[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Desktop (please complete the following information):**\n - OS: [e.g. iOS]\n - Browser [e.g. chrome, safari]\n - Version [e.g. 22]\n\n**Smartphone (please complete the following information):**\n - Device: [e.g. iPhone6]\n - OS: [e.g. iOS8.1]\n - Browser [e.g. stock browser, safari]\n - Version [e.g. 22]\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/openchatkit-feedback-report.yaml",
    "content": "name: OpenChatKit Feedback Report\ndescription: Details of feedback from using OpenChatKit test app\ntitle: OpenChatKit Feedback Report\nlabels: \"feedback report\"\nassignees: []\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Thanks for taking the time to fill out this feedback report!\n  - type: textarea\n    id: my-question\n    attributes:\n      label: \"My question:\"\n    validations:\n      required: true\n  - type: textarea\n    id: bot-response\n    attributes:\n      label: \"Bot response:\"\n    validations:\n      required: true\n  - type: textarea\n    id: ideal-bot-response\n    attributes:\n      label: \"Ideal bot response:\"\n    validations:\n      required: true\n  - type: checkboxes\n    id: response-issues\n    attributes:\n      label: \"Bot response was:\"\n      options:\n        - label: Factually incorrect\n          required: true\n        - label: Not helpful\n          required: true\n        - label: Harmful, inappropriate or unsafe\n          required: true\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# ignore downloaded files\n/data/OIG-moderation/files/\n/data/OIG/files/\n/data/wikipedia-3sentence-level-retrieval-index/files/\n/pretrained/GPT-NeoX-20B/EleutherAI_gpt-neox-20b/\n/pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped/\n/pretrained/RedPajama-3B/togethercomputer_RedPajama-INCITE-Chat-3B-v1\n\n# ignore training output\n/model_ckpts/\n/huggingface_models/\n/training/wandb/\n\n# ignore trained low-rank adapters\n/outputs/\ndata/OIG-chip2/*.jsonl\nwandb/"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n------------- LICENSE for training code -------------\n\nCopyright (c) 2022 Anonymous Institution \n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# OpenChatKit\n\nOpenChatKit provides a powerful, open-source base to create both specialized and general purpose models for various applications. The kit includes an instruction-tuned language models, a moderation model, and an extensible retrieval system for including up-to-date responses from custom repositories. OpenChatKit models were trained on the OIG-43M training dataset, which was a collaboration between [Together](https://www.together.xyz/), [LAION](https://laion.ai), and [Ontocord.ai](https://ontocord.ai). \n\nIn this repo, you'll find code for:\n- Training GPT-NeoXT-Chat-Base-20B, a 20B parameter chat model (see [docs/GPT-NeoXT-Chat-Base-20B.md](docs/GPT-NeoXT-Chat-Base-20B.md))\n- Fine-tuning Llama-2-7B-32K-beta, a 7B parameter long context model\n- Training Pythia-Chat-Base-7B, a 7B parameter chat model\n- Testing inference using either of the chat models\n- Augmenting the model with additional context from a retrieval index\n\n# Contents\n\n- [Getting Started](#getting-started)\n  * [Requirements](#requirements)\n  * [Chatting with Pythia-Chat-Base-7B](#chatting-with-pythia-chat-base-7b)\n- [Fine-tuning Llama-2-7B-32K-beta](#fine-tuning-llama-2-7b-32k-beta)\n  * [Downloading and converting the base model](#downloading-and-converting-the-base-model)\n  * [Fine-tuning the model](#fine-tuning-the-model)\n  * [Converting trained weights to Hugging Face format](#converting-trained-weights-to-hugging-face-format)\n- [Reproducing Pythia-Chat-Base-7B](#reproducing-pythia-chat-base-7b)\n  * [Downloading training data and the base model](#downloading-training-data-and-the-base-model)\n  * [(Optional) 8bit Adam](#optional-8bit-adam)\n  * [Training the model](#training-the-model)\n  * [Converting weights to Hugging Face format](#converting-weights-to-hugging-face-format)\n  * [Testing the new model](#testing-the-new-model)\n- [Monitoring](#monitoring)\n  * [Loguru](#loguru)\n  * [Weights & Biases](#weights--biases)\n- [Experimental: 
Retrieval-Augmented Models](#experimental-retrieval-augmented-models)\n- [See Also](#see-also)\n- [License](#license)\n- [Citing OpenChatKit](#citing-openchatkit)\n- [Acknowledgements](#acknowledgements)\n\n# Getting Started\n\nIn this tutorial, you will download Pythia-Chat-Base-7B, an instruction-tuned language model, and run some inference requests against it using a command-line tool.\n\nPythia-Chat-Base-7B is a 7B-parameter fine-tuned variant of Pythia-6.9B-deduped from Eleuther AI. Pre-trained weights for this model are available on Hugging Face as [togethercomputer/Pythia-Chat-Base-7B](https://huggingface.co/togethercomputer/Pythia-Chat-Base-7B) under an Apache 2.0 license.\n\nMore details can be found on the model card for [Pythia-Chat-Base-7B](https://huggingface.co/togethercomputer/Pythia-Chat-Base-7B) on Hugging Face.\n\n## Requirements\n\nBefore you begin, you need to install PyTorch and other dependencies.\n\n1. Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html) from their website.\n\n2. Install [Git LFS](https://git-lfs.com/) from their website.\n\n3. Install the `git lfs` hooks.\n\n```shell\ngit lfs install\n```\n\n4. Install mamba in the `base` environment so it's available in all environments.\n\n```shell\nconda install mamba -n base -c conda-forge\n```\n\n5. Create an environment called OpenChatKit using the `environment.yml` file at the root of this repo.\n\n> **Note**\n> Use `mamba` to create the environment. It's **much** faster than using `conda`.\n\n```shell\nmamba env create -f environment.yml \n```\n\n6. Activate the new conda environment.\n\n```shell\nconda activate OpenChatKit\n```\n\n## Chatting with Pythia-Chat-Base-7B\n\nTo help you try the model, [`inference/bot.py`](inference/bot.py) is a simple command-line test harness that provides a shell interface enabling you to chat with the model. Simply enter text at the prompt and the model replies. 
The test harness also maintains conversation history to provide the model with context.\n\n\nStart the bot by calling `bot.py` from the root of the repo.\n\n```shell\npython inference/bot.py --model togethercomputer/Pythia-Chat-Base-7B\n```\n\nLoading the model can take some time, but once it's loaded, you are greeted with a prompt. Say hello.\n\n```shell\n$ python inference/bot.py \nLoading /home/csris/src/github.com/togethercomputer/OpenChatKit/inference/../huggingface_models/GPT-NeoXT-Chat-Base-20B to cuda:1...\nWelcome to OpenChatKit shell.   Type /help or /? to list commands.\n\n>>> Hello.\nHello human.\n\n>>> \n```\n\nEnter additional queries at the prompt, and the model replies. Under the covers, the shell forms a prompt from all previous queries and passes it to the model to generate more text.\n\nThe shell also supports additional commands to inspect hyperparameters, the full prompt, and more. Commands are prefixed with a `/`.\n\n> **Note**\n> The `/quit` command exits the shell.\n\nPlease see [the inference README](inference/README.md) for more details about arguments, running on multiple/specific GPUs, and running on consumer hardware.\n\n# Fine-tuning Llama-2-7B-32K-beta\n\nThe Llama-2-7B-32K-beta model can be fine-tuned using various datasets. In this tutorial, we will use the multi-document natural questions dataset and the BookSum dataset.\n\n## Downloading and converting the base model\n\nTo download the Llama-2-7B-32K-beta model and prepare it for fine-tuning, run this command from the root of the repository.\n\n```shell\npython pretrained/Llama-2-7B-32K-beta/prepare.py\n```\n\nThe weights for this model will be in the `pretrained/Llama-2-7B-32K-beta/togethercomputer_Llama-2-7B-32K-beta` directory.\n\n\n## Fine-tuning the model\n\nThe `training/finetune_llama-2-7b-32k-mqa.sh` and `training/finetune_llama-2-7b-32k-booksum.sh` scripts configure and run the training loop.\n\n1. 
To fine-tune on the multi-document natural questions dataset, run:\n   ```shell\n   bash training/finetune_llama-2-7b-32k-mqa.sh\n   ```\n\n2. To fine-tune on the BookSum dataset, run:\n   ```shell\n   bash training/finetune_llama-2-7b-32k-booksum.sh\n   ```\n\nAs the training loop runs, checkpoints are saved to the `model_ckpts` directory at the root of the repo.\n\nPlease see [the training README](training/README.md) for more details about customizing the training run.\n\n## Converting trained weights to Hugging Face format\n\nBefore you can use this model to perform inference, it must be converted to the Hugging Face format. Run this command from the root of the repo to do so.\n\nFor example:\n```shell\nmkdir huggingface_models \\\n  && python tools/convert_to_hf_llama.py \\\n       --config-name togethercomputer/Llama-2-7B-32K-beta \\\n       --ckpt-path model_ckpts/llama-2-7b-32k-mqa/checkpoint_10 \\\n       --save-path huggingface_models/llama-2-7b-32k-mqa \\\n       --n-stages 4 \\\n       --n-layer-per-stage 8 \\\n       --fp16\n```\nwhere the `--fp16` flag will load and store models in fp16.\n\nMake sure to replace `model_ckpts/llama-2-7b-32k-mqa/checkpoint_10` with the latest checkpoint in the `model_ckpts/llama-2-7b-32k-mqa` or `model_ckpts/llama-2-7b-32k-booksum` directory.\n\n\n# Reproducing Pythia-Chat-Base-7B\n\nThis tutorial walks through reproducing the Pythia-Chat-Base-7B model by fine-tuning Eleuther AI's Pythia-6.9B-deduped model using the OIG dataset.\n\n## Downloading training data and the base model\n\nThe chat model was trained on the [OIG](https://huggingface.co/datasets/laion/OIG) dataset built by [LAION](https://laion.ai/), [Together](https://www.together.xyz/), and [Ontocord.ai](https://www.ontocord.ai/). To download the dataset from Hugging Face, run the command below from the root of the repo.\n\n```shell\npython data/OIG/prepare.py\n```\n> **Note** \n> You can help make this chat model better by contributing data! 
See the [OpenDataHub](https://github.com/togethercomputer/OpenDataHub) repo for more details.\n\nOnce the command completes, the data will be in the `data/OIG/files` directory.\n\nPythia-Chat-Base-7B is a fine-tuned variant of Pythia-6.9B-deduped from Eleuther AI. To download the model and prepare it for fine tuning, run this command from the root of the repo.\n\n```shell\npython pretrained/Pythia-6.9B-deduped/prepare.py\n```\n\nThe weights for this model will be in the `pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped` directory.\n\n## (Optional) 8bit Adam\n\nTo use 8bit-adam during training, install the `bitsandbytes` package.\n\n```shell\npip install bitsandbytes # optional, to use 8bit-adam\n```\n\n## Training the model\n\nThe `training/finetune_Pythia-Chat-Base-7B.sh` script configures and runs the training loop. After downloading the dataset and the base model, run:\n\n```shell\nbash training/finetune_Pythia-Chat-Base-7B.sh\n```\n\nAs the training loop runs, checkpoints are saved to the `model_ckpts` directory at the root of the repo.\n\nPlease see [the training README](training/README.md) for more details about customizing the training run.\n\n## Converting weights to Hugging Face format\n\nBefore you can use this model to perform inference, it must be converted to the Hugging Face format. 
Run this command from the root of the repo to do so.\n\n```shell\nmkdir huggingface_models \\\n  && python tools/convert_to_hf_gptneox.py \\\n       --config-name EleutherAI/pythia-6.9b-deduped \\\n       --ckpt-path model_ckpts/Pythia-Chat-Base-7B/checkpoint_100 \\\n       --save-path huggingface_models/Pythia-Chat-Base-7B \\\n       --n-stages 4 \\\n       --n-layer-per-stage 8 \\\n       --fp16\n```\nwhere the `--fp16` flag will load and store models in fp16.\n\nMake sure to replace `model_ckpts/Pythia-Chat-Base-7B/checkpoint_100` with the latest checkpoint in the `model_ckpts/Pythia-Chat-Base-7B` directory.\n\n## Testing the new model\n\nYou can use the OpenChatKit Shell test harness to chat with the new model. From the root of the repo, run:\n\n```shell\npython inference/bot.py\n```\n\nBy default, the script will load the model named Pythia-Chat-Base-7B under the `huggingface_models` directory, but you can override that behavior by specifying `--model`.\n\n```shell\npython inference/bot.py --model ./huggingface_models/GPT-NeoXT-Chat-Base-20B\n```\n\nOnce the model has loaded, enter text at the prompt and the model will reply.\n\n```shell\n$ python inference/bot.py \nLoading /home/csris/src/github.com/togethercomputer/OpenChatKit/inference/../huggingface_models/GPT-NeoXT-Chat-Base-20B to cuda:1...\nWelcome to OpenChatKit shell.   Type /help or /? to list commands.\n\n>>> Hello.\nHello human.\n\n>>> \n```\n\nThe shell also supports additional commands to inspect hyperparameters, the full prompt, and more. 
Commands are prefixed with a `/`.\n\n> **Note**\n> The `/quit` command exits the shell.\n\nPlease see [the inference README](inference/README.md) for more details about arguments, running on multiple/specific GPUs, and running on consumer hardware.\n\n# Monitoring\n\nBy default, the training script simply prints the loss as training proceeds, but it can also output metrics to a file using [loguru](https://github.com/Delgan/loguru) or report them to Weights & Biases.\n\n## Loguru\n\nAdd the flag `--train-log-backend loguru` to your training script to log to `./logs/file_{time}.log`\n\n## Weights & Biases\n\nTo use Weights & Biases, first login with your Weights & Biases token.\n\n```shell\nwandb login\n```\n\nAnd set `--train-log-backend wandb` in the training script to enable logging to Weights & Biases.\n\n# Experimental: Retrieval-Augmented Models\n\n> **Warning**\n> Retrieval support is experimental.\n\nThe code in `/retrieval` implements a python package for querying a Faiss index of Wikipedia. The following steps explain how to use this index to augment queries in the test harness with context from the retriever.\n\n1. Download the Wikipedia index.\n\n```shell\npython data/wikipedia-3sentence-level-retrieval-index/prepare.py\n```\n\n2. Run the bot with the `--retrieval` flag.\n\n```shell\npython inference/bot.py --retrieval\n```\n\nAfter starting, the bot will load both the chat model and the retrieval index, which takes a long time. Once the model and the index are loaded, all queries will be augmented with extra context.\n\n\n```shell\n$ python inference/bot.py --retrieval\nLoading /OpenChatKit/inference/../huggingface_models/GPT-NeoXT-Chat-Base-20B to cuda:0...\nLoading retrieval index...\nWelcome to OpenChatKit shell.   Type /help or /? to list commands.\n\n>>> Where is Zurich?\nWhere is Zurich?\nZurich is located in Switzerland.\n\n>>>\n```\n\n# See Also\n* [docs/GPT-NeoXT-Chat-Base-20B.md](docs/GPT-NeoXT-Chat-Base-20B.md). 
OpenChatKit also provides a larger, 20B parameter chat model, GPT-NeoXT-Chat-Base-20B, that was fine-tuned from GPT-NeoX-20B from Eleuther AI.\n\n# License\n\nAll code in this repository was developed by Together Computer except where otherwise noted.  Copyright (c) 2023, Together Computer.  All rights reserved. The code is licensed under the Apache 2.0 license.\n\n\n```\nCopyright 2023 Together Computer\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n   http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n```\n\nThis repository also contains code written by a number of other authors. Such contributions are marked and the relevant licensing is included where appropriate.\n\nFor full terms, see the LICENSE file. If you have any questions, comments, or concerns about licensing, please [contact us](https://www.together.xyz/contact).\n\n# Citing OpenChatKit\n\n```bibtex\n@software{openchatkit,\n  title = {{OpenChatKit: An Open Toolkit and Base Model for Dialogue-style Applications}},\n  author = {Together Computer},\n  url = {https://github.com/togethercomputer/OpenChatKit},\n  month = {3},\n  year = {2023},\n  version = {0.15},\n}\n```\n\n# Acknowledgements\n\nOur models are fine-tuned versions of large language models trained by [Eleuther AI](https://www.eleuther.ai). We evaluated our model on [HELM](https://crfm.stanford.edu/helm/latest/) provided by the [Center for Research on Foundation Models](https://crfm.stanford.edu). 
We also collaborated with both [CRFM](https://crfm.stanford.edu) and [HazyResearch](http://hazyresearch.stanford.edu) at Stanford to build this model.\n\nWe collaborated with [LAION](https://laion.ai/) and [Ontocord.ai](https://www.ontocord.ai/) to build the training data used to fine-tune this model.\n"
  },
  {
    "path": "data/OIG/prepare.py",
    "content": "import sys\nimport os\n\n# Import the prepare_data function\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nsys.path.append(os.path.join(current_dir, '..'))\nfrom prepare_data import prepare_data\n\nif __name__ == \"__main__\":\n    dest_dir = os.path.join(current_dir, \"files\")\n    prepare_data(\"https://huggingface.co/datasets/laion/OIG\", dest_dir)\n"
  },
  {
    "path": "data/OIG-chip2/prepare.sh",
    "content": "DIR=$(cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\nwget https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl -O ${DIR}/unified_chip2.jsonl"
  },
  {
    "path": "data/OIG-moderation/prepare.py",
    "content": "import sys\nimport os\n\n# Import the prepare_data function\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nsys.path.append(os.path.join(current_dir, '..'))\nfrom prepare_data import prepare_data\n\nif __name__ == \"__main__\":\n    dest_dir = os.path.join(current_dir, \"files\")\n    prepare_data(\"https://huggingface.co/datasets/ontocord/OIG-moderation\", dest_dir)\n"
  },
  {
    "path": "data/prepare_data.py",
    "content": "import argparse\nfrom shutil import copyfile\nimport boto3\nimport botocore\nimport glob\nimport gzip\nimport os\nimport re\nimport requests\nimport shutil\nimport subprocess\nimport sys\nfrom urllib.parse import urlparse\n\n\n# Check if git-lfs is installed.\ndef is_git_lfs_installed():\n    try:\n        process = subprocess.run(['git', 'lfs', 'version'], \n                                 stdout=subprocess.DEVNULL, \n                                 stderr=subprocess.DEVNULL)\n        return process.returncode == 0\n    except FileNotFoundError:\n        return False\n\n# Check if a url is a Hugging Face git URL.\ndef is_huggingface_git_url(url):\n    # Regular expression pattern for Hugging Face git URLs\n    hf_git_pattern = r'^https://huggingface\\.co/datasets/[A-Za-z0-9_\\.\\-/]+$'\n    \n    # Match the pattern against the URL\n    # Return True if a match is found, False otherwise\n    return re.match(hf_git_pattern, url) is not None\n\n# Check if the path is a GitHub repository URL.\ndef is_github_repo_url(url):\n    # Regular expression patterns for GitHub repository URLs\n    ssh_pattern = r'^git@github\\.com:[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+\\.git$'\n    http_pattern = r'^https?://(www\\.)?github\\.com/[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+\\.git$'\n    \n    # Match the patterns against the path\n    # Return True if a match is found in either SSH or HTTP pattern, False otherwise\n    return re.match(ssh_pattern, url) is not None or re.match(http_pattern, url) is not None\n\n\n# Check if the path is an S3 or R2 repository URL.\ndef is_s3_url(url):\n    # Regular expression pattern for S3 URLs\n    s3_pattern = r'^https?://(s3(-[a-z0-9-]+)?\\.amazonaws|[a-fA-F0-9]+\\.r2\\.cloudflarestorage)\\.com/[a-z0-9][a-z0-9\\.\\-]{1,61}[a-z0-9]/[0-9a-zA-Z!\\-_\\.*\\'()/]+$'\n    \n    # Match the pattern against the URL\n    # Return True if a match is found, False otherwise\n    if re.match(s3_pattern, url) is None:\n        return False\n    \n    # Check 
for a valid bucket name\n    bucket_name = url.split('/')[3]\n    if bucket_name.startswith(\"xn--\"):\n        return False\n    if bucket_name.endswith(\"-s3alias\"):\n        return False\n    if bucket_name.endswith(\"--ol-s3\"):\n        return False\n    # Reject consecutive dots anywhere in the bucket name\n    if re.search(r'\\.\\.', bucket_name) is not None:\n        return False\n    if re.match(r'\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}', bucket_name) is not None:\n        return False\n    if re.match(r'\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}', bucket_name) is not None:\n        return False\n    \n    return True\n\n\n# Clone a git repository into destination_dir. If the git-lfs hooks are not\n# configured, run `git lfs install` when git-lfs is available. If git-lfs is\n# not installed, print an error message and exit.\ndef clone_git_repo(data_source, destination_dir):\n    process = subprocess.run(\n        'git lfs env | grep -q \\'git config filter.lfs.smudge = \"git-lfs smudge -- %f\"\\'',\n        shell=True\n    )\n\n    # Check if the git repository has already been cloned\n    if os.path.exists(os.path.join(destination_dir, \".git\")):\n        print(f\"Git repository already exists at {destination_dir}. Skipping clone.\")\n        return\n\n    # Configure the git-lfs hooks if they are missing and git-lfs is available\n    if process.returncode != 0 and is_git_lfs_installed():\n        process = subprocess.run(\n            'git lfs install',\n            shell=True\n        )\n\n    if process.returncode != 0:\n        print('error: git lfs not installed. 
please install git-lfs and run `git lfs install`')\n        sys.exit(1)\n\n    # Clone a GitHub repository.\n    try:\n        subprocess.run(f\"git clone {data_source} {destination_dir}\", shell=True,\n                       check=True)\n    except subprocess.CalledProcessError:\n        print(f\"error: failed to clone repository {data_source}\")\n        sys.exit(1)\n\n    \n\n# Download all files from an S3 compatible storage service.\ndef download_from_s3(url, destination_dir, access_key_id = None,\n                     secret_access_key = None, session_token = None, debug = False):\n    # Get the access key ID and secret access key from the environment variables\n    if access_key_id is None:\n        access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')\n    if secret_access_key is None:\n        secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')\n    if session_token is None:\n        session_token = os.environ.get('AWS_SESSION_TOKEN')\n    \n    # Create an S3 client\n    parsed_url = url.split('/')\n    endpoint_url = f\"{parsed_url[0]}//{parsed_url[2]}\"\n    bucket_name = parsed_url[3]\n    key_prefix = \"/\".join(parsed_url[4:-1])\n    base_file = parsed_url[-1] if not url.endswith('/') else \"\"\n    \n    print(f\"endpoint_url={endpoint_url} ...\")\n    if debug:\n        print(f\"access_key_id={access_key_id}\")\n        print(f\"secret_access_key={secret_access_key}\")\n        print(f\"bucket_name={bucket_name}\")\n        print(f\"key_prefix={key_prefix}\")\n        print(f\"base_file={base_file}\")\n\n    s3 = boto3.resource('s3',\n        endpoint_url = endpoint_url,\n        aws_access_key_id = access_key_id,\n        aws_secret_access_key = secret_access_key,\n        aws_session_token=session_token,\n        region_name = \"auto\"\n    )\n    \n    # Create the destination directory if it does not exist\n    os.makedirs(destination_dir, exist_ok=True)\n\n    try:\n        print(f\"Downloading file(s) from S3 {url} to {destination_dir} 
...\")\n        bucket = s3.Bucket(bucket_name)\n        \n        # If the URL ends with '/', download every object under the prefix\n        if url.endswith('/'):\n            # Skip directory placeholders and download each remaining object\n            for obj in bucket.objects.filter(Prefix=key_prefix):\n                if not obj.key.endswith('/'):\n                    destination_file = os.path.join(destination_dir, os.path.basename(obj.key))\n                    if not os.path.exists(destination_file):\n                        print(f\"Downloading {obj.key} ...\")\n                        bucket.download_file(obj.key, destination_file)\n                    else:\n                        print(f\"File already exists, skipping {obj.key}\")\n        else:\n            # Otherwise, download the single file named by the URL\n            destination_file = os.path.join(destination_dir, base_file)\n            if not os.path.exists(destination_file):\n                print(f\"Downloading {base_file} ...\")\n                bucket.download_file(f'/{key_prefix}/{base_file}', destination_file)\n            else:\n                print(f\"File already exists, skipping {base_file}\")\n\n        print(\"Download completed successfully.\")\n        return\n    \n    except botocore.exceptions.NoCredentialsError:\n        print(\"Error: AWS credentials not found.\") \n    except botocore.exceptions.EndpointConnectionError:\n        print(\"Error: Unable to connect to the S3 endpoint.\")\n    except botocore.exceptions.ParamValidationError as e:\n        print(f\"Error: Invalid S3 URL: {e}\")\n    except botocore.exceptions.ClientError as e:\n        print(f\"Error: {e.response['Error']['Message']}\")\n    \n    # Something went wrong, exit with error.\n    sys.exit(1)\n\ndef download_from_url(url, destination_dir):\n    print(f\"Downloading file from {url} to {destination_dir} ...\")\n    try:\n        # Parse the URL to extract the filename\n        parsed_url = urlparse(url)\n        filename = os.path.basename(parsed_url.path)\n        \n        # Construct the destination file path\n  
      destination_file = os.path.join(destination_dir, filename)\n        \n        # Download the file\n        response = requests.get(url, stream=True)\n        response.raise_for_status()\n        with open(destination_file, 'wb') as f:\n            for chunk in response.iter_content(chunk_size=8192): \n                f.write(chunk)\n        print(\"Download completed successfully.\")\n        return\n    \n    except requests.exceptions.HTTPError as e:\n        print(f\"Error: {e}\")\n    except requests.exceptions.ConnectionError:\n        print(\"Error: Unable to connect to the URL.\")\n    except requests.exceptions.Timeout:\n        print(\"Error: Connection timed out.\")\n    except requests.exceptions.RequestException as e:\n        print(f\"Error: {e}\")\n\n    # Something went wrong, exit with error.\n    sys.exit(1)\n\n# prepare_data fetches the data given by data_source (a local file, a git\n# repository, an S3 URL, or an HTTP URL) into destination_dir.\ndef prepare_data(data_source, destination_dir, access_key_id=None, secret_access_key=None, debug=False):\n\n    # Check that destination_dir is a directory. 
If it does not exist, then\n    # create it.\n    if not os.path.exists(destination_dir):\n        os.makedirs(destination_dir)\n    elif not os.path.isdir(destination_dir):\n        print(f\"Error: {destination_dir} is not a directory.\")\n        sys.exit(1)\n\n    if os.path.isfile(data_source):\n        # Handle the case where the data source is a local file.\n        # shutil.copy accepts a directory as the destination; copyfile does not.\n        print(f\"Copying file {data_source} to {destination_dir} ...\")\n        shutil.copy(data_source, destination_dir)\n    elif is_github_repo_url(data_source) or is_huggingface_git_url(data_source):\n        # Handle the case where the data source is a GitHub or Hugging Face repository\n        clone_git_repo(data_source, destination_dir)\n    elif is_s3_url(data_source):\n        # Handle the case where the data source is an S3 URL\n        download_from_s3(url=data_source, destination_dir=destination_dir, access_key_id=access_key_id, \n                         secret_access_key=secret_access_key, debug=debug)\n    elif data_source.startswith('http://') or data_source.startswith('https://'):\n        # Handle the case where the data source is a URL\n        download_from_url(data_source, destination_dir)\n    else:\n        print(f\"Error: Invalid data source: {data_source}\")\n        sys.exit(1)\n\n    # Extract gzipped files, if present\n    for file_path in glob.glob(f\"{destination_dir}/*.gz\"):\n        out_path, _ = os.path.splitext(file_path)\n        with gzip.open(file_path, 'rb') as infile, open(out_path, 'wb') as outfile:\n            shutil.copyfileobj(infile, outfile)\n        os.remove(file_path)\n    \ndef main():\n    parser = argparse.ArgumentParser(description=\"Script for cloning a git repository and extracting files.\")\n    parser.add_argument(\"-s\", \"--data-source\", required=True, help=\"URL of the data source (git repository)\")\n    parser.add_argument(\"-d\", \"--dest\", required=True, help=\"Destination directory to clone the repository and extract files\")\n    
parser.add_argument(\"-a\", \"--access-key-id\", required=False, help=\"AWS access key ID\")\n    parser.add_argument(\"-k\", \"--secret-access-key\", required=False, help=\"AWS secret access key\")\n    parser.add_argument(\"--debug\", action='store_true', help=\"Enable debug mode\")\n\n    args = parser.parse_args()\n    prepare_data(args.data_source, args.dest, args.access_key_id, args.secret_access_key, args.debug)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/wikipedia-3sentence-level-retrieval-index/prepare.py",
    "content": "import sys\nimport os\n\n# Import the prepare_data function\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nsys.path.append(os.path.join(current_dir, '..'))\nfrom prepare_data import prepare_data\n\nif __name__ == \"__main__\":\n    dest_dir = os.path.join(current_dir, \"files\")\n    prepare_data(\"https://huggingface.co/datasets/ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index\", dest_dir)\n"
  },
  {
    "path": "docs/GPT-NeoXT-Chat-Base-20B.md",
    "content": "# GPT-NeoXT-Chat-Base-20B\n\nOpenChatKit includes an instruction-tuned 20 billion parameter language model called GPT-NeoXT-Chat-Base-20B, a 6 billion parameter moderation model, and an extensible retrieval system for including up-to-date responses from custom repositories. It was trained on the OIG-43M training dataset, which was a collaboration between [Together](https://www.together.xyz/), [LAION](https://laion.ai), and [Ontocord.ai](https://ontocord.ai). Much more than a model release, this is the beginning of an open source project. We are releasing a set of tools and processes for ongoing improvement with community contributions. \n\nIn this doc, you'll find steps for:\n- Training an OpenChatKit model\n- Testing inference using the model\n- Augmenting the model with additional context from a retrieval index\n\n# Contents\n\n- [Requirements](#requirements)\n- [Pre-trained Weights](#pre-trained-weights)\n- [Datasets](#datasets)\n  * [Data Contributions](#data-contributions)\n- [Pretrained Base Model](#pretrained-base-model)\n- [Training and Finetuning](#training-and-finetuning)\n  * [(Optional) 8bit Adam](#optional-8bit-adam)\n  * [Train GPT-NeoX-Chat-Base-20B](#train-gpt-neox-chat-base-20b)\n- [Converting Weights to Huggingface Format](#converting-weights-to-huggingface-format)\n- [Inference](#inference)\n- [Monitoring](#monitoring)\n  * [Loguru](#loguru)\n  * [Weights & Biases](#weights--biases)\n- [Experimental: Retrieval-Augmented Models](#experimental-retrieval-augmented-models)\n- [Acknowledgements](#acknowledgements)\n\n# Requirements\n\nBefore you begin, you need to install PyTorch and other dependencies.\n\n1. Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html) from their website.\n\n2. Install [Git LFS](https://git-lfs.com/) from their website.\n\n3. Install the `git lfs` hooks.\n\n```shell\ngit lfs install\n```\n\n4. 
Install mamba in the `base` environment so it's available in all environments.\n\n```shell\nconda install mamba -n base -c conda-forge\n```\n\n5. Create an environment called OpenChatKit using the `environment.yml` file at the root of this repo.\n\n```shell\nmamba env create -f environment.yml \n```\n\n6. Activate the new conda environment.\n\n```shell\nconda activate OpenChatKit\n```\n\n# Pre-trained Weights\n\nGPT-NeoXT-Chat-Base-20B is a 20B-parameter variant of GPT-NeoX, fine-tuned on conversational datasets. We are releasing pre-trained weights for this model as [togethercomputer/GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) on Huggingface.\n\nMore details can be found on the model card for [GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) on Huggingface.\n\n# Datasets\n\nThe chat model was trained on the [OIG](https://huggingface.co/datasets/laion/OIG) dataset built by [LAION](https://laion.ai/), [Together](https://www.together.xyz/), and [Ontocord.ai](https://www.ontocord.ai/). To download the dataset from Huggingface run the command below from the root of the repo.\n\n```shell\npython data/OIG/prepare.py\n```\n\nOnce the command completes, the data will be in the `data/OIG/files` directory.\n\n## Data Contributions\n\nYou can help make this chat model better by contributing data! See the [OpenDataHub](https://github.com/togethercomputer/OpenDataHub) repo for more details.\n\n# Pretrained Base Model\n\nAs mentioned above, the chat model is a fine-tuned variant of GPT-NeoX-20B from Eleuther AI. To download GPT-NeoX-20B and prepare it for fine tuning, run this command from the root of the repo.\n\n```shell\npython pretrained/GPT-NeoX-20B/prepare.py\n```\n\nThe weights for this model will be in the `pretrained/GPT-NeoX-20B/EleutherAI_gpt-neox-20b`.\n\nIn case you want to fine-tune other gpt-neox models, e.g. 
[the Pythia model suite](https://huggingface.co/models?sort=downloads&search=pythia), you can specify the HF model name, for example:\n\n```shell\npython pretrained/GPT-NeoX-20B/prepare.py --model-name EleutherAI/pythia-6.9b-deduped\n```\n\nAnd the weights for this model will be in the `pretrained/GPT-NeoX-20B/EleutherAI_pythia-6.9b-deduped`.\n\n\n# Training and Finetuning\n\n## (Optional) 8bit Adam\n\nTo use 8bit-adam during training, install the `bitsandbytes` package.\n\n```shell\npip install bitsandbytes # optional, to use 8bit-adam\n```\n\n## Train GPT-NeoX-Chat-Base-20B\n\nThe `training/finetune_GPT-NeoXT-Chat-Base-20B.sh` script configures and runs the training loop. After downloading the dataset and the base model, run:\n\n```shell\nbash training/finetune_GPT-NeoXT-Chat-Base-20B.sh\n```\n\nThe script launches 8 processes with a pipeline-parallel degree of 8 and a data-parallel degree of 1.\n\nAs the training loop runs, checkpoints are saved to the `model_ckpts` directory at the root of the repo.\n\nPlease see [the training README](training/README.md) for more details about customizing the training run.\n\nThe `training/finetune_Pythia-Chat-Base-7B.sh` script is another example to fine-tune a 7B pythia (gpt-neox) model. The script launches 8 processes with a pipeline-parallel degree of 4 and a data-parallel degree of 2.\n\n# Converting Weights to Huggingface Format\n\nBefore you can use this model to perform inference, it must be converted to the Huggingface format. 
Run this command from the root of the repo to do so.\n\n```shell\nmkdir huggingface_models \\\n  && python tools/convert_to_hf_gptneox.py \\\n       --ckpt-path model_ckpts/GPT-Neo-XT-Chat-Base-20B/checkpoint_100  \\\n       --save-path huggingface_models/GPT-NeoXT-Chat-Base-20B  \\\n       --n-stages 8  \\\n       --n-layer-per-stage 6 \\\n       --fp16\n```\nThe `--fp16` flag loads and stores the model in fp16.\n\nMake sure to replace `model_ckpts/GPT-Neo-XT-Chat-Base-20B/checkpoint_100` with the latest checkpoint in the `model_ckpts/GPT-Neo-XT-Chat-Base-20B` directory.\n\nIf you need to convert checkpoints of other gpt-neox variants, make sure to specify the correct config name for your variant.\nFor example, if you want to convert a checkpoint fine-tuned from `EleutherAI/pythia-6.9b-deduped`, pass it as the config name:\n```shell\npython tools/convert_to_hf_gptneox.py \\\n    --config-name EleutherAI/pythia-6.9b-deduped \\\n    --ckpt-path model_ckpts/Pythia-Chat-Base-7B/checkpoint_100 \\\n    --save-path huggingface_models/Pythia-Chat-Base-7B \\\n    --n-stages 4 \\\n    --n-layer-per-stage 8 \\\n    --fp16\n```\n\n\n# Inference\n\nTo help you test the model, we provide a simple command-line test harness to interact with the bot. 
\n\n```shell\npython inference/bot.py\n```\n\nBy default the script will load the model named GPT-NeoXT-Chat-Base-20B under the `huggingface_models` directory, but you can override that behavior by specifying `--model`.\n\nFor example, if you want to load the base model from our Huggingface repo, you can run the following command, which downloads the weights from Huggingface.\n\n```shell\npython inference/bot.py --model togethercomputer/GPT-NeoXT-Chat-Base-20B\n```\n\nOnce the model has loaded, enter text at the prompt and the model will reply.\n\n```shell\n$ python inference/bot.py \nLoading /home/csris/src/github.com/togethercomputer/OpenChatKit/inference/../huggingface_models/GPT-NeoXT-Chat-Base-20B to cuda:1...\nWelcome to OpenChatKit shell.   Type /help or /? to list commands.\n\n>>> Hello.\nSetting `pad_token_id` to `eos_token_id`:0 for open-end generation.\nHello human.\n\n>>> \n```\n\nCommands are prefixed with a `/`, and the `/quit` command exits.\n\nPlease see [the inference README](inference/README.md) for more details about arguments, running on multiple/specific GPUs, and running on consumer hardware.\n\n# Monitoring\n\nBy default, the training script simply prints the loss as training proceeds, but it can also output metrics to a file using [loguru](https://github.com/Delgan/loguru) or report them to Weights & Biases.\n\n## Loguru\n\nAdd the flag `--train-log-backend loguru` to your training script to log to `./logs/file_{time}.log`.\n\n## Weights & Biases\n\nTo use Weights & Biases, first log in with your Weights & Biases token.\n\n```shell\nwandb login\n```\n\nThen set `--train-log-backend wandb` in the training script to enable logging to Weights & Biases.\n\n# Experimental: Retrieval-Augmented Models\n\n*Note: Retrieval is still experimental.*\n\nThe code in `/retrieval` implements a Python package for querying a Faiss index of Wikipedia. 
The following steps explain how to use this index to augment queries in the test harness with context from the retriever.\n\n1. Download the Wikipedia index.\n\n```shell\npython data/wikipedia-3sentence-level-retrieval-index/prepare.py\n```\n\n2. Run the bot with the `--retrieval` flag.\n\n```shell\npython inference/bot.py --retrieval\n```\n\nOn startup, the bot loads both the chat model and the retrieval index, which takes a long time. Once the model and the index are loaded, all queries will be augmented with extra context.\n\n\n```shell\n$ python inference/bot.py --retrieval\nLoading /OpenChatKit/inference/../huggingface_models/GPT-NeoXT-Chat-Base-20B to cuda:0...\nLoading retrieval index...\nWelcome to OpenChatKit shell.   Type /help or /? to list commands.\n\n>>> Where is Zurich?\nSetting `pad_token_id` to `eos_token_id`:0 for open-end generation.\nWhere is Zurich?\nZurich is located in Switzerland.\n\n>>>\n```\n\n# Acknowledgements\n\nOur model is a fine-tuned version of [gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b), a large language model trained by [Eleuther AI](https://www.eleuther.ai). We evaluated our model on [HELM](https://crfm.stanford.edu/helm/latest/) provided by the [Center for Research on Foundation Models](https://crfm.stanford.edu). We also collaborated with both [CRFM](https://crfm.stanford.edu) and [HazyResearch](http://hazyresearch.stanford.edu) at Stanford to build this model.\n\nWe collaborated with [LAION](https://laion.ai/) and [Ontocord.ai](https://www.ontocord.ai/) to build the training data used to fine-tune this model.\n"
  },
  {
    "path": "docs/finetuning-RedPajama-3B.md",
    "content": "# RedPajama-3B\n\nIn this tutorial, you will learn how to fine-tune a base LLM on a sample of data. By the end of \nthe tutorial, you will have fine-tuned the RedPajama-INCITE-Chat-3B model using a sample of \nchat data from the OIG dataset. You can adapt this tutorial to fine-tune with your own data.\n\nIn order to fine-tune the RedPajama 3B models, please follow these steps:\n\nFirst clone the OpenChatKit repo:\n\n```shell\ngit clone git@github.com:togethercomputer/OpenChatKit.git\n```\n\nNext install dependencies as instructed by the OpenChatKit repo.\n\n# Prepare Weights\n\n```shell\npython pretrained/RedPajama-3B/prepare.py\n```\n\nThis script will download the weight from HuggingFace and prepare it for finetuning. The prepared weights will be saved at \n\n```\npretrained/RedPajama-3B/togethercomputer_RedPajama-INCITE-Chat-3B-v1\n```\n\n# Prepare Fine Tuning Data\n\nWe now need to preapre the training data.  We provide an example script that downloads a small slice of data from OIG. \nTo download this sample dataset, please run:\n \n```\nbash data/OIG-chip2/prepare.sh\n````\n \nThe sample dataset will be saved at \n\n```\ndata/OIG-chip2/unified_chip2.jsonl.\n```\n\n# Run Fine Tuning Script\n\nWe provide an example training script.  Please configure the parameters (e.g., learning_rate, batch_size, dataset_path) according to your hardware configuration. \nThen to start training, simply run\n\n```\nbash training/finetune_RedPajama-INCITE-Chat-3B-v1.sh\n```\n\n# Convert to Huggingface Format\n\nThe fine-tuned model will be saved to \n\n```\nmodel_ckpts/rp-incite-chat-3b-finetuned/checkpoint_{steps}\n```\n\nIn order to use it for inference, you will need to convert it to the HuggingFace format. 
To do so, run the following script \n(this is an example; change the checkpoint path, n-stages, and n-layer-per-stage to match the training script).\n\nThe default for n-stages used in the training script is 10 and the n-layer-per-stage is 8.\n\n```\npython tools/convert_to_hf_gptneox.py --config-name togethercomputer/RedPajama-INCITE-Chat-3B-v1 --ckpt-path model_ckpts/redpajama-incite-chat-3b-sample/checkpoint_10/ --save-path model_ckpts/hf --n-stages 4 --n-layer-per-stage 8\n```\n\nThen you are ready to go! You can load the model with HuggingFace and use it for inference, for example:\n\n```python\nimport torch\nimport transformers\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\n\ntokenizer = AutoTokenizer.from_pretrained(\"togethercomputer/RedPajama-INCITE-Chat-3B-v1\")\nmodel = AutoModelForCausalLM.from_pretrained(\"./model_ckpts/hf\", torch_dtype=torch.float16)\nmodel = model.to('cuda:0')\n\nprompt = \"<human>: Who is Alan Turing?\\n<bot>:\"\ninputs = tokenizer(prompt, return_tensors='pt').to(model.device)\ninput_length = inputs.input_ids.shape[1]\noutputs = model.generate(\n    **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True\n)\ntoken = outputs.sequences[0, input_length:]\noutput_str = tokenizer.decode(token)\nprint(output_str)\n\n```\n\nNote that the full finetuning above takes around 60GB of VRAM to fit everything onto the GPU, and may take even more to fit the training data. If you do not have such GPUs, we also provide low-rank finetuning scripts that work with 14GB of VRAM. Here are the steps to get started.\n\n* Clone the OpenChatKit repo, install dependencies and prepare the dataset. 
These steps are the same as for full fine-tuning.\n\n* The sample low-rank finetuning script is at /training/lora/redpajama-incite-chat-3b.py. Modify this script to accommodate your own training data and preferred configuration.\n\n* Then start low-rank finetuning by running this script.\n\nOnce the finetuning is finished, the resulting low-rank adapter will be saved to /outputs, and you can run inference with the following script.\n\n```\npython training/lora/redpajama-incite-chat-3b_inference.py\n```"
  },
  {
    "path": "environment.yml",
    "content": "name: OpenChatKit\nchannels:\n  - pytorch\n  - nvidia\n  - conda-forge\n  - defaults\ndependencies:\n  - cudatoolkit=11.8.0\n  - cupy=12.1.0\n  - faiss-gpu=1.7.2\n  - fastparquet=0.5.0\n  - nccl=2.18.3.1\n  - pip=23.2\n  - pyarrow=12.0.1\n  - python=3.10.9\n  - python-snappy=0.6.1\n  - pytorch=2.0.1\n  - pytorch-cuda=11.8\n  - snappy=1.1.9\n  - torchaudio=2.0.2\n  - torchvision=0.15.2\n  - pip:\n      - accelerate==0.21.0\n      - boto3\n      - datasets==2.13.1\n      - loguru==0.6.0\n      - netifaces==0.11.0\n      - pandas==2.0.3\n      - transformers==4.31.0\n      - wandb==0.15.5\n      - zstandard==0.21.0\n      - sentencepiece\n"
  },
  {
    "path": "inference/README.md",
    "content": "# OpenChatKit Inference\nThis directory contains code for OpenChatKit's inference.\n\n## Arguments\n- `--gpu-id`: Primary GPU device to load inputs onto for inference. Default: `0`\n- `--model`: name/path of the model. Default = `../huggingface_models/GPT-NeoXT-Chat-Base-20B`\n- `--max-tokens`: the maximum number of tokens to generate. Default: `128`\n- `--sample`: indicates whether to sample. Default: `True`\n- `--temperature`: temperature for the LM. Default: `0.6`\n- `--top-k`: top-k for the LM. Default: `40`\n- `--retrieval`: augment queries with context from the retrieval index. Default `False`\n- `-g` `--gpu-vram`: GPU ID and VRAM to allocate to loading the model, separated by a `:` in the format `ID:RAM` where ID is the CUDA ID and RAM is in GiB. `gpu-id` must be present in this list to avoid errors. Accepts multiple values, for example, `-g ID_0:RAM_0 ID_1:RAM_1 ID_N:RAM_N`\n- `-r` `--cpu-ram`: CPU RAM overflow allocation for loading the model. Optional, and only used if the model does not fit onto the GPUs given.\n\n## Hardware requirements for inference\nThe GPT-NeoXT-Chat-Base-20B model requires at least 41GB of free VRAM. Used VRAM also goes up by ~100-200 MB per prompt. \n\n- A **minimum of 80 GB is recommended** \n\n- A **minimum of 48 GB in VRAM is recommended** for fast responses.\n\nIf you'd like to run inference on a GPU with <48 GB VRAM, refer to this section on [running on consumer hardware](#running-on-consumer-hardware).\n\nBy default, inference uses only CUDA Device 0.\n\n**NOTE: Inference currently requires at least 1x GPU.**\n\n## Running on multiple GPUs\nAdd the argument \n\n```-g ID0:MAX_VRAM ID1:MAX_VRAM ID2:MAX_VRAM ...``` \n\nwhere IDx is the CUDA ID of the device and MAX_VRAM is the amount of VRAM you'd like to allocate to the device.\n\nFor example, if you are running this on 4x 48 GB GPUs and want to distribute the model across all devices, add ```-g 0:10 1:12 2:12 3:12 4:12```. 
In this example, the first device gets loaded to a max of 10 GiB while the others are loaded with a max of 12 GiB.\n\nHow it works: The model fills up the max available VRAM on the first device passed and then overflows into the next until the whole model is loaded.\n\n**IMPORTANT: This MAX_VRAM is only for loading the model. It does not account for the additional inputs that are added to the device. It is recommended to set the MAX_VRAM to be at least 1 or 2 GiB less than the max available VRAM on each device, and at least 3 GiB less than the max available VRAM on the primary device (set by `gpu-id` default=0).**\n\n**Decrease MAX_VRAM if you run into CUDA OOM. This happens because each input takes up additional space on the device.**\n\n**NOTE: Total MAX_VRAM across all devices must be > size of the model in GB. If not, `bot.py` automatically offloads the rest of the model to RAM and disk. It will use up all available RAM. To allocate a specific amount of RAM, [refer to this section on running on consumer hardware](#running-on-consumer-hardware).**\n\n## Running on specific GPUs\nIf you have multiple GPUs but would only like to use specific devices, [use the same steps as in this section on running on multiple devices](#running-on-multiple-gpus) and only specify the devices you'd like to use. \n\nAlso, if needed, add the argument `--gpu-id ID` where ID is the CUDA ID of the device you'd like to make the primary device. NOTE: The device specified in `--gpu-id` must be present as one of the IDs in the argument `-g` to avoid errors.\n\n- **Example #1**: to run inference on devices 2 and 5 with a max of 25 GiB on each, and make device 5 the primary device, add: `--gpu-id 5 -g 2:25 5:25`. In this example, not adding `--gpu-id 5` will give you an error.\n- **Example #2**: to run inference on devices 0 and 3 with a max of 10 GiB on 0 and 40 GiB on 3, with device 0 as the primary device, add: `-g 0:10 3:40`. 
In this example, `--gpu-id` is not required because device 0 is the default primary device and is specified in `-g`.\n- **Example #3**: to run inference only on device 1 with a max of 75 GiB, add: `--gpu-id 1 -g 1:75`\n\n\n## Running on consumer hardware\nIf you have multiple GPUs, each with <48 GB VRAM, [the steps mentioned in this section on running on multiple GPUs](#running-on-multiple-gpus) still apply, unless any of the following is true:\n- You are running on just 1x GPU with <48 GB VRAM\n- You have <48 GB VRAM combined across multiple GPUs\n- You are running into Out-Of-Memory (OOM) issues\n\nIn any of these cases, add the flag `-r CPU_RAM` where CPU_RAM is the maximum amount of RAM you'd like to allocate to loading the model. Note: This significantly reduces inference speeds. \n\nThe model will load without `-r`, but this is not recommended because it will allocate all available RAM to the model. To limit how much RAM the model can use, add `-r`.\n\nIf the total VRAM + CPU_RAM < the size of the model in GiB, the rest of the model will be offloaded to a folder \"offload\" at the root of the directory. Note: This significantly reduces inference speeds.\n\n- Example: `-g 0:12 -r 20` will first load up to 12 GiB of the model into the CUDA device 0, then load up to 20 GiB into RAM, and load the rest into the \"offload\" directory.\n\nHow it works: \n- https://github.com/huggingface/blog/blob/main/accelerate-large-models.md\n- https://www.youtube.com/embed/MWCSGj9jEAo\n"
  },
  {
    "path": "inference/bot.py",
    "content": "import os\nimport sys\n\nINFERENCE_DIR = os.path.dirname(os.path.abspath(__file__))\n\n# TODO: PYTHONPATH hacks are never a good idea. clean this up later\nsys.path.append(os.path.join(INFERENCE_DIR, '..'))\n\nimport cmd\nimport torch\nimport argparse\nimport conversation as convo\nimport retrieval.wikipedia as wp\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList\nfrom accelerate import infer_auto_device_map, init_empty_weights\n\n\nclass StopWordsCriteria(StoppingCriteria):\n    def __init__(self, tokenizer, stop_words, stream_callback):\n        self._tokenizer = tokenizer\n        self._stop_words = stop_words\n        self._partial_result = ''\n        self._stream_buffer = ''\n        self._stream_callback = stream_callback\n\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n        first = not self._partial_result\n        text = self._tokenizer.decode(input_ids[0, -1])\n        self._partial_result += text\n        for stop_word in self._stop_words:\n            if stop_word in self._partial_result:\n                return True\n        if self._stream_callback:\n            if first:\n                text = text.lstrip()\n            # buffer tokens if the partial result ends with a prefix of a stop word, e.g. 
\"<hu\"\n            for stop_word in self._stop_words:\n                for i in range(1, len(stop_word)):\n                    if self._partial_result.endswith(stop_word[0:i]):\n                        self._stream_buffer += text\n                        return False\n            self._stream_callback(self._stream_buffer + text)\n            self._stream_buffer = ''\n        return False\n\n\nclass ChatModel:\n    human_id = \"<human>\"\n    bot_id = \"<bot>\"\n\n    def __init__(self, model_name, gpu_id, max_memory):\n        device = torch.device('cuda', gpu_id)   # TODO: allow sending to cpu\n\n        # recommended default for devices with > 40 GB VRAM\n        # load model onto one device\n        if max_memory is None:\n            self._model = AutoModelForCausalLM.from_pretrained(\n                model_name, torch_dtype=torch.float16, device_map=\"auto\")\n            self._model.to(device)\n        # load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)\n        else:\n            config = AutoConfig.from_pretrained(model_name)\n            # load empty weights\n            with init_empty_weights():\n                model_from_conf = AutoModelForCausalLM.from_config(config)\n\n            model_from_conf.tie_weights()\n\n            # create a device_map from max_memory\n            device_map = infer_auto_device_map(\n                model_from_conf,\n                max_memory=max_memory,\n                no_split_module_classes=[\"GPTNeoXLayer\"],\n                dtype=\"float16\"\n            )\n            # load the model with the above device_map\n            self._model = AutoModelForCausalLM.from_pretrained(\n                model_name,\n                device_map=device_map,\n                offload_folder=\"offload\",  # optional offload-to-disk overflow directory (auto-created)\n                offload_state_dict=True,\n                torch_dtype=torch.float16\n            )\n        
self._tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):\n        stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)\n        inputs = (\n            self._tokenizer(prompt, return_tensors='pt')\n            .to(self._model.device)\n        )\n        outputs = self._model.generate(\n            **inputs,\n            max_new_tokens=max_new_tokens,\n            do_sample=do_sample,\n            temperature=temperature,\n            top_k=top_k,\n            pad_token_id=self._tokenizer.eos_token_id,\n            stopping_criteria=StoppingCriteriaList([stop_criteria]),\n        )\n        output = self._tokenizer.batch_decode(outputs)[0]\n\n        # remove the context from the output\n        output = output[len(prompt):]\n\n        return output\n\n\nclass OpenChatKitShell(cmd.Cmd):\n    intro = \"Welcome to OpenChatKit shell.   Type /help or /? 
to list commands.\\n\"\n    prompt = \">>> \"\n\n    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):\n        super().__init__()\n        self._gpu_id = gpu_id\n        self._model_name_or_path = model_name_or_path\n        self._max_tokens = max_tokens\n        self._sample = sample\n        self._temperature = temperature\n        self._top_k = top_k\n        self._retrieval = retrieval\n        self._max_memory = max_memory\n        self._do_stream = do_stream\n\n    def preloop(self):\n        print(f\"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...\")\n        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory)\n\n        if self._retrieval:\n            print(f\"Loading retrieval index...\")\n            self._index = wp.WikipediaIndex()\n\n        self._convo = convo.Conversation(\n            self._model.human_id, self._model.bot_id)\n\n    def precmd(self, line):\n        if line.startswith('/'):\n            return line[1:]\n        else:\n            return 'say ' + line\n\n    def do_say(self, arg):\n        if self._retrieval:\n            results = self._index.search(arg)\n            if len(results) > 0:\n                self._convo.push_context_turn(results[0])\n\n        self._convo.push_human_turn(arg)\n\n        output = self._model.do_inference(\n            self._convo.get_raw_prompt(),\n            self._max_tokens,\n            self._sample,\n            self._temperature,\n            self._top_k,\n            lambda x : print(x, end='', flush=True) if self._do_stream else None,\n        )\n\n        self._convo.push_model_response(output)\n\n        print(\"\" if self._do_stream else self._convo.get_last_turn())\n\n    def do_raw_say(self, arg):\n        output = self._model.do_inference(\n            arg,\n            self._max_tokens,\n            self._sample,\n            self._temperature,\n            self._top_k\n      
  )\n\n        print(output)\n\n    def do_raw_prompt(self, arg):\n        print(self._convo.get_raw_prompt())\n\n    def do_reset(self, arg):\n        self._convo = convo.Conversation(\n            self._model.human_id, self._model.bot_id)\n\n    def do_hyperparameters(self, arg):\n        print(\n            f\"Hyperparameters:\\n\"\n            f\"  max_tokens: {self._max_tokens}\\n\"\n            f\"  sample: {self._sample}\\n\"\n            f\"  temperature: {self._temperature}\\n\"\n            f\"  top_k: {self._top_k}\"\n        )\n\n    def do_quit(self, arg):\n        return True\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description='test harness for OpenChatKit')\n\n    parser.add_argument(\n        '--gpu-id',\n        default=0,\n        type=int,\n        help='the ID of the GPU to run on'\n    )\n    parser.add_argument(\n        '--model',\n        default=f\"{INFERENCE_DIR}/../huggingface_models/Pythia-Chat-Base-7B\",\n        help='name/path of the model'\n    )\n    parser.add_argument(\n        '--max-tokens',\n        default=128,\n        type=int,\n        help='the maximum number of tokens to generate'\n    )\n    # NOTE: with default=True, this store_true flag is always True\n    parser.add_argument(\n        '--sample',\n        default=True,\n        action='store_true',\n        help='indicates whether to sample'\n    )\n    parser.add_argument(\n        '--no-stream',\n        action='store_true',\n        help='disable streaming of tokens as they are generated'\n    )\n    parser.add_argument(\n        '--temperature',\n        default=0.6,\n        type=float,\n        help='temperature for the LM'\n    )\n    parser.add_argument(\n        '--top-k',\n        default=40,\n        type=int,\n        help='top-k for the LM'\n    )\n    parser.add_argument(\n        '--retrieval',\n        default=False,\n        action='store_true',\n        help='augment queries with context from the retrieval index'\n    )\n    parser.add_argument(\n        '-g',\n        '--gpu-vram',\n        
action='store',\n        help='max VRAM to allocate per GPU',\n        nargs='+',\n        required=False,\n    )\n    parser.add_argument(\n        '-r',\n        '--cpu-ram',\n        default=None,\n        type=int,\n        help='max CPU RAM to allocate',\n        required=False\n    )\n    args = parser.parse_args()\n\n    # set max_memory dictionary if given\n    if args.gpu_vram is None:\n        max_memory = None\n    else:\n        max_memory = {}\n        for i in range(len(args.gpu_vram)):\n            # assign CUDA ID as label and XGiB as value\n            max_memory[int(args.gpu_vram[i].split(':')[0])] = f\"{args.gpu_vram[i].split(':')[1]}GiB\"\n\n        if args.cpu_ram is not None:\n            # add cpu to max-memory if given\n            max_memory['cpu'] = f\"{int(args.cpu_ram)}GiB\"\n\n    OpenChatKitShell(\n        args.gpu_id,\n        args.model,\n        args.max_tokens,\n        args.sample,\n        args.temperature,\n        args.top_k,\n        args.retrieval,\n        max_memory,\n        not args.no_stream,\n    ).cmdloop()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "inference/conversation.py",
    "content": "import re\nimport time\n\nMEANINGLESS_WORDS = ['<pad>', '</s>', '<|endoftext|>']\nPRE_PROMPT = \"\"\"\\\nCurrent Date: {}\nCurrent Time: {}\n\n\"\"\"\n\ndef clean_response(response):\n    for word in MEANINGLESS_WORDS:\n        response = response.replace(word, \"\")\n    response = response.strip(\"\\n\")\n    return response\n\nclass Conversation:\n    def __init__(self, human_id, bot_id):\n        cur_date = time.strftime('%Y-%m-%d')\n        cur_time = time.strftime('%H:%M:%S %p %Z')\n\n        self._human_id = human_id\n        self._bot_id = bot_id\n        self._prompt = PRE_PROMPT.format(cur_date, cur_time)\n\n    def push_context_turn(self, context):\n        # for now, context is represented as a human turn\n        self._prompt += f\"{self._human_id}: {context}\\n\"\n\n    def push_human_turn(self, query):\n        self._prompt += f\"{self._human_id}: {query}\\n\"\n        self._prompt += f\"{self._bot_id}:\"\n\n    def push_model_response(self, response):\n        has_finished = self._human_id in response\n        bot_turn = response.split(f\"{self._human_id}:\")[0]\n        bot_turn = clean_response(bot_turn)\n        # if it is truncated, then append \"...\" to the end of the response\n        if not has_finished:\n            bot_turn += \"...\"\n\n        self._prompt += f\"{bot_turn}\\n\"\n\n    def get_last_turn(self):\n        human_tag = f\"{self._human_id}:\"\n        bot_tag = f\"{self._bot_id}:\"\n        turns = re.split(f\"({human_tag}|{bot_tag})\\W?\", self._prompt)\n        return turns[-1]\n\n    def get_raw_prompt(self):\n        return self._prompt\n\n    @classmethod\n    def from_raw_prompt(cls, value):\n        self._prompt = value\n"
  },
  {
    "path": "pretrained/GPT-NeoX-20B/prepare.py",
    "content": "import sys\nimport os\n\n# Import the prepare_data function\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nsys.path.append(os.path.join(current_dir, '..'))\nfrom prepare_pretrained import prepare_pretrained\n\nif __name__ == \"__main__\":\n    model_name = \"EleutherAI/gpt-neox-20b\"\n    save_path = os.path.join(current_dir, model_name.replace('/', '_'))\n    prepare_pretrained(save_path, model_name)\n"
  },
  {
    "path": "pretrained/Llama-2-7B-32K-beta/prepare.py",
    "content": "import os\nimport argparse\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n\nDIR = os.path.dirname(os.path.abspath(__file__))\nUSE_AUTH_TOKEN = False\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Convert HF checkpoints')\n    parser.add_argument('--model-name', type=str, default='togethercomputer/Llama-2-7B-32K-beta', \n                        help='model-name')\n    parser.add_argument('--save-dir', type=str, default=DIR, \n                        help='model-name')\n    parser.add_argument('--offload-dir', type=str, default=None,\n                        help='directory to offload from memory')\n    args = parser.parse_args()\n    \n    if not os.path.exists(args.save_dir):\n        os.mkdir(args.save_dir)\n    save_path = os.path.join(args.save_dir, args.model_name.replace('/', '_'))\n    if not os.path.exists(save_path):\n        os.mkdir(save_path)\n    \n    print('loading model from HF...')\n    config = AutoConfig.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)\n    config.save_pretrained(save_path)\n    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)\n    tokenizer.save_pretrained(save_path)\n\n    # offload model from memory to disk if offload-dir is specified\n    if args.offload_dir is not None:\n        if not os.path.exists(args.offload_dir):\n            os.mkdir(args.offload_dir)\n        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, device_map=\"auto\", offload_folder=args.offload_dir, use_auth_token=USE_AUTH_TOKEN)\n    else:\n        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_auth_token=USE_AUTH_TOKEN)\n    print('loaded model from HF...')\n    \n    print('converting the embedding layer...')\n    item = {}\n    item['embed_tokens.weight'] = model.model.embed_tokens.weight\n    torch.save(item, 
os.path.join(save_path, 'pytorch_embs.pt'))\n    print('converted the embedding layer.')\n\n    for i in range(len(model.model.layers)):\n        print(f'converting the {i}-th transformer layer...')\n        torch.save(model.model.layers[i].state_dict(), os.path.join(save_path, f'pytorch_{i}.pt'))\n        print(f'converted the {i}-th transformer layer.')\n\n    print('converting the lm_head layer...')\n    item = {}\n    item['lm_head.weight'] = model.lm_head.weight\n    item['norm.weight'] = model.model.norm.weight\n    torch.save(item, os.path.join(save_path, 'pytorch_lm_head.pt'))\n    print('converted the lm_head layer.')\n"
  },
  {
    "path": "pretrained/Pythia-6.9B-deduped/prepare.py",
    "content": "import sys\nimport os\n\n# Import the prepare_data function\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nsys.path.append(os.path.join(current_dir, '..'))\nfrom prepare_pretrained import prepare_pretrained\n\nif __name__ == \"__main__\":\n    model_name = \"EleutherAI/pythia-6.9b-deduped\"\n    save_path = os.path.join(current_dir, model_name.replace('/', '_'))\n    prepare_pretrained(save_path, model_name)\n"
  },
  {
    "path": "pretrained/RedPajama-3B/prepare.py",
    "content": "import os\nimport sys\n\n# Import the prepare_data function\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nsys.path.append(os.path.join(current_dir, '..'))\nfrom prepare_pretrained import prepare_pretrained\n\nif __name__ == \"__main__\":\n    model_name = \"togethercomputer/RedPajama-INCITE-Chat-3B-v1\"\n    save_path = os.path.join(current_dir, model_name.replace('/', '_'))\n    prepare_pretrained(save_path, model_name)\n"
  },
  {
    "path": "pretrained/RedPajama-7B/prepare.py",
    "content": "import os\nimport argparse\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n\nDIR = os.path.dirname(os.path.abspath(__file__))\nUSE_AUTH_TOKEN = False\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Convert HF checkpoints')\n    parser.add_argument('--model-name', type=str, default='togethercomputer/RedPajama-INCITE-7B-Chat', \n                        help='model-name')\n    parser.add_argument('--save-dir', type=str, default=DIR, \n                        help='model-name')\n    parser.add_argument('--offload-dir', type=str, default=None,\n                        help='directory to offload from memory')\n    args = parser.parse_args()\n    \n    if not os.path.exists(args.save_dir):\n        os.mkdir(args.save_dir)\n    save_path = os.path.join(args.save_dir, args.model_name.replace('/', '_'))\n    if not os.path.exists(save_path):\n        os.mkdir(save_path)\n    \n    print('loading model from HF...')\n    config = AutoConfig.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)\n    config.save_pretrained(save_path)\n    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)\n    tokenizer.save_pretrained(save_path)\n\n    # offload model from memory to disk if offload-dir is specified\n    if args.offload_dir is not None:\n        if not os.path.exists(args.offload_dir):\n            os.mkdir(args.offload_dir)\n        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, device_map=\"auto\", offload_folder=args.offload_dir, use_auth_token=USE_AUTH_TOKEN)\n    else:\n        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_auth_token=USE_AUTH_TOKEN)\n    print('loaded model from HF...')\n    \n    print('converting the embedding layer...')\n    \n    item = {}\n    item['embed_in.weight'] = model.gpt_neox.embed_in.weight\n    torch.save(item, 
os.path.join(save_path, 'pytorch_embs.pt'))\n    print('converted the embedding layer.')\n\n    for i in range(len(model.gpt_neox.layers)):\n        print(f'converting the {i}-th transformer layer...')\n        torch.save(model.gpt_neox.layers[i].state_dict(), os.path.join(save_path, f'pytorch_{i}.pt'))\n        print(f'converted the {i}-th transformer layer.')\n\n    print('converting the lm_head layer...')\n    item = {}\n    item['embed_out.weight'] = model.embed_out.weight\n    item['final_layer_norm.weight'] = model.gpt_neox.final_layer_norm.weight\n    item['final_layer_norm.bias'] = model.gpt_neox.final_layer_norm.bias\n    torch.save(item, os.path.join(save_path, 'pytorch_lm_head.pt'))\n    print('converted the lm_head layer.')\n"
  },
  {
    "path": "pretrained/prepare_pretrained.py",
    "content": "import os\nimport argparse\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n\nDIR = os.path.dirname(os.path.abspath(__file__))\nUSE_AUTH_TOKEN = False\n\n# Load pretrained model from HuggingFace and save it to disk\ndef prepare_pretrained(save_path, model_name, offload_dir=None):\n    os.makedirs(save_path, exist_ok=True)\n    \n    print('loading model from HF...')\n    config = AutoConfig.from_pretrained(model_name, use_auth_token=USE_AUTH_TOKEN)\n    config.save_pretrained(save_path)\n    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=USE_AUTH_TOKEN)\n    tokenizer.save_pretrained(save_path)\n\n    # offload model from memory to disk if offload-dir is specified\n    if offload_dir is not None:\n        os.makedirs(offload_dir, exist_ok=True)\n        model = AutoModelForCausalLM.from_pretrained(model_name, \n                                                     torch_dtype=torch.float16,\n                                                     device_map=\"auto\",\n                                                     offload_folder=offload_dir,\n                                                     use_auth_token=USE_AUTH_TOKEN)\n    else:\n        model = AutoModelForCausalLM.from_pretrained(model_name,\n                                                     torch_dtype=torch.float16,\n                                                     use_auth_token=USE_AUTH_TOKEN)\n    print('loaded model from HF...')\n    \n    print('converting the embedding layer...')\n    item = {}\n    item['embed_in.weight'] = model.gpt_neox.embed_in.weight\n    torch.save(item, os.path.join(save_path, 'pytorch_embs.pt'))\n    print('converted the embedding layer.')\n\n    for i in range(len(model.gpt_neox.layers)):\n        print(f'converting the {i}-th transformer layer...')\n        torch.save(model.gpt_neox.layers[i].state_dict(), os.path.join(save_path, f'pytorch_{i}.pt'))\n        print(f'converted the {i}-th 
transformer layer.')\n\n    print('converting the lm_head layer...')\n    item = {}\n    item['embed_out.weight'] = model.embed_out.weight\n    item['final_layer_norm.weight'] = model.gpt_neox.final_layer_norm.weight\n    item['final_layer_norm.bias'] = model.gpt_neox.final_layer_norm.bias\n    torch.save(item, os.path.join(save_path, 'pytorch_lm_head.pt'))\n    print('converted the lm_head layer.')\n\n# Example usage:\n# python pretrained/prepare_pretrained.py --model-name EleutherAI/gpt-neox-125M --save-dir pretrained/files --offload-dir pretrained/files/offload\ndef main():\n    parser = argparse.ArgumentParser(description='Convert HF checkpoints')\n    parser.add_argument('--model-name', type=str, required=True, \n                        help='name or path of the HF model to convert')\n    parser.add_argument('--save-dir', type=str, required=True, \n                        help='directory to save the converted checkpoint in')\n    parser.add_argument('--offload-dir', type=str, default=None,\n                        help='directory to offload from memory')\n    args = parser.parse_args()\n    \n    prepare_pretrained(args.save_dir, args.model_name, args.offload_dir)\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "retrieval/README.md",
    "content": "# Retrieval-Enhanced Chatbot\n\nThis is a demonstration of how to enhance a chatbot using Wikipedia. We'll be using [ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index](https://huggingface.co/datasets/ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index). for this demo. Thank Christoph for providing this resource!\n\nIn this demo, we'll be extending the approach of comparing and adding the adjacent `w` sentences to the matched sentence if their cosine similarity is larger than `w_th`. By doing so, we can provide the chatbot with a longer context, which may improve its performance.\n\nThis demo combines both the above index and the chat model into one system\n\n## Start the combined  server\n\nTo get started, we need to install some dependencies and download the Wikipedia index:\n\n0. Install dependencies\n\nInstall the necessary dependencies, including `torch`, `transformers`, `flask`, `faiss`, and `fastparquet`.\n\n1. Open up wiki-server.py and set model_name_or_path to point to the path that contains the chat\nmodel\n\n\n2. Start the retrieval server\n\n```shell\npython wiki-server.py\n```\n\nThe server will listen on port 7003.  It will download the data sets from ChristophSchuhman.  This\nmay take a few minutes.\n\n3. Test the full retrieval enhanced chatbot\n\nWe now demonstrate both the wiki index and the GPT-NeoX-fine-tuned model.\n\n```curl -X POST -H 'Content-Type: application/json' http://127.0.0.1:7003/inference -d '{ \"prompt\" : \"where is zurich located?\" }'```\n\nInternally we first query the wiki index and generate a response using the provided model.  To do\nthis, We concatenate the retrieved information and the users' query into a prompt, \nencode it with a tokenizer, and generate a response using the chatbot model.\n\nThe response should indicate the location of Zurich city.\n\n\n4. To test just the retrieval functionality of the system you can can do the following.  
Curl works\nas well.\n\n```python\nimport requests\n\nendpoint = 'http://127.0.0.1:7003/search'\nres = requests.post(endpoint, json={\n    'query': 'Where is Zurich?',\n    'k': 1,\n    'w': 5,\n    'w_th': 0.7,\n})\nprint(res.json())\n```\n\nThis should print the most relevant sentences about Zurich from Wikipedia. By increasing w and \ndecreasing w_th, we can retrieve a longer context.\n\n\n"
  },
  {
    "path": "retrieval/__init__.py",
    "content": ""
  },
  {
    "path": "retrieval/wikipedia.py",
    "content": "# This file was adapted from ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index:\n#   https://huggingface.co/datasets/ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index/blob/main/wikiindexquery.py\n#\n# The original file was licensed under the Apache 2.0 license.\n\nimport os\n\nfrom transformers import AutoTokenizer, AutoModel\nimport faiss\nimport numpy as np\nimport pandas as pd\n\nDIR = os.path.dirname(os.path.abspath(__file__))\n\n\ndef mean_pooling(token_embeddings, mask):\n    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)\n    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]\n    return sentence_embeddings\n\ndef cos_sim_2d(x, y):\n    norm_x = x / np.linalg.norm(x, axis=1, keepdims=True)\n    norm_y = y / np.linalg.norm(y, axis=1, keepdims=True)\n    return np.matmul(norm_x, norm_y.T)\n\n\nclass WikipediaIndex:\n    def __init__(self):\n        path = os.path.join(DIR, '..', 'data', 'wikipedia-3sentence-level-retrieval-index', 'files')\n        indexpath = os.path.join(path, 'knn.index')\n        wiki_sentence_path = os.path.join(path, 'wikipedia-en-sentences.parquet')\n\n        self._device = 'cuda'\n        self._tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')\n        self._contriever = AutoModel.from_pretrained('facebook/contriever-msmarco').to(self._device)\n\n        self._df_sentences = pd.read_parquet(wiki_sentence_path, engine='fastparquet')\n\n        self._wiki_index = faiss.read_index(indexpath, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)\n\n\n    def search(self, query, k=1, w=5, w_th=0.5):\n        inputs = self._tokenizer(query, padding=True, truncation=True, return_tensors='pt').to(self._device)\n        outputs = self._contriever(**inputs)\n        embeddings = mean_pooling(outputs[0], inputs['attention_mask'])\n        \n        query_vector = embeddings.cpu().detach().numpy().reshape(1, -1)\n        \n        
distances, indices = self._wiki_index.search(query_vector, k)\n        \n        texts = []\n        for i, (dist, indice) in enumerate(zip(distances[0], indices[0])):\n            text = self._df_sentences.iloc[indice]['text_snippet']\n\n            try:\n                input_texts = [self._df_sentences.iloc[indice]['text_snippet']]\n                for j in range(1, w+1):\n                    input_texts = [self._df_sentences.iloc[indice-j]['text_snippet']] + input_texts\n                for j in range(1, w+1):\n                    input_texts = input_texts + [self._df_sentences.iloc[indice+j]['text_snippet']]\n                \n                inputs = self._tokenizer(input_texts, padding=True, truncation=True, return_tensors='pt').to(self._device)\n\n                outputs = self._contriever(**inputs)\n                embeddings = mean_pooling(outputs[0], inputs['attention_mask']).detach().cpu().numpy()\n\n                for j in range(1, w+1):\n                    if cos_sim_2d(embeddings[w-j].reshape(1, -1), embeddings[w].reshape(1, -1)) > w_th:\n                        text = self._df_sentences.iloc[indice-j]['text_snippet'] + text\n                    else:\n                        break\n\n                for j in range(1, w+1):\n                    if cos_sim_2d(embeddings[w+j].reshape(1, -1), embeddings[w].reshape(1, -1)) > w_th:\n                        text += self._df_sentences.iloc[indice+j]['text_snippet']\n                    else:\n                        break\n\n            except Exception as e:\n                print(e)\n\n            texts.append(text)\n        \n        return texts\n"
  },
  {
    "path": "tools/README.md",
    "content": "# OpenChatKit Tools\n\n## convert_to_hf_gptneox.py\n\n## ml_load_benchmark.py\n\nThe commands to run the model load benchmark tool is:\n```shell\n$ python3 model_load_benchmark.py -i benchmark_input.json -o benchmark_results.json -d cuda:0\n```\n\n```\nusage: model_load_benchmark.py [-h] -i INPUT -o OUTPUT [-d DEVICE] [-r REPEAT_INFER]\n\nBenchmark downloading, loading, and running an inferernce for a set of ML models.\n\noptional arguments:\n  -h, --help            show this help message and exit\n  -i INPUT, --input INPUT\n                        Input JSON file containing models to be benchmark\n  -o OUTPUT, --output OUTPUT\n                        Output JSON file with model benchmark results\n  -d DEVICE, --device DEVICE\n                        Cuda device name, e.g. \"cuda:0\"\n  -r REPEAT_INFER, --repeat-infer REPEAT_INFER\n                        Repeat inferrence for warm timings\n```\n\nThe input file is a JSON file with the names and paths of the models to be tested. For example:\n```JSON\n{\n    \"GPT-NeoXT-Chat-Base-20B\": \"togethercomputer/GPT-NeoXT-Chat-Base-20B\",\n    \"Pythia-Chat-Base-7B\": \"togethercomputer/Pythia-Chat-Base-7B\",\n    \"GPT-JT-Moderation-6B\": \"togethercomputer/GPT-JT-Moderation-6B\",\n    \"GPT-JT-6B-v1\": \"togethercomputer/GPT-JT-6B-v1\",\n    \"GPT-JT-6B-v0\": \"togethercomputer/GPT-JT-6B-v0\"\n}\n```\n\nThe output is a json file with the timings for:\n1. tokenizer download time in seconds -- `tokenizer_download_sec`\n2. tokenizer load time in seconds -- `tokenizer_load_sec`\n3. model download time -- `model_download_sec`\n5. model load to RAM time -- `model_load_to_ram_sec`\n6. model transfer to GPU time -- `model_transfer_to_gpu_sec`\n7. inference time (input is \"hello, world!\") -- `inference_sec`\n8. total time (sum of all the above) -- `total_sec`\n9. inference time from a warm start (the average of running inference `REPEAT_INFER` times) -- `inference_warm_sec`\n10. 
model main memory footprint in MB -- `model_main_memory_MB`\n11. model GPU memory footprint in MB -- `model_gpu_memory_MB`\n\nAn example of the output is:\n```JSON\n{\n    \"GPT-JT-6B-v1\": {\n        \"tokenizer_download_sec\": 1.52,\n        \"tokenizer_load_sec\": 0.10,\n        \"model_download_sec\": 124.70,\n        \"model_load_to_ram_sec\": 127.81,\n        \"model_main_memory_MB\": 12297.10,\n        \"model_transfer_to_gpu_sec\": 3.29,\n        \"model_gpu_memory_MB\": 12219.74,\n        \"inference_sec\": 0.93,\n        \"inference_warm_sec\": 0.047,\n        \"total_sec\": 258.38\n    }\n}\n```"
  },
  {
    "path": "tools/benchmark_input.json",
    "content": "{\n    \"GPT-NeoXT-Chat-Base-20B\": \"togethercomputer/GPT-NeoXT-Chat-Base-20B\",\n    \"Pythia-Chat-Base-7B\": \"togethercomputer/Pythia-Chat-Base-7B\",\n    \"GPT-JT-Moderation-6B\": \"togethercomputer/GPT-JT-Moderation-6B\",\n    \"GPT-JT-6B-v1\": \"togethercomputer/GPT-JT-6B-v1\",\n    \"GPT-JT-6B-v0\": \"togethercomputer/GPT-JT-6B-v0\"\n}"
  },
  {
    "path": "tools/convert_to_hf_gptneox.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport argparse\n\nfrom transformers import GPTNeoXForCausalLM\n\nfrom transformers import AutoConfig, AutoTokenizer\n\nfrom transformers.modeling_utils import no_init_weights\nimport os\n\n\ndef create_empty_gptneox(config):\n\n    import torch\n    import torch.nn as nn\n\n    _reset_parameters_linear = nn.Linear.reset_parameters\n    def dummy(*args, **kargs):\n        pass\n    nn.Linear.reset_parameters = dummy\n\n    # 1. disable init for faster initialization\n    # 2. avoid tie token embeddings with lm_head, as we train them separately.\n    with no_init_weights(_enable=True):\n        model = GPTNeoXForCausalLM(config).eval()\n\n    nn.Linear.reset_parameters = _reset_parameters_linear\n\n    return model\n\ndef load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_per_stage=14):\n    input_path = checkpoint_path\n\n    assert n_stages * n_layer_per_stage >= len(model.gpt_neox.layers)\n    # assert model.lm_head.weight.data is not model.transformer.wte.weight.data\n\n    for i in range(n_stages):\n\n        print(f'loading stage {i}')\n\n        checkpoint = torch.load(os.path.join(input_path, f'prank_{i}_checkpoint.pt'), map_location=torch.device(\"cpu\"))\n\n        if i == 0:\n            _tmp = {k[len(f\"{0}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"0.\")}\n            # torch.save(_tmp, os.path.join(output_path, f'pytorch_embs.pt'))\n            model.gpt_neox.embed_in.weight.data[:] = _tmp['embed_in.weight']\n\n            for j in range(n_layer_per_stage):\n                _tmp = {k[len(f\"{j+1}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j+1}.\")}\n                if len(_tmp) == 0:\n                    break\n                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{j}.pt'))\n                model.gpt_neox.layers[j].load_state_dict(_tmp)\n\n        elif i == n_stages - 1:\n            for j in range(n_layer_per_stage):\n            
    _tmp = {k[len(f\"{j}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j}.\")}\n                if len(_tmp) == 0:\n                    break\n                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))\n                model.gpt_neox.layers[i*n_layer_per_stage + j].load_state_dict(_tmp)\n                if i*n_layer_per_stage + j == len(model.gpt_neox.layers) - 1:\n                    j += 1\n                    break\n\n            _tmp = {k[len(f\"{j}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j}.\")}\n            if len(_tmp) == 0:\n                break\n            # torch.save(_tmp, os.path.join(output_path, f'pytorch_lm_head.pt'))\n            model.gpt_neox.final_layer_norm.weight.data[:] = _tmp['final_layer_norm.weight']\n            model.gpt_neox.final_layer_norm.bias.data[:] = _tmp['final_layer_norm.bias']\n            model.embed_out.weight.data[:] = _tmp['embed_out.weight']\n            if 'embed_out.bias' in _tmp:\n                model.embed_out.bias.data[:] = _tmp['embed_out.bias']\n\n        else:\n            for j in range(n_layer_per_stage):\n                _tmp = {k[len(f\"{j}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j}.\")}\n                if len(_tmp) == 0:\n                    break\n                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))\n                model.gpt_neox.layers[i*n_layer_per_stage + j].load_state_dict(_tmp)\n\n    return model\n\n\nif __name__ == '__main__':\n    \n    parser = argparse.ArgumentParser(description='Convert HF checkpoints')\n    parser.add_argument('--config-name', type=str, default='EleutherAI/gpt-neox-20b',\n                        help='config-name')\n    parser.add_argument('--ckpt-path', type=str, default=None, \n                        help='ckpt-path')\n    parser.add_argument('--save-path', type=str, default=None, \n                        help='save-path')\n    
parser.add_argument('--n-stages', type=int, default=8, \n                        help='pipeline group size')\n    parser.add_argument('--n-layer-per-stage', type=int, default=6, \n                        help='n layers per GPU device')\n    parser.add_argument('--fp16', default=False, action='store_true')\n    args = parser.parse_args()\n    \n    assert args.ckpt_path is not None\n    assert args.save_path is not None\n    \n    os.makedirs(args.save_path, exist_ok=True)\n\n    print('loading config...')\n    config = AutoConfig.from_pretrained(args.config_name)\n    print('loaded config.')\n    print('loading tokenizer...')\n    tokenizer = AutoTokenizer.from_pretrained(args.config_name)\n    print('loaded tokenizer.')\n    print('creating empty model...')\n    model = create_empty_gptneox(config)\n    if args.fp16:\n        model = model.half()\n    print('created empty model.')\n    print('loading model ckpt...')\n    load_decentralized_checkpoint(\n        model, args.ckpt_path, n_stages=args.n_stages, n_layer_per_stage=args.n_layer_per_stage,\n    )\n    print('loaded model ckpt.')\n    \n    print('saving HF model...')\n    model.save_pretrained(args.save_path)\n    print(f'saved HF model to `{args.save_path}`')\n    config.save_pretrained(args.save_path)\n    tokenizer.save_pretrained(args.save_path)\n    \n"
  },
  {
    "path": "tools/convert_to_hf_llama.py",
    "content": "import os\nimport argparse\nimport torch\n\nimport torch\nimport torch.nn as nn\n\nfrom transformers import LlamaForCausalLM\nfrom transformers import AutoConfig, AutoTokenizer\n\nfrom transformers.modeling_utils import no_init_weights\nimport os\n\n\ndef create_emtpy_llama(config):\n\n    import torch\n    import torch.nn as nn\n\n    _reset_parameters_linear = nn.Linear.reset_parameters\n    def dummy(*args, **kargs):\n        pass\n    nn.Linear.reset_parameters = dummy\n\n    # 1. disable init for faster initialization\n    # 2. avoid tie token embeddings with lm_head, as we train them separately.\n    with no_init_weights(_enable=True):\n        model = LlamaForCausalLM(config).eval()\n\n    nn.Linear.reset_parameters = _reset_parameters_linear\n\n    return model\n\ndef load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_per_stage=16, ):\n    input_path = checkpoint_path\n\n    n_layers = len(model.model.layers)\n    assert n_stages * n_layer_per_stage >= len(model.model.layers)\n    # assert model.lm_head.weight.data is not model.transformer.wte.weight.data\n\n    for i in range(n_stages):\n\n        print(f'loading stage {i}')\n\n        checkpoint = torch.load(os.path.join(input_path, f'prank_{i}_checkpoint.pt'), map_location=torch.device(\"cpu\"))\n\n        if i == 0:\n            _tmp = {k[len(f\"{0}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"0.\")}\n            # torch.save(_tmp, os.path.join(output_path, f'pytorch_embs.pt'))\n            model.model.embed_tokens.weight.data[:] = _tmp['embed_tokens.weight']\n\n            for j in range(n_layer_per_stage):\n                _tmp = {k[len(f\"{j+1}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j+1}.\")}\n                if len(_tmp) == 0:\n                    break\n                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{j}.pt'))\n                ret = model.model.layers[j].load_state_dict(_tmp, strict=False)\n                
if len(ret.missing_keys):\n                    print('The following weight keys are missing:')\n                    print(ret.missing_keys)\n                if len(ret.unexpected_keys):\n                    print('The following weight keys are unexpected:')\n                    print(ret.unexpected_keys)\n\n        elif i == n_stages - 1:\n            for j in range(n_layer_per_stage):\n                if i*n_layer_per_stage + j == n_layers:\n                    break\n                _tmp = {k[len(f\"{j}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j}.\")}\n                if len(_tmp) == 0:\n                    break\n                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))\n                ret = model.model.layers[i*n_layer_per_stage + j].load_state_dict(_tmp, strict=False)\n                if len(ret.missing_keys):\n                    print('The following weight keys are missing:')\n                    print(ret.missing_keys)\n                if len(ret.unexpected_keys):\n                    print('The following weight keys are unexpected:')\n                    print(ret.unexpected_keys)\n            else:\n                j += 1\n\n            _tmp = {k[len(f\"{j}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j}.\")}\n            if len(_tmp) == 0:\n                break\n            # torch.save(_tmp, os.path.join(output_path, f'pytorch_lm_head.pt'))\n            model.model.norm.weight.data[:] = _tmp['norm.weight']\n            if 'norm.bias' in _tmp:\n                model.model.norm.bias.data[:] = _tmp['norm.bias']\n            model.lm_head.weight.data[:] = _tmp['lm_head.weight']\n            if 'lm_head.bias' in _tmp:\n                model.lm_head.bias.data[:] = _tmp['lm_head.bias']\n\n        else:\n            for j in range(n_layer_per_stage):\n                _tmp = {k[len(f\"{j}.\"):]:v for k,v in checkpoint.items() if k.startswith(f\"{j}.\")}\n                if len(_tmp) == 
0:\n                    break\n                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))\n                ret = model.model.layers[i*n_layer_per_stage + j].load_state_dict(_tmp, strict=False)\n                if len(ret.missing_keys):\n                    print('The following weight keys are missing:')\n                    print(ret.missing_keys)\n                if len(ret.unexpected_keys):\n                    print('The following weight keys are unexpected:')\n                    print(ret.unexpected_keys)\n\n    return model\n\n\nif __name__ == '__main__':\n    \n    parser = argparse.ArgumentParser(description='Convert HF checkpoints')\n    parser.add_argument('--config-name', type=str, default='togethercomputer/Llama-2-7B-32K-beta',\n                        help='config-name')\n    parser.add_argument('--ckpt-path', type=str, default=None, \n                        help='ckpt-path')\n    parser.add_argument('--save-path', type=str, default=None, \n                        help='save-path')\n    parser.add_argument('--n-stages', type=int, default=8, \n                        help='pipeline group size')\n    parser.add_argument('--n-layer-per-stage', type=int, default=4, \n                        help='n layers per GPU device')\n    parser.add_argument('--fp16', default=False, action='store_true')\n    args = parser.parse_args()\n    \n    assert args.ckpt_path is not None\n    assert args.save_path is not None\n    \n    if not os.path.exists(args.save_path):\n        os.mkdir(args.save_path)\n\n    # LlamaForCausalLM LlamaConfig LlamaTokenizer\n    print('loading config...')\n    config = AutoConfig.from_pretrained(args.config_name)\n    print('loaded config.')\n    print('loading tokenizer...')\n    tokenizer = AutoTokenizer.from_pretrained(args.config_name)\n    print('loaded tokenizer.')\n    print('creating empty model...')\n    model = create_emtpy_llama(config)\n    if args.fp16:\n        model = 
model.half()\n    print('created empty model.')\n    print('loading model ckpt...')\n    load_decentralized_checkpoint(\n        model, args.ckpt_path, n_stages=args.n_stages, n_layer_per_stage=args.n_layer_per_stage,\n    )\n    print('loaded model ckpt.')\n    \n    print('saving HF model...')\n    model.save_pretrained(args.save_path)\n    print(f'saved HF model to `{args.save_path}`')\n    config.save_pretrained(args.save_path)\n    tokenizer.save_pretrained(args.save_path)\n"
  },
  {
    "path": "tools/model_load_benchmark.py",
    "content": "import argparse\nimport json\nimport time\nimport torch\nimport torchvision\nimport os\nimport re\nimport psutil\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\n\n# Benchmark download, tokenize, load, inference time.\ndef benchmark(model_dict: dict, device_name: str, repeat_infer: int):\n\n    # Initialize the benchmark results dictionary\n    results_dict = {}\n\n    # Check that we have CUDA GPUs available before running the benchmark\n    if not torch.cuda.is_available():\n        print(\"ERROR: CUDA GPUs are not available, benchmark not run\")\n        return results_dict\n\n    device = torch.device(device_name)\n\n    process = psutil.Process()\n\n    print(f'Using device {device}')\n\n    # Loop through the models to test\n    for model_name, model_path in model_dict.items():\n        # purge unused cached memory\n        torch.cuda.empty_cache()\n\n        print(f\"Testing model: {model_name}\")\n\n        # Measure the time it takes to download the tokenizer data and load the tokenizer\n        tokenizer_download_start_time = time.time()\n        tokenizer = AutoTokenizer.from_pretrained(model_path, force_download=True)\n        tokenizer_download_end_time = time.time()\n\n        tokenizer = None\n\n        # Measure the time it takes to  load the tokenizer\n        tokenizer_load_start_time = time.time()\n        tokenizer = AutoTokenizer.from_pretrained(model_path)\n        tokenizer_load_end_time = time.time()\n\n        tokenizer_load_sec = tokenizer_load_end_time - tokenizer_load_start_time\n        tokenizer_download_sec = tokenizer_download_end_time - tokenizer_download_start_time - tokenizer_load_sec\n\n        print(f\"Testing model: {model_name} --- tokenizer download time = {tokenizer_download_sec:.3} sec\")\n        print(f\"Testing model: {model_name} --- tokenize load time = {tokenizer_load_sec:.3} sec\")\n\n        # Measure the time it takes to download and load the model into main memory\n        
model_download_start_time = time.time()\n        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, torchscript=True, force_download=True)\n        model_download_end_time = time.time()\n        \n        model = None\n\n        # Measure the time it takes to load the model into main memory\n        memory_used_main_start = process.memory_info().rss\n        model_load_to_ram_start_time = time.time()\n        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, torchscript=True)\n        model_load_to_ram_end_time = time.time()\n        memory_used_main_end = process.memory_info().rss\n\n        model_load_to_ram_sec = model_load_to_ram_end_time - model_load_to_ram_start_time\n        model_download_sec = model_download_end_time - model_download_start_time - model_load_to_ram_sec\n        model_main_memory_bytes = memory_used_main_end - memory_used_main_start\n\n        print(f\"Testing model: {model_name} --- model download time = {model_download_sec:.3} sec\")\n        print(f\"Testing model: {model_name} --- model load to RAM time = {model_load_to_ram_sec:.3} sec\")\n        print(f\"Testing model: {model_name} --- model main memory size = {model_main_memory_bytes} bytes\")\n\n        # Measure the time it takes to load the model from main memory to the GPU\n        gpu_memory_start = torch.cuda.memory_allocated(device)\n        model_xfer_to_gpu_start_time = time.time()\n        model = model.to(device)\n        model_xfer_to_gpu_end_time = time.time()\n        gpu_memory_end = torch.cuda.memory_allocated(device)\n\n        model_xfer_to_gpu_sec = model_xfer_to_gpu_end_time - model_xfer_to_gpu_start_time\n        model_gpu_memory_bytes = gpu_memory_end - gpu_memory_start\n\n        print(f\"Testing model: {model_name} --- model transfer to GPU time = {model_xfer_to_gpu_sec:.3} sec\")\n        print(f\"Testing model: {model_name} --- model GPU memory size = {model_gpu_memory_bytes} bytes\")\n\n      
  # Measure the time it takes to run inference from a cold start\n        inference_start_time = time.time()\n        inputs = tokenizer(\"Hello, world!\", return_tensors=\"pt\").to(device)\n        outputs = model(**inputs)\n        inference_end_time = time.time()\n        inference_sec = inference_end_time - inference_start_time\n\n        print(f\"Testing model: {model_name} --- inference time = {inference_sec:.3} sec\")\n\n        # Measure the average time per inference once the model is warm\n        inference_warm_start_time = time.time()\n        for i in range(0, repeat_infer):\n            inputs = tokenizer(\"Hello, world!\", return_tensors=\"pt\").to(device)\n            outputs = model(**inputs)\n        inference_warm_end_time = time.time()\n        inference_warm_sec = (inference_warm_end_time - inference_warm_start_time) / float(repeat_infer)\n\n        print(f\"Testing model: {model_name} --- inference warm time = {inference_warm_sec:.3} sec\")\n\n        total_sec = tokenizer_download_sec + tokenizer_load_sec + model_download_sec + model_load_to_ram_sec + model_xfer_to_gpu_sec + inference_sec\n\n        print(f\"Testing model: {model_name} --- total time = {total_sec:.3} sec\")\n\n        # Add the results to the dictionary\n        results_dict[model_name] = {\n            \"tokenizer_download_sec\": tokenizer_download_sec,\n            \"tokenizer_load_sec\": tokenizer_load_sec,\n            \"model_download_sec\": model_download_sec,\n            \"model_load_to_ram_sec\": model_load_to_ram_sec,\n            \"model_main_memory_MB\": float(model_main_memory_bytes) / 1000000.0,\n            \"model_transfer_to_gpu_sec\": model_xfer_to_gpu_sec,\n            \"model_gpu_memory_MB\": float(model_gpu_memory_bytes) / 1000000.0,\n            \"inference_sec\": inference_sec,\n            \"inference_warm_sec\": inference_warm_sec,\n            \"total_sec\": total_sec\n        }\n\n        # Unload the model\n        model = None\n        
torch.cuda.empty_cache()\n\n    return results_dict\n\n# Define the main function\ndef main(input_file: str, output_file: str, device_name: str, repeat_infer: int):\n\n    # Load the models to test from the input JSON file\n    with open(input_file, \"r\") as f:\n        model_dict = json.load(f)\n\n    # Run the benchmark\n    results_dict = benchmark(model_dict, device_name, repeat_infer)\n\n    # Write the results to the JSON output file\n    # use a regular expression to truncate floating-point values to two decimal places\n    json_data = re.sub('\"(.*?)\":\\s*(0\\.0*\\d{2}|\\d+\\.\\d{2})\\d*(,?\\n)', '\"\\\\1\": \\\\2\\\\3',  json.dumps(results_dict, indent=4))\n    with open(output_file, 'w') as f:\n        f.write(json_data)\n\nif __name__ == \"__main__\":\n    # Create an argument parser\n    parser = argparse.ArgumentParser(description='Benchmark downloading, loading, and running an inference for a set of ML models.')\n    parser.add_argument('-i', '--input', required=True, help='Input JSON file containing models to be benchmarked')\n    parser.add_argument('-o', '--output', required=True, help='Output JSON file with model benchmark results')\n    parser.add_argument('-d', '--device', required=False, default='cuda:0', help='CUDA device name, e.g. \"cuda:0\"')\n    # type=int so that a value passed on the command line can be compared with 1 below\n    parser.add_argument('-r', '--repeat-infer', required=False, type=int, default=30, help='Repeat inference for warm timings')\n\n    # Parse the command line arguments\n    args = parser.parse_args()\n\n    # Process the data\n    main(args.input, args.output, args.device, max(args.repeat_infer, 1))"
  },
  {
    "path": "training/README.md",
    "content": "# OpenChatKit Training\n\nThis directory contains code for training a chat model using OpenChatKit. The main training script is `finetune_GPT-NeoXT-Chat-Base-20B.sh`.\n\nTo customize training, make a copy of the script and modify the arguments.\n\n## Arguments\n\nEnvironment vars that should be set:\n```bash\nexport GLOO_SOCKET_IFNAME=lo # this interface should be consistent to `--net-interface`\nexport NCCL_SOCKET_IFNAME=lo # this interface should be consistent to `--net-interface`\nexport WANDB_NAME=gptj-test # wandb run name\n```\n\nThe following arguments should be carefully set:\n- `--model-name`: The path of model ckpt sharded by layers.\n- `--tokenizer-name`: Usually the same to `--model-name`. You can also use HF's model name.\n- `--model-type`: Indicate the model type. {gptj}. More model types will be added soon.\n- `--num-layers`: Number of Transformer layers **for each GPU**. E.g. GPT-J has 28 layers, if we use two GPUs to form a pipeline, `--num-layers` should be 14.\n- `--embedding-dim`: The hidden size of the model. GPT-J-6B is 4096. This is used to create buffers.\n- `--dist-url`: URL of rank 0 worker (master). It is the same to all workers. And this URL should be accessible by all workers. For local training (single machine multiple GPUs), this can be like `--dist-url tcp://127.0.0.1:7033`\n- `--world-size`: The total number of workers. `world-size == pipeline-group-size * data-group-size`\n- `--pipeline-group-size`: Number of GPU workers for each pipeline\n- `--data-group-size`: Number of data parallel workers. Also the number of pipelines.\n- `--net-interface`: Network interface. Should be consistent with `GLOO_SOCKET_IFNAME` and `NCCL_SOCKET_IFNAME`.\n\nThe following arguments can be tuned / changed:\n- `--train-log-backend `: How to log the training info. {print, loguru, wandb}.\n- `--optimizer`: Optimizer type. 
{adam, 8bit-adam} (8bit-adam requires `pip install bitsandbytes`)\n- `--load-pretrained-model`: Whether to load model weights. Usually `true`.\n- `--task-name`: The task name or the path of a `jsonl` file. For multi-task training, separate task names with `,`.\n   There is an optional sampling weight after each task name, separated by `:` (default is 1.0). Sampling weights will be normalized.\n   For example: `--task-name cot:0.1,/path_task0.jsonl:1.0,/path_task0.jsonl:1.0,/path_task0.jsonl:1.0`.\n   The number after the colon indicates the sampling weight for the task during training. For example, `cot:0.1` means the `cot` task will be sampled with a weight of 0.1.\n- `--checkpoint-path`: Path to save fine-tuned checkpoints.\n- `--checkpoint-steps`: Save a checkpoint every `checkpoint-steps` steps.\n- `--total-steps`: Total number of steps for training. (This counts all `gradient-accumulate-step`s.)\n- `--warmup-steps`: LR warmup steps.\n- `--lr`: learning rate\n- `--seq-length`: sequence length\n- `--batch-size`: batch size for each GPU device (of each gradient accumulation step).\n- `--micro-batch-size`: micro batch size for pipeline parallelism. 1 works fine.\n- `--gradient-accumulate-step`: Accumulate gradients for several steps before updating parameters. This is another way to achieve large batch sizes when GPU memory is limited.\n\nThe following arguments usually do not change:\n- `--dp-backend`: {nccl, gloo}, default nccl.\n- `--dp-mode`: {allreduce}.\n- `--fp16`: Flag to enable FP16 mixed precision training. Always set this flag with the current implementation.\n- `--pp-mode`: always `gpipe`\n- `--profiling`: {no-profiling, tidy_profiling}. `tidy_profiling` will generate profile jsons.\n\n## Adding Your Own Data to the DATASETS\n\nTo add your own data to the training process, you should create a `jsonl` file where each line is a JSON object representing a single training example. 
Once you have your `jsonl` file, you can include it in the `--task-name` argument with an appropriate sampling weight. For instance, if your file is located at `/path_to_your_data/your_data.jsonl` and you wish to give it a sampling weight of 0.5, you would add `/path_to_your_data/your_data.jsonl:0.5` to the `--task-name` argument.\n\nIf you have any questions or need further assistance, please refer to the [OpenDataHub](https://github.com/togethercomputer/OpenDataHub) repository or contact us through our [website](https://www.together.ai/contact).\n"
  },
  {
    "path": "training/comm/__init__.py",
    "content": ""
  },
  {
    "path": "training/comm/comm_utils.py",
    "content": "from .torch_backend import *\nfrom .nccl_backend import *\n\n_DATA_PARALLEL_COMM = None\n_DATA_PARALLEL_RANK = None\n_DATA_PARALLEL_WORLD_SIZE = None\n\n_PIPELINE_PARALLEL_COMM = None\n_PIPELINE_PARALLEL_RANK = None\n_PIPELINE_PARALLEL_WORLD_SIZE = None\n\n_TENSOR_PARALLEL_COMM = None\n_TENSOR_PARALLEL_RANK = None\n_TENSOR_PARALLEL_WORLD_SIZE = None\n\nimport threading \n\n_LOCK = threading.RLock()\n\ndef get_lock():\n    return _LOCK\n\ndef get_data_parallel_comm() -> NCCLCommunicator:\n    assert _DATA_PARALLEL_COMM is not None\n    return _DATA_PARALLEL_COMM\n\n\ndef get_data_parallel_rank() -> int:\n    assert _DATA_PARALLEL_RANK is not None\n    return _DATA_PARALLEL_RANK\n\n\ndef get_data_parallel_world_size() -> int:\n    assert _DATA_PARALLEL_WORLD_SIZE is not None\n    return _DATA_PARALLEL_WORLD_SIZE\n\n\ndef get_pipeline_parallel_comm() -> NCCLCommunicator:\n    assert _PIPELINE_PARALLEL_COMM is not None\n    return _PIPELINE_PARALLEL_COMM\n\n\ndef get_pipeline_parallel_rank() -> int:\n    assert _PIPELINE_PARALLEL_RANK is not None\n    return _PIPELINE_PARALLEL_RANK\n\n\ndef get_pipeline_parallel_world_size() -> int:\n    assert _PIPELINE_PARALLEL_WORLD_SIZE is not None\n    return _PIPELINE_PARALLEL_WORLD_SIZE\n\n\ndef get_megatron_tensor_parallel_comm() -> NCCLCommunicator:\n    assert _TENSOR_PARALLEL_COMM is not None\n    return _TENSOR_PARALLEL_COMM\n\n\ndef get_megatron_tensor_parallel_rank() -> int:\n    assert _TENSOR_PARALLEL_RANK is not None\n    return _TENSOR_PARALLEL_RANK\n\n\ndef get_megatron_tensor_parallel_world_size() -> int:\n    assert _TENSOR_PARALLEL_WORLD_SIZE is not None\n    return _TENSOR_PARALLEL_WORLD_SIZE\n\n\ndef default_init(args):\n    import datetime\n    import time\n    try:\n        dist.destroy_process_group()\n        # the first time will raise exception, so the following code is skipped.\n        print('destroy comm, increase port for 1. 
(this could cause problem)')\n        url = ':'.join(args.dist_url.split(':')[:-1])\n        port = int(args.dist_url.split(':')[-1]) + 1\n        args.dist_url = f\"{url}:{port}\"\n        print(f\"new master url: {args.dist_url}\")\n    except:\n        pass\n    dist.init_process_group(backend='gloo', timeout=datetime.timedelta(seconds=5*60), init_method=args.dist_url, world_size=args.world_size, rank=args.rank)\n    \n\ndef init_communicators(args):\n    default_init(args)\n    assert args.world_size == args.data_group_size * args.pipeline_group_size\n    if args.world_size == args.data_group_size * args.pipeline_group_size:\n        #    We do the following hard code alignment of communication groups:\n        #    Suppose there are 8 instances (world_size), and 4 data parallel groups (data_group_size is 2),\n        #    Then there would be 2 pipeline parallel groups (pipeline_group_size is 4), then the groups will look like:\n        #    pipeline parallel: <group 0: [0,1,2,3]>, <group 1: [4,5,6,7]>\n        #    data parallel: <group 0: [0,4]>, <group 1: [1,5]>, <group 2: [2,6]>, <group 3: [3,7]>\n        # assert args.world_size == args.data_group_size * args.pipeline_group_size\n        global _DATA_PARALLEL_COMM\n        global _PIPELINE_PARALLEL_COMM\n        global _DATA_PARALLEL_RANK\n        global _PIPELINE_PARALLEL_RANK\n        global _DATA_PARALLEL_WORLD_SIZE\n        global _PIPELINE_PARALLEL_WORLD_SIZE\n        # We use pipeline parallel by default.\n        _PIPELINE_PARALLEL_WORLD_SIZE = args.pipeline_group_size\n        _PIPELINE_PARALLEL_RANK = args.rank % args.pipeline_group_size\n        _PIPELINE_PARALLEL_COMM = NCCLCommunicator(_PIPELINE_PARALLEL_RANK, args.cuda_id, args.pipeline_group_size,\n                                                   \"pipeline_group_\"+str(args.rank // args.pipeline_group_size))\n        if args.data_group_size != 1:\n            _DATA_PARALLEL_WORLD_SIZE = args.data_group_size\n            _DATA_PARALLEL_RANK 
= args.rank // args.pipeline_group_size\n            \n            dp_backend = getattr(args, 'dp_backend', 'gloo')\n            if dp_backend == 'nccl':\n            \n                _DATA_PARALLEL_COMM = NCCLCommunicator(_DATA_PARALLEL_RANK, args.cuda_id, args.data_group_size,\n                                                       \"data_group_\"+str(args.rank % args.pipeline_group_size))\n            \n            elif dp_backend == 'gloo':\n                \n                for i in range(args.pipeline_group_size):\n                    ranks = [rank for rank in range(i, args.world_size, args.pipeline_group_size)]\n                    print(args.rank, ranks)\n                    data_group = torch.distributed.new_group(ranks, backend='gloo')\n                    if args.rank in ranks:\n                        def to_global_rank(dp_rank):\n                            rank = _PIPELINE_PARALLEL_RANK + dp_rank * args.pipeline_group_size\n                            # print(f\"{dp_rank} --> {rank}\")\n                            return rank\n                        _DATA_PARALLEL_COMM = TorchCommunicator(\n                            data_group, to_global_rank=to_global_rank, \n                            dp_rank=_DATA_PARALLEL_RANK,\n                            comm_group_size=args.data_group_size,)\n            \n            else:\n                assert False\n            \n        print('comm init done!!')\n            \n    # elif args.world_size == args.data_group_size * args.tensor_group_size:\n    #    global _DATA_PARALLEL_COMM\n    #    global _TENSOR_PARALLEL_COMM\n    #    global _DATA_PARALLEL_RANK\n    #    global _TENSOR_PARALLEL_RANK\n    #    global _DATA_PARALLEL_WORLD_SIZE\n    #    global _TENSOR_PARALLEL_WORLD_SIZE\n        # We use megatron tensor parallel by default.\n    #    _TENSOR_PARALLEL_WORLD_SIZE = args.tensor_group_size\n    #    _TENSOR_PARALLEL_RANK = args.rank % args.tensor_group_size\n    #    _TENSOR_PARALLEL_COMM = 
NCCLCommunicator(_TENSOR_PARALLEL_RANK, args.cuda_id, args.tensor_group_size,\n    #                                             \"tensor_group_\" + str(args.rank // args.tensor_group_size))\n    #    if args.data_group_size != 1:\n    #        _DATA_PARALLEL_WORLD_SIZE = args.data_group_size\n    #        _DATA_PARALLEL_RANK = args.rank // args.tensor_group_size\n    #        _DATA_PARALLEL_COMM = NCCLCommunicator(_DATA_PARALLEL_RANK, args.cuda_id, args.data_group_size,\n    #                                              \"data_group_\" + str(args.rank % args.tensor_group_size))\n    else:\n        print(\"Not supported yet\")\n        assert False\n\n        \n        \ndef reinit_dp_communicator(args):\n    \n    print('###### reinit start #######')\n    \n    default_init(args)\n    assert args.world_size == args.data_group_size * args.pipeline_group_size\n    if args.world_size == args.data_group_size * args.pipeline_group_size:\n        #    We do the following hard code alignment of communication groups:\n        #    Suppose there are 8 instances (world_size), and 4 data parallel groups (data_group_size is 2),\n        #    Then there would be 2 pipeline parallel groups (pipeline_group_size is 4), then the groups will look like:\n        #    pipeline parallel: <group 0: [0,1,2,3]>, <group 1: [4,5,6,7]>\n        #    data parallel: <group 0: [0,4]>, <group 1: [1,5]>, <group 2: [2,6]>, <group 3: [3,7]>\n        # assert args.world_size == args.data_group_size * args.pipeline_group_size\n        global _DATA_PARALLEL_COMM\n        global _PIPELINE_PARALLEL_COMM\n        global _DATA_PARALLEL_RANK\n        global _PIPELINE_PARALLEL_RANK\n        global _DATA_PARALLEL_WORLD_SIZE\n        global _PIPELINE_PARALLEL_WORLD_SIZE\n        \n        if args.data_group_size != 1:\n            \n            dp_backend = getattr(args, 'dp_backend', 'gloo')\n            if dp_backend == 'nccl':\n            \n                raise Exception('NCCL cannot reinit.')\n        
    \n            elif dp_backend == 'gloo':\n                \n                for i in range(args.pipeline_group_size):\n                    ranks = [rank for rank in range(i, args.world_size, args.pipeline_group_size)]\n                    print(args.rank, ranks)\n                    data_group = torch.distributed.new_group(ranks, backend='gloo')\n                    if args.rank in ranks:\n                        def to_global_rank(dp_rank):\n                            rank = _PIPELINE_PARALLEL_RANK + dp_rank * args.pipeline_group_size\n                            # print(f\"{dp_rank} --> {rank}\")\n                            return rank\n                        _DATA_PARALLEL_COMM = TorchCommunicator(\n                            data_group, to_global_rank=to_global_rank, \n                            dp_rank=_DATA_PARALLEL_RANK,\n                            comm_group_size=args.data_group_size,)\n            \n            else:\n                assert False\n            \n        print('######## dp comm reinit done!! ########')"
  },
  {
    "path": "training/comm/nccl_backend.py",
    "content": "import torch\nimport numpy as np\nimport cupy\nimport cupy.cuda.nccl\nimport torch.distributed as dist\nfrom typing import List\n\n\ndef _type_torch_to_cupy(torch_type: torch.dtype):\n    # print(torch_type)\n    mappings = {\n        torch.uint8: cupy.cuda.nccl.NCCL_UINT8,\n        torch.int32: cupy.cuda.nccl.NCCL_INT32,\n        torch.int64: cupy.cuda.nccl.NCCL_INT64,\n        torch.int: cupy.cuda.nccl.NCCL_INT,\n        torch.float16: cupy.cuda.nccl.NCCL_FLOAT16,\n        torch.float32: cupy.cuda.nccl.NCCL_FLOAT32,\n        torch.float64: cupy.cuda.nccl.NCCL_FLOAT64,\n        torch.float: cupy.cuda.nccl.NCCL_FLOAT\n    }\n    return mappings[torch_type]\n\n\nclass NCCLCommunicator:\n    def __init__(self,\n                 comm_rank: int,\n                 cuda_id: int,\n                 comm_group_size: int,\n                 comm_name: str):\n        self.comm_rank = comm_rank\n        cupy.cuda.Device(cuda_id).use()\n        self.comm_group_size = comm_group_size\n        print(\"Initialize NCCLCommunicator: <\", comm_name, \">; rank:\", comm_rank)\n        self.dist_store = dist.distributed_c10d._get_default_store()\n\n        if self.comm_rank == 0:\n            cuda_id = cupy.cuda.nccl.get_unique_id()\n            # print(cuda_id)\n            cuda_id_str = np.array(cuda_id).tobytes()\n            self.dist_store.set('group-'+comm_name+'-unique-id', cuda_id_str)\n            # print(\"Master put <group-\"+comm_name+\"-unique-id: \", cuda_id_str, \">.\")\n        else:\n            cuda_id_str = self.dist_store.get('group-'+comm_name+'-unique-id')\n\n        comm_id = tuple(np.frombuffer(cuda_id_str, dtype=int))\n        # comm_id = cupy.cuda.nccl.get_unique_id()\n        # print(comm_id)\n        self.comm = cupy.cuda.nccl.NcclCommunicator(comm_group_size, comm_id, comm_rank)\n\n    @staticmethod\n    def barrier():\n        dist.barrier()\n\n    def store_set(self, key, value):\n        self.dist_store.set(key, value)\n\n    def 
store_get(self, key):\n        return self.dist_store.get(key)\n\n    def send(self,\n             tensor: torch.Tensor,\n             dst: int,\n             stream=cupy.cuda.Stream.null):\n        # print(\"Send tensor of size:\", torch.numel(tensor))\n        self.comm.send(\n            tensor.data_ptr(),\n            torch.numel(tensor),\n            _type_torch_to_cupy(tensor.dtype),\n            dst,\n            stream.ptr\n        )\n\n    def recv(self,\n             tensor: torch.Tensor,\n             src: int,\n             stream=cupy.cuda.Stream.null):\n        # print(\"Recv tensor of size:\", torch.numel(tensor))\n        # print(\"mean:\", torch.mean(tensor).item(), \" std:\", torch.std(tensor).item())\n        self.comm.recv(\n            tensor.data_ptr(),\n            torch.numel(tensor),\n            _type_torch_to_cupy(tensor.dtype),\n            src,\n            stream.ptr\n        )\n\n    def broadcast(self,\n                  tensor: torch.Tensor,\n                  src: int,\n                  stream=cupy.cuda.Stream.null):\n        self.comm.bcast(\n            tensor.data_ptr(),\n            torch.numel(tensor),\n            _type_torch_to_cupy(tensor.dtype),\n            src,\n            stream.ptr\n        )\n\n    def reduce(self,\n               tensor: torch.Tensor,\n               dst: int,\n               stream=cupy.cuda.Stream.null,\n               op=cupy.cuda.nccl.NCCL_SUM):\n        self.comm.reduce(\n            tensor.data_ptr(),  # force it to be in-place.\n            tensor.data_ptr(),\n            torch.numel(tensor),\n            _type_torch_to_cupy(tensor.dtype),\n            op,\n            dst,\n            stream.ptr\n        )\n\n    def all_reduce(self,\n                  tensor: torch.Tensor,\n                  stream=cupy.cuda.Stream.null,\n                  op=cupy.cuda.nccl.NCCL_SUM):\n        self.comm.allReduce(\n            tensor.data_ptr(),\n            tensor.data_ptr(),\n            
torch.numel(tensor),\n            _type_torch_to_cupy(tensor.dtype),\n            op,\n            stream.ptr\n        )\n\n    def scatter(self,\n                tensor: torch.Tensor,\n                scatter_list: List[torch.Tensor],\n                src: int,\n                stream=cupy.cuda.Stream.null):\n        cupy.cuda.nccl.groupStart()\n        if self.comm_rank == src:\n            for i in range(self.comm_group_size):\n                self.send(\n                    scatter_list[i],\n                    i,\n                    stream\n                )\n        self.recv(\n            tensor,\n            src,\n            stream\n        )\n        cupy.cuda.nccl.groupEnd()\n\n    def gather(self,\n               tensor: torch.Tensor,\n               gather_list: List[torch.Tensor],\n               dst: int,\n               stream=cupy.cuda.Stream.null):\n        cupy.cuda.nccl.groupStart()\n        if self.comm_rank == dst:\n            for i in range(self.comm_group_size):\n                self.recv(\n                    gather_list[i],\n                    i,\n                    stream\n                )\n        self.send(\n            tensor,\n            dst,\n            stream\n        )\n        cupy.cuda.nccl.groupEnd()\n\n    def all_to_all(self,\n                   output_tensor_list: List[torch.Tensor],\n                   input_tensor_list: List[torch.Tensor],\n                   stream=cupy.cuda.Stream.null):\n        assert len(output_tensor_list) == self.comm_group_size and len(input_tensor_list) == self.comm_group_size\n        cupy.cuda.nccl.groupStart()\n        for i in range(self.comm_group_size):\n            self.send(input_tensor_list[i], i, stream)\n            self.recv(output_tensor_list[i], i, stream)\n        cupy.cuda.nccl.groupEnd()\n\n    def all_gather(self,\n                   tensor: torch.Tensor,\n                   output_tensor_list: List[torch.Tensor],\n                   stream=cupy.cuda.Stream.null\n           
        ):\n        assert len(output_tensor_list) == self.comm_group_size\n        cupy.cuda.nccl.groupStart()\n        for i in range(self.comm_group_size):\n            self.send(tensor, i, stream)\n            self.recv(output_tensor_list[i], i, stream)\n        cupy.cuda.nccl.groupEnd()\n\n    def all_reduce_opt(self,\n                       tensor: torch.Tensor,\n                       buffer: List[torch.Tensor],\n                       stream=cupy.cuda.Stream.null,\n                       caller=None):\n        # First do all-to-all\n        assert torch.numel(tensor.data) % self.comm_group_size == 0\n        chunk_size = torch.numel(tensor.data) // self.comm_group_size\n        t_type = _type_torch_to_cupy(tensor.dtype)\n        element_size = tensor.data.element_size()\n        \n        cupy.cuda.nccl.groupStart()\n        for i in range(self.comm_group_size):\n            self.comm.send(tensor.data_ptr()+i*chunk_size*element_size, chunk_size, t_type, i, stream.ptr)\n            self.comm.recv(buffer[i].data_ptr(), chunk_size, t_type, i, stream.ptr)\n        cupy.cuda.nccl.groupEnd()\n        \n        for i in range(1, self.comm_group_size):\n            buffer[0] += buffer[i]\n\n        cupy.cuda.nccl.groupStart()\n        for i in range(self.comm_group_size):\n            self.comm.send(buffer[0].data_ptr(), chunk_size, t_type, i, stream.ptr)\n            self.comm.recv(tensor.data_ptr()+i*chunk_size*element_size, chunk_size, t_type, i, stream.ptr)\n        cupy.cuda.nccl.groupEnd()\n\n"
  },
  {
    "path": "training/comm/torch_backend.py",
    "content": "import torch\nimport torch.distributed as dist\nfrom typing import List\n\nclass TorchCommunicator:\n        \n    def __init__(self,\n                 process_group,\n                 to_global_rank=lambda rank: rank,\n                 dp_rank=None,\n                 comm_group_size=None,):\n        self.process_group = process_group\n        self.to_global_rank = to_global_rank\n        self.dp_rank = dp_rank\n        self.comm_group_size = comm_group_size\n\n    # @staticmethod\n    def barrier(self):\n        dist.barrier(group=self.process_group)\n\n    def send(self,\n             tensor: torch.Tensor,\n             dst: int,\n             stream=None):\n        # print(\"Send tensor of size:\", torch.numel(tensor))\n        if tensor.device == torch.device('cpu'):\n            dist.send(tensor, self.to_global_rank(dst), group=self.process_group)\n        else:\n            dist.send(tensor.cpu(), self.to_global_rank(dst), group=self.process_group)\n            \n    def recv(self,\n             tensor: torch.Tensor,\n             src: int,\n             stream=None):\n        \n        if tensor.device == torch.device('cpu'):\n            dist.recv(tensor, self.to_global_rank(src), group=self.process_group)\n        else:\n            buffer = tensor.cpu()\n            dist.recv(buffer, self.to_global_rank(src), group=self.process_group)\n            tensor[:] = buffer.to(tensor.device)\n    \n    def isend(self,\n             tensor: torch.Tensor,\n             dst: int,\n             stream=None):\n        # print(\"Send tensor of size:\", torch.numel(tensor))\n        if tensor.device == torch.device('cpu'):\n            handler = dist.isend(tensor, self.to_global_rank(dst), group=self.process_group)\n        else:\n            handler = dist.isend(tensor.cpu(), self.to_global_rank(dst), group=self.process_group)\n        return handler\n\n    def irecv(self,\n             tensor: torch.Tensor,\n             src: int,\n             
stream=None):\n        if tensor.device == torch.device('cpu'):\n            handler = dist.irecv(tensor, self.to_global_rank(src), group=self.process_group)\n        else:\n            # Staging through a CPU buffer cannot work for a non-blocking receive: the copy back\n            # to the device would run before the receive completes, so fail loudly instead.\n            raise NotImplementedError('irecv only supports CPU tensors')\n        return handler\n\n    def broadcast(self,\n                  tensor: torch.Tensor,\n                  src: int,\n                  stream=None):\n        if tensor.device == torch.device('cpu'):\n            dist.broadcast(tensor, self.to_global_rank(src), group=self.process_group)\n        else:\n            buffer = tensor.cpu()\n            dist.broadcast(buffer, self.to_global_rank(src), group=self.process_group)\n            tensor[:] = buffer.to(tensor.device)\n\n    def reduce(self,\n               tensor: torch.Tensor,\n               dst: int,\n               stream=None,\n               op=dist.ReduceOp.SUM):\n        dist.reduce(tensor, self.to_global_rank(dst), group=self.process_group, op=op)\n\n    def all_reduce(self,\n                   tensor: torch.Tensor,\n                   stream=None,\n                   op=dist.ReduceOp.SUM):\n        buffer = tensor.cpu()\n        dist.all_reduce(buffer, group=self.process_group, op=op)\n        tensor[:] = buffer.to(tensor.device)\n\n    def gather(self,\n               tensor: torch.Tensor,\n               gather_list: List[torch.Tensor],\n               dst: int,\n               stream=None):\n        dist.gather(tensor, gather_list, self.to_global_rank(dst), group=self.process_group)\n\n    def all_to_all(self,\n                   output_tensor_list: List[torch.Tensor],\n                   input_tensor_list: List[torch.Tensor],\n                   stream=None):\n        dist.all_to_all(output_tensor_list, input_tensor_list, group=self.process_group)\n\n    def all_gather(self,\n                   tensor: torch.Tensor,\n    
               output_tensor_list: List[torch.Tensor],\n                   stream=None):\n        dist.all_gather(output_tensor_list, tensor, group=self.process_group)\n        \n"
  },
  {
    "path": "training/data_parallel/__init__.py",
    "content": ""
  },
  {
    "path": "training/data_parallel/dist_dp_allreduce.py",
    "content": "import torch.cuda\nfrom comm.comm_utils import *\nfrom .flatten_utils import flatten_params\n\n\nclass AllReduceDP:\n    def __init__(self, args, device, module: torch.nn.Module, optimizer: torch.optim.Optimizer = None, flatten=True):\n        self.flatten = flatten\n        self.global_rank = args.rank\n        self.dp_group_size = args.data_group_size\n        self.enable_tidy_profiling = (args.profiling == 'tidy_profiling')\n        self.dp_comm = get_data_parallel_comm()\n        self.dp_rank = get_data_parallel_rank()\n        self.dp_comm_stream = torch.cuda.Stream(device=device, priority=-1)\n        self.torch_optim_comp_stream = torch.cuda.default_stream(device=device)\n        self.backward_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.allreduce_grad_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.optimizer_step_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n\n        self.module = module\n        num_paras, element_size = self._compute_total_para_num()\n        print(\"Total number of parameters: {}, element size: {}, total size {} MB.\"\n              .format(num_paras, element_size, num_paras * element_size // 1024 // 1024))\n\n        if self.flatten:\n            self.flatten_para = flatten_params(self.module.parameters())\n            print(\"Flattened parameter number: {}, element size: {}.\"\n                  .format(self.flatten_para.data.numel(), self.flatten_para.data.element_size()))\n            print(\"Flattened parameter grad number: {}, element size: {}.\"\n                  .format(self.flatten_para.grad.numel(), self.flatten_para.grad.element_size()))\n\n        assert optimizer is not None\n        self.optimizer = optimizer\n\n        if self.enable_tidy_profiling:\n            self.global_rank = args.rank\n            self.init_event = None\n            
self.init_time_stamp = None\n            if self.flatten:\n                self.allreduce_gradients_start_event = torch.cuda.Event(enable_timing=True, blocking=False)\n            else:\n                self.allreduce_gradients_start_events = dict()\n                self.allreduce_gradients_end_events = dict()\n                for name, _ in self.module.named_parameters():\n                    self.allreduce_gradients_start_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n                    self.allreduce_gradients_end_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n\n            self.optimizer_step_start_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling,\n                                                               blocking=False)\n\n    def _compute_total_para_num(self):\n        total_count = 0\n        element_size = 0\n        for para in self.module.parameters():\n            # print(\"Parameter: \", para.data.shape)\n            total_count += torch.numel(para.data)\n            element_size = para.element_size()\n        return total_count, element_size\n\n    def profile_mark_allreduce_start(self, name=None):\n        if self.enable_tidy_profiling:\n            if name is None:\n                self.dp_comm_stream.record_event(self.allreduce_gradients_start_event)\n            else:\n                self.dp_comm_stream.record_event(self.allreduce_gradients_start_events[name])\n\n    def profile_mark_allreduce_end(self, name=None):\n        if self.enable_tidy_profiling:\n            if name:\n                self.dp_comm_stream.record_event(self.allreduce_gradients_end_events[name])\n\n    def profile_mark_optimizer_step_start(self):\n        if self.enable_tidy_profiling:\n            self.torch_optim_comp_stream.record_event(self.optimizer_step_start_event)\n\n    def _allreduce_gradients(self):\n        with torch.cuda.stream(self.dp_comm_stream):\n            cupy_dp_stream = 
cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream)\n            self.dp_comm_stream.wait_event(self.backward_ready_event)\n            if self.flatten:\n                self.profile_mark_allreduce_start()\n                self.dp_comm.all_reduce(self.flatten_para.grad, stream=cupy_dp_stream)\n                self.profile_mark_allreduce_end()\n            else:\n                for name, para in self.module.named_parameters():\n                    if para.grad is None:\n                        continue\n                    self.profile_mark_allreduce_start(name)\n                    self.dp_comm.all_reduce(para.grad, stream=cupy_dp_stream)\n                    self.profile_mark_allreduce_end(name)\n            self.dp_comm_stream.record_event(self.allreduce_grad_ready_event)\n\n    def optimizer_step(self):\n        self._allreduce_gradients()\n        with torch.cuda.stream(self.torch_optim_comp_stream):\n            self.torch_optim_comp_stream.wait_event(self.allreduce_grad_ready_event)\n            self.profile_mark_optimizer_step_start()\n            self.optimizer.step()\n            self.torch_optim_comp_stream.record_event(self.optimizer_step_ready_event)\n\n    def set_time_stamp(self, init_time_stamp, init_event):\n        self.init_event = init_event\n        self.init_time_stamp = init_time_stamp\n\n    def get_ts(self, event):\n        return self.init_time_stamp + self.init_event.elapsed_time(event) * 1e+3\n\n    def profiling_data_parallel(self, init_time_stamp, init_event):\n        self.set_time_stamp(init_time_stamp, init_event)\n        profiling_log = []\n\n        if self.flatten:\n            allreduce_slot = self.allreduce_gradients_start_event.elapsed_time(self.allreduce_grad_ready_event)*1e+3\n            allreduce_log = {\"name\": \"opt_allreduce\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. 
optimizer-comm\",\n                             \"ts\": self.get_ts(self.allreduce_gradients_start_event),\n                             \"dur\": allreduce_slot, \"cname\": \"cq_build_passed\",\n                             \"args\": {'para': 'flattened_grad', 'size': self.flatten_para.grad.numel()}}\n            # print(allreduce_log)\n            profiling_log.append(allreduce_log)\n        else:\n            for name, para in self.module.named_parameters():\n                allreduce_slot = self.allreduce_gradients_start_events[name].elapsed_time(\n                    self.allreduce_gradients_end_events[name]) * 1e+3\n                allreduce_log = {\"name\": \"opt_allreduce\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. optimizer-comm\",\n                                 \"ts\": self.get_ts(self.allreduce_gradients_start_events[name]), \"dur\": allreduce_slot,\n                                 \"cname\": \"cq_build_passed\", \"args\": {'para': name, 'size': torch.numel(para.data)}}\n                # print(allreduce_log)\n                profiling_log.append(allreduce_log)\n\n        optimizer_slot = self.optimizer_step_start_event.elapsed_time(self.optimizer_step_ready_event) * 1e+3\n        optimizer_log = {\"name\": \"opt_comp\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"8. optimizer-comp\",\n                         \"ts\": self.get_ts(self.optimizer_step_start_event), \"dur\": optimizer_slot, \"cname\": \"bad\"}\n        # print(optimizer_log)\n        profiling_log.append(optimizer_log)\n        return profiling_log\n"
  },
  {
    "path": "training/data_parallel/dist_dp_central_ps.py",
    "content": "import torch.cuda\nfrom comm.comm_utils import *\nfrom .flatten_utils import flatten_params\n\n\nclass CentralPSDP:\n    def __init__(self, args, device, module: torch.nn.Module, optimizer: torch.optim.Optimizer = None, flatten=True):\n        self.flatten = flatten\n        self.global_rank = args.rank\n        self.dp_group_size = args.data_group_size\n        self.enable_tidy_profiling = (args.profiling == 'tidy_profiling')\n        self.dp_comm = get_data_parallel_comm()\n        self.dp_rank = get_data_parallel_rank()\n        self.dp_comm_stream = torch.cuda.Stream(device=device, priority=-1)\n        self.torch_optim_comp_stream = torch.cuda.default_stream(device=device)\n        self.backward_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.broadcast_reduced_gradients_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling,\n                                                                        blocking=False)\n        self.optimizer_step_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n\n        self.module = module\n        num_paras, element_size = self._compute_total_para_num()\n        print(\"Total number of parameters: {}, element size: {}, total size {} MB.\"\n              .format(num_paras, element_size, num_paras * element_size // 1024 // 1024))\n\n        if self.flatten:\n            self.flatten_para = flatten_params(self.module.parameters())\n            print(\"Flattened parameter number: {}, element size: {}.\"\n                  .format(self.flatten_para.data.numel(), self.flatten_para.data.element_size()))\n            print(\"Flattened parameter grad number: {}, element size: {}.\"\n                  .format(self.flatten_para.grad.numel(), self.flatten_para.grad.element_size()))\n\n        assert optimizer is not None\n        self.optimizer = optimizer\n\n        if self.enable_tidy_profiling:\n            
self.global_rank = args.rank\n            self.init_event = None\n            self.init_time_stamp = None\n            if self.flatten:\n                self.reduce_gradients_start_event = torch.cuda.Event(enable_timing=True, blocking=False)\n                self.reduce_gradients_end_event = torch.cuda.Event(enable_timing=True, blocking=False)\n                self.broadcast_reduced_grad_start_event = torch.cuda.Event(enable_timing=True, blocking=False)\n            else:\n                self.reduce_gradients_start_events = dict()\n                self.reduce_gradients_end_events = dict()\n                self.broadcast_reduced_grad_start_events = dict()\n                self.broadcast_reduced_grad_end_events = dict()\n\n                for name, _ in self.module.named_parameters():\n                    self.reduce_gradients_start_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n                    self.reduce_gradients_end_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n                    self.broadcast_reduced_grad_start_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n                    self.broadcast_reduced_grad_end_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n\n            self.optimizer_step_start_event = torch.cuda.Event(enable_timing=True, blocking=False)\n\n    def _compute_total_para_num(self):\n        total_count = 0\n        element_size = 0\n        for para in self.module.parameters():\n            # print(\"Parameter: \", para.data.shape)\n            total_count += torch.numel(para.data)\n            element_size = para.element_size()\n        return total_count, element_size\n    \n    def profile_mark_reduce_start(self, name=None):\n        if self.enable_tidy_profiling:\n            if name is None:\n                self.dp_comm_stream.record_event(self.reduce_gradients_start_event)\n            else:\n                
self.dp_comm_stream.record_event(self.reduce_gradients_start_events[name])\n\n    def profile_mark_reduce_end(self, name=None):\n        if self.enable_tidy_profiling:\n            if name is None:\n                self.dp_comm_stream.record_event(self.reduce_gradients_end_event)\n            else:\n                self.dp_comm_stream.record_event(self.reduce_gradients_end_events[name])\n\n    def profile_mark_optimizer_step_start(self):\n        if self.enable_tidy_profiling:\n            self.torch_optim_comp_stream.record_event(self.optimizer_step_start_event)\n            \n    def profile_mark_broadcast_start(self, name=None):\n        if self.enable_tidy_profiling:\n            if name is None:\n                self.dp_comm_stream.record_event(self.broadcast_reduced_grad_start_event)\n            else:\n                self.dp_comm_stream.record_event(self.broadcast_reduced_grad_start_events[name])\n            \n    def profile_mark_broadcast_end(self, name=None):\n        if self.enable_tidy_profiling:\n            if name:\n                self.dp_comm_stream.record_event(self.broadcast_reduced_grad_end_events[name])\n\n    def _reduce_gradients(self):\n        with torch.cuda.stream(self.dp_comm_stream):\n            cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream)\n            self.dp_comm_stream.wait_event(self.backward_ready_event)\n            if self.flatten:\n                self.profile_mark_reduce_start()\n                self.dp_comm.reduce(self.flatten_para.grad, dst=0, stream=cupy_dp_stream)\n                self.profile_mark_reduce_end()\n            else:\n                for name, para in self.module.named_parameters():\n                    self.profile_mark_reduce_start(name)\n                    self.dp_comm.reduce(para.grad, dst=0, stream=cupy_dp_stream)\n                    self.profile_mark_reduce_end(name)\n\n    def _broadcast_reduced_gradients(self):\n        with torch.cuda.stream(self.dp_comm_stream):\n    
        cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream)\n            if self.flatten:\n                self.profile_mark_broadcast_start()\n                self.dp_comm.broadcast(self.flatten_para.grad, src=0, stream=cupy_dp_stream)\n                self.profile_mark_broadcast_end()\n            else:\n                for name, para in self.module.named_parameters():\n                    self.profile_mark_broadcast_start(name)\n                    self.dp_comm.broadcast(para.grad, src=0, stream=cupy_dp_stream)\n                    self.profile_mark_broadcast_end(name)\n            self.dp_comm_stream.record_event(self.broadcast_reduced_gradients_ready_event)\n\n    def optimizer_step(self):\n        self._reduce_gradients()\n        self._broadcast_reduced_gradients()\n        with torch.cuda.stream(self.torch_optim_comp_stream):\n            self.torch_optim_comp_stream.wait_event(self.broadcast_reduced_gradients_ready_event)\n            self.profile_mark_optimizer_step_start()\n            self.optimizer.step()\n            self.torch_optim_comp_stream.record_event(self.optimizer_step_ready_event)\n\n    def set_time_stamp(self, init_time_stamp, init_event):\n        self.init_event = init_event\n        self.init_time_stamp = init_time_stamp\n\n    def get_ts(self, event):\n        return self.init_time_stamp + self.init_event.elapsed_time(event) * 1e+3\n\n    def profiling_data_parallel(self, init_time_stamp, init_event):\n        self.set_time_stamp(init_time_stamp, init_event)\n        profiling_log = []\n        if self.flatten:\n            reduce_slot = self.reduce_gradients_start_event.elapsed_time(self.reduce_gradients_end_event) * 1e+3\n            reduce_log = {\"name\": \"opt_reduce\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. 
optimizer-comm\",\n                          \"ts\": self.get_ts(self.reduce_gradients_start_event),\n                          \"dur\": reduce_slot, \"cname\": \"cq_build_passed\",\n                          \"args\": {'para': 'flattened_grad', 'size': self.flatten_para.grad.numel()}}\n            # print(reduce_log)\n            profiling_log.append(reduce_log)\n        else:\n            for name, para in self.module.named_parameters():\n                reduce_slot = self.reduce_gradients_start_events[name].elapsed_time(\n                    self.reduce_gradients_end_events[name]) * 1e+3\n                reduce_log = {\"name\": \"opt_reduce\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. optimizer-comm\",\n                              \"ts\": self.get_ts(self.reduce_gradients_start_events[name]), \"dur\": reduce_slot,\n                              \"cname\": \"cq_build_passed\", \"args\": {'para': name, 'size': torch.numel(para.data)}}\n                # print(reduce_log)\n                profiling_log.append(reduce_log)\n\n        optimizer_slot = self.optimizer_step_start_event.elapsed_time(self.optimizer_step_ready_event) * 1e+3\n        optimizer_log = {\"name\": \"opt_comp\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"8. optimizer-comp\",\n                         \"ts\": self.get_ts(self.optimizer_step_start_event), \"dur\": optimizer_slot, \"cname\": \"bad\"}\n        # print(optimizer_log)\n        profiling_log.append(optimizer_log)\n\n        if self.flatten:\n            broadcast_slot = self.broadcast_reduced_grad_start_event.elapsed_time(\n                self.broadcast_reduced_gradients_ready_event) * 1e+3\n            broadcast_log = {\"name\": \"opt_broadcast\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. 
optimizer-comm\",\n                             \"ts\": self.get_ts(self.broadcast_reduced_grad_start_event),\n                             \"dur\": broadcast_slot, \"cname\": \"cq_build_passed\",\n                             \"args\": {'para': 'flattened_grad', 'size': self.flatten_para.grad.numel()}}\n            profiling_log.append(broadcast_log)\n        else:\n            for name, para in self.module.named_parameters():\n                broadcast_slot = self.broadcast_reduced_grad_start_events[name].elapsed_time(\n                    self.broadcast_reduced_grad_end_events[name]) * 1e+3\n                broadcast_log = {\"name\": \"opt_broadcast\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. optimizer-comm\",\n                                 \"ts\": self.get_ts(self.broadcast_reduced_grad_start_events[name]), \"dur\": broadcast_slot,\n                                 \"cname\": \"cq_build_passed\", \"args\": {'para': name, 'size': torch.numel(para.data)}}\n                # print(broadcast_log)\n                profiling_log.append(broadcast_log)\n        return profiling_log\n"
  },
  {
    "path": "training/data_parallel/dist_dp_local.py",
    "content": "import torch.cuda\nimport cupy\nfrom comm.comm_utils import *\nfrom .flatten_utils import flatten_params\n\n\nclass LocalDP:\n    def __init__(self, args, device, module: torch.nn.Module, optimizer: torch.optim.Optimizer = None, flatten=True):\n        flatten = True\n        self.flatten = flatten\n        self.global_rank = args.rank\n        self.dp_group_size = args.data_group_size\n        self.enable_tidy_profiling = (args.profiling == 'tidy_profiling')\n        self.dp_comm = get_data_parallel_comm()\n        self.dp_rank = get_data_parallel_rank()\n        self.dp_comm_stream = torch.cuda.Stream(device=device, priority=-1)\n        self.torch_optim_comp_stream = torch.cuda.default_stream(device=device)\n        self.backward_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.allreduce_gradients_start_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.allreduce_grad_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.optimizer_step_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n\n        self.module = module\n        num_paras, element_size = self._compute_total_para_num()\n        print(\"Total number of parameters: {}, element size: {}, total size {} MB.\"\n              .format(num_paras, element_size, num_paras * element_size // 1024 // 1024))\n\n        if self.flatten:\n            self.flatten_para = flatten_params(self.module.parameters())\n            print(\"Flattened parameter number: {}, element size: {}.\"\n                  .format(self.flatten_para.data.numel(), self.flatten_para.data.element_size()))\n            print(\"Flattened parameter grad number: {}, element size: {}.\"\n                  .format(self.flatten_para.grad.numel(), self.flatten_para.grad.element_size()))\n\n        assert optimizer is not None\n        self.optimizer = 
optimizer\n\n        if self.enable_tidy_profiling:\n            self.global_rank = args.rank\n            self.init_event = None\n            self.init_time_stamp = None\n            if self.flatten:\n                self.allreduce_gradients_start_event = torch.cuda.Event(enable_timing=True, blocking=False)\n            else:\n                self.allreduce_gradients_start_events = dict()\n                self.allreduce_gradients_end_events = dict()\n                for name, _ in self.module.named_parameters():\n                    self.allreduce_gradients_start_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n                    self.allreduce_gradients_end_events[name] = torch.cuda.Event(enable_timing=True, blocking=False)\n\n            self.optimizer_step_start_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling,\n                                                               blocking=False)\n\n    def _compute_total_para_num(self):\n        total_count = 0\n        element_size = 0\n        for para in self.module.parameters():\n            # print(\"Parameter: \", para.data.shape)\n            total_count += torch.numel(para.data)\n            element_size = para.element_size()\n        return total_count, element_size\n\n    def profile_mark_allreduce_start(self, name=None):\n        if self.enable_tidy_profiling:\n            if name is None:\n                self.dp_comm_stream.record_event(self.allreduce_gradients_start_event)\n            else:\n                self.dp_comm_stream.record_event(self.allreduce_gradients_start_events[name])\n\n    def profile_mark_allreduce_end(self, name=None):\n        if self.enable_tidy_profiling:\n            if name:\n                self.dp_comm_stream.record_event(self.allreduce_gradients_end_events[name])\n\n    def profile_mark_optimizer_step_start(self):\n        if self.enable_tidy_profiling:\n            
self.torch_optim_comp_stream.record_event(self.optimizer_step_start_event)\n            \n    def allreduce_parameters(self):\n        self._local_parameters_backup = [\n            p.data.clone() for p in self.module.parameters()\n        ]\n        torch.cuda.synchronize()\n        self.dp_comm.barrier()\n        with torch.cuda.stream(self.dp_comm_stream):\n            cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream)\n            self.dp_comm_stream.wait_event(self.backward_ready_event)\n            if self.flatten:\n                self.profile_mark_allreduce_start()\n                self.dp_comm.all_reduce(self.flatten_para.data, stream=cupy_dp_stream)\n                self.flatten_para.data /= self.dp_group_size\n                self.profile_mark_allreduce_end()\n            else:\n                for name, para in self.module.named_parameters():\n                    self.profile_mark_allreduce_start(name)\n                    self.dp_comm.all_reduce(para.data, stream=cupy_dp_stream)\n                    para.data /= self.dp_group_size\n                    self.profile_mark_allreduce_end(name)\n            self.dp_comm_stream.record_event(self.allreduce_grad_ready_event)\n        torch.cuda.synchronize()\n        self.dp_comm.barrier()\n\n    def rollback_parameters(self):\n        if not hasattr(self, '_local_parameters_backup'):\n            return\n        \n        for p, p_local in zip(self.module.parameters(), self._local_parameters_backup):\n            p.data[:] = p_local.data\n            \n        del self._local_parameters_backup\n            \n\n    def optimizer_step(self):\n        # torch.cuda.synchronize()\n        with torch.cuda.stream(self.torch_optim_comp_stream):\n            self.torch_optim_comp_stream.record_event(self.allreduce_gradients_start_event)\n            self.torch_optim_comp_stream.record_event(self.allreduce_grad_ready_event)\n            
self.torch_optim_comp_stream.wait_event(self.backward_ready_event)\n            self.profile_mark_optimizer_step_start()\n            self.optimizer.step()\n            self.torch_optim_comp_stream.record_event(self.optimizer_step_ready_event)\n\n    def set_time_stamp(self, init_time_stamp, init_event):\n        self.init_event = init_event\n        self.init_time_stamp = init_time_stamp\n\n    def get_ts(self, event):\n        return self.init_time_stamp + self.init_event.elapsed_time(event) * 1e+3\n\n    def profiling_data_parallel(self, init_time_stamp, init_event):\n        self.set_time_stamp(init_time_stamp, init_event)\n        profiling_log = []\n\n        if self.flatten:\n            allreduce_slot = self.allreduce_gradients_start_event.elapsed_time(self.allreduce_grad_ready_event)*1e+3\n            allreduce_log = {\"name\": \"opt_allreduce\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. optimizer-comm\",\n                             \"ts\": self.get_ts(self.allreduce_gradients_start_event),\n                             \"dur\": allreduce_slot, \"cname\": \"cq_build_passed\",\n                             \"args\": {'para': 'flattened_grad', 'size': self.flatten_para.grad.numel()}}\n            # print(allreduce_log)\n            profiling_log.append(allreduce_log)\n        else:\n            for name, para in self.module.named_parameters():\n                allreduce_slot = self.allreduce_gradients_start_events[name].elapsed_time(\n                    self.allreduce_gradients_end_events[name]) * 1e+3\n                allreduce_log = {\"name\": \"opt_allreduce\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. 
optimizer-comm\",\n                                 \"ts\": self.get_ts(self.allreduce_gradients_start_events[name]), \"dur\": allreduce_slot,\n                                 \"cname\": \"cq_build_passed\", \"args\": {'para': name, 'size': torch.numel(para.data)}}\n                # print(allreduce_log)\n                profiling_log.append(allreduce_log)\n\n        optimizer_slot = self.optimizer_step_start_event.elapsed_time(self.optimizer_step_ready_event) * 1e+3\n        optimizer_log = {\"name\": \"opt_comp\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"8. optimizer-comp\",\n                         \"ts\": self.get_ts(self.optimizer_step_start_event), \"dur\": optimizer_slot, \"cname\": \"bad\"}\n        # print(optimizer_log)\n        profiling_log.append(optimizer_log)\n        return profiling_log\n"
  },
  {
    "path": "training/data_parallel/dist_dp_sharded_ps.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.cuda\nfrom comm.comm_utils import *\nfrom .flatten_utils import flatten_params\n\n\nclass ShardedPSDP:\n    def __init__(self, args, device, module: torch.nn.Module, optimizer: torch.optim.Optimizer = None, flatten=True):\n        self.flatten = flatten\n        self.global_rank = args.rank\n        self.dp_group_size = args.data_group_size\n        self.enable_tidy_profiling = (args.profiling == 'tidy_profiling')\n        self.dp_comm = get_data_parallel_comm()\n        self.dp_rank = get_data_parallel_rank()\n        self.dp_comm_stream = torch.cuda.Stream(device=device, priority=-1)\n        self.torch_optim_comp_stream = torch.cuda.default_stream(device=device)\n        self.backward_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.sync_gradients_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n        self.optimizer_step_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n\n        self.module = module\n        assert optimizer is not None\n        self.optimizer = optimizer\n        num_paras, element_size = self._compute_total_para_num()\n        print(\"Total number of parameters: {}, element size: {}, total size {} MB.\"\n              .format(num_paras, element_size, num_paras * element_size // 1024 // 1024))\n\n        assert self.flatten\n#         self.para = list(self.module.parameters())\n        self.flatten_para = flatten_params(self.module.parameters(), self.dp_group_size)\n        print(\"Flattened parameter number: {}, element size: {}.\"\n              .format(self.flatten_para.data.numel(), self.flatten_para.data.element_size()))\n        print(\"Flattened parameter grad number: {}, element size: {}.\"\n              .format(self.flatten_para.grad.numel(), self.flatten_para.grad.element_size()))\n\n        
self.grad_buffer = self._declare_grad_buffer()\n\n        if self.enable_tidy_profiling:\n            self.global_rank = args.rank\n            self.init_event = None\n            self.init_time_stamp = None\n\n            assert self.flatten\n            self.sync_gradients_start_event = torch.cuda.Event(enable_timing=True, blocking=False)\n\n            self.optimizer_step_start_event = torch.cuda.Event(enable_timing=True, blocking=False)\n\n    def _compute_total_para_num(self):\n        total_count = 0\n        element_size = 0\n        for para in self.module.parameters():\n            # print(\"Parameter: \", para.data.shape)\n            total_count += torch.numel(para.data)\n            element_size = para.element_size()\n        return total_count, element_size\n\n    def _declare_grad_buffer(self):\n        assert self.flatten_para.data.numel() % self.dp_group_size == 0\n        chunk_size = self.flatten_para.data.numel() // self.dp_group_size\n        grad_buffer = [torch.zeros(chunk_size, device=self.flatten_para.device, dtype=self.flatten_para.dtype)\n                       for _ in range(self.dp_group_size)]\n        return grad_buffer\n\n    def profile_mark_sync_grad_start(self):\n        if self.enable_tidy_profiling:\n            self.dp_comm_stream.record_event(self.sync_gradients_start_event)\n\n    def profile_mark_allreduce_end(self):\n        pass\n\n    def profile_mark_optimizer_step_start(self):\n        if self.enable_tidy_profiling:\n            self.torch_optim_comp_stream.record_event(self.optimizer_step_start_event)\n\n    def _sync_gradients(self):\n        with torch.cuda.stream(self.dp_comm_stream):\n            cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream)\n            self.dp_comm_stream.wait_event(self.backward_ready_event)\n            assert self.flatten\n            self.profile_mark_sync_grad_start()\n            self.dp_comm.all_reduce_opt(self.flatten_para.grad, self.grad_buffer, 
stream=cupy_dp_stream)\n            self.profile_mark_allreduce_end()\n            self.dp_comm_stream.record_event(self.sync_gradients_ready_event)\n\n    def optimizer_step(self):\n        self._sync_gradients()\n        with torch.cuda.stream(self.torch_optim_comp_stream):\n            self.torch_optim_comp_stream.wait_event(self.sync_gradients_ready_event)\n            self.profile_mark_optimizer_step_start()\n            self.optimizer.step()\n            self.torch_optim_comp_stream.record_event(self.optimizer_step_ready_event)\n\n    def set_time_stamp(self, init_time_stamp, init_event):\n        self.init_event = init_event\n        self.init_time_stamp = init_time_stamp\n\n    def get_ts(self, event):\n        return self.init_time_stamp + self.init_event.elapsed_time(event) * 1e+3\n\n    def profiling_data_parallel(self, init_time_stamp, init_event):\n        self.set_time_stamp(init_time_stamp, init_event)\n        profiling_log = []\n\n        assert self.flatten\n        allreduce_slot = self.sync_gradients_start_event.elapsed_time(self.sync_gradients_ready_event)*1e+3\n        allreduce_log = {\"name\": \"opt_shardedPS_sync\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. optimizer-comm\",\n                         \"ts\": self.get_ts(self.sync_gradients_start_event),\n                         \"dur\": allreduce_slot, \"cname\": \"cq_build_passed\",\n                         \"args\": {'para': 'flattened_grad', 'size': self.flatten_para.grad.numel()}}\n        # print(allreduce_log)\n        profiling_log.append(allreduce_log)\n\n        optimizer_slot = self.optimizer_step_start_event.elapsed_time(self.optimizer_step_ready_event) * 1e+3\n        optimizer_log = {\"name\": \"opt_comp\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"8. 
optimizer-comp\",\n                         \"ts\": self.get_ts(self.optimizer_step_start_event), \"dur\": optimizer_slot, \"cname\": \"bad\"}\n        # print(optimizer_log)\n        profiling_log.append(optimizer_log)\n        return profiling_log\n"
  },
  {
    "path": "training/data_parallel/dist_dp_utils.py",
    "content": "from .dist_dp_allreduce import AllReduceDP\nfrom .dist_dp_sharded_ps import ShardedPSDP\nfrom .dist_dp_local import LocalDP\n\n\ndef get_dp_module(args, device, module, optimizer):\n    print(\"Data parallel implementation: \", args.dp_mode)\n    if args.dp_mode == 'allreduce':\n        return AllReduceDP(args, device, module, optimizer, flatten=False) \n        # flatten gradient is not compatible with fp16 now\n    elif args.dp_mode == 'local':\n        return LocalDP(args, device, module, optimizer, flatten=False)\n    elif args.dp_mode == 'sharded_ps':\n        return ShardedPSDP(args, device, module, optimizer, flatten=False)\n    else:\n        print(\"Not recognize this data parallel mode.\")\n        assert False\n"
  },
  {
    "path": "training/data_parallel/flatten_utils.py",
    "content": "import torch\n\n\ndef _assert_contiguous(tensors):\n    data_ptr = None\n    for t in tensors:\n        if data_ptr is not None:\n            assert t.data_ptr() == data_ptr\n        data_ptr = t.data_ptr() + t.numel() * t.element_size()\n\n\ndef flatten_params(param_set, chunk=None):\n    params = [p for p in param_set]\n    weights = [p.data for p in params]\n    grads = [p.grad.data if p.grad is not None else torch.zeros_like(p.data) for p in params]\n    sizes = [p.numel() for p in params]\n    total_size = sum(sizes)\n    if chunk:\n        total_size = ((total_size+chunk-1)//chunk)*chunk\n\n    flatten_weights_tensor = torch.zeros(total_size, dtype=weights[0].dtype).to(weights[0].device)\n    flatten_grads_tensor = torch.zeros(total_size, dtype=weights[0].dtype).to(weights[0].device)\n    flatten_weights_storage = flatten_weights_tensor.storage()\n    flatten_grads_storage = flatten_grads_tensor.storage()\n\n    def set_storage(param, weight_storage, grad_storage, storage_offset):\n        with torch.no_grad():\n            z = torch.zeros_like(param.data)\n            z.set_(weight_storage, storage_offset, param.shape)\n            param.data = z\n\n            t = torch.zeros_like(param.data)\n            t.set_(grad_storage, storage_offset, param.shape)\n            param.grad = t\n\n    offset = 0\n    for i in range(len(params)):\n        flatten_weights_tensor[offset: offset + weights[i].numel()] = weights[i].reshape(-1)\n        flatten_grads_tensor[offset: offset + grads[i].numel()] = grads[i].reshape(-1)\n        set_storage(params[i], flatten_weights_storage, flatten_grads_storage, offset)\n        offset += sizes[i]\n\n    weight_tensors = [p.data for p in params]\n    grad_tensors = [p.grad.data for p in params]\n\n    _assert_contiguous(weight_tensors)\n    _assert_contiguous(grad_tensors)\n\n    with torch.no_grad():\n        flatten_para = torch.nn.Parameter(flatten_weights_tensor, requires_grad=False)\n        flatten_para.grad 
= flatten_grads_tensor\n        return flatten_para\n    \n\ndef flatten_tensors(tensor_set, chunk=None):\n    tensors = [p for p in tensor_set]\n    weights = [p.data for p in tensors]\n    sizes = [p.numel() for p in tensors]\n    total_size = sum(sizes)\n    if chunk:\n        total_size = ((total_size+chunk-1)//chunk)*chunk\n\n    flatten_weights_tensor = torch.zeros(total_size, dtype=weights[0].dtype).to(weights[0].device)\n    flatten_weights_storage = flatten_weights_tensor.storage()\n\n    def set_storage(param, weight_storage, storage_offset):\n        with torch.no_grad():\n            z = torch.zeros_like(param.data)\n            z.set_(weight_storage, storage_offset, param.shape)\n            param.data = z\n\n    offset = 0\n    for i in range(len(tensors)):\n        flatten_weights_tensor[offset: offset + weights[i].numel()] = weights[i].reshape(-1)\n        set_storage(tensors[i], flatten_weights_storage, offset)\n        offset += sizes[i]\n\n    return flatten_weights_tensor\n"
  },
  {
    "path": "training/dist_clm_train.py",
    "content": "import argparse\nimport time\nimport random\nimport numpy as np\nimport torch\nimport torch.autograd.profiler as profiler\nfrom tasks.data_loaders.data_utils import get_train_data_loader, get_eval_data_loader\nfrom modules.utils import gpt_loss_func\nfrom modules.tokenizer import build_tokenizer\nfrom pipeline_parallel.dist_pp_utils import get_pp_module\n\nfrom transformers import AutoConfig\nimport datasets\n\nfrom utils.dist_args_utils import *\nfrom utils.dist_checkpoint_utils import *\nfrom utils.logging_utils import *\nfrom utils.event_report import *\nfrom comm.comm_utils import *\n\nfrom utils.upload_manager import *\n\n\ndef test_loop(args, pipe, device, test_data_loader):\n    \n    if test_data_loader is None:\n        return\n    \n    print('testing starts.....')\n    \n    pipe.model.eval()\n    \n    if get_pipeline_parallel_rank()  == args.pipeline_group_size - 1:\n        \n        def _lm_pred_func(x, y):\n            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')\n            logits = x[:, :-1, :].contiguous().float()\n            labels = y[:, 1:].contiguous()\n            loss = loss_fct(logits.transpose(-1, -2), labels).mean(1).detach().cpu()\n            return loss\n        \n        loss_list = []\n        for i, data in enumerate(test_data_loader):\n            \n            if args.evaluation_num_batch is not None and i >= args.evaluation_num_batch:\n                break\n                \n            input_ids = data['input_ids'].to(device)\n            labels = input_ids.clone()\n            pipe.infer_iter(input_ids, labels, output_=loss_list, pred_func=_lm_pred_func)\n            \n        loss = torch.tensor(loss_list).mean()\n        ppls = torch.exp(loss)\n        metric = {\"valid.perplexity\": ppls.item(), \"valid.loss\": loss.item()}\n        \n        print(metric)\n        train_log(\n            metric, \n            step=pipe.global_step,\n        )\n        \n    else:\n        for i, data in 
enumerate(test_data_loader):\n            \n            if args.evaluation_num_batch is not None and i >= args.evaluation_num_batch:\n                break\n            \n            input_ids = data['input_ids'].to(device)\n            labels = input_ids.clone()\n            current_iter_time = pipe.infer_iter(input_ids, labels)\n    \n    pipe.model.train()\n    \n\n\ndef train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch):\n    \n    print('training starts......')\n\n    event_reporter = EventReporter(host=args.event_host, auth_token=args.event_auth_token, job_id=args.job_id)\n\n    pipe.model.train() # Flag .training to True to enable Dropout\n    \n    use_dp = (args.world_size != args.pipeline_group_size)\n    if use_dp:\n        # dp_comm = get_data_parallel_comm()\n        dp_rank = get_data_parallel_rank()\n        dp_size = get_data_parallel_world_size()\n    else:\n        dp_rank = 0\n        dp_size = 1\n    pp_comm = get_pipeline_parallel_comm()\n    \n    stop_flag = torch.zeros(1, dtype=torch.int64).to(device)\n    \n    input_ids = torch.zeros(\n        [args.batch_size, args.seq_length], \n        dtype=torch.int64\n    ).to(device)\n    \n    do_sync_before_save = (args.dp_mode in ['local'] and use_dp)\n\n    # Get the number of model parameters for the model\n    param_count = torch.zeros(1, dtype=torch.int64).to(device)\n    local_param_count = sum(p.numel() for p in pipe.model.parameters())\n    param_count.data[:] = local_param_count\n    pp_comm.reduce(param_count, 0)\n\n    if get_pipeline_parallel_rank() == 0 and dp_rank == 0:\n\n        print(f\"Training steps:  total_steps={args.total_steps},  steps_per_epoch={steps_per_epoch},  steps_per_checkpoint={args.checkpoint_steps}\")\n\n        upload_checkpoints_enabled = args.checkpoint_upload_prefix is not None \n        upload_manager = UploadManager(aws_endpoint_url = args.aws_endpoint_url,\n                                       aws_access_key_id = 
args.aws_access_key_id,\n                                       aws_secret_access_key = args.aws_secret_access_key,\n                                       aws_session_token = args.aws_session_token,\n                                       aws_region = args.aws_region,\n                                       event_reporter = event_reporter,\n                                       n_stages = args.pipeline_group_size)\n\n        if event_reporter is not None:\n\n            # Get the number of tokens in the dataset\n            token_count = train_data_loader.dataset.get_dataset_token_count()\n\n            # Report training start\n            event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,\n                                  message=f\"Training started for model {args.model_name}\",\n                                  event_type=EventReporter.EVENT_TYPE_TRAINING_START,\n                                  param_count=param_count.item(),\n                                  token_count=token_count,\n                                  requires_is_enabled=False)\n        \n        for i, data in enumerate(train_data_loader):\n            # if i < pipe.global_step:\n            #     continue\n                \n            if use_dp:\n                get_data_parallel_comm().broadcast(stop_flag, 0)\n            pp_comm.broadcast(stop_flag, 0)\n            \n            if stop_flag.item() == 1:\n                break\n            \n            input_ids_global = data['input_ids'].to(torch.int64).to(device)\n            \n            input_ids_list = input_ids_global.chunk(dp_size)\n            \n            if use_dp:\n                for j in range(1, dp_size):\n                    get_data_parallel_comm().send(\n                        input_ids_list[j], j,\n                    )\n                \n            input_ids = input_ids_list[0]\n            \n            pp_comm.broadcast(input_ids, 0)\n            \n            labels = input_ids.clone()\n           
 current_iter_time = pipe.sgd_iter(input_ids, labels, loss_func=gpt_loss_func)\n\n            if event_reporter is not None and (pipe.global_step >= args.total_steps or pipe.global_step % steps_per_epoch == 0):\n                event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,\n                                      message=f\"Epoch completed, at step {pipe.global_step}\",\n                                      event_type=EventReporter.EVENT_TYPE_EPOCH_COMPLETE,\n                                      requires_is_enabled=False)\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n            \n            if pipe.global_step >= args.total_steps or pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    checkpoint_step_path = save_checkpoint(pipe, args)\n                    if upload_checkpoints_enabled:\n                        upload_manager.add_task(directory=checkpoint_step_path,\n                                                checkpoint_upload_prefix=args.checkpoint_upload_prefix,\n                                                step=pipe.global_step)\n\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n            \n            if pipe.global_step >= args.total_steps:\n                stop_flag.data[:] = 1\n        \n        if upload_checkpoints_enabled:\n            upload_manager.wait()\n            \n    elif get_pipeline_parallel_rank() == 0:\n        \n        while True:\n            \n            get_data_parallel_comm().broadcast(stop_flag, 0)\n            pp_comm.broadcast(stop_flag, 0)\n            if stop_flag.item() == 1:\n                break\n                \n            get_data_parallel_comm().recv(\n                
input_ids, 0,\n            )\n            pp_comm.broadcast(input_ids, 0)\n            \n            labels = input_ids.clone()\n            current_iter_time = pipe.sgd_iter(input_ids, labels, loss_func=gpt_loss_func)\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n                \n            if pipe.global_step >= args.total_steps or pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    save_checkpoint(pipe, args)\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n            \n            \n    elif get_pipeline_parallel_rank()  == args.pipeline_group_size - 1:\n        \n        while True:\n            \n            pp_comm.broadcast(stop_flag, 0)\n            if stop_flag.item() == 1:\n                break\n                \n            pp_comm.broadcast(input_ids, 0)\n            labels = input_ids.clone()\n            current_iter_time = pipe.sgd_iter(input_ids, labels, loss_func=gpt_loss_func) # lm loss func\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n                \n            if pipe.global_step >= args.total_steps or pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    save_checkpoint(pipe, args)\n                    pipe.save_on_disk(args.checkpoint_path)\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n    else:\n        while True:\n            pp_comm.broadcast(stop_flag, 0)\n            if stop_flag.item() 
== 1:\n                break\n            pp_comm.broadcast(input_ids, 0)\n            current_iter_time = pipe.sgd_iter(None, None)\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n                \n            if pipe.global_step >= args.total_steps or pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    save_checkpoint(pipe, args)\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n\n# Compute the total number of training steps, steps per epoch, and steps per\n# checkpoint\ndef calculate_training_steps(args, train_data_loader) -> int:\n    total_steps = 0\n    steps_per_epoch = 0\n    steps_per_checkpoint = 0\n\n    token_count = train_data_loader.dataset.get_dataset_token_count()\n\n    # Check the inputs to calculate the total steps\n    if args.batch_size is None or args.world_size is None or args.pipeline_group_size is None or token_count is None or args.seq_length is None:\n        print(\"Missing required arguments for calculating total steps based on epochs.\")\n        raise SystemExit(1)\n\n    global_batch_size = (args.batch_size * args.world_size + args.pipeline_group_size - 1) // args.pipeline_group_size\n    tokens_per_batch = global_batch_size * args.seq_length\n    steps_per_epoch = (token_count + tokens_per_batch - 1) // tokens_per_batch\n\n    if args.total_steps is not None:\n        if args.nepochs is not None:\n            print(f\"WARNING: total_steps ({args.total_steps}) supersedes nepochs ({args.nepochs}).\")\n        total_steps = args.total_steps\n    elif args.nepochs is not None:\n        total_steps = steps_per_epoch * args.nepochs\n    else:\n        total_steps = len(train_data_loader)\n\n    # Set the minimum number of total 
steps\n    if total_steps < 10:\n        total_steps = 10\n\n    # Ensure that the steps per epoch are consistent with total steps\n    # Note: This does not strictly follow the definition of an epoch. It just\n    # approximately distributes the reporting of epochs over the total number of\n    # steps.\n    if args.nepochs is not None:\n        steps_per_epoch = (total_steps + args.nepochs - 1) // args.nepochs\n\n    # Clamp steps_per_epoch to [1, total_steps]\n    if steps_per_epoch > total_steps:\n        steps_per_epoch = total_steps\n    if steps_per_epoch < 1:\n        steps_per_epoch = 1\n\n    # Set the number of steps per checkpoint based on user input.\n    if args.checkpoint_steps is not None and args.checkpoint_steps > 0:\n        steps_per_checkpoint = args.checkpoint_steps\n    elif args.num_checkpoints is not None and args.num_checkpoints > 0:\n        steps_per_checkpoint = (total_steps + args.num_checkpoints - 1) // args.num_checkpoints\n    else:\n        steps_per_checkpoint = total_steps\n    \n    # Clamp steps_per_checkpoint to [1, total_steps]\n    if steps_per_checkpoint > total_steps:\n        steps_per_checkpoint = total_steps\n    if steps_per_checkpoint < 1:\n        steps_per_checkpoint = 1\n\n    # Set the args based on what we computed above\n    args.total_steps = total_steps\n    args.checkpoint_steps = steps_per_checkpoint\n    return steps_per_epoch\n\ndef main():\n    parser = argparse.ArgumentParser(description='Gpipe-GPT')\n    add_device_arguments(parser)\n    add_torch_distributed_arguments(parser)\n    add_model_arguments(parser)\n    add_task_arguments(parser)\n    add_training_hyper_parameter_arguments(parser)\n    add_mixed_precision_arguments(parser)\n    add_parallel_schema_arguments(parser)\n    add_entry_reporter_arguments(parser)\n    parser.add_argument('--model-name', type=str, default='gpt2', metavar='S',\n                        help='model name or path')\n    parser.add_argument('--tokenizer-name', type=str, 
default='gpt2', metavar='S',\n                        help='tokenizer name or path')\n    parser.add_argument('--model-type', type=str, default='gpt2', metavar='S',\n                        help='model type')\n    parser.add_argument('--checkpoint-path', type=str, default='model_checkpoints/gpt2')\n    parser.add_argument('--task-name', type=str, default='cot', metavar='S',\n                        help='task name')\n    parser.add_argument('--warmup-steps', type=int, default=0, help='-')\n    parser.add_argument('--train-warmup-steps', type=int, default=0, help='-')\n    parser.add_argument('--nepochs', type=int, default=None, help='-')\n    parser.add_argument('--total-steps', type=int, default=None, help='-')\n    parser.add_argument('--load-pretrained-model', \n                        type=lambda x: x.lower()=='true', default=True, metavar='S',\n                        help='load pretrained model or not.')\n    parser.add_argument('--load-checkpoint', \n                        type=lambda x: x.lower()=='true', default=True, metavar='S',\n                        help='load checkpoint or not.')\n    parser.add_argument('--seed', type=int, default=1, metavar='S',\n                        help='random seed (default: 1)')\n    parser.add_argument('--profiling', type=str, default='no-profiling', metavar='S',\n                        help='which profiling mode to enable (default: no-profiling)')\n    parser.add_argument('--trace-postfix', type=str, default='default', metavar='S',\n                        help='postfix of the tracing file name.')\n    parser.add_argument('--evaluation-steps', \n                        type=int, default=0, metavar='S',\n                        help='every x steps, do evaluation. 
(0 means do not do evaluation)')\n    parser.add_argument('--evaluation-data',\n                        type=str, default=None, help=\"path of eval data in jsonl\")\n    parser.add_argument('--evaluation-num-batch',\n                        type=int, default=None, help=\"for debug purpose, only eval the first several batch.\")\n    parser.add_argument('--checkpoint-steps', \n                        type=int, default=0, metavar='S',\n                        help='every x steps, save checkpoint. (0 means do not save checkpoint)')\n    parser.add_argument('--num-checkpoints', \n                        type=int, default=0, metavar='S',\n                        help='number of checkpoints to save')\n    parser.add_argument('--net-interface', \n                        type=str, default='lo', metavar='S',\n                        help='net_interface')\n    parser.add_argument('--job-id', \n                        type=str, default=\"0\", metavar='S',\n                        help='an uuid')\n    \n    # Add AWS arguments for uploading checkpoints to S3\n    parser.add_argument('--checkpoint-upload-prefix', default=None, help='S3 bucket name')\n    add_aws_arguments(parser)\n\n    args = parser.parse_args()\n    aws_process_args(args)\n    \n    torch.manual_seed(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    \n    if args.use_cuda:\n        assert (torch.cuda.is_available())\n        device = torch.device('cuda', args.cuda_id)\n    else:\n        device = torch.device('cpu')\n        \n    init_communicators(args)\n    \n    use_dp = (args.world_size != args.pipeline_group_size)\n    if use_dp:\n        dp_comm = get_data_parallel_comm()\n        dp_rank = get_data_parallel_rank()\n        dp_size = get_data_parallel_world_size()\n    else:\n        dp_rank = 0\n        dp_size = 1\n    \n    config = AutoConfig.from_pretrained(args.model_name)\n    \n    # num layer globally\n    if hasattr(config, 'num_hidden_layers'):\n        
args.max_layers = config.num_hidden_layers\n    elif hasattr(config, 'num_layers'):\n        args.max_layers = config.num_layers \n    else:\n        args.max_layers = config.n_layer\n    \n    tokenizer = build_tokenizer(args)\n    tokenizer.model_max_length = args.seq_length\n    config.max_position_embeddings = args.seq_length\n    # config.vocab_size = tokenizer.vocab_size\n    config.bos_token_id = tokenizer.bos_token_id\n    config.eos_token_id = tokenizer.eos_token_id\n    config.pad_token_id = tokenizer.pad_token_id\n    print(\"token vocab size:\", config.vocab_size)\n    \n    train_data_loader = get_train_data_loader(args, tokenizer)\n        \n    if args.evaluation_data is not None and dp_rank == 0:\n        test_data_loader = get_eval_data_loader(args, tokenizer)\n    else:\n        test_data_loader = None\n    \n    # calculate total steps\n    steps_per_epoch = calculate_training_steps(args, train_data_loader)\n    \n    use_dp = (args.world_size != args.pipeline_group_size)\n    if use_dp:\n        print(\"Running \", args.pp_mode, \" with data parallel.\")\n    else:\n        print(\"Running \", args.pp_mode, \" without data parallel.\")\n    \n    pipe = get_pp_module(args, config, device, use_dp)\n    \n    if args.load_checkpoint:\n        load_checkpoint(pipe, args)\n\n    if args.fp16:\n        pipe.optimizer.reload_model_params()\n\n    if args.profiling == 'no-profiling':\n        train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)\n    else:\n        prefix = './trace_json/gpt3_' + args.pp_mode\n        if use_dp:\n            prefix = prefix + '_' + args.dp_mode\n        trace_file = prefix + get_learning_arguments_str(args) + get_model_arguments_str(args) + \\\n                     get_dist_arguments_str(args) + get_mixed_precision_arguments_str(args) + '_' + \\\n                     args.profiling + '_' + args.trace_postfix + '.json'\n        if args.profiling == 'tidy_profiling':\n            try:\n     
           train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)\n            except Exception as e:\n                # report which pipeline rank failed, then re-raise with the original traceback\n                print(get_pipeline_parallel_rank(), e)\n                raise\n            pipe.export_profiling_result(filename=trace_file)\n        elif args.profiling == 'pytorch_profiling':\n            with profiler.profile(profile_memory=True, use_cuda=args.use_cuda) as prof:\n                train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)\n            print(prof.key_averages().table())\n            prof.export_chrome_trace(trace_file)\n        else:\n            print(\"No recognized profiler?\")\n            assert False\n    print(get_pipeline_parallel_rank(), 'finished.')\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "training/dist_prefixlm_train.py",
    "content": "import argparse\nimport time\nimport random\nimport numpy as np\nimport torch\nimport torch.autograd.profiler as profiler\nfrom tasks.data_loaders.data_utils import get_ul2r_train_data_loader\nfrom modules.utils import gpt_loss_func\nfrom modules.tokenizer import build_tokenizer\nfrom pipeline_parallel.dist_pp_utils import get_pp_module\n\nfrom transformers import AutoConfig\nimport datasets\n\nfrom utils.dist_args_utils import *\nfrom utils.dist_checkpoint_utils import *\nfrom utils.logging_utils import *\nfrom comm.comm_utils import *\n\n\ndef test_loop(args, pipe, device, test_data_loader):\n    print(\"no impl for testing, skip.\")\n\n\ndef train_loop(args, pipe, device, train_data_loader, test_data_loader):\n    \n    print('training starts......')\n\n    pipe.model.train() # Flag .training to True to enable Dropout\n    \n    use_dp = (args.world_size != args.pipeline_group_size)\n    if use_dp:\n        # dp_comm = get_data_parallel_comm()\n        dp_rank = get_data_parallel_rank()\n        dp_size = get_data_parallel_world_size()\n    else:\n        dp_rank = 0\n        dp_size = 1\n    pp_comm = get_pipeline_parallel_comm()\n    \n    stop_flag = torch.zeros(1, dtype=torch.int64).to(device)\n    \n    input_ids = torch.zeros(\n        [args.batch_size, args.seq_length], \n        dtype=torch.int64\n    ).to(device)\n    \n    prefix_masks = torch.zeros(\n        [args.batch_size, args.seq_length], \n        dtype=torch.uint8\n    ).to(device)\n    \n    do_sync_before_save = (args.dp_mode in ['local'] and use_dp)\n    \n    if get_pipeline_parallel_rank() == 0 and dp_rank == 0:\n        \n        for i, data in enumerate(train_data_loader):\n            if i < pipe.global_step:\n                continue\n                \n            if use_dp:\n                get_data_parallel_comm().broadcast(stop_flag, 0)\n            pp_comm.broadcast(stop_flag, 0)\n            \n            if stop_flag.item() == 1:\n                break\n           
 \n            input_ids_global = data['input_ids'].to(torch.int64).to(device)\n            prefix_masks_global = data['prefix_masks'].to(torch.uint8).to(device)\n            \n            input_ids_list = input_ids_global.chunk(dp_size)\n            prefix_masks_list = prefix_masks_global.chunk(dp_size)\n            \n            if use_dp:\n                for j in range(1, dp_size):\n                    get_data_parallel_comm().send(\n                        input_ids_list[j], j,\n                    )\n                    get_data_parallel_comm().send(\n                        prefix_masks_list[j], j,\n                    )\n                \n            input_ids = input_ids_list[0]\n            prefix_masks = prefix_masks_list[0]\n            \n            pp_comm.broadcast(input_ids, 0)\n            pp_comm.broadcast(prefix_masks, 0)\n            \n            labels = input_ids.clone()\n            current_iter_time = pipe.sgd_iter(\n                input_ids, labels, aux_input_data={'prefix_masks': prefix_masks}, loss_func=gpt_loss_func\n            )\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n            \n            # checkpoint_steps == 0 means \"do not save\"; guard against modulo by zero\n            if args.checkpoint_steps > 0 and pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    save_checkpoint(pipe, args)\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n            \n            if pipe.global_step >= args.total_steps:\n                stop_flag.data[:] = 1\n            \n    elif get_pipeline_parallel_rank() == 0:\n        \n        while True:\n            \n            get_data_parallel_comm().broadcast(stop_flag, 0)\n            pp_comm.broadcast(stop_flag, 0)\n            if stop_flag.item() == 1:\n                break\n                \n            get_data_parallel_comm().recv(\n                input_ids, 0,\n            )\n            get_data_parallel_comm().recv(\n                prefix_masks, 0,\n            )\n            pp_comm.broadcast(input_ids, 0)\n            pp_comm.broadcast(prefix_masks, 0)\n            \n            labels = input_ids.clone()\n            current_iter_time = pipe.sgd_iter(\n                input_ids, labels, aux_input_data={'prefix_masks': prefix_masks}, loss_func=gpt_loss_func)\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n                \n            if args.checkpoint_steps > 0 and pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    save_checkpoint(pipe, args)\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n            \n            \n    elif get_pipeline_parallel_rank()  == args.pipeline_group_size - 1:\n        \n        while True:\n            \n            pp_comm.broadcast(stop_flag, 0)\n            if stop_flag.item() == 1:\n                break\n                \n            pp_comm.broadcast(input_ids, 0)\n            pp_comm.broadcast(prefix_masks, 0)\n            labels = input_ids.clone()\n            labels[prefix_masks.bool()] = -100 # mask prefix part\n            current_iter_time = pipe.sgd_iter(\n                input_ids, labels, loss_func=gpt_loss_func, aux_input_data={'prefix_masks': prefix_masks}\n            )\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n                \n            if args.checkpoint_steps > 0 and pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    save_checkpoint(pipe, args)\n                    pipe.save_on_disk(args.checkpoint_path)\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n    else:\n        while True:\n            pp_comm.broadcast(stop_flag, 0)\n            if stop_flag.item() == 1:\n                break\n            pp_comm.broadcast(input_ids, 0)\n            pp_comm.broadcast(prefix_masks, 0)\n            current_iter_time = pipe.sgd_iter(None, None, aux_input_data={'prefix_masks': prefix_masks})\n            \n            if args.evaluation_steps > 0 and pipe.global_step % args.evaluation_steps == 0:\n                test_loop(args, pipe, device, test_data_loader)\n                \n            if args.checkpoint_steps > 0 and pipe.global_step % args.checkpoint_steps == 0:\n                if do_sync_before_save:\n                    pipe.dp_optim.allreduce_parameters()\n                if dp_rank == 0:\n                    save_checkpoint(pipe, args)\n                if do_sync_before_save:\n                    pipe.dp_optim.rollback_parameters()\n        \n\ndef main():\n    parser = argparse.ArgumentParser(description='Gpipe-GPT')\n    add_device_arguments(parser)\n    add_torch_distributed_arguments(parser)\n    add_model_arguments(parser)\n    add_task_arguments(parser)\n    add_training_hyper_parameter_arguments(parser)\n    add_mixed_precision_arguments(parser)\n    add_parallel_schema_arguments(parser)\n    parser.add_argument('--model-name', type=str, default='gpt2', metavar='S',\n                        help='model name or path')\n    parser.add_argument('--tokenizer-name', type=str, default='gpt2', metavar='S',\n                        help='tokenizer name or path')\n    parser.add_argument('--model-type', type=str, default='gpt2', metavar='S',\n                        help='model name or path')\n    parser.add_argument('--checkpoint-path', type=str, 
default='model_checkpoints/gpt2')\n    parser.add_argument('--task-name', type=str, default='cot', metavar='S',\n                        help='task name')\n    parser.add_argument('--warmup-steps', type=int, default=0, help='-')\n    parser.add_argument('--train-warmup-steps', type=int, default=0, help='-')\n    parser.add_argument('--total-steps', type=int, default=None, help='-')\n    parser.add_argument('--load-pretrained-model', \n                        type=lambda x: x.lower()=='true', default=True, metavar='S',\n                        help='load pretrained model or not.')\n    parser.add_argument('--load-checkpoint', \n                        type=lambda x: x.lower()=='true', default=True, metavar='S',\n                        help='load checkpoint or not.')\n    parser.add_argument('--seed', type=int, default=1, metavar='S',\n                        help='random seed (default: 1)')\n    parser.add_argument('--profiling', type=str, default='no-profiling', metavar='S',\n                        help='which profiling to enable: no-profiling (default), tidy_profiling, or pytorch_profiling')\n    parser.add_argument('--trace-postfix', type=str, default='default', metavar='S',\n                        help='postfix of the tracing file name.')\n    parser.add_argument('--evaluation-steps', \n                        type=int, default=0, metavar='S',\n                        help='every x steps, do evaluation. (0 means do not do evaluation)')\n    parser.add_argument('--evaluation-data',\n                        type=str, default=None, help=\"path of eval data in jsonl\")\n    parser.add_argument('--evaluation-num-batch',\n                        type=int, default=None, help=\"for debugging purposes, only evaluate the first several batches.\")\n    parser.add_argument('--checkpoint-steps', \n                        type=int, default=0, metavar='S',\n                        help='every x steps, save checkpoint. 
(0 means do not save checkpoint)')\n    parser.add_argument('--net-interface', \n                        type=str, default='lo', metavar='S',\n                        help='net_interface')\n    parser.add_argument('--job-id', \n                        type=str, default=\"0\", metavar='S',\n                        help='an uuid')\n    args = parser.parse_args()\n    \n    torch.manual_seed(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    \n    if args.use_cuda:\n        assert (torch.cuda.is_available())\n        device = torch.device('cuda', args.cuda_id)\n    else:\n        device = torch.device('cpu')\n        \n    init_communicators(args)\n    \n    use_dp = (args.world_size != args.pipeline_group_size)\n    if use_dp:\n        dp_comm = get_data_parallel_comm()\n        dp_rank = get_data_parallel_rank()\n        dp_size = get_data_parallel_world_size()\n    else:\n        dp_rank = 0\n        dp_size = 1\n    \n    config = AutoConfig.from_pretrained(args.model_name)\n    \n    # num layer globally\n    if hasattr(config, 'num_hidden_layers'):\n        args.max_layers = config.num_hidden_layers\n    elif hasattr(config, 'num_layers'):\n        args.max_layers = config.num_layers \n    else:\n        args.max_layers = config.n_layer\n    \n    tokenizer = build_tokenizer(args)\n    tokenizer.model_max_length = args.seq_length\n    # config.vocab_size = tokenizer.vocab_size\n    config.bos_token_id = tokenizer.bos_token_id\n    config.eos_token_id = tokenizer.eos_token_id\n    config.pad_token_id = tokenizer.pad_token_id\n    print(\"token vocab size:\", config.vocab_size)\n    \n    if get_pipeline_parallel_rank() == 0 and dp_rank == 0:\n        train_data_loader = get_ul2r_train_data_loader(args, tokenizer)\n    else:\n        train_data_loader = None\n        \n    test_data_loader = None\n        \n    if args.total_steps is None:\n        args.total_steps = len(train_data_loader)\n    \n    use_dp = (args.world_size != 
args.pipeline_group_size)\n    if use_dp:\n        print(\"Running \", args.pp_mode, \" with data parallel.\")\n    else:\n        print(\"Running \", args.pp_mode, \" without data parallel.\")\n    \n    pipe = get_pp_module(args, config, device, use_dp)\n    \n    if args.load_checkpoint:\n        load_checkpoint(pipe, args)\n\n    if args.fp16:\n        pipe.optimizer.reload_model_params()\n        \n    if args.model_type == 'gptj':\n        # make sure, causal mask is here.\n        max_positions = config.n_positions\n        for module in pipe.model.model:\n            if hasattr(module, 'attn'):\n                print('put back causal mask')\n                module.attn.bias[:] = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(\n                    1, 1, max_positions, max_positions\n                )\n\n    if args.profiling == 'no-profiling':\n        train_loop(args, pipe, device, train_data_loader, test_data_loader)\n    else:\n        prefix = './trace_json/gpt3_' + args.pp_mode\n        if use_dp:\n            prefix = prefix + '_' + args.dp_mode\n        trace_file = prefix + get_learning_arguments_str(args) + get_model_arguments_str(args) + \\\n                     get_dist_arguments_str(args) + get_mixed_precision_arguments_str(args) + '_' + \\\n                     args.profiling + '_' + args.trace_postfix + '.json'\n        if args.profiling == 'tidy_profiling':\n            try:\n                train_loop(args, pipe, device, train_data_loader, test_data_loader)\n            except Exception as e:\n                # report which pipeline rank failed, then re-raise with the original traceback\n                print(get_pipeline_parallel_rank(), e)\n                raise\n            pipe.export_profiling_result(filename=trace_file)\n        elif args.profiling == 'pytorch_profiling':\n            with profiler.profile(profile_memory=True, use_cuda=args.use_cuda) as prof:\n                train_loop(args, pipe, device, train_data_loader, test_data_loader)\n            print(prof.key_averages().table())\n            prof.export_chrome_trace(trace_file)\n        else:\n            print(\"No recognized profiler?\")\n            assert False\n    print(get_pipeline_parallel_rank(), 'finished.')\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "training/finetune_GPT-NeoXT-Chat-Base-20B.sh",
    "content": "DIR=$(cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\n\nnetif=lo\nexport GLOO_SOCKET_IFNAME=${netif}\nexport NCCL_SOCKET_IFNAME=${netif}\nexport MODEL_NAME=GPT-Neo-XT-Chat-Base-20B\n\nexport SHOW_DATA=0\n\nBASE_MODEL=\"${DIR}/../pretrained/GPT-NeoX-20B/EleutherAI_gpt-neox-20b/\"\n\nTOTAL_STEPS=${FINETUNE_TOTAL_STEPS:-20000}\nCHECKPOINT_STEPS=${FINETUNE_CHECKPOINT_STEPS:-100}\nCHECKPOINT_PATH=${FINETUNE_CHECKPOINT_PATH:-\"${DIR}/../model_ckpts/${MODEL_NAME}\"}\n\nDATASETS=\"\\\n${DIR}/../data/OIG/files/unified_ni.jsonl:0.2,\\\n${DIR}/../data/OIG/files/unified_p3.jsonl:0.5,\\\n${DIR}/../data/OIG/files/unified_flan.jsonl:0.2,\\\n${DIR}/../data/OIG/files/unified_chip2.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_rallio_safety_and_prosocial.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_soda_dialog.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_unifiedskg_instructions.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_merged_code_xp3.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_oscar_en_sample_dialog.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_ul2_plus_oscar_en_sample_dialog.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_multi_news.jsonl:0.05,\\\n${DIR}/../data/OIG/files/unified_openai_summarize_tldr.jsonl:0.05,\\\n${DIR}/../data/OIG/files/unified_squad_v2.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_nq.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_poetry_instructions.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_sqlv2.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_unnatural_instructions.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_conv_finqa.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_essays.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_plot_screenplay_books_dialog.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_grade_school_math_instructions.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_mathqa_flanv2_kojma_cot.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_joke_explanations.jsonl:0.01,\\\n${DIR}/../data/OIG/
files/unified_cuad.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_abstract_infill.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_image_prompts_instructions.jsonl:0.01 \\\n\"\n\nARGS=\"--model-name ${BASE_MODEL} \\\n--tokenizer-name ${BASE_MODEL} \\\n--project-name together \\\n--model-type gptneox \\\n--optimizer adam \\\n--seed 42 \\\n--load-pretrained-model true \\\n--task-name \\\n\"${DATASETS}\" \\\n--checkpoint-path ${CHECKPOINT_PATH} \\\n--total-steps ${TOTAL_STEPS} --warmup-steps 10 --train-warmup-steps 0 \\\n--checkpoint-steps ${CHECKPOINT_STEPS} \\\n--lr 1e-6 --seq-length 2048 --batch-size 64 --micro-batch-size 1 --gradient-accumulate-step 1 \\\n--dist-url tcp://127.0.0.1:7033 \\\n--num-layers 6 --embedding-dim 6144 \\\n--world-size 8 --pipeline-group-size 8 --data-group-size 1 \\\n--job-id 0 --net-interface ${netif} \\\n--fp16 \\\n--dp-backend nccl \\\n--dp-mode allreduce \\\n--pp-mode gpipe --profiling no-profiling\"\n\n\n(trap 'kill 0' SIGINT; \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \\\n    & \\\nwait)\n"
  },
  {
    "path": "training/finetune_Pythia-Chat-Base-7B.sh",
    "content": "DIR=$(cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\n\nnetif=lo\nexport GLOO_SOCKET_IFNAME=${netif}\nexport NCCL_SOCKET_IFNAME=${netif}\nexport MODEL_NAME=Pythia-Chat-Base-7B\n\nexport SHOW_DATA=0\n\nBASE_MODEL=\"${DIR}/../pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped/\"\n\nTOTAL_STEPS=${FINETUNE_TOTAL_STEPS:-20000}\nCHECKPOINT_STEPS=${FINETUNE_CHECKPOINT_STEPS:-100}\nCHECKPOINT_PATH=${FINETUNE_CHECKPOINT_PATH:-\"${DIR}/../model_ckpts/${MODEL_NAME}\"}\n\nDATASETS=\"\\\n${DIR}/../data/OIG/files/unified_ni.jsonl:0.2,\\\n${DIR}/../data/OIG/files/unified_p3.jsonl:0.5,\\\n${DIR}/../data/OIG/files/unified_flan.jsonl:0.2,\\\n${DIR}/../data/OIG/files/unified_chip2.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_rallio_safety_and_prosocial.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_soda_dialog.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_unifiedskg_instructions.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_merged_code_xp3.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_oscar_en_sample_dialog.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_ul2_plus_oscar_en_sample_dialog.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_multi_news.jsonl:0.05,\\\n${DIR}/../data/OIG/files/unified_openai_summarize_tldr.jsonl:0.05,\\\n${DIR}/../data/OIG/files/unified_squad_v2.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_nq.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_poetry_instructions.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_sqlv2.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_unnatural_instructions.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_conv_finqa.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_essays.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_plot_screenplay_books_dialog.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_grade_school_math_instructions.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_mathqa_flanv2_kojma_cot.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_joke_explanations.jsonl:0.01,\\\n${DIR}/../
data/OIG/files/unified_cuad.jsonl:0.01,\\\n${DIR}/../data/OIG/files/unified_abstract_infill.jsonl:0.1,\\\n${DIR}/../data/OIG/files/unified_image_prompts_instructions.jsonl:0.01 \\\n\"\n\nARGS=\"--model-name ${BASE_MODEL} \\\n--tokenizer-name ${BASE_MODEL} \\\n--project-name together \\\n--model-type gptneox \\\n--optimizer adam \\\n--seed 42 \\\n--load-pretrained-model true \\\n--task-name \\\n\"${DATASETS}\" \\\n--checkpoint-path ${CHECKPOINT_PATH} \\\n--total-steps ${TOTAL_STEPS} --warmup-steps 10 --train-warmup-steps 0 \\\n--checkpoint-steps ${CHECKPOINT_STEPS} \\\n--lr 1e-5 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \\\n--dist-url tcp://127.0.0.1:7033 \\\n--num-layers 8 --embedding-dim 4096 \\\n--world-size 8 --pipeline-group-size 4 --data-group-size 2 \\\n--job-id 0 --net-interface ${netif} \\\n--fp16 \\\n--dp-backend nccl \\\n--dp-mode allreduce \\\n--pp-mode gpipe --profiling no-profiling\"\n\n\n(trap 'kill 0' SIGINT; \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \\\n    & \\\nwait)\n"
  },
  {
    "path": "training/finetune_RedPajama-INCITE-7B-Chat.sh",
    "content": "DIR=$(cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\n\nnetif=lo\nexport GLOO_SOCKET_IFNAME=${netif}\nexport NCCL_SOCKET_IFNAME=${netif}\nexport MODEL_NAME=redpajama-incite-chat-3b-sample\n\nexport SHOW_DATA=0\n\nBASE_MODEL=\"${DIR}/../pretrained/RedPajama-7B/togethercomputer_RedPajama-INCITE-7B-Chat\"\n\nTOTAL_STEPS=${FINETUNE_TOTAL_STEPS:-10}\nCHECKPOINT_STEPS=${FINETUNE_CHECKPOINT_STEPS:-10}\nCHECKPOINT_PATH=${FINETUNE_CHECKPOINT_PATH:-\"${DIR}/../model_ckpts/${MODEL_NAME}\"}\n\nDATASETS=\"${DIR}/../data/OIG-chip2/unified_chip2.jsonl:1\"\n\nARGS=\"--model-name ${BASE_MODEL} \\\n--tokenizer-name ${BASE_MODEL} \\\n--project-name together \\\n--model-type gptneox \\\n--optimizer adam \\\n--seed 42 \\\n--load-pretrained-model true \\\n--task-name \\\n\"${DATASETS}\" \\\n--checkpoint-path ${CHECKPOINT_PATH} \\\n--total-steps ${TOTAL_STEPS} --warmup-steps 0 --train-warmup-steps 0 \\\n--checkpoint-steps ${CHECKPOINT_STEPS} \\\n--lr 1e-5 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \\\n--dist-url tcp://127.0.0.1:7033 \\\n--num-layers 4 --embedding-dim 2560 \\\n--world-size 8 --pipeline-group-size 8 --data-group-size 1 \\\n--job-id 0 --net-interface ${netif} \\\n--fp16 \\\n--dp-backend nccl \\\n--dp-mode allreduce \\\n--pp-mode gpipe --profiling no-profiling\"\n\n\n(trap 'kill 0' SIGINT; \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \\\n    & \\\npython ${DIR}/dist_clm_train.py 
$(echo ${ARGS}) --cuda-id 7 --rank 7 \\\n    & \\\nwait)\n"
  },
  {
    "path": "training/finetune_RedPajama-INCITE-Chat-3B-v1.sh",
    "content": "DIR=$(cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\n\nnetif=lo\nexport GLOO_SOCKET_IFNAME=${netif}\nexport NCCL_SOCKET_IFNAME=${netif}\nexport MODEL_NAME=redpajama-incite-chat-3b-sample\n\nexport SHOW_DATA=0\n\nBASE_MODEL=\"${DIR}/../pretrained/RedPajama-3B/togethercomputer_RedPajama-INCITE-Chat-3B-v1\"\n\nTOTAL_STEPS=${FINETUNE_TOTAL_STEPS:-10}\nCHECKPOINT_STEPS=${FINETUNE_CHECKPOINT_STEPS:-10}\nCHECKPOINT_PATH=${FINETUNE_CHECKPOINT_PATH:-\"${DIR}/../model_ckpts/${MODEL_NAME}\"}\n\nDATASETS=\"${DIR}/../data/OIG-chip2/unified_chip2.jsonl:1\"\n\nARGS=\"--model-name ${BASE_MODEL} \\\n--tokenizer-name ${BASE_MODEL} \\\n--project-name together \\\n--model-type gptneox \\\n--optimizer adam \\\n--seed 42 \\\n--load-pretrained-model true \\\n--task-name \\\n\"${DATASETS}\" \\\n--checkpoint-path ${CHECKPOINT_PATH} \\\n--total-steps ${TOTAL_STEPS} --warmup-steps 0 --train-warmup-steps 0 \\\n--checkpoint-steps ${CHECKPOINT_STEPS} \\\n--lr 1e-5 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \\\n--dist-url tcp://127.0.0.1:7033 \\\n--num-layers 4 --embedding-dim 2560 \\\n--world-size 8 --pipeline-group-size 8 --data-group-size 1 \\\n--job-id 0 --net-interface ${netif} \\\n--fp16 \\\n--dp-backend nccl \\\n--dp-mode allreduce \\\n--pp-mode gpipe --profiling no-profiling\"\n\n\n(trap 'kill 0' SIGINT; \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \\\n    & \\\npython ${DIR}/dist_clm_train.py 
$(echo ${ARGS}) --cuda-id 7 --rank 7 \\\n    & \\\nwait)\n"
  },
  {
    "path": "training/finetune_llama-2-7b-32k-booksum.sh",
    "content": "DIR=$(cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\n\nnetif=lo\nexport GLOO_SOCKET_IFNAME=${netif}\nexport NCCL_SOCKET_IFNAME=${netif}\nexport MODEL_NAME=llama-2-7b-32k-booksum\n\nexport SHOW_DATA=1\n\nBASE_MODEL=\"${DIR}/../pretrained/Llama-2-7B-32K-beta/togethercomputer_Llama-2-7B-32K-beta\"\n\nTOTAL_STEPS=${FINETUNE_TOTAL_STEPS:-10}\nCHECKPOINT_STEPS=${FINETUNE_CHECKPOINT_STEPS:-10}\nCHECKPOINT_PATH=${FINETUNE_CHECKPOINT_PATH:-\"${DIR}/../model_ckpts/${MODEL_NAME}\"}\n\nDATASETS=\"https://huggingface.co/datasets/togethercomputer/Long-Data-Collections/resolve/main/fine-tune/booksum.jsonl.zst:1\"\n\nARGS=\"--model-name ${BASE_MODEL} \\\n--tokenizer-name ${BASE_MODEL} \\\n--project-name together \\\n--model-type llama \\\n--optimizer adam \\\n--seed 42 \\\n--load-pretrained-model true \\\n--task-name \\\n\"${DATASETS}\" \\\n--checkpoint-path ${CHECKPOINT_PATH} \\\n--total-steps ${TOTAL_STEPS} --warmup-steps 0 --train-warmup-steps 0 \\\n--checkpoint-steps ${CHECKPOINT_STEPS} \\\n--lr 2e-5 --seq-length 32768 --batch-size 4 --micro-batch-size 1 --gradient-accumulate-step 1 \\\n--dist-url tcp://127.0.0.1:7033 \\\n--num-layers 4 --embedding-dim 4096 \\\n--world-size 8 --pipeline-group-size 8 --data-group-size 1 \\\n--job-id 0 --net-interface ${netif} \\\n--fp16 \\\n--dp-backend nccl \\\n--dp-mode allreduce \\\n--pp-mode gpipe --profiling no-profiling\"\n\n(trap 'kill 0' SIGINT; \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 
--rank 6 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \\\n    & \\\nwait)\n"
  },
  {
    "path": "training/finetune_llama-2-7b-32k-mqa.sh",
    "content": "DIR=$(cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd)\n\nnetif=lo\nexport GLOO_SOCKET_IFNAME=${netif}\nexport NCCL_SOCKET_IFNAME=${netif}\nexport MODEL_NAME=llama-2-7b-32k-mqa\n\nexport SHOW_DATA=1\n\nBASE_MODEL=\"${DIR}/../pretrained/Llama-2-7B-32K-beta/togethercomputer_Llama-2-7B-32K-beta\"\n\nTOTAL_STEPS=${FINETUNE_TOTAL_STEPS:-10}\nCHECKPOINT_STEPS=${FINETUNE_CHECKPOINT_STEPS:-10}\nCHECKPOINT_PATH=${FINETUNE_CHECKPOINT_PATH:-\"${DIR}/../model_ckpts/${MODEL_NAME}\"}\n\nDATASETS=\"https://huggingface.co/datasets/togethercomputer/Long-Data-Collections/resolve/main/fine-tune/natural_questions_10_200_docs.jsonl.zst:1\"\n\nARGS=\"--model-name ${BASE_MODEL} \\\n--tokenizer-name ${BASE_MODEL} \\\n--project-name together \\\n--model-type llama \\\n--optimizer adam \\\n--seed 42 \\\n--load-pretrained-model true \\\n--task-name \\\n\"${DATASETS}\" \\\n--checkpoint-path ${CHECKPOINT_PATH} \\\n--total-steps ${TOTAL_STEPS} --warmup-steps 0 --train-warmup-steps 0 \\\n--checkpoint-steps ${CHECKPOINT_STEPS} \\\n--lr 2e-5 --seq-length 32768 --batch-size 4 --micro-batch-size 1 --gradient-accumulate-step 1 \\\n--dist-url tcp://127.0.0.1:7033 \\\n--num-layers 4 --embedding-dim 4096 \\\n--world-size 8 --pipeline-group-size 8 --data-group-size 1 \\\n--job-id 0 --net-interface ${netif} \\\n--fp16 \\\n--dp-backend nccl \\\n--dp-mode allreduce \\\n--pp-mode gpipe --profiling no-profiling\"\n\n(trap 'kill 0' SIGINT; \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo 
${ARGS}) --cuda-id 6 --rank 6 \\\n    & \\\npython ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \\\n    & \\\nwait)\n"
  },
  {
    "path": "training/lora/example/redpajama-incite-chat-3b.py",
    "content": "import os\nimport json\nos.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\nimport torch\nimport transformers\nimport torch.nn as nn\nimport bitsandbytes as bnb\nfrom datasets import Dataset\nfrom peft import LoraConfig, get_peft_model\nfrom transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM\n\n# this script should take around 14GB VRAM\n\nMODEL_NAME='redpajama-incite-chat-3b-sample-lowrank'\n\n# read datasets\nwith open('data/OIG-chip2/unified_chip2.jsonl', 'r') as fp:\n    data = [json.loads(x) for x in fp.readlines()]\n\nmodel = AutoModelForCausalLM.from_pretrained(\n    \"togethercomputer/RedPajama-INCITE-Chat-3B-v1\", \n    device_map='auto',\n)\n\ntokenizer = AutoTokenizer.from_pretrained(\"togethercomputer/RedPajama-INCITE-Chat-3B-v1\")\ntokenizer.pad_token = tokenizer.eos_token\n\nfor param in model.parameters():\n  param.requires_grad = False  # freeze the model - train adapters later\n  if param.ndim == 1:\n    # cast the small parameters (e.g. layernorm) to fp32 for stability\n    param.data = param.data.to(torch.float32)\n\nmodel.gradient_checkpointing_enable()  # reduce number of stored activations\nmodel.enable_input_require_grads()\n\ndef print_trainable_parameters(model):\n    \"\"\"\n    Prints the number of trainable parameters in the model.\n    \"\"\"\n    trainable_params = 0\n    all_param = 0\n    for _, param in model.named_parameters():\n        all_param += param.numel()\n        if param.requires_grad:\n            trainable_params += param.numel()\n    print(\n        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n    )\n\nconfig = LoraConfig(\n    r=16,\n    lora_alpha=32,\n    target_modules=[\"query_key_value\", \"xxx\"],\n    lora_dropout=0.05,\n    bias=\"none\",\n    task_type=\"CAUSAL_LM\"\n)\n\nmodel = get_peft_model(model, config)\nprint_trainable_parameters(model)\n\n## Training\n\ndata = Dataset.from_list(data)\ndata = 
data.map(lambda samples: tokenizer(samples['text']), batched=True)\n\ntrainer = transformers.Trainer(\n    model=model, \n    train_dataset=data,\n    args=transformers.TrainingArguments(\n        per_device_train_batch_size=4, \n        gradient_accumulation_steps=4,\n        warmup_steps=100, \n        max_steps=200, \n        learning_rate=2e-4, \n        fp16=True,\n        logging_steps=1, \n        output_dir='outputs'\n    ),\n    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n)\nmodel.config.use_cache = False  # silence the warnings. Please re-enable for inference!\ntrainer.train()\n\n# save the trained adapter to disk\nmodel.save_pretrained(f\"outputs/{MODEL_NAME}\")\n"
  },
  {
    "path": "training/lora/example/redpajama-incite-chat-3b_inference.py",
    "content": "import torch\nfrom peft import PeftModel, PeftConfig\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\npeft_model_path ='outputs/redpajama-incite-chat-3b-sample-lowrank'\n\nconfig = PeftConfig.from_pretrained(peft_model_path)\nmodel = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map='auto')\ntokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n\n# Load the Lora model\nmodel = PeftModel.from_pretrained(model, peft_model_path)\n\nbatch = tokenizer(\"<human>: Hello!\\n<bot>:\", return_tensors='pt')\n\nwith torch.cuda.amp.autocast():\n  output_tokens = model.generate(**batch, max_new_tokens=50)\n\nprint('\\n\\n', tokenizer.decode(output_tokens[0], skip_special_tokens=True))\n"
  },
  {
    "path": "training/modules/__init__.py",
    "content": ""
  },
  {
    "path": "training/modules/deberta_modules.py",
    "content": "import torch\nimport numpy as np\nimport math\nfrom torch import nn\nfrom torch.nn import functional\nfrom torch.utils.checkpoint import checkpoint\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\n\n\n#### Hack Deberta #####\n\ndef make_log_bucket_position(relative_pos, bucket_size, max_position):\n    sign = torch.sign(relative_pos)\n    mid = bucket_size // 2\n    abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos))\n    log_pos = torch.ceil(torch.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid\n    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type(log_pos.dtype), log_pos * sign).long()\n    return bucket_pos\n\ndef build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device='cpu'):\n    q_ids = torch.arange(0, query_size, device=device)\n    k_ids = torch.arange(0, key_size, device=device)\n    rel_pos_ids = q_ids[:, None] - torch.tile(k_ids, (q_ids.shape[0], 1))\n    if bucket_size > 0 and max_position > 0:\n        rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)\n    rel_pos_ids = rel_pos_ids[:query_size, :]\n    rel_pos_ids = rel_pos_ids.unsqueeze(0)\n    return rel_pos_ids\n\n\nfrom transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax, StableDropout\nclass DisentangledSelfAttention(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        if config.hidden_size % config.num_attention_heads != 0:\n            raise ValueError(\n                f\"The hidden size ({config.hidden_size}) is not a multiple of the number of attention \"\n                f\"heads ({config.num_attention_heads})\"\n            )\n        self.num_attention_heads = config.num_attention_heads\n        _attention_head_size = config.hidden_size // config.num_attention_heads\n        self.attention_head_size = 
getattr(config, \"attention_head_size\", _attention_head_size)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n        self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)\n        self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)\n        self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)\n\n        self.share_att_key = getattr(config, \"share_att_key\", False)\n        self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []\n        self.relative_attention = getattr(config, \"relative_attention\", False)\n\n        if self.relative_attention:\n            self.position_buckets = getattr(config, \"position_buckets\", -1)\n            self.max_relative_positions = getattr(config, \"max_relative_positions\", -1)\n            if self.max_relative_positions < 1:\n                self.max_relative_positions = config.max_position_embeddings\n            self.pos_ebd_size = self.max_relative_positions\n            if self.position_buckets > 0:\n                self.pos_ebd_size = self.position_buckets\n\n            self.pos_dropout = StableDropout(config.hidden_dropout_prob)\n\n            if not self.share_att_key:\n                if \"c2p\" in self.pos_att_type:\n                    self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)\n                if \"p2c\" in self.pos_att_type:\n                    self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = StableDropout(config.attention_probs_dropout_prob)\n\n    def transpose_for_scores(self, x, attention_heads):\n        new_x_shape = x.size()[:-1] + (attention_heads, -1)\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask,\n        
output_attentions=False,\n        query_states=None,\n        relative_pos=None,\n        rel_embeddings=None,\n    ):\n        if query_states is None:\n            query_states = hidden_states\n        query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)\n        key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)\n        value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)\n\n        rel_att = None\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        scale_factor = 1\n        if \"c2p\" in self.pos_att_type:\n            scale_factor += 1\n        if \"p2c\" in self.pos_att_type:\n            scale_factor += 1\n        scale = math.sqrt(query_layer.size(-1) * scale_factor)\n        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale\n        if self.relative_attention:\n            rel_embeddings = self.pos_dropout(rel_embeddings)\n            rel_att = self.disentangled_attention_bias(\n                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor\n            )\n\n        if rel_att is not None:\n            attention_scores = attention_scores + rel_att\n        attention_scores = attention_scores\n        attention_scores = attention_scores.view(\n            -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)\n        )\n\n        # bsz x height x length x dimension\n        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)\n        attention_probs = self.dropout(attention_probs)\n        context_layer = torch.bmm(\n            attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer\n        )\n        context_layer = (\n            context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))\n            
.permute(0, 2, 1, 3)\n            .contiguous()\n        )\n        new_context_layer_shape = context_layer.size()[:-2] + (-1,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n        if output_attentions:\n            return (context_layer, attention_probs)\n        else:\n            return context_layer\n\n    def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):\n        if relative_pos is None:\n            q = query_layer.size(-2)\n            relative_pos = build_relative_position(\n                q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions, device=query_layer.device,\n            )\n        if relative_pos.dim() == 2:\n            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)\n        elif relative_pos.dim() == 3:\n            relative_pos = relative_pos.unsqueeze(1)\n        # bsz x height x query x key\n        elif relative_pos.dim() != 4:\n            raise ValueError(f\"Relative position ids must be of dim 2 or 3 or 4. 
{relative_pos.dim()}\")\n\n        att_span = self.pos_ebd_size\n\n        rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)\n        if self.share_att_key:\n            pos_query_layer = self.transpose_for_scores(\n                self.query_proj(rel_embeddings), self.num_attention_heads\n            ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)\n            pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(\n                query_layer.size(0) // self.num_attention_heads, 1, 1\n            )\n        else:\n            if \"c2p\" in self.pos_att_type:\n                pos_key_layer = self.transpose_for_scores(\n                    self.pos_key_proj(rel_embeddings), self.num_attention_heads\n                ).repeat(\n                    query_layer.size(0) // self.num_attention_heads, 1, 1\n                )  # .split(self.all_head_size, dim=-1)\n            if \"p2c\" in self.pos_att_type:\n                pos_query_layer = self.transpose_for_scores(\n                    self.pos_query_proj(rel_embeddings), self.num_attention_heads\n                ).repeat(\n                    query_layer.size(0) // self.num_attention_heads, 1, 1\n                )  # .split(self.all_head_size, dim=-1)\n\n        score = 0\n        # content->position\n        if \"c2p\" in self.pos_att_type:\n            scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)\n            c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))\n            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)\n            c2p_att = torch.gather(\n                c2p_att,\n                dim=-1,\n                index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),\n            )\n            score += c2p_att / scale\n\n        # position->content\n        if \"p2c\" in self.pos_att_type:\n            scale = 
math.sqrt(pos_query_layer.size(-1) * scale_factor)\n            if key_layer.size(-2) != query_layer.size(-2):\n                r_pos = build_relative_position(\n                    key_layer.size(-2),\n                    key_layer.size(-2),\n                    bucket_size=self.position_buckets,\n                    max_position=self.max_relative_positions,\n                    device=query_layer.device,\n                )\n                r_pos = r_pos.unsqueeze(0)\n            else:\n                r_pos = relative_pos\n\n            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)\n            p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))\n            p2c_att = torch.gather(\n                p2c_att,\n                dim=-1,\n                index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),\n            ).transpose(-1, -2)\n            score += p2c_att / scale\n\n        return score\nimport transformers.models.deberta_v2.modeling_deberta_v2\ntransformers.models.deberta_v2.modeling_deberta_v2.DisentangledSelfAttention = DisentangledSelfAttention\n\n#### Hack Deberta #####\n\nfrom transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Embeddings, ConvLayer\nfrom transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Layer\nfrom transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Encoder as _DebertaV2Encoder\nfrom transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config\nfrom transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout, ContextPooler \n    \nclass DebertaV2Layers(_DebertaV2Encoder):\n    def __init__(self, config, first_block=False):\n        super(_DebertaV2Encoder, self).__init__()\n        \n        self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])\n        self.relative_attention = getattr(config, \"relative_attention\", False)\n\n        if 
self.relative_attention:\n            self.max_relative_positions = getattr(config, \"max_relative_positions\", -1)\n            if self.max_relative_positions < 1:\n                self.max_relative_positions = config.max_position_embeddings\n\n            self.position_buckets = getattr(config, \"position_buckets\", -1)\n            pos_ebd_size = self.max_relative_positions * 2\n\n            if self.position_buckets > 0:\n                pos_ebd_size = self.position_buckets * 2\n\n            self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)\n\n        self.norm_rel_ebd = [x.strip() for x in getattr(config, \"norm_rel_ebd\", \"none\").lower().split(\"|\")]\n\n        if \"layer_norm\" in self.norm_rel_ebd:\n            self.LayerNorm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)\n\n        if first_block:\n            self.conv = ConvLayer(config) if getattr(config, \"conv_kernel_size\", 0) > 0 else None\n        else:\n            self.conv = None\n            \n        self.gradient_checkpointing = True # TODO\n        \n        if hasattr(self, 'LayerNorm'):\n            for p in self.LayerNorm.parameters():\n                p.requires_grad = False\n        if hasattr(self, 'rel_embeddings'):\n            for p in self.rel_embeddings.parameters():\n                p.requires_grad = False\n                \n    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):\n        if self.relative_attention and relative_pos is None:\n            q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)\n            relative_pos = build_relative_position(\n                q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions, device=hidden_states.device,\n            )\n        return relative_pos\n    \n    def forward(\n        self,\n        hidden_states,\n        attention_mask,\n        query_states=None,\n        
relative_pos=None,\n    ):\n        if attention_mask.dim() <= 2:\n            input_mask = attention_mask\n        else:\n            input_mask = (attention_mask.sum(-2) > 0).byte()\n        attention_mask = self.get_attention_mask(attention_mask)\n        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)\n\n        next_kv = hidden_states # TODOs\n        rel_embeddings = self.get_rel_embedding()\n        output_states = next_kv\n        for i, layer_module in enumerate(self.layer):\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                output_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer_module),\n                    next_kv,\n                    attention_mask,\n                    query_states,\n                    relative_pos,\n                    rel_embeddings,\n                )\n            else:\n                output_states = layer_module(\n                    next_kv,\n                    attention_mask,\n                    query_states=query_states,\n                    relative_pos=relative_pos,\n                    rel_embeddings=rel_embeddings,\n                )\n\n            if i == 0 and self.conv is not None:\n                output_states = self.conv(hidden_states, output_states, input_mask)\n                \n            next_kv = output_states\n\n        return output_states\n\n\n\nclass DebertaClassificationHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.pooler = ContextPooler(config)\n        self.classifier = nn.Linear(\n            self.pooler.output_dim, getattr(config, \"num_labels\", 2),\n        )\n        \n        drop_out = getattr(config, \"cls_dropout\", 
None)\n        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out\n        self.dropout = StableDropout(drop_out)\n        \n    def forward(self, hidden_states, input_ids=None):\n        pooled_output = self.pooler(hidden_states)\n        pooled_output = self.dropout(pooled_output)\n        logits = self.classifier(pooled_output)\n        return logits"
  },
  {
    "path": "training/modules/dist_deberta_pp_module.py",
    "content": "from torch import nn\nfrom .deberta_modules import DebertaV2Embeddings, DebertaV2Layers, DebertaClassificationHead\n\n\nclass DebertaStageBase(nn.Module):\n    def __init__(self, args, config):\n        super().__init__()\n        self._to_cpu = False # (args.dist_backend == \"gloo\")\n        self.config = config\n\n    def _create_first_layer(self):\n        return DebertaV2Embeddings(self.config)\n\n    def _create_last_layer(self):\n        return DebertaClassificationHead(self.config)\n\n    def _create_transformer_layers(self, first_block=False):\n        return DebertaV2Layers(self.config, first_block=first_block) # TODO: checkpoint\n\n\nclass DebertaStageFirst(DebertaStageBase):\n    def __init__(self, args, config, device):\n        super().__init__(args, config)\n        self.device = device\n        self.embeddings = self._create_first_layer().to(device)\n        self.encoder = self._create_transformer_layers(first_block=True).to(device)\n\n    def forward(self, x, token_type_ids=None, attention_mask=None):\n        if self._to_cpu:\n            x = x.to(self.device)\n            if token_type_ids is not None:\n                token_type_ids = token_type_ids.to(self.device)\n            if attention_mask is not None:\n                attention_mask = attention_mask.to(self.device)\n        x = self.embeddings(x, token_type_ids=token_type_ids)\n        out = self.encoder(x, attention_mask=attention_mask)\n        return out.cpu() if self._to_cpu else out\n\n\nclass DebertaStageMiddle(DebertaStageBase):\n    def __init__(self, args, config, device):\n        super().__init__(args, config)\n        self.device = device\n        self.encoder = self._create_transformer_layers(first_block=False).to(device)\n\n    def forward(self, x, attention_mask=None):\n        if self._to_cpu:\n            x = x.to(self.device)\n            if attention_mask is not None:\n                attention_mask = attention_mask.to(self.device)\n        out = 
self.encoder(x, attention_mask=attention_mask)\n        return out.cpu() if self._to_cpu else out\n\n\nclass DebertaStageLast(DebertaStageBase):\n    def __init__(self, args, config, device):\n        super().__init__(args, config)\n        self.device = device\n        self.encoder = self._create_transformer_layers(first_block=False).to(device)\n        self.output_head = self._create_last_layer().to(device)\n\n    def forward(self, x, attention_mask=None, input_ids=None):\n        if self._to_cpu:\n            x = x.to(self.device)\n            if attention_mask is not None:\n                attention_mask = attention_mask.to(self.device)\n        x = self.encoder(x, attention_mask=attention_mask)\n        out = self.output_head(x)\n        return out.cpu() if self._to_cpu else out"
  },
  {
    "path": "training/modules/dist_gpt_fsdp_module.py",
    "content": "import torch\nfrom fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP\nfrom .task_modules import GlueClassification\nfrom .gpt_modules import MultiHeadAttention, TwoLayerMLP, GPTEmbedding\nfrom fairscale.nn.checkpoint import checkpoint_wrapper\n\n\n# This is only implemented to support checkpoint in FSDP\n\nclass GPTTransformerFsdpLayer(torch.nn.Module):\n    def __init__(self, model_dim, head_num, feedforward_dim=2048, layer_norm_eps=1e-5, use_checkpoint=True,\n                 explicit_fsdp=False) -> None:\n        super(GPTTransformerFsdpLayer, self).__init__()\n        self.attn = MultiHeadAttention(model_dim, head_num)\n        if use_checkpoint:\n            self.attn = checkpoint_wrapper(self.attn)\n        if explicit_fsdp:\n            self.attn = FSDP(self.attn, reshard_after_forward=True, move_params_to_cpu=False, mixed_precision=False,\n                             flatten_parameters=False)\n        # Implementation of Feedforward model\n        self.mlp = TwoLayerMLP(model_dim, feedforward_dim)\n        if use_checkpoint:\n            self.mlp = checkpoint_wrapper(self.mlp)\n        if explicit_fsdp:\n            self.attn = FSDP(self.attn, reshard_after_forward=True, move_params_to_cpu=False, mixed_precision=False,\n                             flatten_parameters=False)\n        self.norm1 = torch.nn.LayerNorm(model_dim, eps=layer_norm_eps)\n        self.norm2 = torch.nn.LayerNorm(model_dim, eps=layer_norm_eps)\n        # self.dropout1 = nn.Dropout(dropout)\n        # self.dropout2 = nn.Dropout(dropout)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.norm1(x)\n        # x = x + self.dropout_1(self.attn(x2, x2, x2))\n        x.requires_grad_(True)\n        x = self.attn(x)\n        x = self.norm2(x)\n        # x = x + self.dropout_2(self.ff(x2))\n        x.requires_grad_(True)\n        x = self.mlp(x)\n        return x\n\n\nclass GPTGlueFsdpModel(torch.nn.Module):\n    def __init__(self, 
args, vocab_size, num_classes, use_checkpoint=True):\n        super(GPTGlueFsdpModel, self).__init__()\n        self.embedding = GPTEmbedding(vocab_size, args.embedding_dim, args.seq_length)\n\n        module_list = []\n        for _ in range(args.num_layers):\n            # pass use_checkpoint by keyword; positionally it would bind to layer_norm_eps\n            module_list.append(GPTTransformerFsdpLayer(args.embedding_dim, args.num_heads,\n                                                       args.embedding_dim * 4, use_checkpoint=use_checkpoint,\n                                                       explicit_fsdp=False))\n        self.transformers = torch.nn.Sequential(*module_list)\n        self.classifier = GlueClassification(args.embedding_dim, num_classes)\n\n    def forward(self, input_ids, position_ids):\n        input_emb = self.embedding(input_ids, position_ids)\n        output_emb = self.transformers(input_emb)\n        return self.classifier(output_emb)\n\n\nclass GPTFsdpStageBase(torch.nn.Module):\n    def __init__(self, args, num_stage_layers, vocab_size, num_classes, use_checkpoint=True, explicit_fsdp=True):\n        super(GPTFsdpStageBase, self).__init__()\n        self._vocab_size = vocab_size\n        self._explicit_fsdp = explicit_fsdp\n        self._use_checkpoint = use_checkpoint\n        self._embedding_dim = args.embedding_dim  # embedding dimension\n        self._seq_length = args.seq_length\n        self._num_classes = num_classes\n        # the dimension of the feedforward network model in nn.TransformerEncoder\n        self._feedforward_dim = args.embedding_dim * 4\n        self._num_heads = args.num_heads  # the number of heads in the multi-head attention models\n        self._num_layers = num_stage_layers\n\n    def _create_first_layer(self):\n        emb = GPTEmbedding(self._vocab_size, self._embedding_dim, self._seq_length)\n        if self._explicit_fsdp:\n            return FSDP(emb, reshard_after_forward=True, move_params_to_cpu=False, mixed_precision=False,\n                        flatten_parameters=False)\n        else:\n            return emb\n\n    def _create_last_layer(self):\n 
       classifier = GlueClassification(self._embedding_dim, self._num_classes)\n        if self._explicit_fsdp:\n            return FSDP(classifier, reshard_after_forward=True, move_params_to_cpu=False, mixed_precision=False,\n                        flatten_parameters=False)\n        else:\n            return classifier\n\n    def _create_fsdp_transformer_layer(self):\n        return GPTTransformerFsdpLayer(self._embedding_dim, self._num_heads, self._feedforward_dim,\n                                       use_checkpoint=self._use_checkpoint, explicit_fsdp=self._explicit_fsdp)\n\n\nclass GPTFsdpStageFirst(GPTFsdpStageBase):\n    def __init__(self, args, num_stage_layers, vocab_size, num_classes, device, use_checkpoint=True, explicit_fsdp=True):\n        super(GPTFsdpStageFirst, self).__init__(args, num_stage_layers, vocab_size, num_classes, use_checkpoint,\n                                                explicit_fsdp)\n        self.device = device\n        module_list = [self._create_first_layer()]\n        for _ in range(self._num_layers):\n            module_list.append(self._create_fsdp_transformer_layer())\n        self.model = torch.nn.Sequential(*module_list).to(device)\n\n    def forward(self, x):\n        out = self.model(x)\n        return out\n\n\nclass GPTFsdpStageMiddle(GPTFsdpStageBase):\n    def __init__(self, args, num_stage_layers, vocab_size, num_classes, device, use_checkpoint=True, explicit_fsdp=True):\n        super(GPTFsdpStageMiddle, self).__init__(args, num_stage_layers, vocab_size, num_classes, use_checkpoint,\n                                                 explicit_fsdp)\n        self.device = device\n        module_list = []\n        for _ in range(self._num_layers):\n            module_list.append(self._create_fsdp_transformer_layer())\n        self.model = torch.nn.Sequential(*module_list).to(device)\n\n    def forward(self, x):\n        out = self.model(x)\n        return out\n\n\nclass GPTFsdpStageLast(GPTFsdpStageBase):\n    def 
__init__(self, args, num_stage_layers, vocab_size, num_classes, device, use_checkpoint=True, explicit_fsdp=True):\n        super(GPTFsdpStageLast, self).__init__(args, num_stage_layers, vocab_size, num_classes, use_checkpoint,\n                                               explicit_fsdp)\n        self.device = device\n        module_list = []\n        for _ in range(self._num_layers):\n            module_list.append(self._create_fsdp_transformer_layer())\n        module_list.append(self._create_last_layer())\n        self.model = torch.nn.Sequential(*module_list).to(device)\n\n    def forward(self, x):\n        out = self.model(x)\n        return out\n"
  },
  {
    "path": "training/modules/dist_gpt_pp_module.py",
    "content": "import numpy as np\nfrom torch import nn\nfrom comm.comm_utils import *\n\nfrom copy import deepcopy\n\n\nclass GPTStageBase(nn.Module):\n    def __init__(self, args, config):\n        super(GPTStageBase, self).__init__()\n        self._to_cpu = (args.dist_backend == \"gloo\")\n        self._embedding_dim = args.embedding_dim  # embedding dimension\n        self._seq_length = args.seq_length\n        # the dimension of the feedforward aws_network model in nn.TransformerEncoder\n        self._feedforward_dim = args.embedding_dim * 4\n        self._num_heads = args.num_heads  # the number of heads in the multi-head attention models\n        self._num_layers = args.num_layers\n        self._layer_begin = get_pipeline_parallel_rank() * args.num_layers\n        self._layer_end = min(self._layer_begin + args.num_layers, args.max_layers)\n        \n        self._task_type = getattr(args, 'task_type', 'language_model')\n        \n        self.load_pretrained_model = args.load_pretrained_model\n        self.model_name = args.model_name\n        self.config = config\n        \n        if hasattr(args, 'model_type'):\n            if args.model_type == \"gpt2\":\n                from .hf_gpt2_modules import GPTEmbeddings, GPTBlock, GPTLMHead\n            elif args.model_type == \"gptj\":\n                from .hf_gptj_modules import GPTEmbeddings, GPTBlock, GPTLMHead\n            elif args.model_type == \"gptneox\":\n                from .hf_gptneox_modules import GPTEmbeddings, GPTBlock, GPTLMHead\n            elif args.model_type == 'llama':\n                from .llama_modules import GPTEmbeddings, GPTBlock, GPTLMHead\n            else:\n                raise Exception(\"unknown\")\n        else:\n            raise Exception(\"!!!! 
model type not defined\")\n            \n        self._GPTEmbeddings = GPTEmbeddings\n        self._GPTBlock = GPTBlock\n        self._GPTLMHead = GPTLMHead\n\n    def _create_first_layer(self):\n        layer = self._GPTEmbeddings(deepcopy(self.config))\n        if self.load_pretrained_model:\n            print('loading embs')\n            ret = layer.load_state_dict(\n                torch.load(f'{self.model_name}/pytorch_embs.pt'), strict=False\n            )\n            if len(ret.missing_keys):\n                print('The following weight keys are missing:')\n                print(ret.missing_keys)\n            if len(ret.unexpected_keys):\n                print('The following weight keys are unexpected:')\n                print(ret.unexpected_keys)\n        return layer\n\n    def _create_last_layer(self):\n        layer = self._GPTLMHead(deepcopy(self.config))\n        if self.load_pretrained_model:\n            print('loading lm_head')\n            ret = layer.load_state_dict(\n                torch.load(f'{self.model_name}/pytorch_lm_head.pt'), strict=False\n            )\n            if len(ret.missing_keys):\n                print('The following weight keys are missing:')\n                print(ret.missing_keys)\n            if len(ret.unexpected_keys):\n                print('The following weight keys are unexpected:')\n                print(ret.unexpected_keys)\n        return layer\n\n    def _create_transformer_layer(self, layer_idx=0):\n        config = deepcopy(self.config)\n        layer = self._GPTBlock(config, layer_id=layer_idx) # TODO: checkpoint\n        if self.load_pretrained_model:\n            print(f'loading layer {layer_idx}')\n            ret = layer.load_state_dict(\n                torch.load(f'{self.model_name}/pytorch_{layer_idx}.pt'), strict=False\n            )\n            if len(ret.missing_keys):\n                print('The following weight keys are missing:')\n                print(ret.missing_keys)\n            if 
len(ret.unexpected_keys):\n                print('The following weight keys are unexpected:')\n                print(ret.unexpected_keys)\n        return layer\n    \n\nclass GPTStageFull(GPTStageBase):\n    def __init__(self, args, config, device):\n        super(GPTStageFull, self).__init__(args, config)\n        self.device = device\n        module_list = [self._create_first_layer()]\n        for layer_idx in range(self._layer_begin, self._layer_end):\n            module_list.append(self._create_transformer_layer(layer_idx=layer_idx))\n        if hasattr(args, 'skip_lm_head') and args.skip_lm_head:\n            pass\n        else:\n            module_list.append(self._create_last_layer())\n        self.model = nn.Sequential(*module_list).to(device)\n\n    def forward(self, x, **kargs):\n        for module in self.model:\n            x = module(x, **kargs)\n        return x\n\n\nclass GPTStageFirst(GPTStageBase):\n    def __init__(self, args, config, device):\n        super(GPTStageFirst, self).__init__(args, config)\n        self.device = device\n        module_list = [self._create_first_layer()]\n        for layer_idx in range(self._layer_begin, self._layer_end):\n            module_list.append(self._create_transformer_layer(layer_idx=layer_idx))\n        self.model = nn.Sequential(*module_list).to(device)\n\n    def forward(self, x, **kargs):\n        for module in self.model:\n            x = module(x, **kargs)\n        return x\n        # out = self.model(x.to(self.device), **kargs)\n        # return out.cpu() if self._to_cpu else out\n\n\nclass GPTStageMiddle(GPTStageBase):\n    def __init__(self, args, config, device):\n        super(GPTStageMiddle, self).__init__(args, config)\n        self.device = device\n        module_list = []\n        for layer_idx in range(self._layer_begin, self._layer_end):\n            module_list.append(self._create_transformer_layer(layer_idx=layer_idx))\n        self.model = nn.Sequential(*module_list).to(device)\n\n    def 
forward(self, x, **kargs):\n        for module in self.model:\n            x = module(x, **kargs)\n        return x\n        # out = self.model(x.to(self.device), **kargs) if self._to_cpu else self.model(x)\n        # return out.cpu() if self._to_cpu else out\n\n\nclass GPTStageLast(GPTStageBase):\n    def __init__(self, args, config, device):\n        super(GPTStageLast, self).__init__(args, config)\n        self.device = device\n        module_list = []\n        for layer_idx in range(self._layer_begin, self._layer_end):\n            module_list.append(self._create_transformer_layer(layer_idx=layer_idx))\n            \n        if hasattr(args, 'skip_lm_head') and args.skip_lm_head:\n            pass\n        else:\n            module_list.append(self._create_last_layer())\n        \n        self.model = nn.Sequential(*module_list).to(device)\n        \n        # self.upscale_last = nn.Linear(args.embedding_dim, 9216).to(device)\n        \n    def forward(self, x, **kargs):\n        for module in self.model:\n            x = module(x, **kargs)\n        \n        return x\n\n#     def forward(self, x, **kargs):\n#         for module in self.model[:-1]:\n#             x = module(x, **kargs)\n#         hid = x\n#         x = self.model[-1](x, **kargs)\n        \n#         hid = self.upscale_last(hid)\n#         loss = torch.nn.functional.mse_loss(hid, kargs['teacher_hidden_states'])\n#         print(loss.item())\n#         return x, loss\n    "
  },
  {
    "path": "training/modules/hf_gpt2_modules.py",
    "content": "import torch\nimport math\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional\nfrom torch.utils.checkpoint import checkpoint\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2Attention as _GPT2Attention\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2MLP as _GPT2MLP\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2Block as _GPT2Block\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2Model as _GPT2Model\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as _GPT2LMHeadModel\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification as _GPT2ForSequenceClassification\nfrom transformers.models.gpt2.configuration_gpt2 import GPT2Config as GPTConfig\nfrom typing import Optional, Tuple, Union\n\n\n# @torch.jit.script\ndef gpt_loss_func(input, target):\n    lm_logits, labels = input, target\n    shift_logits = lm_logits[..., :-1, :].contiguous()\n    shift_labels = labels[..., 1:].contiguous()\n    loss = functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n    return loss\n\n\nclass GPTEmbeddings(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        \n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)\n        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)\n        self.drop = nn.Dropout(config.embd_pdrop)\n        \n    def forward(self, input_ids, **kargs):\n        \n        device = input_ids.device\n        \n        # input ids\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n        batch_size = input_ids.shape[0]\n        \n        # position ids\n        position_ids = torch.arange(0, 
input_shape[-1], dtype=torch.long, device=device)\n        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])\n            \n        inputs_embeds = self.wte(input_ids)\n        position_embeds = self.wpe(position_ids)\n        hidden_states = inputs_embeds + position_embeds\n\n        hidden_states = self.drop(hidden_states)\n        \n        return hidden_states\n    \nclass GPTAttention(_GPT2Attention):\n    \n    def _attn(self, query, key, value, attention_mask=None, head_mask=None, prefix_masks=None):\n        attn_weights = torch.matmul(query, key.transpose(-1, -2))\n\n        if self.scale_attn_weights:\n            attn_weights = attn_weights / torch.tensor(\n                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device\n            )\n\n        # Layer-wise attention scaling\n        if self.scale_attn_by_inverse_layer_idx:\n            attn_weights = attn_weights / float(self.layer_idx + 1)\n\n        if not self.is_cross_attention:\n            # if only \"normal\" attention layer implements causal mask\n            query_length, key_length = query.size(-2), key.size(-2)\n            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)\n            mask_value = torch.finfo(attn_weights.dtype).min\n            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.\n            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`\n            if prefix_masks is not None:\n                bsz = query.size(0)\n                causal_mask = causal_mask.repeat(bsz, 1, 1, 1) # (bsz, 1, src_len, tgt_len)\n                causal_mask = causal_mask.permute(0, 3, 1, 2) # (bsz, tgt_len, 1, src_len)\n                causal_mask[prefix_masks.bool()] = 1\n                causal_mask = causal_mask.permute(0, 2, 3, 1) # (bsz, 1, src_len, tgt_len)\n            mask_value = 
torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)\n            attn_weights = torch.where(causal_mask, attn_weights, mask_value)\n\n        if attention_mask is not None:\n            # Apply the attention mask\n            attn_weights = attn_weights + attention_mask\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise\n        attn_weights = attn_weights.type(value.dtype)\n        attn_weights = self.attn_dropout(attn_weights)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n\n        return attn_output, attn_weights\n    \n    \n    def forward(\n        self,\n        hidden_states: Optional[Tuple[torch.FloatTensor]],\n        layer_past: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = False,\n        output_attentions: Optional[bool] = False,\n        prefix_masks = None,\n    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:\n        if encoder_hidden_states is not None:\n            if not hasattr(self, \"q_attn\"):\n                raise ValueError(\n                    \"If class is used as cross attention, the weights `q_attn` have to be defined. 
\"\n                    \"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.\"\n                )\n\n            query = self.q_attn(hidden_states)\n            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)\n            attention_mask = encoder_attention_mask\n        else:\n            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)\n\n        query = self._split_heads(query, self.num_heads, self.head_dim)\n        key = self._split_heads(key, self.num_heads, self.head_dim)\n        value = self._split_heads(value, self.num_heads, self.head_dim)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            key = torch.cat((past_key, key), dim=-2)\n            value = torch.cat((past_value, value), dim=-2)\n\n        if use_cache is True:\n            present = (key, value)\n        else:\n            present = None\n\n        if self.reorder_and_upcast_attn:\n            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)\n        else:\n            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, prefix_masks=prefix_masks)\n\n        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)\n        attn_output = self.c_proj(attn_output)\n        attn_output = self.resid_dropout(attn_output)\n\n        outputs = (attn_output, present)\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs  # a, present, (attentions)\n    \n\nclass GPTBlock(_GPT2Block):\n    def __init__(self, config, layer_idx=None, use_checkpoint=True):\n        super(_GPT2Block, self).__init__()\n        hidden_size = config.hidden_size\n        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size\n\n        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        
self.attn = GPTAttention(config, layer_idx=layer_idx)\n        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.mlp = _GPT2MLP(inner_dim, config)\n        \n        self.config = config\n        self.use_checkpoint = use_checkpoint\n        \n        def attn_res(x: torch.Tensor, prefix_masks: torch.Tensor) -> torch.Tensor:\n            res = x\n            x = self.ln_1(x)\n            x = self.attn(x, prefix_masks=prefix_masks)[0]\n            return x + res\n        self.attn_res = attn_res\n        \n        def mlp_res(x: torch.Tensor) -> torch.Tensor:\n            res = x\n            x = self.ln_2(x)\n            x = self.mlp(x)\n            return x + res\n        self.mlp_res = mlp_res\n        \n\n    def forward(self, x: torch.Tensor, prefix_masks=None, **kargs) -> torch.Tensor:\n        \n        if not self.training:\n            x = self.attn_res(x, prefix_masks=prefix_masks)\n            x = self.mlp_res(x)\n            return x\n        \n        if self.use_checkpoint:\n            x.requires_grad_(True)\n            x = checkpoint(self.attn_res, x, prefix_masks)\n        else:\n            x = self.attn_res(x, prefix_masks=prefix_masks)\n        if self.use_checkpoint:\n            x.requires_grad_(True)\n            x = checkpoint(self.mlp_res, x)\n        else:\n            x = self.mlp_res(x)\n        return x\n    \n    \nclass GPTModel(_GPT2Model):\n    def __init__(self, config):\n        super(_GPT2Model, self).__init__(config)\n\n        self.embed_dim = config.hidden_size\n        \n        emb_layer = GPTEmbeddings(config)\n        self.wte = emb_layer.wte\n        self.wpe = emb_layer.wpe\n\n        self.drop = nn.Dropout(config.embd_pdrop)\n        self.h = nn.ModuleList([GPTBlock(config, layer_idx=i, use_checkpoint=True) for i in range(config.num_hidden_layers)])\n        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)\n\n        # Model parallel\n        self.model_parallel = 
False\n        self.device_map = None\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n        \n    def forward(self, input_ids, attention_mask=None, **kargs):\n        \n        device = input_ids.device\n        \n        # input ids\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n        batch_size = input_shape[0]\n        \n        # position ids\n        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)\n        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])\n            \n        inputs_embeds = self.wte(input_ids)\n        position_embeds = self.wpe(position_ids)\n        hidden_states = inputs_embeds + position_embeds\n\n        hidden_states = self.drop(hidden_states)\n\n        hidden_states_tuple = tuple()\n        for layer in self.h:\n            hidden_states_tuple = hidden_states_tuple + (hidden_states,)\n            hidden_states = layer(hidden_states)\n        hidden_states = self.ln_f(hidden_states)\n        hidden_states_tuple = hidden_states_tuple + (hidden_states,)\n        \n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=None,\n            hidden_states=hidden_states_tuple,\n            attentions=None,\n            cross_attentions=None,\n        )\n    \nclass GPTLMHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n        \n    def forward(self, x, **kargs):\n        x = self.ln_f(x)\n        x = self.lm_head(x)\n        return x\n    \nclass GPTLMHeadModel(_GPT2LMHeadModel):\n\n    def __init__(self, config):\n        super(_GPT2LMHeadModel, self).__init__(config)\n        self.transformer 
= GPTModel(config)\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n        # ln_f will be calculated in self.transformer\n\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n        \n        # Initialize weights and apply final processing\n        self.post_init()\n        \nclass GPTClassificationHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n        self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)\n        \n    def forward(self, hidden_states, input_ids=None):\n        \n        batch_size, sequence_length = hidden_states.shape[:2]\n        if input_ids is not None:\n            sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1\n        else:\n            sequence_lengths = -1\n        \n        pooled_hidden_states = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]\n        \n        logits = self.score(self.ln_f(pooled_hidden_states))\n        \n        return logits\n        \nclass GPTForClassification(_GPT2ForSequenceClassification):\n    \n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.transformer = GPTModel(config)\n        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)\n\n        # Model parallel\n        self.model_parallel = False\n        self.device_map = None\n\n        # Initialize weights and apply final processing\n        self.post_init()\n        \n#     def forward(self, input_ids, labels=None):\n        \n#         ret = self.transformer(input_ids)\n#         pool_hidden_state = ret.last_hidden_state[:, -1]\n        \n#         logits = self.score(pool_hidden_state)\n        \n#         loss = functional.cross_entropy(logits, labels)\n        \n# 
        return loss\n        "
  },
  {
    "path": "training/modules/hf_gptj_modules.py",
    "content": "import os\nimport torch\nimport math\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional\nfrom torch.utils.checkpoint import checkpoint\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\nfrom transformers.models.gptj.modeling_gptj import ACT2FN\nfrom transformers.models.gptj.modeling_gptj import GPTJAttention as _GPTJAttention\nfrom transformers.models.gptj.modeling_gptj import GPTJMLP as _GPTJMLP\nfrom transformers.models.gptj.modeling_gptj import GPTJBlock as _GPTJBlock\nfrom transformers.models.gptj.modeling_gptj import GPTJModel as _GPTJModel\nfrom transformers.models.gptj.modeling_gptj import fixed_pos_embedding\nfrom transformers.models.gptj.configuration_gptj import GPTJConfig as GPTConfig\nfrom transformers.models.gptj.modeling_gptj import fixed_pos_embedding, rotate_every_two, apply_rotary_pos_emb\n\n\n# @torch.jit.script\ndef gpt_loss_func(input, target):\n    lm_logits, labels = input, target\n    shift_logits = lm_logits[..., :-1, :].contiguous()\n    shift_labels = labels[..., 1:].contiguous()\n    loss = functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n    return loss\n\n# put things on GPU to avoid high CPU usage\ndef fixed_pos_embedding(x, seq_dim=1, seq_len=None):\n    dim = x.shape[-1]\n    if seq_len is None:\n        seq_len = x.shape[seq_dim]\n    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=x.device) / dim))\n    sinusoid_inp = torch.einsum(\"i , j -> i j\", torch.arange(seq_len, device=x.device), inv_freq).float()\n    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)\n    \n            \nclass GPTJMLP(_GPTJMLP):\n    def __init__(self, intermediate_size, config, device='cpu'):  # in MLP: intermediate_size= 4 * embed_dim\n        super(_GPTJMLP, self).__init__()\n        embed_dim = config.n_embd\n\n        self.fc_in = nn.Linear(embed_dim, 
intermediate_size, device=device)\n        self.fc_out = nn.Linear(intermediate_size, embed_dim, device=device)\n\n        self.act = ACT2FN[config.activation_function]\n        self.dropout = nn.Dropout(config.resid_pdrop)\n\n\nclass GPTJAttention(_GPTJAttention):\n    \n    def __init__(self, config, device='cpu'):\n        super(_GPTJAttention, self).__init__()\n\n        max_positions = config.max_position_embeddings\n        self.register_buffer(\n            \"bias\",\n            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(\n                1, 1, max_positions, max_positions\n            ),\n        )\n        self.register_buffer(\"masked_bias\", torch.tensor(-1e9))\n\n        self.attn_dropout = nn.Dropout(config.attn_pdrop)\n        self.resid_dropout = nn.Dropout(config.resid_pdrop)\n\n        self.embed_dim = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_attention_heads\n        if self.head_dim * self.num_attention_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and\"\n                f\" `num_attention_heads`: {self.num_attention_heads}).\"\n            )\n        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())\n\n        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False, device=device)\n        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False, device=device)\n        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False, device=device)\n        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False, device=device)\n        self.rotary_dim = None\n        if config.rotary_dim is not None:\n            self.rotary_dim = config.rotary_dim\n            \n    def _attn(\n        self,\n        query,\n 
       key,\n        value,\n        attention_mask=None,\n        head_mask=None,\n        prefix_masks=None,\n    ):\n\n        # compute causal mask from causal mask buffer\n        query_length, key_length = query.size(-2), key.size(-2)\n        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)\n        \n        if prefix_masks is not None:\n            bsz = query.size(0)\n            causal_mask = causal_mask.repeat(bsz, 1, 1, 1) # (bsz, 1, src_len, tgt_len)\n            causal_mask = causal_mask.permute(0, 3, 1, 2) # (bsz, tgt_len, 1, src_len)\n            causal_mask[prefix_masks.bool()] = 1\n            causal_mask = causal_mask.permute(0, 2, 3, 1) # (bsz, 1, src_len, tgt_len)\n\n        # Keep the attention weights computation in fp32 to avoid overflow issues\n        query = query.to(torch.float32)\n        key = key.to(torch.float32)\n\n        attn_weights = torch.matmul(query, key.transpose(-1, -2))\n\n        mask_value = torch.finfo(attn_weights.dtype).min\n        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.\n        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`\n        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)\n        attn_weights = torch.where(causal_mask, attn_weights, mask_value)\n\n        attn_weights = attn_weights / self.scale_attn\n\n        if attention_mask is not None:\n            # Apply the attention mask\n            attn_weights = attn_weights + attention_mask\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n        attn_weights = attn_weights.to(value.dtype)\n        attn_weights = self.attn_dropout(attn_weights)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n\n     
   return attn_output, attn_weights\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        layer_past=None,\n        head_mask=None,\n        offset=None,\n        use_cache=False,\n        output_attentions=False,\n        prefix_masks=None,\n    ):\n\n        query = self.q_proj(hidden_states)\n        key = self.k_proj(hidden_states)\n        value = self.v_proj(hidden_states)\n\n        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)\n        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)\n        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)\n\n        seq_len = key.shape[1]\n\n        if layer_past is not None:\n            if offset is None:\n                offset = layer_past[0].shape[-2]\n            seq_len += layer_past[0].shape[-2]\n            \n        if offset is None:\n            offset = 0\n\n        if self.rotary_dim is not None:\n            k_rot = key[:, :, :, : self.rotary_dim]\n            k_pass = key[:, :, :, self.rotary_dim :]\n\n            q_rot = query[:, :, :, : self.rotary_dim]\n            q_pass = query[:, :, :, self.rotary_dim :]\n\n            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)\n            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)\n            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)\n\n            key = torch.cat([k_rot, k_pass], dim=-1)\n            query = torch.cat([q_rot, q_pass], dim=-1)\n        else:\n            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)\n            key = apply_rotary_pos_emb(key, sincos, offset=offset)\n            query = apply_rotary_pos_emb(query, sincos, offset=offset)\n\n        key = key.permute(0, 2, 1, 3)\n        query = query.permute(0, 2, 1, 3)\n\n        if layer_past is not None:\n            past_key = layer_past[0]\n            past_value = layer_past[1]\n            key = 
torch.cat((past_key, key), dim=-2)\n            value = torch.cat((past_value, value), dim=-2)\n\n        if use_cache is True:\n            present = (key, value)\n        else:\n            present = None\n\n        # compute self-attention: V x Softmax(QK^T)\n        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, prefix_masks=prefix_masks)\n\n        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)\n        attn_output = self.out_proj(attn_output)\n        attn_output = self.resid_dropout(attn_output)\n\n        outputs = (attn_output, present)\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs  # a, present, (attentions)\n\n\nclass GPTEmbeddings(nn.Module):\n    def __init__(self, config, device='cpu'):\n        super().__init__()\n        \n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.wte = nn.Embedding(config.vocab_size, self.embed_dim, device=device)\n        \n    @classmethod\n    def from_pretrained(cls, model_path, config=None):\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        # module = cls(config).eval()\n        module = torch.nn.utils.skip_init(cls, config).eval() # fast init\n        try:\n            module.load_state_dict(torch.load(os.path.join(\n                model_path, 'pytorch_embs.pt',\n            )))\n        except:\n            print(f'Cannot load from <model_path>. 
The model is randomly initialized.')\n        return module\n        \n    def forward(self, input_ids, *args, **kargs):\n        \n        # input ids\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n        hidden_states = self.wte(input_ids)\n        return hidden_states\n    \n\nclass GPTBlock(_GPTJBlock):\n    def __init__(self, config, *args, use_checkpoint=True, device='cpu', **kargs):\n        super(_GPTJBlock, self).__init__()\n        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd\n        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon, device=device)\n        self.attn = GPTJAttention(config, device=device)\n        self.mlp = GPTJMLP(inner_dim, config, device=device)\n        self.config = config\n        self.use_checkpoint = use_checkpoint\n\n        def block_forward(x: torch.Tensor, attention_mask: torch.Tensor, prefix_masks: torch.Tensor) -> torch.Tensor:\n            res = x\n            x = self.ln_1(x)\n            x_a = self.attn(x, prefix_masks=prefix_masks, attention_mask=attention_mask)[0]\n            x_m = self.mlp(x)\n            return res + x_a + x_m\n        \n        self.block_forward = block_forward\n        \n    @classmethod\n    def from_pretrained(cls, model_path, config=None, layer_index=None):\n        assert layer_index is not None\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        # module = cls(config).eval()\n        module = torch.nn.utils.skip_init(cls, config).eval() # fast init\n        try:\n            module.load_state_dict(torch.load(os.path.join(\n                model_path, f'pytorch_{layer_index}.pt',\n            )))\n        except Exception as e:\n            print('Cannot load from <model_name>. 
The model is randomly initialized.')\n            \n        return module\n\n    def forward(self, x: torch.Tensor, prefix_masks=None, layer_past=None, mask=None, skip_ln=False, **kargs) -> torch.Tensor:\n        \n        if mask is not None:\n            # bool -> float\n            attention_mask = (1e4)*(mask[:, None, None, :]-1.0)\n        else:\n            attention_mask = None\n            \n        if mask is None:\n            if layer_past is not None:\n                offset = layer_past[0].size(2)\n            else:\n                offset = 0\n        else:\n            # masked tokens\n            offset = (mask-1).sum(-1, keepdims=False)\n            if layer_past is not None:\n                offset += layer_past[0].size(2)\n                \n        if self.training:\n            \n            if self.use_checkpoint:\n                x.requires_grad_(True)\n                x = checkpoint(self.block_forward, x, attention_mask, prefix_masks)\n            else:\n                # block_forward requires attention_mask; pass it as in the checkpointed path\n                x = self.block_forward(x, attention_mask, prefix_masks)\n            \n            return x\n           \n        else:\n            res = x\n            if not skip_ln:\n                x = self.ln_1(x)\n            x_a = self.attn(x, use_cache=False, layer_past=layer_past, attention_mask=attention_mask, offset=offset, prefix_masks=prefix_masks)[0]\n            x_m = self.mlp(x)\n            return x_a + x_m + res\n    \n    \nclass GPTLMHead(nn.Module):\n    def __init__(self, config, device='cpu'):\n        super().__init__()\n        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon, device=device)\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, device=device)\n        \n    @classmethod\n    def from_pretrained(cls, model_path, config=None):\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        # module = cls(config).eval()\n        module = torch.nn.utils.skip_init(cls, config).eval() # fast 
init\n        try:\n            module.load_state_dict(torch.load(os.path.join(\n                model_path, 'pytorch_lm_head.pt',\n            )))\n        except Exception as e:\n            # Fall back to random weights, but report which path failed and why.\n            print(f'Cannot load from {model_path} ({e}). The model is randomly initialized.')\n        return module\n        \n    def forward(self, x, **kargs):\n        x = self.ln_f(x)\n        x = self.lm_head(x)\n        return x\n"
  },
  {
    "path": "training/modules/hf_gptneox_modules.py",
    "content": "import os\nimport torch\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional\nfrom torch.utils.checkpoint import checkpoint\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention as _GPTNeoXAttention\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXMLP\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer as _GPTNeoXBlock\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXModel as _GPTNeoXModel\nfrom transformers.models.gpt_neox.configuration_gpt_neox import GPTNeoXConfig as GPTConfig\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding\n\n\ntry:\n    from flash_attn.flash_attention import FlashAttention\n    flash_attn_installed = True\n    print('>>>>> using flash attention')\nexcept ImportError:\n    flash_attn_installed = False\n\ntry:\n    from fav2.fav2_interface import flash_attn_qkvpacked_func as fav2_qkvpacked_func\n    flash_attn_v2_installed = True\n    print('>>>>> using flash attention v2')\n\n    class FlashAttentionV2(nn.Module):\n        \"\"\"Implement the scaled dot product attention with softmax.\n        Arguments\n        ---------\n            softmax_scale: The temperature to use for the softmax attention.\n                          (default: 1/sqrt(d_keys) where d_keys is computed at\n                          runtime)\n            attention_dropout: The dropout rate to apply to the attention\n                               (default: 0.0)\n        \"\"\"\n        def __init__(self, softmax_scale=None, attention_dropout=0.0):\n            super().__init__()\n            self.softmax_scale = softmax_scale\n            self.dropout_p = attention_dropout\n    \n        def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,\n                    
max_s=None, need_weights=False):\n            \"\"\"Implements the multihead softmax attention.\n            Arguments\n            ---------\n                qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None\n                    if unpadded: (nnz, 3, h, d)\n                key_padding_mask: a bool tensor of shape (B, S)\n            \"\"\"\n            assert not need_weights\n            assert qkv.dtype in [torch.float16, torch.bfloat16]\n            assert qkv.is_cuda\n            assert key_padding_mask is None\n            assert cu_seqlens is None\n            assert max_s is None\n\n            output = fav2_qkvpacked_func(\n                qkv, self.dropout_p if self.training else 0.0, \n                softmax_scale=self.softmax_scale, causal=causal\n            )\n    \n            return output, None\nexcept ImportError:\n    flash_attn_v2_installed = False\n\n    \n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, offset=0):\n    if isinstance(offset, torch.Tensor):\n        realidx = torch.arange(q.shape[-2], device=q.device).view(\n            1, q.shape[-2]) + offset[:, None]\n        cos = cos.squeeze(0).squeeze(0)[realidx].view(offset.size(0),\n                                                      1, q.shape[-2],\n                                                      cos.size(-1))\n        sin = sin.squeeze(0).squeeze(0)[realidx].view(offset.size(0),\n                                                      1, q.shape[-2],\n                                                      sin.size(-1))\n    else:\n        cos = cos[..., offset : q.shape[-2] + offset, :]\n        sin = sin[..., offset : q.shape[-2] + offset, :]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * 
sin)\n    return q_embed, k_embed\n\n\nclass GPTNeoXAttention(_GPTNeoXAttention):\n    \n    def __init__(self, config):\n        super(_GPTNeoXAttention, self).__init__()\n        self.num_attention_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_attention_heads\n        self.rotary_ndims = int(self.head_size * config.rotary_pct)\n        max_positions = config.max_position_embeddings\n        self.register_buffer(\n            \"bias\",\n            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(\n                1, 1, max_positions, max_positions\n            ),\n        )\n        self.register_buffer(\"masked_bias\", torch.tensor(-1e9))\n        self.rotary_emb = GPTNeoXRotaryEmbedding(\n            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base\n        )\n        self.register_buffer(\n            \"norm_factor\",\n            torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),\n            persistent=False,\n        )\n        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n\n        if flash_attn_v2_installed:\n            self.flash_attn = FlashAttentionV2(softmax_scale=1.0/self.norm_factor, attention_dropout = 0)\n        elif flash_attn_installed:\n            self.flash_attn = FlashAttention(softmax_scale=1.0/self.norm_factor, attention_dropout = 0)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask,\n        head_mask=None,\n        layer_past=None,\n        use_cache=False,\n        offset=None,\n        output_attentions=False,\n    ):\n        \n        bsz, tgt_len, _ = hidden_states.shape\n        \n        has_layer_past = layer_past is not None\n\n        # Compute QKV\n        # Attention heads [batch, seq_len, 
hidden_size]\n        #   --> [batch, seq_len, (np * 3 * head_size)]\n        qkv = self.query_key_value(hidden_states)\n\n        # [batch, seq_len, (num_heads * 3 * head_size)]\n        #   --> [batch, seq_len, num_heads, 3 * head_size]\n        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads,\n                                           3 * self.head_size)\n        qkv = qkv.view(*new_qkv_shape)\n\n        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]\n        query = qkv[..., :self.head_size].permute(0, 2, 1, 3)\n        key = qkv[..., self.head_size:2 * self.head_size].permute(0, 2, 1, 3)\n        value = qkv[..., 2 * self.head_size:].permute(0, 2, 1, 3)\n\n        # Compute rotary embeddings on rotary_ndims\n        query_rot = query[..., :self.rotary_ndims]\n        query_pass = query[..., self.rotary_ndims:]\n        key_rot = key[..., :self.rotary_ndims]\n        key_pass = key[..., self.rotary_ndims:]\n\n        # Compute token offset for rotary embeddings (when decoding)\n        seq_len = key.shape[-2]\n\n        if layer_past is not None:\n            if offset is None:\n                offset = layer_past[0].shape[-2]\n            seq_len += layer_past[0].shape[-2]\n\n        if offset is None:\n            offset = 0\n\n        cos, sin = self.rotary_emb(value, seq_len=seq_len)\n        query, key = apply_rotary_pos_emb(query_rot,\n                                          key_rot,\n                                          cos,\n                                          sin,\n                                          offset=offset)\n        query = torch.cat((query, query_pass), dim=-1)\n        key = torch.cat((key, key_pass), dim=-1)\n\n        # Cache QKV values\n        if has_layer_past:\n            past_key = layer_past[0]\n            past_value = layer_past[1]\n            key = torch.cat((past_key, key), dim=-2)\n            value = torch.cat((past_value, value), 
dim=-2)\n        present = None if use_cache else (key, value)\n\n        # Compute attention    \n        if flash_attn_installed or flash_attn_v2_installed:\n            \n            query = query.permute(0, 2, 1, 3).half()\n            key = key.permute(0, 2, 1, 3).half()\n            value = value.permute(0, 2, 1, 3).half()\n            qkv = torch.stack(\n                [\n                    query.reshape((bsz, tgt_len, self.num_attention_heads, self.head_size)),\n                    key.reshape((bsz, tgt_len, self.num_attention_heads, self.head_size)),\n                    value.reshape((bsz, tgt_len, self.num_attention_heads, self.head_size)),\n                ],\n                dim=2\n            )\n\n            attn_weights = None\n            attn_output, _ = self.flash_attn(qkv, causal=True)\n            attn_output = attn_output.view(bsz, tgt_len, self.num_attention_heads * self.head_size)\n        else:\n            attn_output, attn_weights = self._attn(query, key, value,\n                                                   attention_mask, head_mask)\n            # Reshape outputs\n            attn_output = self._merge_heads(attn_output, self.num_attention_heads,\n                                            self.head_size)\n        attn_output = self.dense(attn_output)\n\n        outputs = (attn_output, present)\n        if output_attentions:\n            outputs += (attn_weights, )\n\n        return outputs\n\n    # fix nan problem\n    def _attn(self, query, key, value, attention_mask=None, head_mask=None):\n        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]\n        # compute causal mask from causal mask buffer\n        batch_size, num_attention_heads, query_length, attn_head_size = query.size(\n        )\n        key_length = key.size(-2)\n\n        causal_mask = self.bias[:, :, key_length -\n                                query_length:key_length, :key_length].bool()\n\n        query = query.view(batch_size * 
num_attention_heads, query_length,\n                           attn_head_size)\n        key = key.view(batch_size * num_attention_heads, key_length,\n                       attn_head_size)\n        attn_scores = torch.zeros(  # empty sometimes gives nan\n            batch_size * num_attention_heads,\n            query_length,\n            key_length,\n            dtype=query.dtype,\n            device=key.device,\n        )\n        attn_scores = torch.baddbmm(\n            attn_scores,\n            query,\n            key.transpose(1, 2),\n            beta=0.0,\n            alpha=(1.0 / self.norm_factor),\n        )\n        attn_scores = attn_scores.view(batch_size, num_attention_heads,\n                                       query_length, key_length)\n\n        mask_value = torch.finfo(attn_scores.dtype).min\n        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.\n        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`\n        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(\n            attn_scores.device)\n        attn_scores = torch.where(causal_mask, attn_scores, mask_value)\n\n        if attention_mask is not None:\n            # Apply the attention mask\n            attn_scores = attn_scores + attention_mask\n\n        attn_weights = nn.functional.softmax(attn_scores, dim=-1)\n        attn_weights = attn_weights.to(value.dtype)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n        return attn_output, attn_weights\n\n\nclass GPTEmbeddings(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.embed_in = nn.Embedding(config.vocab_size, self.embed_dim)\n\n    @classmethod\n    def 
from_pretrained(cls, model_path, config=None):\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        module = cls(config).eval()\n        try:\n            module.load_state_dict(\n                torch.load(os.path.join(\n                    model_path,\n                    'pytorch_embs.pt',\n                )))\n        except:\n            print(\n                f'Cannot load from <model_path>. The model is randomly initialized.'\n            )\n        return module\n\n    def forward(self, input_ids, *args, **kargs):\n\n        # input ids\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n        hidden_states = self.embed_in(input_ids)\n        return hidden_states\n\n\nclass GPTBlock(_GPTNeoXBlock):\n\n    def __init__(self, config, *args, use_checkpoint=True, **kargs):\n        super(_GPTNeoXBlock, self).__init__()\n        self.input_layernorm = nn.LayerNorm(config.hidden_size,\n                                            eps=config.layer_norm_eps)\n        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,\n                                                     eps=config.layer_norm_eps)\n        self.attention = GPTNeoXAttention(config)\n        self.mlp = GPTNeoXMLP(config)\n        self.config = config\n        self.use_checkpoint = use_checkpoint\n\n        def block_forward(x: torch.Tensor, attention_mask: torch.Tensor,\n                          prefix_masks: torch.Tensor) -> torch.Tensor:\n            res = x\n            \"\"\"\n            To be compatible with https://github.com/huggingface/transformers/blob/a0ae2310ec46a2c592950babc85cf02e325bf6a7/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L336-L347\n            \"\"\"\n            layer_norm_out = self.input_layernorm(x)\n            attention_layer_output = self.attention(layer_norm_out, attention_mask=attention_mask)\n            attn_output = attention_layer_output[0]\n        
    # outputs = attention_layer_output[1:]\n\n            if hasattr(self.config, 'use_parallel_residual') and self.config.use_parallel_residual:\n                # x = x + attn(ln1(x)) + mlp(ln2(x))\n                # x_a = attn_output, \n                mlp_out = self.mlp(self.post_attention_layernorm(x))\n                return res + attn_output + mlp_out\n            else:\n                # x = x + attn(ln1(x)) \n                # x = x + mlp(ln2(x))\n                attn_output = attn_output + x\n                mlp_out = self.mlp(self.post_attention_layernorm(attn_output))\n                return attn_output + mlp_out\n\n        self.block_forward = block_forward\n\n    @classmethod\n    def from_pretrained(cls, model_path, config=None, layer_index=None):\n        assert layer_index is not None\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        module = cls(config).eval().half()\n        try:\n            module.load_state_dict(\n                torch.load(\n                    os.path.join(\n                        model_path,\n                        f'pytorch_{layer_index}.pt',\n                    )))\n        except Exception as e:\n            print(\n                'Cannot load from <model_name>. 
The model is randomly initialized.'\n            )\n        return module\n\n    def forward(self,\n                x: torch.Tensor,\n                layer_past=None,\n                mask=None,\n                **kargs) -> torch.Tensor:\n\n        if mask is not None:\n            # bool -> float\n            attention_mask = 1e9 * (mask[:, None, None, :] - 1)\n        else:\n            attention_mask = None\n\n        if mask is None:\n            if layer_past is not None:\n                offset = layer_past[0].size(2)\n            else:\n                offset = 0\n        else:\n            # masked tokens\n            offset = (mask - 1).sum(-1, keepdims=False)\n            if layer_past is not None:\n                offset += layer_past[0].size(2)\n\n        if self.training:\n\n            if self.use_checkpoint:\n                x.requires_grad_(True)\n                x = checkpoint(self.block_forward, x, attention_mask, None)\n            else:\n                x = self.block_forward(x, attention_mask, None)\n\n            return x\n\n        else:\n\n            residual = x\n            ln_out = self.input_layernorm(x)\n            attention_layer_outputs = self.attention(\n                ln_out,\n                attention_mask=attention_mask,\n            )\n            attn_output = attention_layer_outputs[\n                0]  # output_attn: a, present, ...\n\n            mlp_output = self.mlp(self.post_attention_layernorm(x))\n            x = mlp_output + attn_output + residual\n\n            return x\n\n\nclass GPTLMHead(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.final_layer_norm = nn.LayerNorm(config.hidden_size,\n                                             eps=config.layer_norm_eps)\n        self.embed_out = nn.Linear(config.hidden_size,\n                                   config.vocab_size,\n                                   bias=False)\n\n    @classmethod\n    def from_pretrained(cls, 
model_path, config=None):\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        module = cls(config).eval()\n        try:\n            module.load_state_dict(\n                torch.load(os.path.join(\n                    model_path,\n                    'pytorch_lm_head.pt',\n                )))\n        except:\n            print(\n                'Cannot load from <model_name>. The model is randomly initialized.'\n            )\n        return module\n\n    def forward(self, x, *args, **kargs):\n        x = self.final_layer_norm(x)\n        x = self.embed_out(x)\n        return x\n"
  },
  {
    "path": "training/modules/hf_opt_modules.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nimport os\nimport torch\nfrom torch import nn\nfrom torch.utils.checkpoint import checkpoint\nimport torch.nn.functional as F\nfrom transformers.models.opt.modeling_opt import ACT2FN\nfrom transformers.models.opt.modeling_opt import OPTDecoderLayer\nfrom transformers.models.opt.modeling_opt import OPTAttention as _OPTAttention\nfrom transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding\nfrom transformers.models.opt.configuration_opt import OPTConfig as GPTConfig\n\n\ndef _make_causal_mask(\n    input_ids_shape: torch.Size, \n    dtype: torch.dtype,\n    device: torch.device,\n    past_key_values_length: int = 0\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.tensor(float(\"-inf\")), device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat(\n            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), \n             mask], dim=-1\n        )\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\ndef _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, 
past_key_values_length):\n    # create causal mask\n    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n    combined_attention_mask = None\n    if input_shape[-1] > 1:\n        combined_attention_mask = _make_causal_mask(\n            input_shape, inputs_embeds.dtype, inputs_embeds.device,\n            past_key_values_length=past_key_values_length\n        )\n\n    if attention_mask is not None:\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        expanded_attn_mask = _expand_mask(\n            attention_mask, inputs_embeds.dtype,tgt_len=input_shape[-1])\n        combined_attention_mask = (\n            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n        )\n\n    return combined_attention_mask\n\n\nclass GPTEmbeddings(nn.Module):\n    def __init__(self, config, device='cpu'):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx, device=device)\n        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False, device=device)\n        else:\n            self.project_in = None\n        \n    @classmethod\n    def from_pretrained(cls, model_path, config=None):\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        # module = cls(config).eval()\n        module = torch.nn.utils.skip_init(cls, config).eval() # fast init\n        try:\n            module.load_state_dict(torch.load(os.path.join(\n                model_path, 'pytorch_embs.pt',\n            )))\n        except:\n            print('Cannot load from <model_name>. 
The model is randomly initialized.')\n        return module\n\n    def forward(self, input_ids, past_layer=None, mask=None, **kargs):\n        \n        if mask is None:\n            if past_layer is not None:\n                past_length = past_layer[0].size(2)\n            else:\n                past_length = 0\n        else:\n            # masked tokens\n            past_length = (mask-1).sum(-1, keepdims=True)\n            if past_layer is not None:\n                past_length += past_layer[0].size(2)\n                \n        device = input_ids.device\n        # input ids\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n        batch_size = input_ids.shape[0]\n\n        inputs_embeds = self.embed_tokens(input_ids)\n        \n        # attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)\n        # position_embeds = self.embed_positions(attention_mask, past_length)\n        # position ids\n        position_ids = torch.arange(\n            0, input_shape[-1], dtype=torch.long, device=device)\n        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])\n        position_ids = position_ids + past_length + self.embed_positions.offset\n        position_ids[position_ids<0] = 0\n        \n        position_embeds = F.embedding(\n            position_ids, self.embed_positions.weight, self.embed_positions.padding_idx, self.embed_positions.max_norm,\n            self.embed_positions.norm_type, self.embed_positions.scale_grad_by_freq, self.embed_positions.sparse)\n        \n        if self.project_in is not None:\n            inputs_embeds = self.project_in(inputs_embeds)\n        \n        hidden_states = inputs_embeds + position_embeds\n\n        # hidden_states = self.drop(hidden_states)\n\n        return hidden_states\n    \n\nclass OPTAttention(_OPTAttention):\n    def __init__(\n        self,\n        embed_dim: int,\n        num_heads: int,\n        
dropout: float = 0.0,\n        is_decoder: bool = False,\n        bias: bool = True,\n        device='cpu',\n    ):\n        super(_OPTAttention, self).__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n\n        if (self.head_dim * num_heads) != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}\"\n                f\" and `num_heads`: {num_heads}).\"\n            )\n        self.scaling = self.head_dim**-0.5\n        self.is_decoder = is_decoder\n\n        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)\n        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)\n        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)\n        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)\n        \n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        key_value_states: Optional[torch.Tensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        layer_head_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        # if key_value_states are provided this layer is used as a cross-attention layer\n        # for the decoder\n        is_cross_attention = key_value_states is not None\n\n        bsz, tgt_len, _ = hidden_states.size()\n\n        # get query proj\n        query_states = self.q_proj(hidden_states) * self.scaling\n        # get key, 
value proj\n        if is_cross_attention and past_key_value is not None:\n            # reuse k,v, cross_attentions\n            key_states = past_key_value[0]\n            value_states = past_key_value[1]\n        elif is_cross_attention:\n            # cross_attentions\n            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)\n            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)\n        elif past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n        else:\n            # self_attention\n            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n\n        if self.is_decoder:\n            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.\n            # Further calls to cross_attention layer can then reuse all cross-attention\n            # key/value_states (first \"if\" case)\n            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of\n            # all previous decoder key/value_states. 
Further calls to uni-directional self-attention\n            # can concat previous decoder key/value_states to current projected key/value_states (third \"elif\" case)\n            # if encoder bi-directional self-attention `past_key_value` is always `None`\n            past_key_value = (key_states, value_states)\n\n        proj_shape = (bsz * self.num_heads, -1, self.head_dim)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask\n            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        # record the dtype outside the mask branch so it is defined even when attention_mask is None\n        dtype_attn_weights = attn_weights.dtype\n\n        # upcast to fp32 if the weights are in fp16. 
Please see https://github.com/huggingface/transformers/pull/17437\n        if dtype_attn_weights == torch.float16:\n            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)\n        else:\n            attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        if layer_head_mask is not None:\n            if layer_head_mask.size() != (self.num_heads,):\n                raise ValueError(\n                    f\"Head mask for a single layer should be of size {(self.num_heads,)}, but is\"\n                    f\" {layer_head_mask.size()}\"\n                )\n            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        if output_attentions:\n            # this operation is a bit awkward, but it's required to\n            # make sure that attn_weights keeps its gradient.\n            # In order to do so, attn_weights have to be reshaped\n            # twice and have to be reused in the following\n            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)\n        else:\n            attn_weights_reshaped = None\n\n        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)\n\n        attn_output = torch.bmm(attn_probs, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)\n        attn_output = attn_output.transpose(1, 2)\n\n        # Use the `embed_dim` from the 
config (stored in the class) rather than `hidden_state` because `attn_output` can be\n        # partitioned across GPUs when using tensor-parallelism.\n        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, attn_weights_reshaped, past_key_value\n\n\nclass GPTBlock(OPTDecoderLayer):\n    def __init__(self, config, *args, use_checkpoint=True, device='cpu', **kargs):\n        # super().__init__(config=config, *args, **kargs)\n        super(OPTDecoderLayer, self).__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = OPTAttention(\n            embed_dim=self.embed_dim,\n            num_heads=config.num_attention_heads,\n            dropout=config.attention_dropout,\n            is_decoder=True,\n            device=device,\n        )\n        self.do_layer_norm_before = config.do_layer_norm_before\n        self.dropout = config.dropout\n        self.activation_fn = ACT2FN[config.activation_function]\n\n        self.activation_dropout = config.activation_dropout\n\n        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, device=device)\n        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, device=device)\n        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, device=device)\n        self.final_layer_norm = nn.LayerNorm(self.embed_dim, device=device)\n        \n        self.config = config\n        self.use_checkpoint = use_checkpoint\n        \n        def attn_res(hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor:\n            residual = hidden_states\n            if self.do_layer_norm_before:\n                hidden_states = self.self_attn_layer_norm(hidden_states)\n                \n            # Self Attention\n            hidden_states, _, present = self.self_attn(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n            )\n            hidden_states = 
residual + hidden_states\n\n            # 350m applies layer norm AFTER attention\n            if not self.do_layer_norm_before:\n                hidden_states = self.self_attn_layer_norm(hidden_states)\n\n            return hidden_states\n        \n        self.attn_res = attn_res\n        \n        def mlp_res(hidden_states: torch.Tensor) -> torch.Tensor:\n            # Fully Connected\n            hidden_states_shape = hidden_states.shape\n            hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))\n            residual = hidden_states\n\n            # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention\n            if self.do_layer_norm_before:\n                hidden_states = self.final_layer_norm(hidden_states)\n\n            hidden_states = self.fc1(hidden_states)\n            hidden_states = self.activation_fn(hidden_states)\n\n            hidden_states = self.fc2(hidden_states)\n\n            hidden_states = (residual + hidden_states).view(hidden_states_shape)\n            return hidden_states\n        \n        self.mlp_res = mlp_res\n        \n    @classmethod\n    def from_pretrained(cls, model_path, config=None, layer_index=None):\n        assert layer_index is not None\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        # module = cls(config).eval()\n        # module = cls(config).eval()\n        module = torch.nn.utils.skip_init(cls, config).eval() # fast init\n        try:\n            module.load_state_dict(torch.load(os.path.join(\n                model_path, f'pytorch_{layer_index}.pt',\n            )))\n        except:\n            print('Cannot load from <model_name>. 
The model is randomly initialized.')\n        return module\n\n    def forward(self, x: torch.Tensor, layer_past=None, mask=None, *args, **kargs) -> torch.Tensor:\n        \n        if layer_past is not None:\n            past_length = layer_past[0].size(2)\n        else:\n            past_length = 0\n        if mask is None:\n            mask = torch.ones((x.size(0), x.size(1)+past_length), \n                dtype=torch.bool, device=x.device)\n        attention_mask = _prepare_decoder_attention_mask(\n            mask, x.shape[:2], x, past_length\n        )\n        \n        if self.training:\n            \n            if self.use_checkpoint:\n                x.requires_grad_(True)\n                x = checkpoint(self.attn_res, x, attention_mask)\n            else:\n                x = self.attn_res(x, attention_mask)\n\n            if self.use_checkpoint:\n                x.requires_grad_(True)\n                x = checkpoint(self.mlp_res, x)\n            else:\n                x = self.mlp_res(x)\n            \n            return x\n        \n        else:\n\n            hidden_states = x # alias\n            residual = hidden_states\n\n            # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention\n            if self.do_layer_norm_before:\n                hidden_states = self.self_attn_layer_norm(hidden_states)\n\n            # Self Attention\n            hidden_states, _, present = self.self_attn(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                past_key_value=layer_past,\n            )\n            hidden_states = residual + hidden_states\n\n            # 350m applies layer norm AFTER attention\n            if not self.do_layer_norm_before:\n                hidden_states = self.self_attn_layer_norm(hidden_states)\n\n            # Fully Connected\n            hidden_states_shape = hidden_states.shape\n            hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))\n          
  residual = hidden_states\n\n            # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention\n            if self.do_layer_norm_before:\n                hidden_states = self.final_layer_norm(hidden_states)\n\n            hidden_states = self.fc1(hidden_states)\n            hidden_states = self.activation_fn(hidden_states)\n\n            hidden_states = self.fc2(hidden_states)\n\n            hidden_states = (residual + hidden_states).view(hidden_states_shape)\n\n            return hidden_states\n\n\nclass GPTLMHead(nn.Module):\n    def __init__(self, config, device='cpu'):\n        super().__init__()\n        \n        if config.do_layer_norm_before and not config._remove_final_layer_norm:\n            self.final_layer_norm = nn.LayerNorm(config.hidden_size, device=device)\n        else:\n            self.final_layer_norm = None\n            \n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False, device=device)\n        else:\n            self.project_out = None\n        \n        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False, device=device)\n        \n    @classmethod\n    def from_pretrained(cls, model_path, config=None):\n        if config is None:\n            config = GPTConfig.from_pretrained(model_path)\n        # module = cls(config).eval()\n        module = torch.nn.utils.skip_init(cls, config).eval() # fast init\n        try:\n            module.load_state_dict(torch.load(os.path.join(\n                model_path, 'pytorch_lm_head.pt',\n            )))\n        except:\n            print('Cannot load from <model_name>. 
The model is randomly initialized.')\n        return module\n\n    def forward(self, x, input_ids=None, *args, **kargs):\n        if self.final_layer_norm is not None:\n            x = self.final_layer_norm(x)\n        if self.project_out is not None:\n            x = self.project_out(x)\n        x = self.lm_head(x)\n        return x"
  },
  {
    "path": "training/modules/llama_modules.py",
    "content": "# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch Llama model.\"\"\"\nimport os\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\nfrom torch.utils.checkpoint import checkpoint\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import add_start_docstrings, logging, replace_return_docstrings\nfrom transformers import LlamaConfig\n\nfrom flash_attn.layers.rotary import (\n    apply_rotary_emb_func,\n    apply_rotary_emb_qkv_,\n    apply_rotary_emb_kv_,\n)\n\nclass RotaryEmbedding(torch.nn.Module):\n    \"\"\"\n    The rotary position embeddings from RoFormer_ (Su et. 
al).\n    A crucial insight from the method is that the query and keys are\n    transformed by rotation matrices which depend on the relative positions.\n\n    Other implementations are available in the Rotary Transformer repo_ and in\n    GPT-NeoX_, GPT-NeoX was an inspiration\n\n    .. _RoFormer: https://arxiv.org/abs/2104.09864\n    .. _repo: https://github.com/ZhuiyiTechnology/roformer\n    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox\n\n    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).\n    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96\n    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        base=10000.0,\n        interleaved=False,\n        scale_base=None,\n        scaling_factor=1.0,\n        pos_idx_in_fp32=True,\n        device=None,\n    ):\n        \"\"\"\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead\n            of 1st half and 2nd half (GPT-NeoX style).\n        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,\n            otherwise they might be in lower precision.\n            This option was added because previously (before 2023-07-02), when we construct\n            the position indices, we use the dtype of self.inv_freq. In most cases this would\n            be fp32, but if the model is trained in pure bf16 (not mixed precision), then\n            self.inv_freq would be bf16, and the position indices are also in bf16.\n            Because of the limited precision of bf16 (e.g. 
1995.0 is rounded to 2000.0), the\n            embeddings for some positions will coincide.\n            To maintain compatibility with models previously trained in pure bf16,\n            we add this option.\n        \"\"\"\n        super().__init__()\n        self.dim = dim\n        self.base = float(base)\n        self.pos_idx_in_fp32 = pos_idx_in_fp32\n        # Generate and save the inverse frequency buffer (non trainable)\n        inv_freq = self._compute_inv_freq(device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.interleaved = interleaved\n        self.scale_base = scale_base\n        self.scaling_factor = scaling_factor\n        scale = (\n            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)\n            / (1.4 * dim)\n            if scale_base is not None\n            else None\n        )\n        self.register_buffer(\"scale\", scale, persistent=False)\n\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n\n    def _compute_inv_freq(self, device=None):\n        return 1.0 / (\n            self.base\n            ** (\n                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)\n                / self.dim\n            )\n        )\n\n    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):\n        # Reset the tables if the sequence length has changed,\n        # if we're on a new device (possibly due to tracing for instance),\n        # or if we're switching from inference mode to training\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n            or (self.training and self._cos_cached.is_inference())\n        ):\n            self._seq_len_cached = seqlen\n            # We want fp32 here, not self.inv_freq.dtype, since the 
model could be loaded in bf16\n            # And the output of arange can be quite large, so bf16 would lose a lot of precision.\n            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.\n            if self.pos_idx_in_fp32:\n                t = torch.arange(seqlen, device=device, dtype=torch.float32)\n                t /= self.scaling_factor\n                # We want fp32 here as well since inv_freq will be multiplied with t, and the output\n                # will be large. Having it in bf16 will lose a lot of precision and cause the\n                # cos & sin output to change significantly.\n                # We want to recompute self.inv_freq if it was not loaded in fp32\n                if self.inv_freq.dtype != torch.float32:\n                    inv_freq = self._compute_inv_freq(device=device)\n                else:\n                    inv_freq = self.inv_freq\n            else:\n                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n                t /= self.scaling_factor\n                inv_freq = self.inv_freq\n            # Don't do einsum, it converts fp32 to fp16 under AMP\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n            freqs = torch.outer(t, inv_freq)\n            if self.scale is None:\n                self._cos_cached = torch.cos(freqs).to(dtype)\n                self._sin_cached = torch.sin(freqs).to(dtype)\n            else:\n                power = (\n                    torch.arange(\n                        seqlen, dtype=self.scale.dtype, device=self.scale.device\n                    )\n                    - seqlen // 2\n                ) / self.scale_base\n                # unsqueeze(-1) reshapes power from (seqlen,) to (seqlen, 1); einops.rearrange\n                # is not imported in this module.\n                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)\n                # We want the multiplication by scale to happen in fp32\n                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)\n     
           self._sin_cached = (torch.sin(freqs) * scale).to(dtype)\n                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)\n                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)\n\n    def forward(\n        self,\n        qkv: torch.Tensor,\n        kv: Optional[torch.Tensor] = None,\n        seqlen_offset: int = 0,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,\n             else it's just q of shape (batch, seqlen, nheads, headdim)\n        kv: (batch, seqlen, 2, nheads, headdim)\n        seqlen_offset: can be used in generation where the qkv being passed in is only the last\n        token in the batch.\n        \"\"\"\n        seqlen = qkv.shape[1]\n        self._update_cos_sin_cache(\n            seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype\n        )\n        if kv is None:\n            if self.scale is None:\n                return apply_rotary_emb_qkv_(\n                    qkv,\n                    self._cos_cached[seqlen_offset:],\n                    self._sin_cached[seqlen_offset:],\n                    None,\n                    None,\n                    self.interleaved,\n                )\n            else:\n                return apply_rotary_emb_qkv_(\n                    qkv,\n                    self._cos_cached[seqlen_offset:],\n                    self._sin_cached[seqlen_offset:],\n                    self._cos_k_cached[seqlen_offset:],\n                    self._sin_k_cached[seqlen_offset:],\n                    self.interleaved,\n                )\n        else:\n            q = qkv\n            q = apply_rotary_emb_func(\n                q,\n                self._cos_cached[seqlen_offset:],\n                self._sin_cached[seqlen_offset:],\n                self.interleaved,\n                True,\n            )\n            if self.scale is None:\n                kv = apply_rotary_emb_kv_(\n                    
kv,\n                    self._cos_cached[seqlen_offset:],\n                    self._sin_cached[seqlen_offset:],\n                    self.interleaved,\n                )\n            else:\n                kv = apply_rotary_emb_kv_(\n                    kv,\n                    self._cos_k_cached[seqlen_offset:],\n                    self._sin_k_cached[seqlen_offset:],\n                    self.interleaved,\n                )\n            return q, kv\n\n\ntry:\n    from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func\n\n    flash_attn_v2_installed = True\n    print(\">>>>> using flash attention v2\")\n\n    class FlashAttentionV2(nn.Module):\n        \"\"\"Implement the scaled dot product attention with softmax.\n        Arguments\n        ---------\n            softmax_scale: The temperature to use for the softmax attention.\n                          (default: 1/sqrt(d_keys) where d_keys is computed at\n                          runtime)\n            attention_dropout: The dropout rate to apply to the attention\n                               (default: 0.0)\n        \"\"\"\n\n        def __init__(self, softmax_scale=None, attention_dropout=0.0):\n            super().__init__()\n            self.softmax_scale = softmax_scale\n            self.dropout_p = attention_dropout\n\n        def forward(\n            self,\n            qkv,\n            key_padding_mask=None,\n            causal=False,\n            cu_seqlens=None,\n            max_s=None,\n            need_weights=False,\n        ):\n            \"\"\"Implements the multihead softmax attention.\n            Arguments\n            ---------\n                qkv: The tensor containing the query, key, and value. 
(B, S, 3, H, D) if key_padding_mask is None\n                    if unpadded: (nnz, 3, h, d)\n                key_padding_mask: a bool tensor of shape (B, S)\n            \"\"\"\n            assert not need_weights\n            assert qkv.dtype in [torch.float16, torch.bfloat16]\n            assert qkv.is_cuda\n            assert key_padding_mask is None\n            assert cu_seqlens is None\n            assert max_s is None\n\n            output = flash_attn_qkvpacked_func(\n                qkv,\n                self.dropout_p if self.training else 0.0,\n                softmax_scale=self.softmax_scale,\n                causal=causal,\n            )\n\n            return output, None\n\nexcept ImportError:\n    flash_attn_v2_installed = False\n\ntry:\n    import xformers.ops as xops\n\n    xops_installed = True\n    print(\">>>>> Xformers installed\")\nexcept:\n    xops_installed = False\n\n\nlogger = logging.get_logger(__name__)\n\n\ndef _make_causal_mask(\n    input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))\n    mask_cond = torch.arange(mask.size(-1))\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat(\n            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1\n        )\n    return mask[None, None, :, :].expand(\n        bsz, 1, tgt_len, tgt_len + past_key_values_length\n    )\n\n\ndef _make_causal_mask_device(\n    input_ids_shape: torch.Size,\n    dtype: torch.dtype,\n    device: torch.device,\n    past_key_values_length: int = 0,\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = 
torch.full((tgt_len, tgt_len), torch.tensor(float(\"-inf\")), device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat(\n            [\n                torch.zeros(\n                    tgt_len, past_key_values_length, dtype=dtype, device=device\n                ),\n                mask,\n            ],\n            dim=-1,\n        )\n    return mask[None, None, :, :].expand(\n        bsz, 1, tgt_len, tgt_len + past_key_values_length\n    )\n\n\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(\n        inverted_mask.to(torch.bool), torch.finfo(dtype).min\n    )\n\n\ndef _prepare_decoder_attention_mask(\n    attention_mask, input_shape, inputs_embeds, past_key_values_length\n):\n    # create causal mask\n    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n    combined_attention_mask = None\n    if input_shape[-1] > 1:\n        combined_attention_mask = _make_causal_mask_device(\n            input_shape,\n            inputs_embeds.dtype,\n            inputs_embeds.device,\n            past_key_values_length=past_key_values_length,\n        )\n\n    if attention_mask is not None:\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        expanded_attn_mask = _expand_mask(\n            attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]\n        )\n        combined_attention_mask = (\n            expanded_attn_mask\n            if 
combined_attention_mask is None\n            else expanded_attn_mask + combined_attention_mask\n        )\n\n    return combined_attention_mask\n\n\n# @torch.jit.script\ndef rmsnorm_func(hidden_states, weight, variance_epsilon):\n    input_dtype = hidden_states.dtype\n    hidden_states = hidden_states.to(torch.float32)\n    variance = hidden_states.pow(2).mean(-1, keepdim=True)\n    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)\n    return weight * hidden_states.to(input_dtype)\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.register_buffer(\n            \"variance_epsilon\",\n            torch.tensor(eps),\n            persistent=False,\n        )\n\n    def forward(self, hidden_states):\n        return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        hidden_act: str,\n    ):\n        super().__init__()\n        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.act_fn = ACT2FN[hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass LlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        config,\n    ):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.num_heads = num_heads\n        self.head_dim = hidden_size // num_heads\n        
max_positions = config.max_position_embeddings\n        self.max_positions = max_positions\n        self.config = config\n\n        if (self.head_dim * num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {num_heads}).\"\n            )\n        self.q_proj = nn.Linear(\n            hidden_size,\n            num_heads * self.head_dim,\n            bias=False,\n        )\n        self.k_proj = nn.Linear(\n            hidden_size,\n            num_heads * self.head_dim,\n            bias=False,\n        )\n        self.v_proj = nn.Linear(\n            hidden_size,\n            num_heads * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(\n            num_heads * self.head_dim,\n            hidden_size,\n            bias=False,\n        )\n\n        self.rotary_ndims = self.head_dim\n        self.register_buffer(\n            \"norm_factor\",\n            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(\n                torch.get_default_dtype()\n            ),\n            persistent=False,\n        )\n\n        if self.config.rope_scaling is None:\n            # by default do linear scale if not specified.\n            scaling_factor = max(self.max_positions / 4096, 1.0)\n            print(f\"Linearly scaling {scaling_factor}x.\")\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            assert scaling_type == \"linear\"\n        self.rotary_emb = RotaryEmbedding(\n            self.rotary_ndims,\n            base=10000,\n            interleaved=False,\n            scaling_factor=scaling_factor,\n        )\n\n        if flash_attn_v2_installed:\n            self.flash_attn = FlashAttentionV2(\n                softmax_scale=1.0 / self.norm_factor, attention_dropout=0\n    
        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states).view(\n            bsz, q_len, self.num_heads, self.head_dim\n        )\n        key_states = self.k_proj(hidden_states).view(\n            bsz, q_len, self.num_heads, self.head_dim\n        )\n        value_states = self.v_proj(hidden_states).view(\n            bsz, q_len, self.num_heads, self.head_dim\n        )\n\n        qkv = torch.stack([query_states, key_states, value_states], dim=2)\n        qkv = self.rotary_emb(qkv)\n\n        if flash_attn_v2_installed:\n            attn_output, _ = self.flash_attn(qkv, causal=True)\n        elif xops_installed:\n            q, k, v = qkv.unbind(2)\n            attn_output = xops.memory_efficient_attention(\n                q, k, v, attn_bias=xops.LowerTriangularMask()\n            )\n        else:\n            raise Exception(\"Flash Attention not found.\")\n\n        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, None\n\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = 
LlamaAttention(\n            hidden_size=self.hidden_size,\n            num_heads=config.num_attention_heads,\n            config=config,\n        )\n        self.mlp = LlamaMLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            past_key_value=past_key_value,\n            attention_mask=attention_mask,\n            output_attentions=output_attentions,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nclass GPTEmbeddings(nn.Module):\n    def __init__(self, config, device=\"cpu\"):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n\n    def forward(\n        self,\n        input_ids,\n        *args,\n        **kargs,\n    ):\n        inputs_embeds = self.embed_tokens(input_ids)\n\n        return inputs_embeds\n\n\nclass GPTLMHead(nn.Module):\n    def __init__(self, config, device=\"cpu\"):\n        super().__init__()\n        self.config = config\n\n        
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n    def forward(\n        self,\n        hidden_states,\n        *args,\n        **kargs,\n    ):\n        hidden_states = self.norm(hidden_states)\n\n        logits = self.lm_head(hidden_states)\n\n        return logits\n\n\nclass GPTBlock(nn.Module):\n    def __init__(self, config: LlamaConfig, *args, **kargs):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = LlamaAttention(\n            hidden_size=self.hidden_size,\n            num_heads=config.num_attention_heads,\n            config=config,\n        )\n        self.mlp = LlamaMLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n        def attn_res(hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor:\n            residual = hidden_states\n\n            hidden_states = self.input_layernorm(hidden_states)\n\n            # Self Attention\n            hidden_states, self_attn_weights, present_key_value = self.self_attn(\n                hidden_states=hidden_states,\n                past_key_value=None,\n                attention_mask=attention_mask,\n            )\n            hidden_states = residual + hidden_states\n\n            return hidden_states\n\n        self.attn_res = attn_res\n\n        def mlp_res(hidden_states: torch.Tensor) -> torch.Tensor:\n            # Fully Connected\n            residual = hidden_states\n            hidden_states = self.post_attention_layernorm(hidden_states)\n            hidden_states = self.mlp(hidden_states)\n            hidden_states = residual + 
hidden_states\n            return hidden_states\n\n        self.mlp_res = mlp_res\n\n        self.use_checkpoint = True\n\n    def forward(\n        self, x: torch.Tensor, layer_past=None, mask=None, *args, **kargs\n    ) -> torch.Tensor:\n        if layer_past is not None:\n            past_length = layer_past[0].size(2)\n        else:\n            past_length = 0\n        if mask is None:\n            mask = torch.ones(\n                (x.size(0), x.size(1) + past_length), dtype=torch.bool, device=x.device\n            )\n\n        attention_mask = None\n\n        if self.use_checkpoint:\n            x.requires_grad_(True)\n            x = checkpoint(self.attn_res, x, attention_mask)\n        else:\n            x = self.attn_res(x, attention_mask)\n\n        if self.use_checkpoint:\n            x.requires_grad_(True)\n            x = checkpoint(self.mlp_res, x)\n        else:\n            x = self.mlp_res(x)\n\n        return x\n"
  },
  {
    "path": "training/modules/task_modules.py",
    "content": "import torch\n\n\nclass GlueClassification(torch.nn.Module):\n    def __init__(self, model_dim, num_classes):\n        super(GlueClassification, self).__init__()\n        self.model_dim = model_dim\n        self.num_classes = num_classes\n        self.pooler_layer = torch.nn.Linear(model_dim, model_dim)\n        self.fc_layer = torch.nn.Linear(model_dim, num_classes)\n\n    def forward(self, hidden_states, pooler_index=0):\n        pooled = hidden_states[:, pooler_index, :]\n        pooled = self.pooler_layer(pooled)\n        pooled = torch.tanh(pooled)\n        return self.fc_layer(pooled)\n"
  },
  {
    "path": "training/modules/tokenizer.py",
    "content": "\nfrom transformers import AutoTokenizer, GPT2TokenizerFast, DebertaV2Tokenizer\n\ndef build_tokenizer(args):\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n    return tokenizer\n\ndef build_gpt2_tokenizer(args):\n    tokenizer = GPT2TokenizerFast.from_pretrained(args.tokenizer_name)\n    tokenizer.pad_token = tokenizer.eos_token\n    return tokenizer\n\ndef build_deberta_tokenizer(args):\n    tokenizer = DebertaV2Tokenizer.from_pretrained(args.tokenizer_name)\n    return tokenizer\n    "
  },
  {
    "path": "training/modules/utils.py",
    "content": "import torch\nimport math\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional\nfrom typing import Optional, Tuple, Union\n\n\n# @torch.jit.script\ndef gpt_loss_func(input, target):\n    lm_logits, labels = input, target\n    shift_logits = lm_logits[..., :-1, :].contiguous()\n    shift_labels = labels[..., 1:].contiguous()\n    loss = functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n    return loss"
  },
  {
    "path": "training/optimizer/__init__.py",
    "content": ""
  },
  {
    "path": "training/optimizer/grad_scalar.py",
    "content": "from abc import ABC\nfrom abc import abstractmethod\n\nimport torch\n\n\nclass GradScaler(ABC):\n    def __init__(self, initial_scale, device=None):\n        \"\"\"Initialize scale value with the input initial scale.\"\"\"\n        assert initial_scale > 0.0\n        self.device = device\n        self._scale = torch.cuda.FloatTensor([initial_scale], device=device)\n\n    @property\n    def scale(self):\n        return self._scale\n\n    @property\n    def inv_scale(self):\n        return self._scale.double().reciprocal().float()\n\n    @abstractmethod\n    def update(self, found_inf):\n        pass\n\n    @abstractmethod\n    def state_dict(self):\n        pass\n\n    @abstractmethod\n    def load_state_dict(self, state_dict):\n        pass\n\n\nclass ConstantGradScaler(GradScaler):\n\n    def update(self, found_inf):\n        pass\n\n    def state_dict(self):\n        return dict()\n\n    def load_state_dict(self, state_dict):\n        pass\n\n\nclass DynamicGradScaler(GradScaler):\n\n    def __init__(self, initial_scale, min_scale,\n                 growth_factor, backoff_factor,\n                 growth_interval, hysteresis, device=None):\n        \"\"\"\"Grad scaler with dynamic scale that gets adjusted\n        during training.\"\"\"\n        super(DynamicGradScaler, self).__init__(initial_scale, device=device)\n\n        # Lower bound on the scale.\n        assert min_scale > 0.0\n        assert min_scale <= initial_scale\n        self.min_scale = torch.cuda.FloatTensor([min_scale], device=device)\n        # Growth and backoff factors for the scale.\n        assert growth_factor > 1.0\n        self.growth_factor = torch.cuda.FloatTensor([growth_factor], device=device)\n        assert backoff_factor < 1.0\n        assert backoff_factor > 0.0\n        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor], device=device)\n        # Interval over which if we don't see any inf/nan,\n        # we will scale the grad scale by the growth 
factor.\n        assert growth_interval > 0\n        self.growth_interval = growth_interval\n        # Number of inf/nans we should see before scaling down\n        # the grad scale by the backoff factor.\n        assert hysteresis > 0\n        self.hysteresis = hysteresis\n\n        # Trackers.\n        self._growth_tracker = 0\n        self._hysteresis_tracker = self.hysteresis\n\n    def update(self, found_inf):\n        # If we have an inf/nan, growth tracker is set to 0\n        # and hysteresis tracker is reduced by 1.\n        if found_inf:\n            self._growth_tracker = 0\n            self._hysteresis_tracker -= 1\n            # Now if we are out of hysteresis count, scale down the loss.\n            if self._hysteresis_tracker <= 0:\n                self._scale = torch.max(self._scale * self.backoff_factor,\n                                        self.min_scale)\n                print('##### scale backoff to', self._scale)\n        else:\n            # If there is no nan/inf, increment the growth tracker.\n            self._growth_tracker += 1\n            # If we have had enough consecutive steps with no nan/inf:\n            if self._growth_tracker == self.growth_interval:\n                # Reset the growth and hysteresis trackers,\n                self._growth_tracker = 0\n                self._hysteresis_tracker = self.hysteresis\n                # and scale up the loss scale.\n                self._scale = self._scale * self.growth_factor\n                print('##### scale grow to', self._scale)\n\n    def state_dict(self):\n        state_dict = {}\n        state_dict['scale'] = self._scale\n        state_dict['growth_tracker'] = self._growth_tracker\n        state_dict['hysteresis_tracker'] = self._hysteresis_tracker\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        self._scale = state_dict['scale'].to(self.device)\n        self._growth_tracker = state_dict['growth_tracker']\n        
self._hysteresis_tracker = state_dict['hysteresis_tracker']"
  },
  {
    "path": "training/optimizer/optimizer.py",
    "content": "import torch\nfrom .grad_scalar import *\n\n# This follows some implementation from Megatron\n\n\ndef _has_overflow_serial(grads):\n\n    def _has_inf_or_nan(x):\n        try:\n            # if x is half, the .float() incurs an additional deep copy, but it's necessary if\n            # Pytorch's .sum() creates a one-element tensor of the same type as x\n            # (which is true for some recent version of pytorch).\n            cpu_sum = float(x.float().sum())\n            # More efficient version that can be used if .sum() returns a Python scalar\n            # cpu_sum = float(x.sum())\n        except RuntimeError as instance:\n            # We want to check if inst is actually an overflow exception.\n            # RuntimeError could come from a different error.\n            # If so, we still want the exception to propagate.\n            if \"value cannot be converted\" not in instance.args[0]:\n                raise\n            return True\n        else:\n            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:\n                return True\n            return False\n\n    for p in grads:\n        if _has_inf_or_nan(p):\n            return torch.FloatTensor([1.0])\n\n    return torch.FloatTensor([0.0])\n\n\n# `x` is a torch.Tensor\n\n\n\ndef _zero_grad_group(group, set_to_none):\n    \"\"\"Zero out the gradient for a group of parameters.\n    Note: copied from torch.optim.optimizer.\"\"\"\n    for param in group:\n        if param.grad is not None:\n            if set_to_none:\n                param.grad = None\n            else:\n                if param.grad.grad_fn is not None:\n                    param.grad.detach_()\n                else:\n                    param.grad.requires_grad_(False)\n                param.grad.zero_()\n\n\n'''\ndef _multi_tensor_copy_this_to_that(this, that):\n    for this_, that_ in zip(this, that):\n        that_.copy_(this_)\n'''\n\n\nclass Fp16Optimizer:\n    # If offload is set to true, 
the fp32 copy is stored on CPU.\n    def __init__(self, optimizer, grad_scaler, device, offload=False):\n        self.offload = offload\n        if self.offload:\n            self.cpu_to_gpu_stream = torch.cuda.Stream(device=device, priority=-1)\n            self.gpu_to_cpu_stream = torch.cuda.Stream(device=device, priority=-1)\n        self.optimizer = optimizer\n        self.grad_scaler = grad_scaler\n\n        if self.grad_scaler:\n            self.found_inf = torch.cuda.FloatTensor([0.0], device=device) if not self.offload else torch.FloatTensor([0.0])\n\n        self._dummy_overflow_buf = torch.cuda.IntTensor([0], device=device) if not self.offload else torch.IntTensor([0])\n\n        # Note that the model should first be cast to fp16 before passing to the optimizer.\n        self.float16_groups = []\n        self.fp32_from_float16_groups = []\n\n        # For all the groups in the original optimizer:\n        for param_group in self.optimizer.param_groups:\n            float16_params_this_group = []\n            fp32_from_float16_params_this_group = []\n            # For all the parameters in this group:\n            for i, param in enumerate(param_group['params']):\n                if param.requires_grad:\n                    # float16 params:\n                    assert param.type() == 'torch.cuda.HalfTensor'\n                    float16_params_this_group.append(param)\n                    # Create a copy\n                    if self.offload:\n                        optimizer_param = param.detach().clone().float().to(device='cpu')\n                        assert optimizer_param.device == torch.device('cpu')\n                        if optimizer_param.grad is None:\n                            optimizer_param.grad = torch.zeros_like(optimizer_param.data)\n                    else:\n                        optimizer_param = param.detach().clone().float()\n                    # Replace the optimizer params with the new fp32 copy.\n                    
param_group['params'][i] = optimizer_param\n                    fp32_from_float16_params_this_group.append(optimizer_param)\n                    # Reset existing state dict key to the new optimizer param.\n                    if param in self.optimizer.state:\n                        self.optimizer.state[optimizer_param] = self.optimizer.state.pop(param)\n\n            self.float16_groups.append(float16_params_this_group)\n            self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)\n\n        # Leverage state_dict() and load_state_dict() to\n        # recast preexisting per-param state tensors\n        self.optimizer.load_state_dict(self.optimizer.state_dict())\n\n    def zero_grad(self, set_to_none=True):\n        for group in self.float16_groups:\n            _zero_grad_group(group, set_to_none)\n        if not self.offload:\n            for group in self.fp32_from_float16_groups:\n                _zero_grad_group(group, set_to_none)\n\n    def get_loss_scale(self):\n        return self.grad_scaler.scale\n\n    def _copy_model_grads_to_optimizer_grads(self):\n        # This only needs to be done for the float16 group.\n        for model_group, optimizer_group in zip(self.float16_groups, self.fp32_from_float16_groups):\n            for model_param, optimizer_param in zip(model_group, optimizer_group):\n                if model_param.grad is not None:\n                    if self.offload:\n                        with torch.cuda.stream(self.gpu_to_cpu_stream):\n                            optimizer_param.grad.copy_(model_param.grad, non_blocking=False)\n                    else:\n                        optimizer_param.grad = model_param.grad.float()\n                # Safe to deallocate model's grad/optimizer_grad after copying.\n                # (If using contiguous buffers, optimizer_grad's memory should\n                # persist and therefore should not be deallocated.)\n                model_param.grad = None\n\n    def 
_unscale_optimizer_grads_and_check_for_nan(self):\n        optimizer_grads = []\n        # fp32 params from float16 ones.\n        for optimizer_group in self.fp32_from_float16_groups:\n            for optimizer_param in optimizer_group:\n                if optimizer_param.grad is not None:\n                    optimizer_grads.append(optimizer_param.grad.data)\n        # Reset found inf.\n        self.found_inf.fill_(0.0)\n        # Unscale and set found inf/nan\n        print(optimizer_grads[0].device, self.found_inf.device, self.grad_scaler.inv_scale.device)\n        if self.offload:\n            self.found_inf = _has_overflow_serial(optimizer_grads)\n        else:\n            torch._amp_foreach_non_finite_check_and_unscale_(optimizer_grads, self.found_inf, self.grad_scaler.inv_scale)\n        # Check for nan.\n        found_inf_flag = (self.found_inf.item() > 0)\n        return found_inf_flag\n\n    def _get_model_and_optimizer_params_data_float16_deprecated(self):\n        model_data = []\n        optimizer_data = []\n        for model_group, optimizer_group in zip(self.float16_groups, self.fp32_from_float16_groups):\n            for model_param, optimizer_param in zip(model_group, optimizer_group):\n                model_data.append(model_param.data)\n                optimizer_data.append(optimizer_param.data)\n        return model_data, optimizer_data\n\n    def _copy_optimizer_params_to_model_params(self):\n        # Only needed for the float16 params.\n        # model_data, optimizer_data = self._get_model_and_optimizer_params_data_float16_deprecated()\n        # _multi_tensor_copy_this_to_that(this=optimizer_data, that=model_data)\n\n        for model_group, optimizer_group in zip(self.float16_groups, self.fp32_from_float16_groups):\n            for model_param, optimizer_param in zip(model_group, optimizer_group):\n                if self.offload:\n                    with torch.cuda.stream(self.cpu_to_gpu_stream):\n                        
model_param.data.copy_(optimizer_param.data, non_blocking=False)\n                else:\n                    model_param.data.copy_(optimizer_param.data)\n\n    def _copy_model_params_to_optimizer_params(self):\n        # Only needed for the float16 params.\n        # model_data, optimizer_data = self._get_model_and_optimizer_params_data_float16_deprecated()\n        # _multi_tensor_copy_this_to_that(this=model_data, that=optimizer_data)\n        for model_group, optimizer_group in zip(self.float16_groups, self.fp32_from_float16_groups):\n            for model_param, optimizer_param in zip(model_group, optimizer_group):\n                if self.offload:\n                    with torch.cuda.stream(self.gpu_to_cpu_stream):\n                        optimizer_param.data.copy_(model_param.data, non_blocking=False)\n                else:\n                    optimizer_param.data.copy_(model_param.data)\n\n    def reload_model_params(self):\n        self._copy_model_params_to_optimizer_params()\n\n    @torch.no_grad()\n    def step(self):\n        self._copy_model_grads_to_optimizer_grads()\n\n        found_inf_flag = self._unscale_optimizer_grads_and_check_for_nan()\n        self.grad_scaler.update(found_inf_flag)\n\n        # If we found inf/nan, skip the update.\n        if found_inf_flag:\n            print(\"!!! 
Warning: found inf in fp16 optimizer-step() !!!\")\n            return False\n        \n        for params in self.fp32_from_float16_groups:\n            torch.nn.utils.clip_grad_norm_(params, 1.0)\n\n        # Step the optimizer.\n        self.optimizer.step()\n\n        self._copy_optimizer_params_to_model_params()\n        # Successful update.\n        return True\n    \n    def scale(self, z):\n        return z * self.grad_scaler.scale\n    \n    def unscale(self, z):\n        return z * self.grad_scaler.inv_scale\n    \n    def state_dict(self):\n        return self.optimizer.state_dict()\n    \n    def load_state_dict(self, state_dict):\n        self.optimizer.load_state_dict(state_dict)\n\n\ndef get_fp16_optimizer(args, optimizer, device):\n    assert args.fp16 is not None\n    if args.loss_scale:\n        print(\"fp16 uses ConstantGradScaler.\")\n        grad_scaler = ConstantGradScaler(args.loss_scale)\n    else:\n        print(\"fp16 uses DynamicGradScaler.\")\n        grad_scaler = DynamicGradScaler(\n            initial_scale=args.initial_loss_scale,\n            min_scale=args.min_loss_scale,\n            growth_factor=2.0,\n            backoff_factor=0.5,\n            growth_interval=args.loss_scale_window,\n            hysteresis=args.hysteresis)\n    return Fp16Optimizer(optimizer, grad_scaler, device, getattr(args, 'use_offload', False))\n\n"
  },
  {
    "path": "training/pipeline_parallel/__init__.py",
    "content": ""
  },
  {
    "path": "training/pipeline_parallel/dist_gpipe_pipeline_async.py",
    "content": "import time\nimport json\nimport torch.nn.functional\nfrom torch import optim\nfrom comm.comm_utils import *\nfrom modules.dist_gpt_pp_module import *\nfrom utils.logging_utils import *\nfrom data_parallel.dist_dp_utils import get_dp_module\nfrom optimizer.optimizer import get_fp16_optimizer\nimport os\nimport cupy\nfrom transformers import get_linear_schedule_with_warmup\n\nflag_profile = int(os.environ.get('FLAG_BENCHMARK', '0'))\n\ndef get_parameter_names(model, forbidden_layer_types):\n    \"\"\"\n    Returns the names of the model parameters that are not inside a forbidden layer.\n    \"\"\"\n    result = []\n    for name, child in model.named_children():\n        result += [\n            f\"{name}.{n}\"\n            for n in get_parameter_names(child, forbidden_layer_types)\n            if not isinstance(child, tuple(forbidden_layer_types))\n        ]\n    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.\n    result += list(model._parameters.keys())\n    return result\n\n\ndef create_optimizer(model, optimizer_type, weight_decay=0.01, learning_rate=2e-5,\n                     adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-6):\n    \n    if optimizer_type == 'adamw' or optimizer_type == 'adam':\n        from torch.optim import AdamW\n        print('>>>>> using Adam')\n    elif optimizer_type == '8bit-adam':\n        from bitsandbytes.optim import Adam8bit as AdamW\n        print('>>>>> using 8bit-Adam')\n    else:\n        assert False\n    \n    decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])\n    decay_parameters = [\n        name for name in decay_parameters if \"bias\" not in name]\n    optimizer_grouped_parameters = [\n        {\n            \"params\": [p for n, p in model.named_parameters() if n in decay_parameters and p.requires_grad],\n            \"weight_decay\": weight_decay,\n        },\n        {\n            \"params\": [p for n, p in model.named_parameters() if n 
not in decay_parameters and p.requires_grad],\n            \"weight_decay\": 0.0,\n        }\n    ]\n    optimizer_cls = AdamW\n    optimizer_kwargs = {\n        \"betas\": (adam_beta1, adam_beta2),\n        \"eps\": adam_epsilon,\n    }\n    optimizer_kwargs[\"lr\"] = learning_rate\n    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n    return optimizer\n\n\nclass GpipeAsync:\n    r\"\"\"\n    Async implementation of Gpipe.\n    The current implementation leaves the computation on the PyTorch default stream and runs the communication on\n    separate send/recv streams. Synchronization uses:\n        a group of events to check if recv (from rank i-1) finishes in the forward propagation;\n        a group of events to check if recv (from rank i+1) finishes in the backward propagation;\n        a group of events to check if computation finishes in the forward propagation;\n        a group of events to check if computation finishes in the backward propagation.\n    \"\"\"\n\n    def __init__(self, args, config, device, use_dp=False,\n                 _StageFull=GPTStageFull,\n                 _StageFirst=GPTStageFirst,\n                 _StageLast=GPTStageLast,\n                 _StageMiddle=GPTStageMiddle):\n        print(\"=======Initialize Gpipe.\")\n        if args.fp16:\n            self.use_fp16 = True\n            self.use_dynamic_scale = (args.loss_scale == 0)\n            print(\"=======Gpipe use FP16\")\n        else:\n            self.use_fp16 = False\n            print(\"=======Gpipe use FP32\")\n        self.use_dp = use_dp\n        self.dtype = torch.float16 if self.use_fp16 else torch.float32\n        self.global_rank = args.rank\n        self.pipeline_group_size = args.pipeline_group_size\n        # Rank is the pipeline rank by default.\n        self.pp_rank = get_pipeline_parallel_rank()\n        if use_dp:\n            self.dp_rank = get_data_parallel_rank()\n        else:\n            self.dp_rank = 0\n        self.pre_node_rank = self.pp_rank - 
1\n        self.post_node_rank = self.pp_rank + \\\n            1 if self.pp_rank != self.pipeline_group_size - 1 else -1\n        self.comm = get_pipeline_parallel_comm()\n        self.gradient_accumulate_step = args.gradient_accumulate_step\n        print(\"=======Gradient accumulate step: \",\n              self.gradient_accumulate_step)\n\n        assert (args.batch_size % args.micro_batch_size == 0)\n        self.micro_batch_num = args.batch_size // args.micro_batch_size\n        self.micro_batch_size = args.micro_batch_size\n        self.seq_length = args.seq_length\n        self.embedding_dim = args.embedding_dim\n        self.config = config\n        self.vocab_size = config.vocab_size\n        self.num_classes = config.num_labels\n\n        self.enable_tidy_profiling = (args.profiling == 'tidy_profiling')\n        self.device = device\n        self.torch_comp_stream = torch.cuda.default_stream(device=device)\n        self.torch_recv_stream = torch.cuda.Stream(device=device, priority=-1)\n        self.torch_send_stream = torch.cuda.Stream(device=device, priority=-1)\n\n        self.forward_recv_ready_events = [torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n                                          for _ in range(self.micro_batch_num)]\n        self.forward_comp_ready_events = [torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n                                          for _ in range(self.micro_batch_num)]\n\n        self.backward_recv_ready_events = [torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n                                           for _ in range(self.micro_batch_num)]\n        self.backward_comp_ready_events = [torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False)\n                                           for _ in range(self.micro_batch_num)]\n\n        if self.enable_tidy_profiling:\n            self.profiling_log = []\n            
self.forward_recv_start_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                              for _ in range(self.micro_batch_num)]\n            self.forward_comp_start_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                              for _ in range(self.micro_batch_num)]\n            self.forward_send_start_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                              for _ in range(self.micro_batch_num)]\n            self.forward_send_end_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                            for _ in range(self.micro_batch_num)]\n\n            self.backward_recv_start_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                               for _ in range(self.micro_batch_num)]\n            self.backward_comp_start_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                               for _ in range(self.micro_batch_num)]\n            self.backward_send_start_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                               for _ in range(self.micro_batch_num)]\n            self.backward_send_end_events = [torch.cuda.Event(enable_timing=True, blocking=False)\n                                             for _ in range(self.micro_batch_num)]\n            self.init_event = torch.cuda.Event(\n                enable_timing=True, blocking=False)\n            self.init_time_stamp = None\n            self.optimizer_start_event = torch.cuda.Event(\n                enable_timing=True, blocking=False)\n            self.optimizer_end_event = torch.cuda.Event(\n                enable_timing=True, blocking=False)\n\n        self._compute_micro_batch_size()\n\n        if hasattr(args, 'infer_only') and args.infer_only:\n            do_train = False\n        else:\n           
 do_train = True\n\n        if self.pp_rank == 0:\n            self.input_micro_batches = None\n        else:\n            self.input_micro_batches = [\n                torch.zeros((self.micro_batch_size, self.seq_length, self.embedding_dim),\n                            requires_grad=do_train, device=self.device, dtype=self.dtype\n                            ) for _ in range(self.micro_batch_num)\n            ]\n\n        if do_train:\n            if self.pp_rank == self.pipeline_group_size - 1:\n\n                init_train_logger(args)\n\n            if self.pp_rank == self.pipeline_group_size - 1:\n                self.output_micro_batches_grad = None\n            else:\n                self.output_micro_batches_grad = [\n                    torch.zeros((self.micro_batch_size, self.seq_length, self.embedding_dim),\n                                requires_grad=False, device=self.device, dtype=self.dtype\n                                ) for _ in range(self.micro_batch_num)\n                ]\n\n        if self.pipeline_group_size > 1:\n            if self.pp_rank == 0:\n                self.model = _StageFirst(args, config, device)\n            elif self.pp_rank == self.pipeline_group_size - 1:\n                self.model = _StageLast(args, config, device)\n            else:\n                self.model = _StageMiddle(args, config, device)\n        else:\n            self.model = _StageFull(args, config, device)\n\n        if self.use_fp16:\n            self.model.half()\n\n        if do_train:\n            if self.use_fp16:\n                tmp_optimizer = create_optimizer(\n                    self.model, optimizer_type=getattr(args, 'optimizer', 'adamw'), learning_rate=args.lr)\n                self.optimizer = get_fp16_optimizer(\n                    args, tmp_optimizer, device)\n                self.scheduler = get_linear_schedule_with_warmup(\n                    tmp_optimizer, args.warmup_steps, args.total_steps, )\n            else:\n                
self.optimizer = create_optimizer(\n                    self.model, optimizer_type=getattr(args, 'optimizer', 'adamw'), learning_rate=args.lr)\n                self.scheduler = get_linear_schedule_with_warmup(\n                    self.optimizer, args.warmup_steps, args.total_steps, )\n\n            # Notice that if we use fp16, gradients are aggregated in fp16, this may not be the default in Megatron.\n            if use_dp:\n                self.dp_optim = get_dp_module(\n                    args, device, self.model, self.optimizer)\n\n        self.global_step = 0\n\n    def _compute_micro_batch_size(self):\n        micro_batch_float_num = self.micro_batch_size * \\\n            self.seq_length * self.embedding_dim\n        if self.use_fp16:\n            print(\"=======Current micro-batch send/recv size: {} MB (fp16)\"\n                  .format(micro_batch_float_num * 2 // 1024 // 1024))\n        else:\n            print(\"=======Current micro-batch send/recv size: {} MB (fp32)\"\n                  .format(micro_batch_float_num*4//1024//1024))\n        print(\"=======Number of micro-batches: {}.\".format(self.micro_batch_num))\n\n    def zero_input_grad(self):\n        if self.input_micro_batches:\n            for input_micro_batch in self.input_micro_batches:\n                if input_micro_batch.grad is not None:\n                    input_micro_batch.grad.zero_()\n\n    def profile_mark_forward_comp_start(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_comp_stream.record_event(\n                self.forward_comp_start_events[i])\n\n    def profile_mark_forward_recv_start(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_recv_stream.record_event(\n                self.forward_recv_start_events[i])\n\n    def profile_mark_forward_send_start(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_send_stream.record_event(\n                self.forward_send_start_events[i])\n\n    def 
profile_mark_forward_send_end(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_send_stream.record_event(\n                self.forward_send_end_events[i])\n\n    def profile_mark_backward_comp_start(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_comp_stream.record_event(\n                self.backward_comp_start_events[i])\n\n    def profile_mark_backward_recv_start(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_recv_stream.record_event(\n                self.backward_recv_start_events[i])\n\n    def profile_mark_backward_send_start(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_send_stream.record_event(\n                self.backward_send_start_events[i])\n\n    def profile_mark_backward_send_end(self, i):\n        if self.enable_tidy_profiling:\n            self.torch_send_stream.record_event(\n                self.backward_send_end_events[i])\n\n    def get_ts(self, event):\n        return self.init_time_stamp + self.init_event.elapsed_time(event) * 1e+3\n\n    def forward_stage(self, input_data=None, aux_input_data=None):\n        # print(\"Forward stage start! 
rank-\", self.rank)\n\n        if aux_input_data is not None:\n            for k in aux_input_data:\n                aux_input_data[k] = torch.chunk(\n                    aux_input_data[k], self.micro_batch_num, dim=0)\n        else:\n            aux_input_data = {}\n\n        if self.pp_rank == 0:\n            assert(input_data is not None)\n            self.input_micro_batches = torch.chunk(\n                input_data, self.micro_batch_num, dim=0)\n        if self.pp_rank == self.pipeline_group_size - 1:\n            if input_data is not None:\n                input_ids_micro_batches = torch.chunk(\n                    input_data, self.micro_batch_num, dim=0)\n            else:\n                input_ids_micro_batches = [None]*self.micro_batch_num\n        output_micro_batches = []\n\n        for i in range(self.micro_batch_num):\n            if self.pipeline_group_size > 1:\n                if self.pp_rank == 0:  # Only send output to next node, do not receive\n                    with torch.cuda.stream(self.torch_comp_stream):\n                        self.profile_mark_forward_comp_start(i)\n                        current_micro_output = self.model(\n                            self.input_micro_batches[i],\n                            **{k: v[i] for k, v in aux_input_data.items()}\n                        )\n                        self.torch_comp_stream.record_event(\n                            self.forward_comp_ready_events[i])\n                    with torch.cuda.stream(self.torch_send_stream):\n                        cupy_send_stream = cupy.cuda.ExternalStream(\n                            self.torch_send_stream.cuda_stream)\n                        self.torch_send_stream.wait_event(\n                            self.forward_comp_ready_events[i])\n                        self.profile_mark_forward_send_start(i)\n                        self.comm.send(current_micro_output.data,\n                                       dst=self.post_node_rank, 
stream=cupy_send_stream)\n                        self.profile_mark_forward_send_end(i)\n                elif self.pp_rank == self.pipeline_group_size - 1:  # Only receive input from last node, do not send\n                    with torch.cuda.stream(self.torch_recv_stream):\n                        cupy_recv_stream = cupy.cuda.ExternalStream(\n                            self.torch_recv_stream.cuda_stream)\n                        self.profile_mark_forward_recv_start(i)\n                        self.comm.recv(\n                            self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream)\n                        self.torch_recv_stream.record_event(\n                            self.forward_recv_ready_events[i])\n                    with torch.cuda.stream(self.torch_comp_stream):\n                        self.torch_comp_stream.wait_event(\n                            self.forward_recv_ready_events[i])\n                        self.profile_mark_forward_comp_start(i)\n                        current_micro_output = self.model(\n                            self.input_micro_batches[i], input_ids=input_ids_micro_batches[i],\n                            **{k: v[i] for k, v in aux_input_data.items()}\n                        )\n                        self.torch_comp_stream.record_event(\n                            self.forward_comp_ready_events[i])\n                else:  # receive, compute, and send\n                    with torch.cuda.stream(self.torch_recv_stream):\n                        cupy_recv_stream = cupy.cuda.ExternalStream(\n                            self.torch_recv_stream.cuda_stream)\n                        self.profile_mark_forward_recv_start(i)\n                        self.comm.recv(\n                            self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream)\n                        self.torch_recv_stream.record_event(\n                            self.forward_recv_ready_events[i])\n           
         with torch.cuda.stream(self.torch_comp_stream):\n                        self.torch_comp_stream.wait_event(\n                            self.forward_recv_ready_events[i])\n                        self.profile_mark_forward_comp_start(i)\n                        current_micro_output = self.model(\n                            self.input_micro_batches[i],\n                            **{k: v[i] for k, v in aux_input_data.items()}\n                        )\n                        self.torch_comp_stream.record_event(\n                            self.forward_comp_ready_events[i])\n                    with torch.cuda.stream(self.torch_send_stream):\n                        cupy_send_stream = cupy.cuda.ExternalStream(\n                            self.torch_send_stream.cuda_stream)\n                        self.torch_send_stream.wait_event(\n                            self.forward_comp_ready_events[i])\n                        self.profile_mark_forward_send_start(i)\n                        self.comm.send(current_micro_output.data,\n                                       dst=self.post_node_rank, stream=cupy_send_stream)\n                        self.profile_mark_forward_send_end(i)\n            else:\n                with torch.cuda.stream(self.torch_comp_stream):\n                    self.profile_mark_forward_comp_start(i)\n                    current_micro_output = self.model(\n                        self.input_micro_batches[i],\n                        **{k: v[i] for k, v in aux_input_data.items()}\n                    )\n                    self.torch_comp_stream.record_event(\n                        self.forward_comp_ready_events[i])\n\n            output_micro_batches.append(current_micro_output)\n\n        return output_micro_batches\n\n    def profiling_forward_stage(self):\n        torch.cuda.synchronize()\n        for i in range(self.micro_batch_num):\n            if self.pp_rank != 0:\n                recv_slot = 
self.forward_recv_start_events[i].elapsed_time(\n                    self.forward_recv_ready_events[i]) * 1e+3\n                recv_log = {\"name\": \"recv\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"1. forward-recv\",\n                            \"ts\": self.get_ts(self.forward_recv_start_events[i]), \"dur\": recv_slot,\n                            \"args\": {\"micro-batch\": i}, \"cname\": \"startup\"}  # cname is for color, a little silly.\n                # print(recv_log)\n                self.profiling_log.append(recv_log)\n\n            comp_slot = self.forward_comp_start_events[i].elapsed_time(\n                self.forward_comp_ready_events[i]) * 1e+3\n            comp_log = {\"name\": \"comp\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"2. forward-compute\",\n                        \"ts\": self.get_ts(self.forward_comp_start_events[i]), \"dur\": comp_slot,\n                        \"args\": {\"micro-batch\": i}, \"cname\": \"good\"}\n            # print(comp_log)\n            self.profiling_log.append(comp_log)\n\n            if self.pp_rank != self.pipeline_group_size - 1:\n                send_slot = self.forward_send_start_events[i].elapsed_time(\n                    self.forward_send_end_events[i]) * 1e+3\n                send_log = {\"name\": \"send\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"3. forward-send\",\n                            \"ts\": self.get_ts(self.forward_send_start_events[i]), \"dur\": send_slot,\n                            \"args\": {\"micro-batch\": i}, \"cname\": \"thread_state_iowait\"}\n                # print(send_log)\n                self.profiling_log.append(send_log)\n\n    def backward_stage(self, cached_output_micro_batches: List[torch.Tensor], target=None,\n                       loss_func=torch.nn.functional.cross_entropy):\n        # print(\"Backward stage start! 
rank-\", self.rank)\n        if self.pp_rank == self.pipeline_group_size - 1:\n            assert(target is not None)\n            target_as_micro_batches = torch.chunk(\n                target, self.micro_batch_num, dim=0)\n        # else:\n        #     assert(target is None)\n\n        if self.pp_rank == self.pipeline_group_size - 1:\n            tr_loss = []\n\n        for i in range(self.micro_batch_num):\n            if self.pipeline_group_size > 1:\n                if self.pp_rank == self.pipeline_group_size - 1:  # only send grad back to last node, do not receive\n                    with torch.cuda.stream(self.torch_comp_stream) as st:\n                        self.profile_mark_backward_comp_start(i)\n                        loss = loss_func(\n                            input=cached_output_micro_batches[i], target=target_as_micro_batches[i])\n                        if not flag_profile:\n                            tr_loss.append(loss.item())\n                        if self.use_fp16:\n                            self.optimizer.scale(loss).backward()\n                        else:\n                            loss.backward()\n                        self.torch_comp_stream.record_event(\n                            self.backward_comp_ready_events[i])\n                    with torch.cuda.stream(self.torch_send_stream):\n                        cupy_send_stream = cupy.cuda.ExternalStream(\n                            self.torch_send_stream.cuda_stream)\n                        self.torch_send_stream.wait_event(\n                            self.backward_comp_ready_events[i])\n                        self.profile_mark_backward_send_start(i)\n                        self.comm.send(\n                            self.input_micro_batches[i].grad, dst=self.pre_node_rank, stream=cupy_send_stream)\n                        self.profile_mark_backward_send_end(i)\n                elif self.pp_rank == 0:  # only receive grad from previous node, do not send\n             
       with torch.cuda.stream(self.torch_recv_stream):\n                        cupy_recv_stream = cupy.cuda.ExternalStream(\n                            self.torch_recv_stream.cuda_stream)\n                        self.profile_mark_backward_recv_start(i)\n                        self.comm.recv(\n                            self.output_micro_batches_grad[i], src=self.post_node_rank, stream=cupy_recv_stream)\n                        self.torch_recv_stream.record_event(\n                            self.backward_recv_ready_events[i])\n                    with torch.cuda.stream(self.torch_comp_stream):\n                        self.torch_comp_stream.wait_event(\n                            self.backward_recv_ready_events[i])\n                        self.profile_mark_backward_comp_start(i)\n                        cached_output_micro_batches[i].backward(\n                            gradient=self.output_micro_batches_grad[i])\n                        self.torch_comp_stream.record_event(\n                            self.backward_comp_ready_events[i])\n                else:  # receive, compute and send\n                    with torch.cuda.stream(self.torch_recv_stream):\n                        cupy_recv_stream = cupy.cuda.ExternalStream(\n                            self.torch_recv_stream.cuda_stream)\n                        self.profile_mark_backward_recv_start(i)\n                        self.comm.recv(\n                            self.output_micro_batches_grad[i], src=self.post_node_rank, stream=cupy_recv_stream)\n                        self.torch_recv_stream.record_event(\n                            self.backward_recv_ready_events[i])\n                    with torch.cuda.stream(self.torch_comp_stream):\n                        self.torch_comp_stream.wait_event(\n                            self.backward_recv_ready_events[i])\n                        self.profile_mark_backward_comp_start(i)\n                        cached_output_micro_batches[i].backward(\n     
                       gradient=self.output_micro_batches_grad[i])\n                        self.torch_comp_stream.record_event(\n                            self.backward_comp_ready_events[i])\n                    with torch.cuda.stream(self.torch_send_stream):\n                        cupy_send_stream = cupy.cuda.ExternalStream(\n                            self.torch_send_stream.cuda_stream)\n                        self.torch_send_stream.wait_event(\n                            self.backward_comp_ready_events[i])\n                        self.profile_mark_backward_send_start(i)\n                        self.comm.send(\n                            self.input_micro_batches[i].grad, dst=self.pre_node_rank, stream=cupy_send_stream)\n                        self.profile_mark_backward_send_end(i)\n            else:\n                \n                with torch.cuda.stream(self.torch_comp_stream) as st:\n                    self.profile_mark_backward_comp_start(i)\n                    loss = loss_func(\n                        input=cached_output_micro_batches[i], target=target_as_micro_batches[i])\n                    if not flag_profile:\n                        tr_loss.append(loss.item())\n                    if self.use_fp16:\n                        self.optimizer.scale(loss).backward()\n                    else:\n                        loss.backward()\n                    self.torch_comp_stream.record_event(\n                        self.backward_comp_ready_events[i])\n\n        if not flag_profile:\n            if self.pp_rank == self.pipeline_group_size - 1:\n                train_log(\n                    {\n                        'loss': sum(tr_loss)/len(tr_loss),\n                        'lr': self.scheduler.get_last_lr()[0],\n                    }, step=self.global_step,\n                )\n\n    def profiling_backward_stage(self):\n        torch.cuda.synchronize()\n        for i in range(self.micro_batch_num):\n            if self.pp_rank != 
self.pipeline_group_size - 1:\n                recv_slot = self.backward_recv_start_events[i].elapsed_time(\n                    self.backward_recv_ready_events[i]) * 1e+3\n                recv_log = {\"name\": \"recv\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"4. backward-recv\",\n                            \"ts\": self.get_ts(self.backward_recv_start_events[i]), \"dur\": recv_slot,\n                            \"args\": {\"micro-batch\": i}, \"cname\": \"startup\"}\n                # print(recv_log)\n                self.profiling_log.append(recv_log)\n\n            comp_slot = self.backward_comp_start_events[i].elapsed_time(\n                self.backward_comp_ready_events[i]) * 1e+3\n            comp_log = {\"name\": \"comp\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"5. backward-compute\",\n                        \"ts\": self.get_ts(self.backward_comp_start_events[i]), \"dur\": comp_slot,\n                        \"args\": {\"micro-batch\": i}, \"cname\": \"good\"}\n            # print(comp_log)\n            self.profiling_log.append(comp_log)\n            if self.pp_rank != 0:\n                send_slot = self.backward_send_start_events[i].elapsed_time(\n                    self.backward_send_end_events[i]) * 1e+3\n                send_log = {\"name\": \"send\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"6. 
backward-send\",\n                            \"ts\": self.get_ts(self.backward_send_start_events[i]), \"dur\": send_slot,\n                            \"args\": {\"micro-batch\": i}, \"cname\": \"thread_state_iowait\"}\n                # print(send_log)\n                self.profiling_log.append(send_log)\n\n    def save_on_disk(self, path):\n        os.makedirs(path, exist_ok=True)\n        torch.save(self.model.state_dict(), os.path.join(path, 'pytorch_model.bin'))\n        \n    def optimizer_step(self):\n        # hard code: grad clipping\n        if not self.use_fp16:\n            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)\n        if self.use_dp:\n            with torch.cuda.stream(self.torch_comp_stream):\n                self.torch_comp_stream.record_event(\n                    self.dp_optim.backward_ready_event)\n            self.dp_optim.optimizer_step()\n            self.scheduler.step()\n        else:\n            with torch.cuda.stream(self.torch_comp_stream):\n                if self.enable_tidy_profiling:\n                    self.optimizer_start_event.record()\n                self.optimizer.step()\n                self.scheduler.step()\n                if self.enable_tidy_profiling:\n                    self.optimizer_end_event.record()\n        if self.enable_tidy_profiling:\n            self.profiling_optimizer_step()\n\n    def profiling_optimizer_step(self):\n        torch.cuda.synchronize()\n        if not self.use_dp:\n            optimizer_slot = self.optimizer_start_event.elapsed_time(\n                self.optimizer_end_event) * 1e+3\n            optimizer_log = {\"name\": \"opt\", \"ph\": \"X\", \"pid\": self.global_rank, \"tid\": \"7. 
optimizer-step\",\n                             \"ts\": self.get_ts(self.optimizer_start_event), \"dur\": optimizer_slot, \"cname\": \"bad\"}\n            # print(optimizer_log)\n            self.profiling_log.append(optimizer_log)\n        else:\n            self.profiling_log.extend(self.dp_optim.profiling_data_parallel(\n                self.init_time_stamp, self.init_event))\n\n    def export_profiling_result(self, filename):\n        with open(filename, 'w') as outfile:\n            json.dump(self.profiling_log, outfile)\n\n    def sgd_iter(self, input_=None, target=None,\n                 aux_input_data=None, loss_func=torch.nn.functional.cross_entropy):\n        \n        \n        if self.use_fp16 and self.use_dynamic_scale:\n            scales_buffer = [torch.ones_like(self.optimizer.grad_scaler._scale) for _ in range(self.pipeline_group_size)]\n            self.comm.all_gather(self.optimizer.grad_scaler._scale, scales_buffer)\n            self.optimizer.grad_scaler._scale.data[:] = min([s.item() for s in scales_buffer])\n        \n        self.comm.barrier()\n            \n        start_time = time.time()\n        if self.enable_tidy_profiling:\n            torch.cuda.synchronize()\n            self.init_time_stamp = time.time() * 1e+6\n            self.init_event.record()\n\n        step = self.global_step % self.gradient_accumulate_step\n        self.zero_input_grad()\n        if step == 0:\n            self.optimizer.zero_grad(set_to_none=False)\n            \n        if step == self.gradient_accumulate_step - 1 and self.use_dp:\n            if hasattr(self.dp_optim, 'pre_optimizer_step'):\n                self.dp_optim.pre_optimizer_step()\n\n        outputs = self.forward_stage(input_, aux_input_data=aux_input_data)\n        forward_time = time.time()\n        forward_slot = forward_time-start_time\n        print(\"Rank {} node forward pass {}/{} takes {:3.2f}s\"\n              .format(self.global_rank, step, self.gradient_accumulate_step, 
forward_slot))\n        \n        # This is an educated guess that such barrier would make it fair TC (probably required)\n        # self.comm.barrier()\n        self.backward_stage(outputs, target, loss_func=loss_func)\n        backward_time = time.time()\n        print(\"Rank {} node backward pass {}/{} takes {:3.2f}s\"\n              .format(self.global_rank, step, self.gradient_accumulate_step, backward_time-forward_time))\n        if step == self.gradient_accumulate_step - 1:\n            optimizer_time = time.time()\n            self.optimizer_step()\n            torch.cuda.synchronize()\n            \n            if self.enable_tidy_profiling:\n                self.profiling_forward_stage()\n                self.profiling_backward_stage()\n            \n            print('after cuda sync', self.global_rank)\n            self.comm.barrier()\n            end_time = time.time()\n            print(\"Rank {} node optimizer step takes {:3.2f}s\".format(\n                self.global_rank, end_time - optimizer_time))\n        else:\n            self.comm.barrier()\n            end_time = time.time()\n        iter_time = end_time - start_time\n        print(\"Rank {} node whole iteration takes {:3.2f}s\".format(\n            self.global_rank, iter_time))\n        print(\"-------------------------------------------\")\n        # torch.cuda.empty_cache()\n        # print(torch.cuda.memory_summary())\n        self.global_step += 1\n        return iter_time\n    \n    \n    def infer_stage(self, input_data=None, aux_input_data=None, \n                    labels=None, pred_func=None):\n        \n        if aux_input_data is not None:\n            for k in aux_input_data:\n                aux_input_data[k] = torch.chunk(aux_input_data[k], self.micro_batch_num, dim=0)\n        else:\n            aux_input_data = {}\n        \n        if self.pp_rank == 0:\n            assert(input_data is not None)\n            self.input_micro_batches = torch.chunk(input_data, 
self.micro_batch_num, dim=0)\n        if self.pp_rank == self.pipeline_group_size - 1:\n            if input_data is not None:\n                input_ids_micro_batches = torch.chunk(input_data, self.micro_batch_num, dim=0)\n            else:\n                input_ids_micro_batches = [None]*self.micro_batch_num\n            if labels is not None:\n                labels = torch.chunk(labels, self.micro_batch_num, dim=0)\n            else:\n                labels = [None]*self.micro_batch_num\n                \n        output_micro_batches = []\n\n        for i in range(self.micro_batch_num):\n            if self.pipeline_group_size > 1:\n                if self.pp_rank == 0:  # Only send output to next node, do not receive\n                    with torch.cuda.stream(self.torch_comp_stream):\n                        current_micro_output = self.model(\n                            self.input_micro_batches[i], \n                            **{k: v[i] for k, v in aux_input_data.items()},\n                        )\n                        self.torch_comp_stream.record_event(self.forward_comp_ready_events[i])\n                    with torch.cuda.stream(self.torch_send_stream):\n                        cupy_send_stream = cupy.cuda.ExternalStream(self.torch_send_stream.cuda_stream)\n                        self.torch_send_stream.wait_event(self.forward_comp_ready_events[i])\n                        self.comm.send(current_micro_output.data, dst=self.post_node_rank, stream=cupy_send_stream)\n                elif self.pp_rank == self.pipeline_group_size - 1:  # Only receive input from last node, do not send\n                    with torch.cuda.stream(self.torch_recv_stream):\n                        cupy_recv_stream = cupy.cuda.ExternalStream(self.torch_recv_stream.cuda_stream)\n                        self.comm.recv(self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream)\n                        
self.torch_recv_stream.record_event(self.forward_recv_ready_events[i])\n                    with torch.cuda.stream(self.torch_comp_stream):\n                        self.torch_comp_stream.wait_event(self.forward_recv_ready_events[i])\n                        current_micro_output = self.model(\n                            self.input_micro_batches[i], input_ids=input_ids_micro_batches[i],\n                            **{k: v[i] for k, v in aux_input_data.items()},\n                        )\n                        current_micro_output = pred_func(current_micro_output, labels[i])\n                        self.torch_comp_stream.record_event(self.forward_comp_ready_events[i])\n                else:  # receive, compute, and send\n                    with torch.cuda.stream(self.torch_recv_stream):\n                        cupy_recv_stream = cupy.cuda.ExternalStream(self.torch_recv_stream.cuda_stream)\n                        self.comm.recv(self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream)\n                        self.torch_recv_stream.record_event(self.forward_recv_ready_events[i])\n                    with torch.cuda.stream(self.torch_comp_stream):\n                        self.torch_comp_stream.wait_event(self.forward_recv_ready_events[i])\n                        current_micro_output = self.model(\n                            self.input_micro_batches[i],\n                            **{k: v[i] for k, v in aux_input_data.items()},\n                        )\n                        self.torch_comp_stream.record_event(self.forward_comp_ready_events[i])\n                    with torch.cuda.stream(self.torch_send_stream):\n                        cupy_send_stream = cupy.cuda.ExternalStream(self.torch_send_stream.cuda_stream)\n                        self.torch_send_stream.wait_event(self.forward_comp_ready_events[i])\n                        self.comm.send(current_micro_output.data, dst=self.post_node_rank, stream=cupy_send_stream)\n           
 else:\n                with torch.cuda.stream(self.torch_comp_stream):\n                    current_micro_output = self.model(\n                        self.input_micro_batches[i],\n                        **{k: v[i] for k, v in aux_input_data.items()}\n                    )\n                    current_micro_output = pred_func(current_micro_output, labels[i])\n                    self.torch_comp_stream.record_event(\n                        self.forward_comp_ready_events[i])\n                    \n            output_micro_batches.append(current_micro_output)\n            \n        return output_micro_batches\n    \n    def infer_iter(self, input_=None, target=None, \n                   output_=None, \n                   aux_input_data=None, \n                   pred_func=None):\n        # self.comm.barrier()\n        torch.cuda.synchronize()\n        with torch.no_grad():\n            outputs = self.infer_stage(input_, \n                                       aux_input_data=aux_input_data,\n                                       labels=target, pred_func=pred_func)\n            if output_ is not None:\n                outputs = torch.cat(outputs, 0).mean().item()\n                print(outputs)\n                output_.append(outputs)\n        torch.cuda.synchronize()\n        # self.comm.barrier()\n\n"
  },
  {
    "path": "training/pipeline_parallel/dist_pp_utils.py",
    "content": "from .dist_gpipe_pipeline_async import GpipeAsync\n\n\ndef get_pp_module(args, config, device, use_dp):\n    \n    if args.pp_mode == 'gpipe':\n        return GpipeAsync(args, config, device, use_dp)\n    else:\n        # Raise instead of `assert False`: asserts are stripped under `python -O`.\n        raise ValueError(f\"Unrecognized pipeline parallel mode: {args.pp_mode!r}\")\n        \n"
  },
  {
    "path": "training/tasks/__init__.py",
    "content": ""
  },
  {
    "path": "training/tasks/data_loaders/__init__.py",
    "content": ""
  },
  {
    "path": "training/tasks/data_loaders/data_utils.py",
    "content": "import os\nimport re\nimport torch\nimport json\nimport numpy as np\nfrom torch.utils.data import IterableDataset, DataLoader\nfrom itertools import cycle, islice\nimport random\nfrom random import randint\nfrom datasets import Dataset\nfrom datasets import load_dataset, load_from_disk\nfrom comm.comm_utils import *\n\n\nSHOW_DATA = int(os.environ.get('SHOW_DATA', 1))\nUL2R_DENOISE_ENABLED = int(os.environ.get('UL2R_DENOISE_ENABLED', 0))\n\n\ndef random_chunk(li, min_chunk=1, max_chunk=5):\n    it = iter(li)\n    while True:\n        nxt = list(islice(it,randint(min_chunk,max_chunk)))\n        if nxt:\n            yield nxt\n        else:\n            break\n            \n\nclass UL2RProcessor:\n    '''\n    This is a replication of UL2R from our understanding.\n    We welcome PR if there are better implementations.\n    '''\n    \n    def __init__(self, tokenizer, seq_length=1024):\n        self.tokenizer = tokenizer\n        self.seq_length = seq_length\n        \n        self.s2s_prefix = self.tokenizer(\"[S2S]\")['input_ids']\n        self.nlg_prefix = self.tokenizer(\"[NLG]\")['input_ids']\n        self.nlu_prefix = self.tokenizer(\"[NLU]\")['input_ids']\n        \n        self.extra_ids = [self.tokenizer.eos_token_id - 100 + i for i in range(80)]\n        \n        \n    def preprocess_tokens_s2s(self, tokens):\n        \n        tokens = self.s2s_prefix + tokens\n        \n        split = int(random.random() * len(tokens))\n        \n        tokens = tokens[:split] + tokens[split:]\n        tokens = tokens[:self.seq_length]\n        \n        prefix_masks = torch.zeros(len(tokens), dtype=torch.uint8)\n        prefix_masks[:split] = 1\n        \n        
return {\n            'input_ids': torch.tensor(tokens),\n            'prefix_masks': prefix_masks,\n        }\n    \n    def preprocess_tokens_nlg(self, tokens):\n        \n        tokens = tokens[:self.seq_length - len(self.nlg_prefix) - 2]\n        \n        start = int(random.random() * len(tokens))\n        end = start + 1 + int(random.random() * 31)\n        \n        left = self.nlg_prefix + tokens[:start] + [self.extra_ids[0]] + tokens[end:]\n        right = [self.extra_ids[0]] + tokens[start:end]\n    \n        tokens = left + right\n        tokens = tokens[:self.seq_length]\n        tokens = tokens + (self.seq_length - len(tokens)) * [self.tokenizer.eos_token_id]\n        \n        prefix_masks = torch.zeros(len(tokens), dtype=torch.uint8)\n        prefix_masks[:len(left)] = 1\n        \n        return {\n            'input_ids': torch.tensor(tokens),\n            'prefix_masks': prefix_masks,\n        }\n        \n    def preprocess_tokens_nlu(self, tokens):\n        \n        tokens = tokens[:self.seq_length - len(self.nlu_prefix) - 10]\n        \n        # split into chunks\n        chunks = list(random_chunk(tokens, min_chunk=1, max_chunk=5))\n        \n        # randomly select 15% of the chunks\n        K = int(0.15 * len(chunks))\n        indices = random.sample(range(len(chunks)), K)\n        \n        # copy the prefix: the in-place `+=` below must not mutate self.nlu_prefix\n        left = list(self.nlu_prefix)\n        right = []\n        extra_id_count = 0\n        \n        last_corrupt = False\n        for i, chunk in enumerate(chunks):\n            # never corrupt two consecutive chunks\n            if i in indices and not last_corrupt and extra_id_count < len(self.extra_ids):\n                left += [self.extra_ids[extra_id_count]]\n                right += [self.extra_ids[extra_id_count]] + chunk\n                extra_id_count += 1\n                last_corrupt = True\n            else:\n                left += chunk\n                last_corrupt = False\n        \n        tokens = left + right\n        tokens = tokens[:self.seq_length]\n        tokens = tokens + 
(self.seq_length - len(tokens)) * [self.tokenizer.eos_token_id]\n        \n        prefix_masks = torch.zeros(len(tokens), dtype=torch.uint8)\n        prefix_masks[:len(left)] = 1\n        \n        return {\n            'input_ids': torch.tensor(tokens),\n            'prefix_masks': prefix_masks,\n        }\n\n    def preprocess_ul2r(self, inputs):\n        tokens = inputs['input_ids'].tolist()\n        p = random.random()\n        if p > 0.5:\n            return self.preprocess_tokens_s2s(tokens)\n        elif p > 0.25:\n            return self.preprocess_tokens_nlg(tokens)\n        else:\n            return self.preprocess_tokens_nlu(tokens)\n    \n    def preprocess_random(self, inputs):\n        \n        tokens = inputs['input_ids'].tolist()\n        \n        if random.random() < 0.2:\n            # short prompt\n            split = int(random.random() * 20)\n        else:\n            # random length prompt\n            split = int(random.random() * len(tokens))\n        \n        tokens = tokens[:split] + tokens[split:]\n        tokens = tokens[:self.seq_length]\n        \n        prefix_masks = torch.zeros(len(tokens), dtype=torch.uint8)\n        prefix_masks[:split] = 1\n        \n        return {\n            'input_ids': torch.tensor(tokens),\n            'prefix_masks': prefix_masks,\n        }\n    \n    def __call__(self, inputs):\n        if UL2R_DENOISE_ENABLED:\n            return self.preprocess_ul2r(inputs)\n        else:\n            return self.preprocess_random(inputs)\n\n\nclass StreamDataset(IterableDataset):\n    default_doc_separator = '\\n'\n    def __init__(self, data, tokenizer, seq_length=1024, doc_separator=None, cycling=True):\n        self.data = data\n        self.tokenizer = tokenizer\n        self.seq_length = seq_length\n        self.doc_separator = doc_separator or StreamDataset.default_doc_separator\n        self.cycling = cycling\n        self.it = None\n        self.iter_count = 0\n        self.buffer_tokens = []\n        
\n    def state_dict(self):\n        return {}\n    \n    def load_state_dict(self, state_dict):\n        pass\n        \n    def get_sequence(self):\n        buffer_tokens = self.buffer_tokens\n        for x in self.data:\n            self.iter_count += 1\n            curr_tokens = self.tokenizer(self.doc_separator + x['text'])['input_ids']\n            buffer_tokens += curr_tokens\n            while len(buffer_tokens) >= self.seq_length:\n                tokens = buffer_tokens[:self.seq_length]\n                buffer_tokens = buffer_tokens[self.seq_length:]\n                input_ids = torch.tensor(tokens)\n                self.buffer_tokens = buffer_tokens # update for restore\n                yield {\n                    'input_ids': input_ids,\n                }\n                \n    def get_stream(self):\n        if self.cycling:\n            return cycle(self.get_sequence())\n        else:\n            return self.get_sequence()\n    \n    def __iter__(self):\n        if self.it is None:\n            self.it = self.get_stream()\n        return self.it\n\n\nclass StreamDatasetList(IterableDataset):\n    def __init__(self, task_names, datasets, sample_probs, tokenizer, seq_length=1024, print_sample_every_n=64, post_processor=None):\n        \n        self.task_names = task_names\n        self.datasets = datasets\n        self.sample_probs = sample_probs\n        self.tokenizer = tokenizer\n        self.seq_length = seq_length\n        self.print_sample_every_n = print_sample_every_n\n        self.post_processor = post_processor\n        self.token_count = None\n        \n        self.it = None\n        \n    def state_dict(self):\n        return {}\n    \n    def load_state_dict(self, state_dict):\n        pass\n        \n    def get_sequence(self):\n        \n        iterators = [cycle(d.get_sequence()) for d in self.datasets]\n        prob_ths = np.cumsum([p / sum(self.sample_probs) for p in self.sample_probs])\n        \n        global_i = 0\n        \n   
     while True:\n            \n            p = random.random()\n            \n            for task_name, it, th in zip(self.task_names, iterators, prob_ths):\n                if p < th:\n                    \n                    inputs = next(it)\n                    \n                    if self.post_processor is not None:\n                        inputs = self.post_processor(inputs)\n                    \n                    if SHOW_DATA:\n                        if global_i % self.print_sample_every_n == 0:\n                            print(p, th)\n                            print(f\"**{task_name}**:\", self.tokenizer.decode(inputs['input_ids']))\n                        \n                    yield inputs\n                    global_i += 1\n                    break\n                \n    def get_stream(self):\n        # get_sequence() never terminates, so wrapping it in cycle() would only\n        # cache every yielded sample and grow memory without bound.\n        return self.get_sequence()\n    \n    def __iter__(self):\n        if self.it is None:\n            self.it = self.get_stream()\n        return self.it\n\n    def tokenize_function(self, examples):\n        # Update here\n        output = self.tokenizer(\n            examples[\"text\"], padding=False, truncation=True, max_length=self.tokenizer.model_max_length,\n        )\n        return output\n        \n    # Compute the number of tokens in a dataset using the tokenizer\n    # - return: the total number of tokens in the text field of each sample in the dataset\n    def get_dataset_token_count(self) -> int:\n        if self.token_count is not None:\n            return self.token_count\n\n        self.token_count = 0\n\n        if self.task_names is None:\n            return self.token_count\n\n        raw_datasets = load_dataset(\n                        \"json\",\n                        data_files=self.task_names,\n                        split=\"train\",\n                    )\n        \n        column_names = list(raw_datasets.features)\n        \n        tokenized_datasets = raw_datasets.map(\n            self.tokenize_function,\n 
           batched=True,\n            remove_columns=column_names,\n            desc=\"Running tokenizer on dataset\",\n        )\n        \n        for item in tokenized_datasets:\n            self.token_count += len(item['input_ids'])\n\n        return self.token_count\n\n    def get_dataset_example_count(self) -> int:\n        num_lines = 0\n\n        if self.task_names is None:\n            return num_lines\n\n        for jsonl_file in self.task_names:\n            with open(jsonl_file, \"r\") as file:\n                for line in file:\n                    if line.replace(\" \", \"\") != \"\\n\":\n                        num_lines += 1\n\n        return num_lines\n\n    \ndef name_to_dataset(task, tokenizer, args):\n    \n    if 'prosocial_plus_regular.jsonl' in task:\n        from .prosocial import StreamDataset as _StreamDataset\n        data = load_dataset(\"json\", data_files=task, split=\"train\", streaming=True).shuffle(buffer_size=100_000, seed=args.seed)\n        dataset = _StreamDataset(data, tokenizer, args.seq_length)\n    elif task != '':\n        data = load_dataset(\"json\", data_files=task, split=\"train\", streaming=True).shuffle(buffer_size=100_000, seed=args.seed)\n        dataset = StreamDataset(data, tokenizer, args.seq_length)\n    else:\n        raise Exception('One of the provided datasets is an empty string.')\n        \n    return dataset\n\ndef name_to_dataset_eval(task, tokenizer, args):\n    \n    if task != '':\n        data = load_dataset(\"json\", data_files=task, split=\"train\", streaming=True)\n        dataset = StreamDataset(data, tokenizer, args.seq_length, cycling=False)\n    else:\n        # Fail with a clear message instead of an UnboundLocalError on `dataset`.\n        raise Exception('The evaluation dataset is an empty string.')\n        \n    return dataset\n\n    \ndef get_train_data_loader(args, tokenizer, num_workers=1, state_dict=None):\n    \n    task_list = args.task_name.split(',')\n    task_names = []\n    datasets = []\n    probs = []\n    \n    print('data_utils: parse task_list')\n    \n    for task in task_list:\n        if task.startswith('http'):\n            # data 
from a URL has an additional ':'\n            if len(task.split(':')) == 3:\n                prefix, task, prob = task.strip().split(':')\n                task = f\"{prefix}:{task}\"\n                prob = float(prob)\n            elif len(task.split(':')) == 2:\n                task = task.strip()\n                prob = 1.0\n            else:\n                raise Exception('Cannot parse task.')\n        elif ':' in task:\n            task, prob = task.strip().split(':')\n            prob = float(prob)\n        else:\n            task = task.strip()\n            prob = 1.0\n            \n        dataset = name_to_dataset(task, tokenizer, args)\n            \n        print('data_utils:', task, prob)\n    \n        task_names.append(task)\n        datasets.append(dataset)\n        probs.append(prob)\n    \n    stream_dataset = StreamDatasetList(\n        task_names, datasets, probs,\n        tokenizer=tokenizer, seq_length=args.seq_length)\n    \n    if state_dict is not None:\n        stream_dataset.load_state_dict(state_dict)\n    \n    train_data_loader = torch.utils.data.DataLoader(stream_dataset,\n                                                    batch_size=args.batch_size * args.data_group_size,\n                                                    shuffle=False,\n                                                    num_workers=num_workers,\n                                                    pin_memory=True,\n                                                    collate_fn=None)\n    \n    print('data_utils: get train_data_loader')\n    \n    return train_data_loader\n\n\ndef get_eval_data_loader(args, tokenizer, num_workers=1, state_dict=None):\n    \n    task_list = args.task_name.split(',')\n    task_names = []\n    datasets = []\n    probs = []\n    \n    print('data_utils: parse task_list')\n    \n    evaluation_data = args.evaluation_data\n    \n    if evaluation_data is None:\n        return None\n    \n    dataset = name_to_dataset_eval(evaluation_data, 
tokenizer, args)\n    \n    train_data_loader = torch.utils.data.DataLoader(dataset,\n                                                    batch_size=args.batch_size,\n                                                    shuffle=False,\n                                                    drop_last=True,\n                                                    num_workers=num_workers,\n                                                    pin_memory=True,\n                                                    collate_fn=None)\n    \n    return train_data_loader\n\n\ndef get_ul2r_train_data_loader(args, tokenizer, num_workers=1, state_dict=None):\n    \n    task_list = args.task_name.split(',')\n    task_names = []\n    datasets = []\n    probs = []\n    for task in task_list:\n        if ':' in task:\n            task, prob = task.strip().split(':')\n            prob = float(prob)\n        else:\n            task = task.strip()\n            prob = 1.0\n            \n        dataset = name_to_dataset(task, tokenizer, args)\n    \n        task_names.append(task)\n        datasets.append(dataset)\n        probs.append(prob)\n        \n    ul2r_processor = UL2RProcessor(tokenizer, seq_length=args.seq_length)\n    \n    stream_dataset = StreamDatasetList(\n        task_names, datasets, probs,\n        tokenizer=tokenizer, seq_length=args.seq_length, post_processor=ul2r_processor)\n    \n    if state_dict is not None:\n        stream_dataset.load_state_dict(state_dict)\n    \n    train_data_loader = torch.utils.data.DataLoader(stream_dataset,\n                                                    batch_size=args.batch_size * args.data_group_size,\n                                                    shuffle=False,\n                                                    num_workers=num_workers,\n                                                    pin_memory=True,\n                                                    collate_fn=None)\n    \n    print('ul2r dataloader init done.')\n    \n    
return train_data_loader\n"
  },
  {
    "path": "training/tasks/data_loaders/prosocial.py",
    "content": "import os\nimport re\nimport torch\nimport json\nfrom torch.utils.data import IterableDataset, DataLoader\nfrom itertools import cycle, islice\nimport random\nfrom datasets import Dataset\nfrom datasets import load_dataset, load_from_disk\nfrom comm.comm_utils import *\n\n\n\nclass StreamDataset(IterableDataset):\n    def __init__(self, dataset, tokenizer, seq_length=1024):\n        \n        self.dataset = dataset\n        \n        self.tokenizer = tokenizer\n        self.seq_length = seq_length\n        \n        self.it = None\n        self.iter_count = 0\n        \n    def state_dict(self):\n        return {\n            'iter_count': self.iter_count,\n        }\n    \n    def load_state_dict(self, state_dict):\n        self.iter_count = state_dict['iter_count']\n        self.dataset = self.dataset.skip(self.iter_count)\n        \n    def get_sequence(self):\n        \n        it = cycle(iter(self.dataset))\n        \n        while True:\n\n            text_context = '''Possible labels:\n1. casual\n2. needs caution\n3. needs intervention\n4. possibly needs caution\n5. probably needs caution'''\n\n            while True:\n                \n                instance = next(it)\n                \n                text = instance['text']\n                text_context += '\\n\\n' + text\n                \n                input_ids = self.tokenizer(text_context.strip())['input_ids']\n                if len(input_ids) > self.seq_length:\n                    break\n                \n            input_ids = input_ids[:self.seq_length]\n            input_ids = torch.tensor(input_ids).long()\n\n            yield {\n                'input_ids': input_ids,\n            }\n            \n                \n    def get_stream(self):\n        return cycle(self.get_sequence())\n    \n    def __iter__(self):\n        if self.it is None:\n            self.it = self.get_stream()\n        return self.it\n    "
  },
  {
    "path": "training/utils/__init__.py",
    "content": ""
  },
  {
    "path": "training/utils/dist_args_utils.py",
    "content": "def add_device_arguments(parser):\n    parser.add_argument('--use-cuda', default=True, type=lambda x: (str(x).lower() == 'true'),\n                        help='if this is set to True, will use cuda to train')\n    parser.add_argument('--cuda-id', type=int, default=0, metavar='N',\n                        help='cuda index, if the instance has multiple GPUs.')\n    parser.add_argument('--cuda-num', type=int, default=1, metavar='N',\n                        help='number of GPUs, if the instance has multiple GPUs.')\n    parser.add_argument('--debug-mem', default=True, type=lambda x: (str(x).lower() == 'true'),\n                        help='if this is set to True, we will print some memory stats.')\n\n\ndef add_torch_distributed_arguments(parser):\n    parser.add_argument('--dist-backend', type=str, default='cupy_nccl', metavar='S',\n                        help='backend type for distributed PyTorch (default: cupy_nccl)')\n    parser.add_argument('--dp-backend', type=str, default='nccl', metavar='S',\n                        help='backend type for data parallel')\n    parser.add_argument('--dist-url', type=str, default='tcp://127.0.0.1:9000', metavar='S',\n                        help='master ip for distributed PyTorch')\n    parser.add_argument('--world-size', type=int, default=4, metavar='D',\n                        help='world-size (default: 4)')\n    parser.add_argument('--pipeline-group-size', type=int, default=4, metavar='D',\n                        help='pipeline parallel group size (default: 4)')\n    parser.add_argument('--data-group-size', type=int, default=1, metavar='D',\n                        help='data parallel group size (default: 1)')\n    parser.add_argument('--rank', type=int, default=0, metavar='N',\n                        help='rank of the node')\n\n\ndef add_task_arguments(parser):\n    parser.add_argument('--train-data', nargs='+', default=['./glue_dataset/data/QQP/train.tsv'], metavar='S',\n                        help='path to the training data')\n    
parser.add_argument('--valid-data', nargs='+', default=['./glue_dataset/data/QQP/test.tsv'], metavar='S',\n                        help='path to the validation data')\n    parser.add_argument('--tokenizer-type', type=str, default='BertWordPieceLowerCase', metavar='S',\n                        help='which tokenizer to use.')\n    parser.add_argument('--vocab-file', type=str, default='./glue_dataset/data/bert-large-cased-vocab.txt', metavar='S',\n                        help='path to the vocabulary file.')\n    parser.add_argument('--vocab-extra-ids', type=int, default=0, metavar='N',\n                        help='-')\n    parser.add_argument('--make-vocab-size-divisible-by', type=int, default=128, metavar='N',\n                        help='-')\n    parser.add_argument('--optimizer', type=str, default='adamw', metavar='N',\n                        help='-')\n\n\ndef add_model_arguments(parser):\n    parser.add_argument('--seq-length', type=int, default=1024, metavar='N',\n                        help='-')\n    parser.add_argument('--embedding-dim', type=int, default=768, metavar='N',\n                        help='-')\n    parser.add_argument('--num-layers', type=int, default=4, metavar='N',\n                        help='-')\n    parser.add_argument('--num-heads', type=int, default=12, metavar='N',\n                        help='-')\n\n\ndef add_training_hyper_parameter_arguments(parser):\n    parser.add_argument('--train-log-backend', type=str, default='print', metavar='N',\n                        help='-')\n    parser.add_argument('--project-name', type=str, default='test', metavar='N',\n                        help='-')\n    parser.add_argument('--batch-size', type=int, default=32, metavar='N',\n                        help='input batch size for training (default: 32)')\n    parser.add_argument('--micro-batch-size', type=int, default=8, metavar='N',\n                        help='input micro batch size for training (default: 8)')\n    parser.add_argument('--lr', 
type=float, default=0.01, metavar='N',\n                        help='-')\n    parser.add_argument('--num-iters', type=int, default=10, metavar='N',\n                        help='-')\n\n\ndef add_mixed_precision_arguments(parser):\n    parser.add_argument('--fp16', action='store_true',\n                        help='Run model in fp16 mode.')\n    parser.add_argument('--loss-scale', type=float, default=0,\n                        help='Static loss scaling, positive power of 2 values can improve fp16 convergence. ')\n    parser.add_argument('--initial-loss-scale', type=float, default=32768,\n                        help='Initial loss-scale for dynamic loss scaling.')\n    parser.add_argument('--min-loss-scale', type=float, default=1.0,\n                        help='Minimum loss scale for dynamic loss scale.')\n    parser.add_argument('--loss-scale-window', type=float, default=1000,\n                        help='Window over which to raise/lower dynamic scale.')\n    parser.add_argument('--hysteresis', type=int, default=2,\n                        help='hysteresis for dynamic loss scaling')\n    parser.add_argument('--use-offload', action='store_true',\n                        help='Offload optim states to CPU')\n    \n\n\ndef add_parallel_schema_arguments(parser):\n    parser.add_argument('--pp-mode', type=str, default='gpipe', metavar='S',\n                        help='use which pipeline parallel mode: gpipe or 1f1b.')\n    parser.add_argument('--dp-mode', type=str, default='allreduce', metavar='S',\n                        help='use which data parallel mode: allreduce.')\n    parser.add_argument('--gradient-accumulate-step', type=int, default=1,\n                        help='Number of gradient computation in Pipeline without data parallel sync.')\n    \n\ndef get_model_arguments_str(args):\n    return '_l' + str(args.seq_length) + '_m' + str(args.embedding_dim)\n\n\ndef get_dist_arguments_str(args, add_rank=True):\n    dist_str = '_w' + str(args.world_size) + 
'_p' + str(args.pipeline_group_size) + \"_\" + \\\n               str(args.gradient_accumulate_step) + '_d' + str(args.data_group_size)\n    if add_rank:\n        dist_str = dist_str + '_' + str(args.rank)\n    return dist_str\n\n\ndef get_learning_arguments_str(args):\n    return '_b' + str(args.batch_size) + '_' + str(args.micro_batch_size)\n\n\ndef get_mixed_precision_arguments_str(args):\n    if args.fp16:\n        return '_fp16'\n    else:\n        return ''\n"
  },
  {
    "path": "training/utils/dist_checkpoint_utils.py",
    "content": "import os\nimport time\nimport random\nimport json\nimport numpy as np\nimport torch\n\nfrom comm.comm_utils import *\n\n\ndef load_checkpoint(pipe, args):\n    \n    if os.path.isfile(os.path.join(args.checkpoint_path, 'latest')):\n        with open(os.path.join(args.checkpoint_path, 'latest')) as f:\n            latest_step = int(f.read())\n    else:\n        print('no checkpoint available, skipping')\n        return\n    \n    checkpoint_step_path = os.path.join(args.checkpoint_path, f\"checkpoint_{latest_step}\")\n    \n    try:\n        with open(os.path.join(checkpoint_step_path, 'meta.json')) as f:\n            meta = json.load(f)\n    except:\n        print('failed to load meta.')\n        \n    pipe.global_step = latest_step\n    \n    try:\n        pipe.model.model.load_state_dict(\n            torch.load(\n                os.path.join(\n                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_checkpoint.pt'\n                ), map_location=torch.device('cpu')\n            )\n        )\n    except:\n        print('failed to load model params.')\n    \n    try:\n        pipe.optimizer.load_state_dict(\n            torch.load(\n                os.path.join(\n                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_optimizer.pt'\n                ), map_location=torch.device('cpu')\n            )\n        )\n    except:\n        print('failed to load optim states.')\n    \n    try:\n        pipe.scheduler.load_state_dict(\n            torch.load(\n                os.path.join(\n                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt'\n                )\n            )\n        )\n    except:\n        print('failed to load scheduler states.')\n        \n            \ndef save_checkpoint(pipe, args) -> str:\n    \n    latest_step = pipe.global_step\n    checkpoint_step_path = os.path.join(args.checkpoint_path, f\"checkpoint_{latest_step}\")\n    \n    
os.makedirs(checkpoint_step_path, exist_ok=True)\n\n    print(f\"Saving checkpoint to {checkpoint_step_path} ...\")\n\n    torch.save(\n        pipe.model.model.state_dict(),\n        os.path.join(\n            checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_checkpoint.pt'\n        )\n    )\n    \n    torch.save(\n        pipe.optimizer.state_dict(),\n        os.path.join(\n            checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_optimizer.pt'\n        )\n    )\n    \n    torch.save(\n        pipe.scheduler.state_dict(),\n        os.path.join(\n            checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt'\n        )\n    )\n    \n    with open(os.path.join(checkpoint_step_path, 'meta.json'), 'w') as f:\n        json.dump({\n            'step': latest_step,\n        }, f)\n    \n    with open(os.path.join(args.checkpoint_path, 'latest'), 'w') as f:\n        f.write(f\"{latest_step}\")\n\n    print(f\"Checkpoint saved to {checkpoint_step_path} ... 
Done\")\n\n    return checkpoint_step_path\n        \n        \ndef save_stream_dataloader_state_dict(dataloader, pipe, args):\n    \n    latest_step = pipe.global_step\n    checkpoint_step_path = os.path.join(args.checkpoint_path, f\"checkpoint_{latest_step}\")\n    \n    os.makedirs(checkpoint_step_path, exist_ok=True)\n    \n    torch.save(\n        dataloader.dataset.state_dict(),\n        os.path.join(\n            checkpoint_step_path, 'dataset_state_dict.pt'\n        )\n    )\n    \ndef load_stream_dataloader_state_dict(dataloader, pipe, args):\n    \n    latest_step = pipe.global_step\n    checkpoint_step_path = os.path.join(args.checkpoint_path, f\"checkpoint_{latest_step}\")\n    \n    try:\n        state_dict = torch.load(\n            os.path.join(\n                checkpoint_step_path, 'dataset_state_dict.pt'\n            )\n        )\n\n        # The state was saved from dataloader.dataset, so restore it there.\n        dataloader.dataset.load_state_dict(state_dict)\n    \n    except Exception as e:\n        \n        print(f'failed to load dataset state_dict: {e}')"
  },
  {
    "path": "training/utils/dist_debug_utils.py",
    "content": "import torch\n\n\ndef print_cuda_memory(args, info: str, device=None):\n    if args.debug_mem:\n        if device is None:\n            device = torch.device('cuda', args.cuda_id)\n        print(\"<{}>: current memory allocated: {:2.3f} MB, peak memory: {:2.3f} MB\".format(\n            info, torch.cuda.memory_allocated(device)/1048576, torch.cuda.max_memory_allocated(device)/1048576))\n\n\ndef print_multi_cuda_memory(args, info: str):\n    if args.debug_mem:\n        for local_gpu_rank in range(args.cuda_num):\n            device = torch.device('cuda', local_gpu_rank)\n            print(\"<{}>({}): current memory allocated: {:2.3f} MB, peak memory: {:2.3f} MB\".format(info, local_gpu_rank,\n                  torch.cuda.memory_allocated(device)/1048576, torch.cuda.max_memory_allocated(device)/1048576))\n"
  },
  {
    "path": "training/utils/event_report.py",
    "content": "#!/usr/bin/env python3\n\n# This application reports events that are stored in the event log REST service.\n# Events will be reported to the event log REST service via POST at:\n#\n# http://<endpoint>:<port>/v1/internal/fine-tunes/<job_id>/event\n#\n# with Bearer authorization tokens.\n#\n# The output format is a JSON object with the following fields:\n# - \"object\": <object type>\n# - \"created_at\": <timestamp>\n# - \"level\": <event level>\n# - \"message\": <event message>\n# - \"type\": <event type>\n# - \"param_count\": <number of parameters> (optional)\n# - \"token_count\": <number of tokens> (optional)\n# - \"checkpoint_path\": <path to checkpoint> (optional)\n# - \"model_path\": <path to model> (optional)\n\n\nimport argparse\nimport json\nimport requests\nimport sys\nimport time\n\nclass EventReporter:\n\n    # Event type constants\n    EVENT_TYPE_JOB_START = \"JOB_START\"\n    EVENT_TYPE_MODEL_DOWNLOAD_COMPLETE = \"MODEL_DOWNLOAD_COMPLETE\"\n    EVENT_TYPE_TRAINING_DATA_DOWNLOAD_COMPLETE = \"TRAINING_DATA_DOWNLOAD_COMPLETE\"\n    EVENT_TYPE_TRAINING_START = \"TRAINING_START\"\n    EVENT_TYPE_CHECKPOINT_SAVE = \"CHECKPOINT_SAVE\"\n    EVENT_TYPE_EPOCH_COMPLETE = \"EPOCH_COMPLETE\"\n    EVENT_TYPE_TRAINING_COMPLETE = \"TRAINING_COMPLETE\"\n    EVENT_TYPE_JOB_COMPLETE = \"JOB_COMPLETE\"\n    EVENT_TYPE_JOB_ERROR = \"JOB_ERROR\"\n\n    supported_event_types = [\n        EVENT_TYPE_JOB_START,\n        EVENT_TYPE_MODEL_DOWNLOAD_COMPLETE,\n        EVENT_TYPE_TRAINING_DATA_DOWNLOAD_COMPLETE,\n        EVENT_TYPE_TRAINING_START,\n        EVENT_TYPE_CHECKPOINT_SAVE,\n        EVENT_TYPE_EPOCH_COMPLETE,\n        EVENT_TYPE_TRAINING_COMPLETE,\n        EVENT_TYPE_JOB_COMPLETE,\n        EVENT_TYPE_JOB_ERROR,\n    ]\n\n    # Event level constants\n    LEVEL_INFO = \"Info\"\n    LEVEL_WARNING = \"Warning\"\n    LEVEL_ERROR = \"Error\"\n\n    supported_event_levels = [\n        LEVEL_INFO,\n        LEVEL_WARNING,\n        LEVEL_ERROR,\n    ]\n\n    # Object 
type constants\n    OBJECT_FINE_TUNE = \"fine-tune\"\n\n    supported_object_types = [\n        OBJECT_FINE_TUNE,\n    ]\n\n    object_type_to_endpoint = {\n        \"fine-tune\": \"fine-tunes\",\n    }\n\n    def __init__(self, host=None, auth_token=None, job_id=None):\n        self.host = host\n        self.auth_token = auth_token\n        self.job_id = job_id\n\n    def is_enabled(self) -> bool:\n        # Validate the URL.\n        if self.host is None:\n            return False\n        \n        # Validate the authorization token.\n        if self.auth_token is None:\n            return False\n        \n        # Validate the job ID.\n        if self.job_id is None:\n            return False\n        \n        return True\n\n    # Report an event to the event log REST service.\n    # The event will be reported to the event log REST service via POST at:\n    # http://<endpoint>:<port>/v1/internal/fine-tunes/<job_id>/event\n    # with Bearer authorization tokens.\n    # The output format is a JSON object with the following fields:\n    # - \"object\": object type to be reported. Supported object types are given by\n    #   `supported_object_types`\n    # - \"created_at\": The creation timestamp for the event. If not specified, the\n    #   current time will be used.\n    # - \"level\": Event level. Supported event levels are given by `supported_event_levels`\n    # - \"message\": Event message.\n    # - \"type\": Event type. Supported event types are given by `supported_event_types`\n    # - \"param_count\": Report the number of model parameters. (optional)\n    # - \"token_count\": Report the number of tokens in the training data. (optional)\n    # - \"checkpoint_path\": The path to a checkpoint file(s) (optional)\n    # - \"model_path\": The path to model file(s) (optional)\n    # - \"requires_is_enabled\": When true, verify that is_enabled returns true \n    #   and raise an exception if it does not. 
When false, this function silently\n    #   exits without error. (optional)\n    def report(self, object, message, event_type,\n               level=LEVEL_INFO, checkpoint_path=None,\n               model_path=None, param_count=None, token_count=None, \n               requires_is_enabled=True):\n\n        if requires_is_enabled:\n            # Validate the host.\n            if self.host is None:\n                raise ValueError(\"Host is required\")\n            \n            # Validate the authorization token.\n            if self.auth_token is None:\n                raise ValueError(\"Authorization token is required\")\n            \n            # Validate the job ID.\n            if self.job_id is None:\n                raise ValueError(\"Job ID is required\")\n        elif not self.is_enabled():\n            return\n        \n        # Get the creation timestamp.\n        created_at = int(time.time())\n        \n        # Validate the object type.\n        if object is None:\n            raise ValueError(\"Object type is required\")\n        elif object not in self.supported_object_types:\n            raise ValueError(f\"Invalid object type : {object}\")\n        \n        # Validate the message.\n        if message is None:\n            raise ValueError(\"Message is required\")\n\n        # Validate the event type.\n        if event_type is None:\n            raise ValueError(\"Event type is required\")\n        elif event_type not in self.supported_event_types:\n            raise ValueError(f\"Invalid event type : {event_type}\")\n        \n        # Validate the event level.\n        if level is None:\n            level = self.supported_event_levels[0]\n        elif level not in self.supported_event_levels:\n            raise ValueError(f\"Invalid event level : {level}\")\n\n        # Create the JSON object.\n        event = {\n            \"object\": object,\n            \"created_at\": created_at,\n            \"level\": level,\n            \"message\": 
message,\n            \"type\": event_type\n        }\n        if checkpoint_path is not None and len(checkpoint_path) > 0:\n            event[\"checkpoint_path\"] = checkpoint_path\n        if model_path is not None and len(model_path) > 0:\n            event[\"model_path\"] = model_path\n        if param_count is not None:\n            event[\"param_count\"] = int(param_count)\n        if token_count is not None:\n            event[\"token_count\"] = int(token_count)\n        event_str = json.dumps(event)\n\n        # Send the event to the event log REST service.\n        headers = {\n            \"Authorization\": f\"Bearer {self.auth_token}\",\n            \"Content-Type\": \"application/json\"\n        }\n        endpoint = f\"{self.host}/v1/privlieged/{self.object_type_to_endpoint[object]}/{self.job_id}/event\"\n        response = requests.post(endpoint, headers=headers, data=event_str)\n        if response.status_code != 200:\n            raise ValueError(f\"Failed to send event to event log REST service: ({response.status_code}) response=\\\"{response.text}\\\"\\nEvent: {event_str}\")\n        print(f\"Event reported: {event_str}\")\n        \ndef add_entry_reporter_arguments(parser):\n    parser.add_argument('--event-host', type=str, required=False,\n                        metavar='endpoint:port', help='Event reporting entrypoint URL')\n    parser.add_argument('--event-auth-token', type=str, required=False,\n                        help='Bearer authorization token')\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-u', '--event-host', type=str, required=True,\n                        metavar='<scheme><hostname>:<port>',\n                        help='Event reporting entrypoint URL (e.g. 
https://127.0.0.1:8895)')\n    parser.add_argument('-a', '--auth-token', type=str, required=True,\n                        help='Bearer authorization token')\n    parser.add_argument('-j', '--job-id', type=str, required=True, help='job id')\n    parser.add_argument('-o', '--object', type=str, required=True, help='object type',\n                        metavar=\"|\".join(EventReporter.supported_object_types))\n    parser.add_argument('-m', '--message', type=str, required=True, help='event message')\n    parser.add_argument('-e', '--event-type', type=str, required=True, help='event type',\n                        metavar=\"|\".join(EventReporter.supported_event_types))\n    parser.add_argument('-c', '--created-at', type=str, required=False, help='timestamp')\n    parser.add_argument('-C', '--checkpoint-path', type=str, required=False, help='S3 checkpoint path')\n    parser.add_argument('-M', '--model-path', type=str, required=False, help='S3 model path')\n    parser.add_argument('-p', '--param-count', type=int, required=False, help='number of parameters')\n    parser.add_argument('-t', '--token-count', type=int, required=False, help='number of tokens')\n    parser.add_argument('-l', '--level', type=str, required=False, help='event level',\n                        metavar=\"|\".join(EventReporter.supported_event_levels))\n    args = parser.parse_args()\n\n    # Create the event reporter.\n    event_reporter = EventReporter(host=args.event_host,\n                                   auth_token=args.auth_token,\n                                   job_id=args.job_id)\n    \n    event_reporter.report(object=args.object,\n                          message=args.message,\n                          event_type=args.event_type,\n                          level=args.level,\n                          checkpoint_path=args.checkpoint_path,\n                          model_path=args.model_path,\n                          param_count=args.param_count,\n                          
token_count=args.token_count)\n\n#usage: event_report.py [-h] -u <scheme><hostname>:<port> -a AUTH_TOKEN -j\n#                       JOB_ID -o fine-tune -m MESSAGE -e\n#                       JOB_START|MODEL_DOWNLOAD_COMPLETE|TRAINING_DATA_DOWNLOAD_COMPLETE|TRAINING_START|CHECKPOINT_SAVE|EPOCH_COMPLETE|TRAINING_COMPLETE|JOB_COMPLETE|JOB_ERROR\n#                       [-c CREATED_AT] [-C CHECKPOINT_PATH] [-M MODEL_PATH]\n#                       [-p PARAM_COUNT] [-t TOKEN_COUNT]\n#                       [-l Info|Warning|Error]\n#\n#optional arguments:\n#  -h, --help            show this help message and exit\n#  -u, --event-host <scheme><hostname>:<port>\n#                        Event reporting entrypoint URL (e.g.\n#                        https://127.0.0.1:8895)\n#  -a, --auth-token AUTH_TOKEN\n#                        Bearer authorization token\n#  -j, --job-id JOB_ID\n#                        job id\n#  -o, --object fine-tune\n#                        object type\n#  -m, --message MESSAGE\n#                        event message\n#  -e, --event-type JOB_START|MODEL_DOWNLOAD_COMPLETE|TRAINING_DATA_DOWNLOAD_COMPLETE|TRAINING_START|CHECKPOINT_SAVE|EPOCH_COMPLETE|TRAINING_COMPLETE|JOB_COMPLETE|JOB_ERROR\n#                        event type\n#  -c, --created-at CREATED_AT\n#                        timestamp\n#  -C, --checkpoint-path CHECKPOINT_PATH\n#                        S3 checkpoint path\n#  -M, --model-path MODEL_PATH\n#                        S3 model path\n#  -p, --param-count PARAM_COUNT\n#                        number of parameters\n#  -t, --token-count TOKEN_COUNT\n#                        number of tokens\n#  -l, --level Info|Warning|Error\n#                        event level\nif __name__ == '__main__':\n    try:\n        main()\n    except Exception as e:\n        print(e)\n        sys.exit(1)\n    \n    sys.exit(0)"
  },
  {
    "path": "training/utils/logging_utils.py",
    "content": "import os\n\ntry:\n    import wandb\n    _has_wandb = True\nexcept:\n    _has_wandb = False\n    print(\"wandb is not installed.\")\n    \ntry:\n    import loguru\n    _has_loguru = True\nexcept:\n    _has_loguru = False\n    print(\"loguru is not installed.\")\n    \ntrain_log_backend = None\n    \ndef init_train_logger(args):\n    \n    global train_log_backend\n    train_log_backend = getattr(args, 'train_log_backend', 'print')\n    \n    if train_log_backend == 'print':\n        pass\n    elif train_log_backend == 'loguru':\n        os.system(\"mkdir -p logs\")\n        loguru.logger.add(\"logs/file_{time}.log\")\n    elif train_log_backend == 'wandb':\n        \n        assert _has_wandb\n        \n        if not hasattr(args, 'project_name'):\n            import re\n            args.project_name = \"test-\" + \\\n                re.sub('[^a-zA-Z0-9 \\n\\.]', '_', args.task_name)\n\n        wandb.init(\n            project=args.project_name, \n            config=args,\n        )\n        \n    else:\n        raise Exception('Unknown logging backend.')\n        \ndef train_log(x, *args, **kargs):\n    \n    if train_log_backend == 'print':\n        print(x)\n    elif train_log_backend == 'loguru':\n        loguru.logger.info(x)\n    elif train_log_backend == 'wandb':\n        wandb.log(x, *args, **kargs)\n    else:\n        raise Exception('Unknown logging backend.')\n    \n    "
  },
  {
    "path": "training/utils/upload_manager.py",
    "content": "import argparse\nimport boto3\nimport concurrent.futures\nimport os\nimport re\nimport sys\nimport time\n\nfrom utils.event_report import *\n\nclass UploadManager:\n    def __init__(self, aws_endpoint_url: str, aws_access_key_id: str,\n                 aws_secret_access_key: str, aws_session_token: str = None,\n                 aws_region: str = \"auto\", event_reporter: EventReporter = None,\n                 n_stages: int = 1, max_wait_sec: int = 600, dry_run: bool = False):\n\n        self.executor = concurrent.futures.ThreadPoolExecutor()\n        self.futures = []\n\n        if aws_endpoint_url is not None and aws_access_key_id is not None and aws_secret_access_key is not None and aws_region is not None:\n            # Create an S3 client\n            self.aws_access_key_id = aws_access_key_id\n            self.aws_secret_access_key = aws_secret_access_key\n            self.aws_session_token = aws_session_token\n            self.aws_region = aws_region\n            self.aws_endpoint_url = aws_endpoint_url\n            self.enabled = True\n        else:\n            self.aws_access_key_id = None\n            self.aws_secret_access_key = None\n            self.aws_session_token = None\n            self.aws_region = None\n            self.aws_endpoint_url = None\n            self.enabled = False\n\n        self.event_reporter = event_reporter\n        self.dry_run = dry_run\n        if n_stages < 1 and self.enabled:\n            raise ValueError(\"n_stages must be greater than or equal to 1\")\n        self.n_stages = n_stages\n        self.max_wait_sec = max_wait_sec\n\n    def add_task(self, directory: str, checkpoint_upload_prefix: str, step: int = 0):\n        if self.enabled:\n            # Check that the provided checkpoint upload s3 prefix is valid regex\n            if not re.match(r\"s3://[a-zA-Z0-9.\\-_]{3,255}/.+\", checkpoint_upload_prefix):\n                raise ValueError(\"checkpoint_upload_prefix must start with s3://\")\n         
   # Get the s3 bucket and key from the checkpoint upload prefix\n            s3_bucket = checkpoint_upload_prefix.split(\"/\")[2]\n            s3_key_prefix = \"/\".join(checkpoint_upload_prefix.split(\"/\")[3:])\n            if not s3_key_prefix.endswith(\"/\"):\n                s3_key_prefix += \"/\"\n            print(f\"Uploading checkpoint to bucket=\\\"{s3_bucket}\\\", prefix=\\\"{s3_key_prefix}\\\"\")\n\n            future = self.executor.submit(self._execute_task, directory, s3_bucket, s3_key_prefix, step)\n            self.futures.append(future)\n\n    def wait(self):\n        if self.enabled:\n            concurrent.futures.wait(self.futures)\n\n    def _report_event(self, **kwargs):\n        if self.event_reporter is not None:\n            self.event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE, **kwargs)\n\n    def _wait_for_file_write_to_finish(self, file_path: str, wait_start_time: float) -> bool:\n        try:\n            file_size = os.stat(file_path).st_size\n            while True:\n                time.sleep(2)\n                file_size_after = os.stat(file_path).st_size\n                if file_size == file_size_after:\n                    return True\n                if time.time() - wait_start_time > self.max_wait_sec:\n                    return False\n                file_size = file_size_after\n        except Exception as e:\n            print(f\"Exception while waiting for file write to finish: {e}\")\n            return False\n\n    def _execute_task(self, directory, s3_bucket, s3_key_prefix, step: int):\n        try:\n            # Create an S3 client\n            session = boto3.Session(\n                aws_access_key_id=self.aws_access_key_id,\n                aws_secret_access_key=self.aws_secret_access_key,\n                aws_session_token=self.aws_session_token,\n                region_name=self.aws_region\n            )\n            s3_client = session.client('s3', endpoint_url=self.aws_endpoint_url)\n\n            
print(f\"Step {step} - Wait for all checkpoint stages to finish ...\")\n\n            wait_start_time = time.time()\n            finished_files = set()\n\n            # Wait for all stages to finish\n            # Each stage is written by a separate process. We don't know which process\n            # will finish first. So we wait for all stages to finish before proceeding.\n            while True:\n                # Get the list of files in the directory\n                files = os.listdir(directory)\n                print(f\"Step {step} - Found {len(files)} of expected {3 * self.n_stages + 1} files in directory: {directory}\")\n\n                # Check if all stages have finished\n                all_finished = False\n                if len(files) == 3 * self.n_stages + 1:\n                    all_finished = True\n                    # Check if all files are closed\n                    for file in files:\n                        print(f\"Step {step} - Checking if {file} has is finished writing ...\")\n                        if file not in finished_files:\n                            if self._wait_for_file_write_to_finish(os.path.join(directory, file), wait_start_time) == False:\n                                all_finished = False\n                                break\n                            else:\n                                print(f\"Step {step} - Checking if {file} has is finished writing ... 
Done\")\n                                finished_files.add(file)\n\n                else:\n                    all_finished = False\n\n                if all_finished:\n                    break\n\n                # Check if we have timed out waiting for all stages to finish\n                if time.time() - wait_start_time > self.max_wait_sec:\n                    print(f\"Step {step} - Timeout waiting for all stages to finish\")\n                    return\n                \n                time.sleep(10)\n\n            print(f\"Step {step} - Compressing files in directory: {directory}\")\n            tar_file_path = f\"{directory}.tar.zst\"\n\n            # Get the tar file path\n            tar_file_name = os.path.basename(tar_file_path)\n\n            # Compress the directory via cli\n            if not self.dry_run:\n                if os.system(f\"tar -cf - -C \\\"{directory}\\\" . | zstd -3 -T4 > \\\"{tar_file_path}\\\"\") != 0:\n                    print(f\"Step {step} - Failed to compress {directory}\")\n                    return\n\n            s3_key = f\"{s3_key_prefix}{tar_file_name}\"\n            print(f\"Step {step} - Uploading checkpoint to s3://{s3_bucket}/{s3_key}\")\n            if not self.dry_run:\n                # Try uploading the tar file to s3. 
If it fails, try again after\n                # 20 seconds.\n                for i in range(3):\n                    try:\n                        s3_client.upload_file(tar_file_path, s3_bucket, s3_key)\n                        break\n                    except Exception as e:\n                        print(f\"Step {step} - Failed to upload checkpoint to s3: {e}\")\n                        if i == 2:\n                            self._report_event(message=f\"Step {step}, failed to upload checkpoint\",\n                                               event_type=EventReporter.EVENT_TYPE_JOB_ERROR,\n                                               level=EventReporter.LEVEL_ERROR,\n                                               requires_is_enabled=False)\n                            return\n                        time.sleep(20)\n\n                os.remove(tar_file_path)\n\n            if self.event_reporter is not None:\n                print(f\"Step {step} - Reporting event\")\n                try:\n                    self._report_event(message=f\"Uploaded checkpoint, at step {step}\",\n                                       event_type=EventReporter.EVENT_TYPE_CHECKPOINT_SAVE,\n                                       checkpoint_path=f\"s3://{s3_bucket}/{s3_key}\",\n                                       requires_is_enabled=False)\n                except Exception as e:\n                    print(f\"Step {step} - Failed to report event: {e}\")\n            else:\n                print(f\"Step {step} - Event reporter is disabled, skipping reporting event\")\n        except Exception as e:\n            print(f\"Exception: Step {step} - {e}\")\n            self._report_event(message=f\"Step {step}, failed to upload checkpoint\",\n                               event_type=EventReporter.EVENT_TYPE_JOB_ERROR,\n                               level=EventReporter.LEVEL_ERROR,\n                               requires_is_enabled=False)\n\ndef add_aws_arguments(parser: 
argparse.ArgumentParser):\n    parser.add_argument('--aws-endpoint-url', help='AWS endpoint URL')\n    parser.add_argument('--aws-access-key-id', help='AWS access key ID')\n    parser.add_argument('--aws-secret-access-key', help='AWS secret access key')\n    parser.add_argument('--aws-session-token', help='AWS session token')\n    parser.add_argument('--aws-region', default='auto', help='AWS region (default: auto)')\n\ndef aws_process_args(args: argparse.Namespace, required: bool = False):\n    if args.aws_endpoint_url is None:\n        args.aws_endpoint_url = os.environ.get('AWS_ENDPOINT_URL', 'https://s3.amazonaws.com')\n    if args.aws_access_key_id is None:\n        args.aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')\n        if required and args.aws_access_key_id is None:\n            print(\"Error: AWS_ACCESS_KEY_ID is not set\")\n            sys.exit(1)\n    if args.aws_secret_access_key is None:\n        args.aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')\n        if required and args.aws_secret_access_key is None:\n            print(\"Error: AWS_SECRET_ACCESS_KEY is not set\")\n            sys.exit(1)\n    if args.aws_session_token is None:\n        args.aws_session_token = os.environ.get('AWS_SESSION_TOKEN')\n\ndef main():\n    parser = argparse.ArgumentParser(description='Process S3 file objects with a specific prefix')\n    parser.add_argument('--bucket-name', required=True, help='S3 bucket name')\n    parser.add_argument('--prefix', required=True, help='Prefix for the S3 objects')\n    add_aws_arguments(parser)\n    add_entry_reporter_arguments(parser)\n    parser.add_argument('--job-id', '-j', type=str, required=True, help='job id')\n    parser.add_argument('--n-stages', type=int, default=1, help='Number of stages')\n    parser.add_argument('--dry-run', action='store_true', default=False, \n                        help='Perform a dry run (only print file paths)')\n    parser.add_argument('directories', nargs='+', 
help='Directories to upload')\n\n    args = parser.parse_args()\n    aws_process_args(args, required=True)\n\n    event_reporter = None\n    if args.event_host is not None and args.event_auth_token is not None and args.job_id is not None:\n        event_reporter = EventReporter(host=args.event_host, auth_token=args.event_auth_token, job_id=args.job_id)\n\n    task_manager = UploadManager(aws_endpoint_url = args.aws_endpoint_url,\n                                 aws_access_key_id = args.aws_access_key_id,\n                                 aws_secret_access_key = args.aws_secret_access_key,\n                                 aws_session_token = args.aws_session_token,\n                                 aws_region = args.aws_region,\n                                 event_reporter = event_reporter,\n                                 n_stages = args.n_stages,\n                                 dry_run = args.dry_run)\n\n    checkpoint_upload_prefix = f\"s3://{args.bucket_name}/{args.prefix}/\"\n    step = 0\n    for directory in args.directories:\n        print(f\"Adding task for directory: {directory}\")\n        step += 1\n        task_manager.add_task(directory=directory, checkpoint_upload_prefix=checkpoint_upload_prefix, step=step)\n        time.sleep(20)\n\n    print(\"Waiting for tasks to complete...\")\n    start_time = time.time()\n    task_manager.wait()\n    end_time = time.time()\n    print(f\"Tasks completed in {end_time - start_time} sec\")\n\nif __name__ == \"__main__\":\n    main()"
  }
]