Full Code of EleutherAI/gpt-neo for AI

Repository: EleutherAI/gpt-neo
Branch: master
Commit: 23485e3c7940
Files: 43
Total size: 285.4 KB

Directory structure:
gitextract_tmcv9_l6/

├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows/
│       └── pytest.yml
├── .gitignore
├── CITATION.bib
├── CODEOWNERS
├── Dockerfile
├── GPTNeo_example_notebook.ipynb
├── LICENSE
├── README.md
├── configs/
│   ├── dataset_configs/
│   │   ├── example.json
│   │   ├── openwebtext2_new_inputs.json
│   │   └── pile.json
│   ├── gpt2_small.json
│   ├── gpt3_13B_256.json
│   ├── gpt3_13B_256_Pile.json
│   ├── gpt3_2-7B_256.json
│   ├── gpt3_6-7B_256.json
│   ├── gpt3_PAR_small_256.json
│   ├── gpt3_XL_256_Pile.json
│   ├── gpt3_large_256.json
│   ├── gpt3_medium_256.json
│   └── gpt3_small_256.json
├── configs.py
├── data/
│   ├── create_tfrecords.py
│   ├── encoders.py
│   └── train_tokenizer.py
├── docker-compose.yml
├── encoders.py
├── export.py
├── inputs.py
├── main.py
├── model_fns.py
├── models/
│   ├── activations.py
│   ├── gpt2/
│   │   └── gpt2.py
│   ├── layers.py
│   └── utils.py
├── optimizers.py
├── requirements.txt
├── run_experiment.py
├── sample.py
├── tasks.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Proposed solution**
If you have an idea for how we can fix this problem, describe it here. 

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Environment (please complete the following information):**
- GPUs:
- Configs:

**Additional context**
Add any other context about the problem here.


================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: feature request
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.


================================================
FILE: .github/workflows/pytest.yml
================================================
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python package

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.6, 3.7]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pytest
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
    - name: Test with pytest
      run: |
        pytest


================================================
FILE: .gitignore
================================================
# testing
.test/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

logs/
*.log
test_*
test/
.vscode


run_configs/


================================================
FILE: CITATION.bib
================================================
@software{gpt-neo,
  author       = {Black, Sid and
                  Gao, Leo and
                  Wang, Phil and
                  Leahy, Connor and
                  Biderman, Stella},
  title        = {{GPT-Neo: Large Scale Autoregressive Language 
                   Modeling with Mesh-Tensorflow}},
  month        = mar,
  year         = 2021,
  publisher    = {Zenodo},
  version      = {1.0},
  doi          = {10.5281/zenodo.5297715},
  url          = {https://doi.org/10.5281/zenodo.5297715}
}


================================================
FILE: CODEOWNERS
================================================
* EleutherAI/pm-gptneo


================================================
FILE: Dockerfile
================================================
FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15

WORKDIR /neogpt

# Make RUN commands use `bash --login`:
SHELL ["/bin/bash", "--login", "-c"]
ENV DEBIAN_FRONTEND=noninteractive 
RUN apt-get update -y && apt-get install tmux -y
RUN conda install gcc_linux-64 gxx_linux-64 -y 
ADD requirements.txt .
RUN pip install -r requirements.txt 
RUN apt-get install screen htop -y
RUN python -m pip install tensorboard==1.15 cloud_tpu_profiler==1.15

CMD tmux

================================================
FILE: GPTNeo_example_notebook.ipynb
================================================
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "GPTNeo_example_notebook.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J0i5MRP0SV8D"
      },
      "source": [
        "Welcome to the colab notebook for [GPTNeo](https://github.com/EleutherAI/GPTNeo) - a fully open source implementation of GPT like models for mesh-tensorflow by [EleutherAI](eleuther.ai).\n",
        "\n",
        "Our library provides training and inference for GPT models up to GPT3 sizes on both TPUs and GPUs. \n",
        "\n",
        "In this notebook we walk you through TPU training (or finetuning!) and sampling using the freely available colab TPUs.\n",
        "\n",
        "If you find our repo useful, come join [our discord](https://discord.gg/BK2v3EJ) and say hi! 😬\n",
        "\n",
        "Before we get going - make sure you are running this notebook with a TPU available. Go to Runtime -> Change Runtime Type and select 'TPU' under hardware accelerator.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K-53qkZV6Lv9",
        "cellView": "form"
      },
      "source": [
        "#@title Setup\n",
        "%tensorflow_version 2.x\n",
        "!git clone https://github.com/EleutherAI/GPTNeo\n",
        "%cd GPTNeo\n",
        "!pip3 install -q -r requirements.txt\n",
        "pretrained_model = None\n",
        "dataset = None\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M0R1owh2qvp8"
      },
      "source": [
        "## Set Up Google Cloud"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0PmzM4dy7diP"
      },
      "source": [
        "To train on TPUs we need to store our data on a google cloud bucket - as TPUs can't read from local filesystems.\n",
        "\n",
        "You can set up a bucket by signing up for a free trial here: https://console.cloud.google.com/\n",
        "\n",
        "Make a bucket at https://console.cloud.google.com/storage and come back when that's done.\n",
        "\n",
        "Make sure to select 'Uniform' access control when setting up the bucket, or the colab notebook won't have the required permissions to read from it.\n",
        "\n",
        "The next cell sets up google authentication and gives the notebook read and write access to your bucket.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "71bQUjPA7qvj"
      },
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "!gcloud init"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cr_c6A2NBK5i",
        "cellView": "form"
      },
      "source": [
        "path_to_cloud_bucket = 'gs://your-cloud-bucket/' #@param {type:\"string\"}"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EZGbzUPD0tad"
      },
      "source": [
        "## Set Up Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R918l14UhrBR"
      },
      "source": [
        "We first need to download and tokenize a dataset. If you just want to sample from a pretrained model, you can skip this step and move on to the `Pretrained Model` section.\n",
        "\n",
        "You can choose from:\n",
        "\n",
        "*   Sampling Only - choose this option if you only wish to sample from our trained models, then move on to the `Pretrained Model` section.\n",
        "\n",
        "*   OpenWebText - an opensource clone of OpenAI's WebText dataset, the original training data of GPT2.\n",
        "\n",
        "*   YoutubeSubtitles - a dataset of subtitles scraped from youtube videos.\n",
        "\n",
        "* Hackernews - comments scraped from hackernews\n",
        "\n",
        "* NIHExporter - Data relating to various projects from the national institute of health.\n",
        "\n",
        "* Custom - if this option is chosen you will be prompted to enter the path to your own dataset. It should be a directory containing .txt or .jsonl files.\n",
        "\n",
        "All these datasets are from EleutherAI's side project - [The Pile™](https://github.com/EleutherAI/The-Pile) - an effort to gather a general purpose, diverse and open source plain text dataset large enough to train 1T+ parameter language models.\n",
        "\n",
        "Even the smallest datasets are fairly large files, so this step will likely take a while. Select a dataset in the next cell, then run the next two cells, and go grab a snack and a cup of tea 😊\n",
        "\n",
        "Alternatively, you can provide your own dataset in the form of a folder or gzip archive of .txt files. Simply select 'Custom' below and follow input the path to your data and the name of your dataset when prompted."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pM8jP3Am_hsx",
        "cellView": "form"
      },
      "source": [
        "# Select a Dataset:\n",
        "import os\n",
        "dataset = 'Sampling_Only' #@param [\"Sampling_Only\", \"OpenWebText\", \"YoutubeSubtitles\", \"HackerNews\", \"NIHExporter\", \"Custom\"]\n",
        "\n",
        "if dataset == \"Sampling_Only\":\n",
        "  pass\n",
        "elif dataset == 'OpenWebText':\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/openwebtext2.jsonl.zst.tar -O openwebtext.tar.xz\n",
        "  !tar xf openwebtext.tar.xz\n",
        "  dataset_path = \"openwebtext\"\n",
        "  dataset_name = dataset_path\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == 'YoutubeSubtitles':\n",
        "  os.makedirs('data', exist_ok=True)\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/yt_subs.jsonl.zst -O data/yt_subs.jsonl.zst\n",
        "  dataset_path = 'data'\n",
        "  dataset_name = 'ytsubs'\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == 'HackerNews':\n",
        "  os.makedirs('data', exist_ok=True)\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/hn.tar.gz -O data/hn.tar.gz\n",
        "  dataset_path = 'data'\n",
        "  dataset_name = 'hackernews'\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == \"NIHExporter\":\n",
        "  os.makedirs('data', exist_ok=True)\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/NIH_ExPORTER_awarded_grant_text.jsonl.zst -O data/NIH_ExPORTER_awarded_grant_text.jsonl.zst\n",
        "  dataset_path = 'data'\n",
        "  os.system('mv NIH_ExPORTER_awarded_grant_text.jsonl.zst ./data')\n",
        "  dataset_name = 'nihexporter'\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == \"Custom\":\n",
        "  dataset_path = input('Enter the path to the folder containing your data: ')\n",
        "  dataset_name = input('Enter the name of your dataset: ')\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "else:\n",
        "  raise NotImplementedError('please select from available options: [\"OpenWebText\", \"YoutubeSubtitles\", \"HackerNews\", \"NIHExporter\", \"Custom\"]')\n"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zMl1cHtN5I_W"
      },
      "source": [
        "### Tokenize and Upload Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6IBIompTJaqm"
      },
      "source": [
        "Now tokenize the dataset and copy it over to your google cloud bucket. You may skip this step if you are sampling from a pre-trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pq5u0WUSJWwz",
        "cellView": "both"
      },
      "source": [
        "# Tokenize Data\n",
        "!python data/create_tfrecords.py --input_dir /content/GPTNeo/$dataset_path --name $dataset_name --files_per 1000 --output_dir $out_name --write_dataset_config --processes 1\n",
        "\n",
        "# copy the data to your bucket\n",
        "if not path_to_cloud_bucket.endswith('/'):\n",
        "  path_to_cloud_bucket += '/'\n",
        "copy_loc = path_to_cloud_bucket + \"datasets/\" + dataset\n",
        "!gsutil -m cp -r /content/GPTNeo/$out_name $copy_loc\n",
        "!gsutil ls $path_to_cloud_bucket"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NhvmTFD7b_fb"
      },
      "source": [
        "Before starting training - you'll need to edit your dataset & model configs to point to your buckets / data. You need to do this even if you are sampling from a pre-trained model.\n",
        "\n",
        "*   First change the writefile path to point to your chosen dataset - e.g `%%writefile configs/dataset_configs/ytsubs.json`\n",
        "*   Change the \"path\" field to point to your cloud bucket location - e.g `gs://neo_lmdatasets/datasets/ytsubs_*.tfrecords`\n",
        "* Change `dataset_name` in `%%writefile configs/dataset_configs/dataset_name.json` to the name of your chosen dataset.\n",
        "* Once you've made the edits, then run the cell below to overwrite the existing files.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MCsZP48vavCP"
      },
      "source": [
        "%%writefile configs/dataset_configs/Sampling_Only.json\n",
        "\n",
        "{\n",
        "  \"path\": \"gs://eleutherai/datasets/Sampling_Only/Sampling_Only*.tfrecords\",\n",
        "  \"eval_path\": \"\",\n",
        "  \"n_vocab\": 50256,\n",
        "  \"tokenizer_is_pretrained\": true,\n",
        "  \"tokenizer_path\": \"gpt2\",\n",
        "  \"eos_id\": 50256,\n",
        "  \"padding_id\": 50257\n",
        "}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dH0x3dI9j85P"
      },
      "source": [
        "## Set Model Configs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I6GnCgAkB7GQ"
      },
      "source": [
        "The model below is identical to our pretrained GPT3XL model (1.3B Params). \n",
        "\n",
        "If you want to use a smaller model, you can modify any of the config files in ../configs/ ending in _8.json, all of which are designed to train on tpu-v8s.\n",
        "\n",
        "For a more detailed breakdown on what each item in the configuration file means - please read through our training and config guides in our [github README](https://github.com/EleutherAI/GPTNeo#training-guide). \n",
        "\n",
        "You'll want to change the first item in the `datasets` list to the name of your chosen dataset. (the filename minus .json in ./configs/dataset_configs)\n",
        "\n",
        "You'll also want to modify the `model_path` field to point to your google cloud bucket, so checkpoints get saved to there."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "L9hUDdokiWj6"
      },
      "source": [
        "%%writefile configs/GPT3_XL.json\n",
        "\n",
        "{\n",
        "    \"n_head\": 16,\n",
        "    \"n_vocab\": 50257,\n",
        "    \"embed_dropout\": 0,\n",
        "    \"lr\": 0.0002,\n",
        "    \"lr_decay\": \"cosine\",\n",
        "    \"warmup_steps\": 3000,\n",
        "    \"beta1\": 0.9,\n",
        "    \"beta2\": 0.95,\n",
        "    \"epsilon\": 1e-8,\n",
        "    \"opt_name\": \"adam\",\n",
        "    \"weight_decay\": 0,\n",
        "    \"train_batch_size\": 256,\n",
        "    \"attn_dropout\": 0,\n",
        "    \"train_steps\": 600000,\n",
        "    \"eval_steps\": 0,\n",
        "    \"predict_steps\": 1,\n",
        "    \"res_dropout\": 0,\n",
        "    \"eval_batch_size\": 4,\n",
        "    \"predict_batch_size\": 1,\n",
        "    \"iterations\": 100,\n",
        "    \"n_embd\": 2048,\n",
        "    \"datasets\": [[\"pile\", null, null, null]],\n",
        "    \"model\": \"GPT\",\n",
        "    \"model_path\": \"gs://eleutherai/GPT3_XL\",\n",
        "    \"n_ctx\": 2048,\n",
        "    \"n_layer\": 24,\n",
        "    \"scale_by_depth\": true,\n",
        "    \"scale_by_in\": false,\n",
        "    \"attention_types\" :  [[[\"global\", \"local\"],12]],\n",
        "    \"mesh_shape\": \"x:4,y:2\",\n",
        "    \"layout\": \"intermediate_expanded:x,heads:x,vocab:n_vocab,memory_length:y,embd:y\",\n",
        "    \"activation_function\": \"gelu\",\n",
        "    \"recompute_grad\": true,\n",
        "    \"gradient_clipping\": 1.0,\n",
        "    \"tokens_per_mb_per_replica\": 2048,\n",
        "    \"precision\": \"bfloat16\"\n",
        "}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GWK9MJqwcXKn"
      },
      "source": [
        "## Training from Scratch\n",
        "\n",
        "Now we will begin to train the model. If no previous model is found in \"model_path\", the model will start training from scratch. If you'd prefer to finetune from pretrained, skip to the `Finetune a Pretrained Model` section.\n",
        "\n",
        "If everything's set up correctly, you can now run the main.py function to start training!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VUtrysOSBzjJ"
      },
      "source": [
        "!python3 main.py --model colab_XL --steps_per_checkpoint 500 --tpu colab"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "koKQHA5ikCvD"
      },
      "source": [
        "## Pretrained Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0QZv4_pnkk26"
      },
      "source": [
        "If you want to sample from or finetune a pretrained model, EleutherAI has pretrained two models for release. One with [1.3B parameters](https://the-eye.eu/public/AI/gptneo-release/GPT3_XL/), and another with [2.7B](https://the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/). \n",
        "\n",
        "Select an option below to download the weights locally. You will then need to upload them to your cloud bucket in order to finetune from them. If the download command isn't working, try the commented out code to download from a different source.\n",
        "\n",
        "The 2-7B model likely won't fit into the colab TPUs memory, and you may have to get some larger pods to finetune from it.\n",
        "\n",
        "Sampling from it, however, works just fine.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lgTG1ammqGB0",
        "cellView": "form"
      },
      "source": [
        "# @title Download pretrained model weights:\n",
        "pretrained_model = 'GPT3_2-7B' #@param [\"GPT3_XL\", \"GPT3_2-7B\"]\n",
        "!wget -m -np -c -U \"eye02\" -w 2 -R \"index.html*\" \"https://the-eye.eu/public/AI/gptneo-release/$pretrained_model/\"\n",
        "path_to_local_weights = f\"/content/GPTNeo/the-eye.eu/public/AI/gptneo-release/{pretrained_model}\"\n",
        "\n",
        "# URL = f\"http://eaidata.bmk.sh/data/gptneo-release/{pretrained_model}/\"\n",
        "# FOLDER_NAME = \"GPT3_XL\"\n",
        "# !curl $URL | grep -i \"</a>\" | sed -n 's/.*href=\"\\([^\"]*\\).*/\\1/p' | sed \"s|^|$URL|\" | xargs -n 1 -P 4 wget -P $pretrained_model\n",
        "# path_to_local_weights = pretrained_model\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GU3BDNJN_ZXE"
      },
      "source": [
        "# upload to your bucket\n",
        "bucket_base = \"gs://\" + path_to_cloud_bucket.replace('gs://', '').split('/')[0]\n",
        "!gsutil -m cp -r $path_to_local_weights $bucket_base"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bnqkKBTOn0ox"
      },
      "source": [
        "If everything has worked successfully you should now see your model listed in your bucket below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "80t9MMionm2h"
      },
      "source": [
        "!gsutil ls $bucket_base"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QDKL8fCSoApL"
      },
      "source": [
        "Now we want to make a few modifications to the model config in order to get training / sampling working on colab.\n",
        "\n",
        "If you are just sampling from our pretrained models, you can leave the settings as is, run the cell below, then move on to the `Sample from your model` section.\n",
        "\n",
        "If finetuning, you can change parameters below. \n",
        "\n",
        "* `path_to_model` should point to the model weights location in your cloud bucket, and will default to `$bucket_base/${pretrained_model}` if nothing is entered.\n",
        "\n",
        "* `batch_size` is your train batch size - if you're encountering memory errors, try lowering this.\n",
        "\n",
        "* `dataset_name` is the name of your dataset, if nothing is entered, this should default to the dataset you selected in the `Prepare Data` section.\n",
        "\n",
        "* `mesh_shape` specifies the way the model will be divided up across the TPU cores. We suggest leaving this alone unless you know what you're doing.\n",
        "\n",
        "* `train_steps` specifies how many steps you want the model to finetune for. We set this to 1000 for demonstrative purposes but you may need to increase this a little depending on your goals. If you are just sampling from the model, you can leave this as is.\n",
        "\n",
        "* `steps_per_checkpoint` specifies how often you want to save model weights during training.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Laf0slBMDCUj",
        "cellView": "form"
      },
      "source": [
        "# @title Modify config for colab. \n",
        "  \n",
        "import json\n",
        "from pprint import pprint\n",
        "\n",
        "path_to_model = \"\" #@param {type:\"string\"}\n",
        "batch_size = 8 #@param {type:\"integer\"}\n",
        "dset = \"\"  #@param {type:\"string\"}\n",
        "mesh_shape = \"x:4,y:2\" #@param {type:\"string\"}\n",
        "train_steps = 1000 #@param {type:\"integer\"}\n",
        "steps_per_checkpoint = 500 #@param {type:\"integer\"}\n",
        "start_step = 400000 if pretrained_model == \"GPT3_2-7B\" else 362000\n",
        "\n",
        "if path_to_model == \"\":\n",
        "  path_to_model = f'{bucket_base.strip(\"/\")}/{pretrained_model}'\n",
        "print(f'MODEL PATH: {path_to_model}\\n')\n",
        "\n",
        "if dset == \"\" and dataset != \"Sampling_Only\":\n",
        "  dset = dataset\n",
        "elif dataset is None and dset == \"\":\n",
        "  dset = \"pile\"\n",
        "\n",
        "def pad_to_multiple_of(n, mult):\n",
        "  \"\"\"\n",
        "  pads n to a multiple of mult\n",
        "  \"\"\"\n",
        "  extra = n % mult\n",
        "  if extra > 0:\n",
        "      n = n + mult - extra\n",
        "  return n\n",
        "\n",
        "with open(f'{path_to_local_weights}/config.json', 'r') as f:\n",
        "  data = json.load(f)\n",
        "  pprint(data)\n",
        "  dset_val = [[dset, None, None, None]] if dset != \"\" else data[\"datasets\"]\n",
        "  mods = {\n",
        "          \"mesh_shape\": mesh_shape,\n",
        "          \"layout\": \"intermediate_expanded:x,heads:x,memory_length:y,embd:y\",\n",
        "          \"model_path\": path_to_model,\n",
        "          \"datasets\": dset_val,\n",
        "          \"train_steps\": start_step + train_steps,\n",
        "          \"eval_steps\": 0,\n",
        "          \"train_batch_size\": batch_size,\n",
        "          \"predict_batch_size\": batch_size\n",
        "        }\n",
        "  data.update(mods)\n",
        "  print('\\n--->\\n')\n",
        "  pprint(data)\n",
        "  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\n",
        "    json.dump(data, outfile, indent=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fPwwbPCA6O7r"
      },
      "source": [
        "### Begin Fine-Tuning\n",
        "\n",
        "If you are fine-tuning the pretrained model, this line of code will begin the training."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0YlaHzyXuMaj"
      },
      "source": [
        "!python3 main.py --model $pretrained_model --steps_per_checkpoint $steps_per_checkpoint --tpu colab"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I_HxtEmBGTGT"
      },
      "source": [
        "### Sample from your model\n",
        "\n",
        "Once training is finished, (or your pretrained model is on your bucket), you can run the same command with the --predict flag to sample from your model.\n",
        "\n",
        "To pass in a prompt, save it to a .txt file, and pass in the name of the file with the --prompt flag.\n",
        "\n",
        "use the cell below to enter your prompt, and run it to save it to example_prompt.txt.\n",
        "\n",
        "You may need to decrease the predict batch size in your config if you're facing OOM errors.\n",
        "\n",
        "Let's see if the GPTNeo model can finish coding itself, with a sample prompt consisting of the beginning of a `torch.nn.Module`:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CQE1Y5wPFx7h",
        "outputId": "e1a92c0c-18ee-4014-a0b8-d67161384940",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "%%writefile example_prompt.txt\n",
        "\n",
        "class GPT(nn.Module):\n",
        "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
        "\n",
        "    def __init__(self, config):\n",
        "        super().__init__()\n",
        "\n",
        "        # input embedding stem\n",
        "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
        "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
        "        self.drop = nn.Dropout(config.embd_pdrop)\n",
        "        # transformer\n",
        "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
        "        # decoder head\n",
        "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
        "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
        "\n",
        "        self.block_size = config.block_size\n",
        "        self.apply(self._init_weights)\n",
        "\n",
        "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Overwriting example_prompt.txt\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sf_5E4fHFQIh",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f3c12a94-7ef8-43c1-a668-6365966d42b4"
      },
      "source": [
        "!python3 main.py --model $pretrained_model --steps_per_checkpoint 500 --tpu colab --predict --prompt example_prompt.txt"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "2021-03-22 12:20:43.411018: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/compat/v2_compat.py:96: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "non-resource variables are not supported in the long term\n",
            "Current step 400000\n",
            "Saving config to gs://test-bucket-neo/GPT3_2-7B\n",
            "2021-03-22 12:20:50.689547: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
            "2021-03-22 12:20:50.691059: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1\n",
            "2021-03-22 12:20:50.701975: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
            "2021-03-22 12:20:50.702051: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (eeb4af61eb99): /proc/driver/nvidia/version does not exist\n",
            "2021-03-22 12:20:52.229703: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:196] None of the MLIR optimization passes are enabled (registered 0 passes)\n",
            "Done!\n",
            "params = defaultdict(<function fetch_model_params.<locals>.<lambda> at 0x7f64ee76fb90>, {'n_head': 20, 'n_vocab': 50257, 'embed_dropout': 0, 'lr': 0.00016, 'lr_decay': 'cosine', 'warmup_steps': 3000, 'beta1': 0.9, 'beta2': 0.95, 'epsilon': 1e-08, 'ada_epsilon1': '1e-30', 'ada_epsilon2': 0.001, 'opt_name': 'adam', 'weight_decay': 0, 'train_batch_size': 16, 'attn_dropout': 0, 'train_steps': 401000, 'lr_decay_end': 300000, 'eval_steps': 0, 'predict_steps': 0, 'res_dropout': 0, 'eval_batch_size': 128, 'predict_batch_size': 4, 'iterations': 500, 'n_embd': 2560, 'datasets': [['pile', None, None, None]], 'model_path': 'gs://test-bucket-neo/GPT3_2-7B', 'n_ctx': 2048, 'n_layer': 32, 'scale_by_depth': True, 'scale_by_in': False, 'attention_types': ['global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local'], 'mesh_shape': 'x:4,y:2', 'layout': 'intermediate_expanded:x,heads:x,memory_length:y,embd:y', 'activation_function': 'gelu', 'recompute_grad': True, 'gradient_clipping': 1.0, 'tokens_per_mb_per_replica': 4096, 'padding_id': 50257, 'eos_id': 50256, 'dataset_configs': {'pile': {'n_vocab': 50257, 'path': 'gs://neo-datasets/pile/pile_*.tfrecords', 'eval_path': 'gs://neo-datasets/pile_val.tfrecords', 'tokenizer_is_pretrained': True, 'tokenizer_path': 'gpt2', 'eos_id': 50256, 'padding_id': 50257}}, 'mlm_training': False, 'causal': True, 'num_cores': 8, 'auto_layout': False, 'auto_layout_and_mesh_shape': False, 'use_tpu': True, 'gpu_ids': ['device:GPU:0'], 'steps_per_checkpoint': 500, 'predict': True, 'model': 'GPT', 'export': False, 'sampling_use_entmax': False, 'moe_layers': None, 'slow_sampling': False})\n",
            "Using config: {'_model_dir': 'gs://test-bucket-neo/GPT3_2-7B', '_tf_random_seed': None, '_save_summary_steps': 500, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
            "cluster_def {\n",
            "  job {\n",
            "    name: \"worker\"\n",
            "    tasks {\n",
            "      key: 0\n",
            "      value: \"10.82.219.162:8470\"\n",
            "    }\n",
            "  }\n",
            "}\n",
            "isolate_session_state: true\n",
            ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.82.219.162:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.82.219.162:8470', '_evaluation_master': 'grpc://10.82.219.162:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=500, num_shards=8, num_cores_per_replica=1, per_host_input_for_training=4, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver object at 0x7f64ee774a90>}\n",
            "_TPUContext: eval_on_tpu True\n",
            "Predictions generated\n",
            "Querying Tensorflow master (grpc://10.82.219.162:8470) for TPU system metadata.\n",
            "2021-03-22 12:20:53.623443: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:373] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.\n",
            "Initializing TPU system (master: grpc://10.82.219.162:8470) to fetch topology for model parallelism. This might take a while.\n",
            "Found TPU system:\n",
            "*** Num TPU Cores: 8\n",
            "*** Num TPU Workers: 1\n",
            "*** Num TPU Cores Per Worker: 8\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 6478766768852144079)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1341089584581626564)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -607673649088781696)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -4050793109911027603)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, -6683233089843062258)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -4741539030516422912)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 2164395643386766058)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 3352841220362516620)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 5726423099271110669)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 7316344872981758207)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 7432402242254058183)\n",
            "Calling model_fn.\n",
            "num_cores_per_replica: 1\n",
            "computation_shape: [1, 1, 1, 1]\n",
            "num_replicas: 8\n",
            "device_assignment.topology.device_coordinates: [[[0 0 0 0]\n",
            "  [0 0 0 1]\n",
            "  [1 0 0 0]\n",
            "  [1 0 0 1]\n",
            "  [0 1 0 0]\n",
            "  [0 1 0 1]\n",
            "  [1 1 0 0]\n",
            "  [1 1 0 1]]]\n",
            "device_assignment.core_assignment: [[[0 0 0 0]]\n",
            "\n",
            " [[0 0 0 1]]\n",
            "\n",
            " [[1 0 0 0]]\n",
            "\n",
            " [[1 0 0 1]]\n",
            "\n",
            " [[0 1 0 0]]\n",
            "\n",
            " [[0 1 0 1]]\n",
            "\n",
            " [[1 1 0 0]]\n",
            "\n",
            " [[1 1 0 1]]]\n",
            "2021-03-22 12:21:11.005988: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
            "device_list = ['/job:worker/task:0/device:CPU:0']\n",
            "SimdMeshImpl ignoring devices ['', '', '', '', '', '', '', '']\n",
            "SimdMeshImpl init: Shape[x=4, y=2] LayoutRules{('heads', 'x'), ('embd', 'y'), ('intermediate_expanded', 'x'), ('memory_length', 'y')}\n",
            "Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f64e9078050>\n",
            "Create pnum_tensor\n",
            "Variable gpt2/h0/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h0/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h0/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h0/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h0/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h0/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h1/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h1/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h1/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h1/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h1/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h1/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h10/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h10/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h10/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h10/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h10/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h10/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h11/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h11/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h11/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h11/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h11/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h11/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h12/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h12/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h12/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h12/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h12/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h12/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h13/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h13/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h13/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h13/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h13/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h13/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h14/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h14/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h14/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h14/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h14/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h14/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h15/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h15/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h15/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h15/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h15/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h15/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h16/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h16/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h16/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h16/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h16/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h16/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h17/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h17/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h17/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h17/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h17/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h17/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h18/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h18/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h18/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h18/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h18/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h18/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h19/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h19/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h19/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h19/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h19/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h19/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h2/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h2/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h2/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h2/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h2/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h2/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h20/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h20/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h20/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h20/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h20/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h20/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h21/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h21/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h21/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h21/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h21/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h21/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h22/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h22/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h22/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h22/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h22/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h22/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h23/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h23/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h23/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h23/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h23/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h23/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h24/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h24/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h24/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h24/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h24/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h24/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h25/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h25/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h25/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h25/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h25/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h25/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h26/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h26/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h26/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h26/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h26/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h26/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h27/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h27/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h27/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h27/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h27/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h27/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h28/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h28/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h28/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h28/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h28/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h28/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h29/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h29/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h29/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h29/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h29/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h29/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h3/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h3/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h3/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h3/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h3/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h3/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h30/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h30/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h30/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h30/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h30/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h30/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h31/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h31/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h31/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h31/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h31/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h31/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h4/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h4/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h4/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h4/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h4/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h4/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h5/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h5/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h5/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h5/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h5/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h5/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h6/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h6/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h6/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h6/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h6/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h6/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h7/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h7/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h7/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h7/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h7/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h7/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h8/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h8/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h8/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h8/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h8/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h8/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h9/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h9/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h9/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h9/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h9/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h9/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/wpe                                                     size 5242880      slice_size 2621440      Shape[embed_sequence=2048, embd=2560]                       \n",
            "Variable gpt2/wte                                                     size 128657920    slice_size 64328960     Shape[vocab=50257, embd=2560]                               \n",
            "Variable stacked/gpt2/h0/mlp/conv1d_main/c_fc/bias                    size 256000       slice_size 64000        Shape[stacked=25, intermediate_expanded=10240]              \n",
            "    gpt2/h0/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h1/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h2/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h3/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h4/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h5/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h6/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h7/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h8/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h9/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h10/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h11/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h12/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h13/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h14/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h15/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h16/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h17/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h18/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h19/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h20/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h21/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h22/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h23/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h24/mlp/conv1d_main/c_fc/bias\n",
            "Variable stacked/gpt2/h0/norm_1/g                                     size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \n",
            "    gpt2/h0/norm_1/g\n",
            "    gpt2/h0/norm_1/b\n",
            "    gpt2/h0/attn/compute_output_bias/o_b\n",
            "    gpt2/h0/norm_2/g\n",
            "    gpt2/h0/norm_2/b\n",
            "    gpt2/h0/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h1/norm_1/g\n",
            "    gpt2/h1/norm_1/b\n",
            "    gpt2/h1/attn/compute_output_bias/o_b\n",
            "    gpt2/h1/norm_2/g\n",
            "    gpt2/h1/norm_2/b\n",
            "    gpt2/h1/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h2/norm_1/g\n",
            "    gpt2/h2/norm_1/b\n",
            "    gpt2/h2/attn/compute_output_bias/o_b\n",
            "    gpt2/h2/norm_2/g\n",
            "    gpt2/h2/norm_2/b\n",
            "    gpt2/h2/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h3/norm_1/g\n",
            "    gpt2/h3/norm_1/b\n",
            "    gpt2/h3/attn/compute_output_bias/o_b\n",
            "    gpt2/h3/norm_2/g\n",
            "    gpt2/h3/norm_2/b\n",
            "    gpt2/h3/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h4/norm_1/g\n",
            "    gpt2/h4/norm_1/b\n",
            "    gpt2/h4/attn/compute_output_bias/o_b\n",
            "    gpt2/h4/norm_2/g\n",
            "    gpt2/h4/norm_2/b\n",
            "    gpt2/h4/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h5/norm_1/g\n",
            "    gpt2/h5/norm_1/b\n",
            "    gpt2/h5/attn/compute_output_bias/o_b\n",
            "    gpt2/h5/norm_2/g\n",
            "    gpt2/h5/norm_2/b\n",
            "    gpt2/h5/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h6/norm_1/g\n",
            "    gpt2/h6/norm_1/b\n",
            "    gpt2/h6/attn/compute_output_bias/o_b\n",
            "    gpt2/h6/norm_2/g\n",
            "    gpt2/h6/norm_2/b\n",
            "    gpt2/h6/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h7/norm_1/g\n",
            "    gpt2/h7/norm_1/b\n",
            "    gpt2/h7/attn/compute_output_bias/o_b\n",
            "    gpt2/h7/norm_2/g\n",
            "    gpt2/h7/norm_2/b\n",
            "    gpt2/h7/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h8/norm_1/g\n",
            "    gpt2/h8/norm_1/b\n",
            "    gpt2/h8/attn/compute_output_bias/o_b\n",
            "Variable stacked/gpt2/h17/norm_1/g                                    size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \n",
            "    gpt2/h17/norm_1/g\n",
            "    gpt2/h17/norm_1/b\n",
            "    gpt2/h17/attn/compute_output_bias/o_b\n",
            "    gpt2/h17/norm_2/g\n",
            "    gpt2/h17/norm_2/b\n",
            "    gpt2/h17/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h18/norm_1/g\n",
            "    gpt2/h18/norm_1/b\n",
            "    gpt2/h18/attn/compute_output_bias/o_b\n",
            "    gpt2/h18/norm_2/g\n",
            "    gpt2/h18/norm_2/b\n",
            "    gpt2/h18/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h19/norm_1/g\n",
            "    gpt2/h19/norm_1/b\n",
            "    gpt2/h19/attn/compute_output_bias/o_b\n",
            "    gpt2/h19/norm_2/g\n",
            "    gpt2/h19/norm_2/b\n",
            "    gpt2/h19/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h20/norm_1/g\n",
            "    gpt2/h20/norm_1/b\n",
            "    gpt2/h20/attn/compute_output_bias/o_b\n",
            "    gpt2/h20/norm_2/g\n",
            "    gpt2/h20/norm_2/b\n",
            "    gpt2/h20/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h21/norm_1/g\n",
            "    gpt2/h21/norm_1/b\n",
            "    gpt2/h21/attn/compute_output_bias/o_b\n",
            "    gpt2/h21/norm_2/g\n",
            "    gpt2/h21/norm_2/b\n",
            "    gpt2/h21/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h22/norm_1/g\n",
            "    gpt2/h22/norm_1/b\n",
            "    gpt2/h22/attn/compute_output_bias/o_b\n",
            "    gpt2/h22/norm_2/g\n",
            "    gpt2/h22/norm_2/b\n",
            "    gpt2/h22/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h23/norm_1/g\n",
            "    gpt2/h23/norm_1/b\n",
            "    gpt2/h23/attn/compute_output_bias/o_b\n",
            "    gpt2/h23/norm_2/g\n",
            "    gpt2/h23/norm_2/b\n",
            "    gpt2/h23/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h24/norm_1/g\n",
            "    gpt2/h24/norm_1/b\n",
            "    gpt2/h24/attn/compute_output_bias/o_b\n",
            "    gpt2/h24/norm_2/g\n",
            "    gpt2/h24/norm_2/b\n",
            "    gpt2/h24/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h25/norm_1/g\n",
            "    gpt2/h25/norm_1/b\n",
            "    gpt2/h25/attn/compute_output_bias/o_b\n",
            "Variable stacked/gpt2/h25/mlp/conv1d_main/c_fc/bias                   size 71680        slice_size 17920        Shape[stacked=7, intermediate_expanded=10240]               \n",
            "    gpt2/h25/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h26/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h27/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h28/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h29/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h30/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h31/mlp/conv1d_main/c_fc/bias\n",
            "Variable stacked/gpt2/h25/norm_2/g                                    size 104960       slice_size 52480        Shape[stacked=41, embd=2560]                                \n",
            "    gpt2/h25/norm_2/g\n",
            "    gpt2/h25/norm_2/b\n",
            "    gpt2/h25/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h26/norm_1/g\n",
            "    gpt2/h26/norm_1/b\n",
            "    gpt2/h26/attn/compute_output_bias/o_b\n",
            "    gpt2/h26/norm_2/g\n",
            "    gpt2/h26/norm_2/b\n",
            "    gpt2/h26/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h27/norm_1/g\n",
            "    gpt2/h27/norm_1/b\n",
            "    gpt2/h27/attn/compute_output_bias/o_b\n",
            "    gpt2/h27/norm_2/g\n",
            "    gpt2/h27/norm_2/b\n",
            "    gpt2/h27/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h28/norm_1/g\n",
            "    gpt2/h28/norm_1/b\n",
            "    gpt2/h28/attn/compute_output_bias/o_b\n",
            "    gpt2/h28/norm_2/g\n",
            "    gpt2/h28/norm_2/b\n",
            "    gpt2/h28/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h29/norm_1/g\n",
            "    gpt2/h29/norm_1/b\n",
            "    gpt2/h29/attn/compute_output_bias/o_b\n",
            "    gpt2/h29/norm_2/g\n",
            "    gpt2/h29/norm_2/b\n",
            "    gpt2/h29/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h30/norm_1/g\n",
            "    gpt2/h30/norm_1/b\n",
            "    gpt2/h30/attn/compute_output_bias/o_b\n",
            "    gpt2/h30/norm_2/g\n",
            "    gpt2/h30/norm_2/b\n",
            "    gpt2/h30/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h31/norm_1/g\n",
            "    gpt2/h31/norm_1/b\n",
            "    gpt2/h31/attn/compute_output_bias/o_b\n",
            "    gpt2/h31/norm_2/g\n",
            "    gpt2/h31/norm_2/b\n",
            "    gpt2/h31/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/ln_f/g\n",
            "    gpt2/ln_f/b\n",
            "Variable stacked/gpt2/h8/norm_2/g                                     size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \n",
            "    gpt2/h8/norm_2/g\n",
            "    gpt2/h8/norm_2/b\n",
            "    gpt2/h8/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h9/norm_1/g\n",
            "    gpt2/h9/norm_1/b\n",
            "    gpt2/h9/attn/compute_output_bias/o_b\n",
            "    gpt2/h9/norm_2/g\n",
            "    gpt2/h9/norm_2/b\n",
            "    gpt2/h9/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h10/norm_1/g\n",
            "    gpt2/h10/norm_1/b\n",
            "    gpt2/h10/attn/compute_output_bias/o_b\n",
            "    gpt2/h10/norm_2/g\n",
            "    gpt2/h10/norm_2/b\n",
            "    gpt2/h10/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h11/norm_1/g\n",
            "    gpt2/h11/norm_1/b\n",
            "    gpt2/h11/attn/compute_output_bias/o_b\n",
            "    gpt2/h11/norm_2/g\n",
            "    gpt2/h11/norm_2/b\n",
            "    gpt2/h11/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h12/norm_1/g\n",
            "    gpt2/h12/norm_1/b\n",
            "    gpt2/h12/attn/compute_output_bias/o_b\n",
            "    gpt2/h12/norm_2/g\n",
            "    gpt2/h12/norm_2/b\n",
            "    gpt2/h12/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h13/norm_1/g\n",
            "    gpt2/h13/norm_1/b\n",
            "    gpt2/h13/attn/compute_output_bias/o_b\n",
            "    gpt2/h13/norm_2/g\n",
            "    gpt2/h13/norm_2/b\n",
            "    gpt2/h13/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h14/norm_1/g\n",
            "    gpt2/h14/norm_1/b\n",
            "    gpt2/h14/attn/compute_output_bias/o_b\n",
            "    gpt2/h14/norm_2/g\n",
            "    gpt2/h14/norm_2/b\n",
            "    gpt2/h14/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h15/norm_1/g\n",
            "    gpt2/h15/norm_1/b\n",
            "    gpt2/h15/attn/compute_output_bias/o_b\n",
            "    gpt2/h15/norm_2/g\n",
            "    gpt2/h15/norm_2/b\n",
            "    gpt2/h15/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h16/norm_1/g\n",
            "    gpt2/h16/norm_1/b\n",
            "    gpt2/h16/attn/compute_output_bias/o_b\n",
            "    gpt2/h16/norm_2/g\n",
            "    gpt2/h16/norm_2/b\n",
            "    gpt2/h16/mlp/conv1d_main/c_proj/bias\n",
            "Trainable Variables            count: 200     Total size: 2651307520       Total slice_size: 381853440      \n",
            "All Variables                  count: 200     Total size: 2651307520       Total slice_size: 381853440      \n",
            "Counters:\n",
            "allreduce: 1.68e+10\n",
            " allreduce/[0]: 5.37e+09\n",
            "  allreduce/[0]/einsum_op: 5.37e+09\n",
            " allreduce/[1]: 1.14e+10\n",
            "  allreduce/[1]/einsum_op: 1.14e+10\n",
            "  allreduce/[1]/reduce_op: 1.9e+07\n",
            "einsum: 3.19e+13\n",
            "einsum_unique: 2.48e+13\n",
            "output: 2.02e+11\n",
            " output/AddOperation: 5.68e+10\n",
            " output/BinaryOpWithBroadcasting: 6.88e+08\n",
            " output/BroadcastOperation: 5.4e+09\n",
            " output/ConcatOperation: 2.69e+09\n",
            " output/Constant: 2.62e+05\n",
            " output/EinsumOperation: 5.59e+10\n",
            " output/ImportOperation: 1.31e+05\n",
            " output/OneHotOperation: 3.33e+09\n",
            " output/RangeOperation: 3.19e+05\n",
            " output/ReduceOperation: 2.95e+07\n",
            " output/ReshapeOperation: 1.01e+10\n",
            " output/ScalarAddOperation: 5.37e+09\n",
            " output/ScalarMultiplyOperation: 1.89e+10\n",
            " output/ShiftOperation: 1.34e+09\n",
            " output/SlicewiseOperation: 2.73e+10\n",
            " output/StackedVariable: 2.64e+06\n",
            " output/StopGradient: 8.05e+09\n",
            " output/UnstackOperation: 2.64e+06\n",
            " output/Variable: 3.05e+09\n",
            " output/WhileLoopOperation: 2.68e+09\n",
            "output_unique: 1.09e+11\n",
            " output_unique/AddOperation: 3.1e+10\n",
            " output_unique/BinaryOpWithBroadcasting: 8.81e+07\n",
            " output_unique/BroadcastOperation: 5.38e+09\n",
            " output_unique/ConcatOperation: 1.34e+09\n",
            " output_unique/Constant: 3.28e+04\n",
            " output_unique/EinsumOperation: 2.53e+10\n",
            " output_unique/ImportOperation: 1.64e+04\n",
            " output_unique/OneHotOperation: 4.16e+08\n",
            " output_unique/RangeOperation: 4.1e+04\n",
            " output_unique/ReduceOperation: 1.16e+07\n",
            " output_unique/ReshapeOperation: 5.37e+09\n",
            " output_unique/ScalarAddOperation: 2.68e+09\n",
            " output_unique/ScalarMultiplyOperation: 8.75e+09\n",
            " output_unique/ShiftOperation: 6.71e+08\n",
            " output_unique/SlicewiseOperation: 1.75e+10\n",
            " output_unique/StackedVariable: 8.24e+05\n",
            " output_unique/StopGradient: 6.71e+09\n",
            " output_unique/UnstackOperation: 8.24e+05\n",
            " output_unique/Variable: 2.65e+09\n",
            " output_unique/WhileLoopOperation: 1.34e+09\n",
            "variables: 2.65e+09\n",
            " variables/trainable: 2.65e+09\n",
            "Done calling model_fn.\n",
            "TPU job name worker\n",
            "Graph was finalized.\n",
            "Restoring parameters from gs://test-bucket-neo/GPT3_2-7B/model.ckpt-400000\n",
            "Running local_init_op.\n",
            "Done running local_init_op.\n",
            "From /usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:840: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Prefer Variable.assign which has equivalent behavior in 2.X.\n",
            "Starting infeed thread controller.\n",
            "Starting outfeed thread controller.\n",
            "Initialized dataset iterators in 0 seconds\n",
            "Before copy master to slices.\n",
            "Done with copy master to slices.\n",
            "Enqueue next (1) batch(es) of data to infeed.\n",
            "Dequeue next (1) batch(es) of data from outfeed.\n",
            "Outfeed finished for iteration (0, 0)\n",
            "======================================== SAMPLE 0 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "\n",
            "    def forward(self, input):\n",
            "        \"\"\" return gpt from position embedding (embedding for position and context)\"\"\"\n",
            "        return GPT(input, self.pos_emb, self.tok_emb, self.drop, self.ln_f, self.head)\n",
            "\n",
            "    def get_type_log_probability(self, input, target, p_type):\n",
            "        \"\"\" get negative log-likelihood for the current probability (p_type)\n",
            "        \"\"\"\n",
            "        embedding = self.tok_emb(input)\n",
            "        return nn.log_softmax(embedding, dim=1) / sum(input.size(1) for input in input)\n",
            "\n",
            "\n",
            "def update_parameters_for_training(model, input_length, targets,\n",
            "                                   target_length, context_size, apply_onehot=False):\n",
            "    \"\"\" update parameters after re-training or training in 2-shot.\n",
            "\n",
            "            model.set_params(...)..returns(model_post_training)\n",
            "            model_post_training: the updated model\n",
            "    \"\"\"\n",
            "    if not model.sampler:\n",
            "        model.reset_params()\n",
            "    elif model.sampler.get_seed()!= 0 or limit_sampled_sequences(model.sampler.get_seed()):\n",
            "        if apply_onehot:\n",
            "            model.reset_params()\n",
            "\n",
            "    loss = nn.BCELoss()\n",
            "    model.loss = loss\n",
            "    model.disp = model.disp + (1.0 - model.disp) * model.log_prob(input_length, target_length)\n",
            "    model.mean_disp = model.disp\n",
            "    model.mean_pos = model.pos\n",
            "    score = model.log_prob(target_length, target_length)\n",
            "    if (input_length == target_length):\n",
            "        # single shot - ignore intro, ilux and outros\n",
            "        if apply_onehot:\n",
            "            target[0][0] = '%s %s' % (target_length, target_length)\n",
            "        else:\n",
            "            target[0][0] = '%s %d' % (target_length, target_length)\n",
            "    else:\n",
            "        # 2-shot - batch one of the input embedding, multi-shot - batch by sequence.\n",
            "        targets = torch.cat([tuple([chr[0] if chr[0] in target[0] else '?' for chr in target])\n",
            "                             for target in target_length], 2)\n",
            "        target_length = len(targets)\n",
            "\n",
            "    pos_emb = self.pos_emb(input)\n",
            "    tok_emb = self.tok_emb(input)\n",
            "    drop = self.drop(input)\n",
            "\n",
            "    head_drop = tok_emb.nonlinearity * drop\n",
            "\n",
            "    for ln_f in self.ln_f:\n",
            "        self.ln_f = nn.LayerNorm(self.n_embd)\n",
            "        self.ln_f.weight.data.zero_()\n",
            "        self.ln_f.bias.data.zero_()\n",
            "\n",
            "    for block in self.blocks:\n",
            "        self.head_drop.weight.data.zero_()\n",
            "        self.head_drop.bias.data.zero_()\n",
            "\n",
            "    for i in range(self.n_layer):\n",
            "        param_tuple = (i, block, head_drop, len(targets), config.init_lstm_c)\n",
            "        t_pos, t_targets, _ = torch.max(target, param_tuple[0], param_tuple[1])\n",
            "\n",
            "        # fast threshold -> 1 will be equal to target, non-zero will not be all 0\n",
            "        t_pos = t_pos if t_pos == 0 else 1\n",
            "        t_targets = t_targets if t_targets == 0 else 1\n",
            "        self.pixel_to_pos = target[t_pos:t_pos+1]\n",
            "\n",
            "        # linear decrease\n",
            "        self.disp_drop = tok_emb.nonlinearity * self.drop(t_pos)\n",
            "        self.disp_drop.weight.data.zero_()\n",
            "\n",
            "        self.weight_reset = torch.zeros(2)\n",
            "        self.bias_reset = torch.zeros(2)\n",
            "\n",
            "        emb_tok_id = model.pixel_to_pos\n",
            "        weight_last = Embedding(1, config.n_embd)\n",
            "        self.q = weight_last(emb_tok_id)\n",
            "        self.q_last = weight_last(self.q)\n",
            "\n",
            "        mask_name = '%s/%s/%s_%d' % (config.tok_id, config.pos_id, tok_emb.size(), pos_emb.size())\n",
            "        self.loss_mask = nn.LogSoftmax(dim=1)\n",
            "        self.loss_state = nn.Linear(config.n_embd+num_tok_c, config.n_embd)\n",
            "        self.target_to_pos = per_target_pos(target, param_tuple[0], param_tuple[1], self.head_drop, label=targets)\n",
            "        self.loss_target_to_pos = per_target_pos(target, param_tuple[0], param_tuple[1], self.head_drop, label=targets)\n",
            "        self.mask_loss_name = \"loss_mask\"\n",
            "        target_to_pos = nn.LogSoftmax(dim=1)\n",
            "        for i in\n",
            "\n",
            "================================================================================\n",
            "\n",
            "======================================== SAMPLE 1 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "        # normalization\n",
            "        self.weight_gpu = nn.Parameter(torch.Tensor(self.weight.size(1).num()))\n",
            "        self.bias_gpu = nn.Parameter(torch.zeros(1).type(torch.float32))\n",
            "\n",
            "    def _init_weights(self):\n",
            "        num_b = self.head.weight.size(1)\n",
            "        drop_b = self.head.bias.size(0)\n",
            "        self.weight = nn.Parameter(torch.Tensor(num_b, drop_b))\n",
            "        self.bias = nn.Parameter(torch.zeros(drop_b).type(torch.float32))\n",
            "\n",
            "    def forward(self, H, g, X): \n",
            "        \"\"\"  - token-level feed forward\n",
            "        - Embed Otherwise\n",
            "            (f) g is ignored for the embeddings, and this is only used to save the\n",
            "                gpt translation encoder memory.\n",
            "        \"\"\"\n",
            "        output = {}\n",
            "        if self.head.keep:\n",
            "            X_top = X_top.view(-1, self.emb_size, 1)\n",
            "\n",
            "            for j in range(self.head.nheads):\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(1, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(0, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "\n",
            "            for i in range(self.head.n_layer):\n",
            "                H = torch.cat([H, self.ln_f(H)[0]]).view(-1)\n",
            "                if self.drop > 0:\n",
            "                    g = torch.zeros_like(H.long()).float()\n",
            "                else:\n",
            "                    g = H.long()\n",
            "\n",
            "                g = g.transpose(1, 2).contiguous().view(-1, g.size(1))\n",
            "                if self.apply_del_emb:\n",
            "                    output[j] = g.transpose(0, 1)\n",
            "                else:\n",
            "                    output[j] = self.head(g)\n",
            "\n",
            "                H = H.transpose(0, 1)\n",
            "        else:\n",
            "            X_top = X_top.view(-1, self.emb_size, 1)\n",
            "            for j in range(self.head.nheads):\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(1, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(0, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "\n",
            "            for i in range(self.head.n_layer):\n",
            "                g = torch.cat([self.ln_f(H)[0], g])[0]\n",
            "                if self.drop > 0:\n",
            "                    g = torch.zeros_like(g).float()\n",
            "                else:\n",
            "                    g = g.transpose(1, 2).contiguous().view(-1, g.size(1))\n",
            "                if self.apply_del_emb:\n",
            "                    output[j] = g.transpose(0, 1)\n",
            "                else:\n",
            "                    output[j] = self.head(g)\n",
            "\n",
            "        output = output[\"h\"].transpose(0, 1)\n",
            "        return output\n",
            "\n",
            "\n",
            "\n",
            "================================================================================\n",
            "\n",
            "======================================== SAMPLE 2 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "\n",
            "        self.optimizer = optim.Adam(\n",
            "            self.head,\n",
            "            parameters_ub=self.parameters(),\n",
            "            lam=config.initial_learning_rate\n",
            "        )\n",
            "\n",
            "    def forward(self, input_text):\n",
            "        \"\"\" the overall model.co: forward pass \"\"\"\n",
            "\n",
            "        limit = self.head.output_size(0)\n",
            "        head = self.head\n",
            "        attn = self.head.weight\n",
            "        # tagwith = self.head.weight\n",
            "\n",
            "        block = self.blocks[:,0][self.block_size:,:]\n",
            "        forward_attn = block(attn)\n",
            "        forward_text = forward_attn + input_text\n",
            "        forward_text = conv_block(forward_text)\n",
            "        forward_text = forward_linear(forward_text)\n",
            "        forward_text = forward_linear(forward_linear(forward_text))\n",
            "\n",
            "        lower_attn = (self.tok_emb(forward_text)).sum(1, keepdim=True)\n",
            "        # lower_attn = self.tok_emb(1)\n",
            "\n",
            "        #rnn_basic_block1 = forward_attn[:self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0),:].view(1, 1, self.block_size, -1)\n",
            "        #rnn_basic_block1 = rnn_tok(forward_text[:self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0), :].transpose(1, 0, 2) + forward_text[self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0),:])\n",
            "        #rnn_basic_block1_drop = nn.Dropout(config.drop_rate)\n",
            "        #print(rnn_basic_block1_drop.shape)\n",
            "        #print(rnn_basic_block1.weight.shape)\n",
            "        # post_drop = rnn_basic_block1_drop.view(1, 1, self.block_size, 1)\n",
            "        # rnn_part_block1 = forward_attn[self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(1, self.block_size, self.head.n_layer)\n",
            "        # post_drop = post_drop + rnn_part_block1.weight.view(self.block_size, self.head.n_layer, 1) + rnn_part_block1.bias.view(self.head.n_layer, 1, 1).expand_as(rnn_part_block1)\n",
            "\n",
            "        #rnn_part_block1 = (head(head(numpy.squeeze(forward_text), 1)))[:self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(self.block_size, self.head.layers_[0].n_layer, -1)\n",
            "        #rnn_part_block1 = rnn_basic_block1_drop + post_drop + rnn_part_block1.weight.view(self.block_size, self.head.n_layer, 1) + rnn_part_block1.bias.view(self.head.n_layer, 1, 1).expand_as(rnn_part_block1)\n",
            "\n",
            "        lower_rnn_text = (head(head(numpy.squeeze(forward_text), 1)))[self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(self.head.n_layer, self.block_size, -1)\n",
            "        lower_rnn = rnn_tok(lower_rnn_text)\n",
            "\n",
            "        #attention_layers = self.attention_layer\n",
            "        #context_attention_layers = self.context_attention_layer\n",
            "        #attn_context_layers = self.attention_layer + self.context_attention_layer\n",
            "        #attn_context_layers = self.attention_layer\n",
            "        #propagation_layers = self.proper_layer + self.context_attention_layer\n",
            "\n",
            "        return lower_attn + lower_rnn_text + lower_rnn\n",
            "\n",
            "    def backward(self, grad_output, grad_input):\n",
            "        \"\"\" the model.co: backward pass \"\"\"\n",
            "\n",
            "        grad_weight = torch.matmul(grad_output[self.head.layers_[0].output_size(0)], grad_input.contiguous())\n",
            "        return grad_weight.view(batch_size, -1, self.head.n_layer), grad_weight.view(batch_size, -1, self.head.n_layer)\n",
            "\n",
            "    def clip_gradient(self, grad_input):\n",
            "        \"\"\" clip gradient \"\"\"\n",
            "        logger.warning(\"clip_gradient: clip(grad_input, 0.0 - 1.0)\")\n",
            "        return grad_input.clamp(0.0 - 1.0).detach().cpu().numpy()\n",
            "\n",
            "    def _get_cell(self, name):\n",
            "        if self.args.tied_base_model:\n",
            "            return self.head.layers_[name].n_op\n",
            "\n",
            "        return self.head.layers_[name]\n",
            "\n",
            "    def _get_head(self, head_name):\n",
            "        if self.args.tied_base_model:\n",
            "            return head_name\n",
            "\n",
            "        return self.head.n_layer\n",
            "\n",
            "\n",
            "    def forward_gpt_cell(self, head):\n",
            "        \"\"\" the forward pass of the gpt\n",
            "\n",
            "================================================================================\n",
            "\n",
            "======================================== SAMPLE 3 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "        logger.info(\"images size: %e\", config.images_len)\n",
            "        logger.info(\"embedding size: %e\", config.embedding_size)     \n",
            "\n",
            "        self.vocab_size = config.vocab_size\n",
            "        self.hidden_size = config.hidden_size\n",
            "        self.n_layer = config.n_layer\n",
            "        self.block_size = config.block_size\n",
            "        self.cell_dim = config.cell_dim\n",
            "        self.n_embd = config.n_embd\n",
            "        self.n_embd = config.n_embd\n",
            "        self.embd_pdrop = config.embd_pdrop\n",
            "        self.n_batch = config.n_batch\n",
            "        self.n_dembd = config.n_dembd\n",
            "        self.101k_embd = config.101k_embd\n",
            "        self.shotting_dist = config.shotting_dist\n",
            "        self.dropout = config.embd_pdrop\n",
            "\n",
            "        # init variables\n",
            "        self._init_weights()\n",
            "\n",
            "    def _init_weights(self):\n",
            "        for layer in self.blocks:\n",
            "            for cell in layer:\n",
            "                param_init = cell.init_weights()\n",
            "                self.parameters()[layer][cell] = param_init.assign(param_init)\n",
            "\n",
            "    def forward(self, x, gpt_emb, gpt_state, gpt_emb_dim, gpt_state_dim):\n",
            "        \"\"\"  a forward pass for language model derivations\n",
            "\n",
            "            input, latent and context embeddings of convolutional layers as well as the entity embedding to obtain the\n",
            "            topic-embedding applied to the gpt entity embedding to generate the knowledge graph representation\n",
            "            gpt latent representations are then transformed into some vector representation\n",
            "\n",
            "            latent representation is then used as input to the decoder head, to produce the gpt entity representation\n",
            "\n",
            "            finally, the gpt entity representation is used as input to the decoder head, to produce the gpt latent representation on top of which\n",
            "            the knowledge graph representation is constructed\n",
            "        \"\"\"\n",
            "\n",
            "        n_ctx = len(x)\n",
            "        x_bn = x.nonzero()[0]/n_ctx\n",
            "        latent_bn = x_bn.nonzero()[0]/n_ctx\n",
            "        cv_emb = self.tok_emb(x_bn)\n",
            "        # consider the entity embedding to get the gpt latent representation\n",
            "        entity_mask = self.apply(gpt_emb_dim) if self.embd else 0\n",
            "        latent = self.apply(gpt_state_dim)\n",
            "        latent = latent * gpt_emb + entity_mask * gpt_state + self.drop\n",
            "\n",
            "        self.ln_f.weight.data.fill_(1.0)\n",
            "        self.ln_f.bias.data.zero_()\n",
            "        self.ln_f.weight.data[0].copy_(self.tok_emb)\n",
            "        self.ln_f.bias.data[0].copy_(self.pos_emb)\n",
            "        mlp = nn.Linear(config.hidden_size, config.n_embd)\n",
            "        mlp.bias.data[0].copy_(self.hidden_size)\n",
            "        self.ln_f.weight.data[0].copy_(mlp.weight.data)\n",
            "        # get the gpt latent representation on top of which the knowledge graph\n",
            "        ln_gpt_emb = self.apply(gpt_emb_dim)\n",
            "        # ln_gpt_emb = logits.sample(self.shotting_dist)\n",
            "        # ln_gpt_emb_shape = [1]\n",
            "        # gpt_ln_emb.data[0].copy_(ln_gpt_emb.data[0])\n",
            "        # gpt_ln_emb_shape = [0]\n",
            "        # gpt_gpt_emb = gpt_ln_emb.gather([0], gpt_ln_emb.shape)\n",
            "        # get the gpt latent representation to be used as the starting latent embedding of the decoder\n",
            "        ln_src_emb = gpt_ln_emb\n",
            "        ln_state = gpt_ln_emb.gather([0], ln_gpt_emb.shape)\n",
            "\n",
            "        # get the context and latent embedding representation of the entire input\n",
            "        # x_ext = x_bn[latex_str].squeeze()\n",
            "        self.apply(n)\n",
            "        # get the context representation used to decode the embedded gpt representation\n",
            "        x_ext = x_bn[latex_str].squeeze()\n",
            "        x_ext = x_ext.transpose(1, 0)\n",
            "        x_ext = F.relu(self.apply(x_ext))\n",
            "        x_ext_bn = x_ext.transpose(1, 0)\n",
            "        # x_ext_bn = x_ext_bn.transpose(1, 0)\n",
            "        # initialize the decoder hidden state\n",
            "        ln_src_emb, diff_emb = collections.defaultdict(list), []\n",
            "        for i, i_emb in enumerate(self.ln_f):\n",
            "            i_blk = int(self.block_size*(i+1))\n",
            "            mlp = nn.Linear(context_embedding_dim, n\n",
            "\n",
            "================================================================================\n",
            "\n",
            "Enqueue next (1) batch(es) of data to infeed.\n",
            "Dequeue next (1) batch(es) of data from outfeed.\n",
            "Outfeed finished for iteration (1, 0)\n",
            "Stop infeed thread controller\n",
            "Shutting down InfeedController thread.\n",
            "InfeedController received shutdown signal, stopping.\n",
            "Infeed thread finished, shutting down.\n",
            "infeed marked as finished\n",
            "Stop output thread controller\n",
            "Shutting down OutfeedController thread.\n",
            "OutfeedController received shutdown signal, stopping.\n",
            "Outfeed thread finished, shutting down.\n",
            "outfeed marked as finished\n",
            "Shutdown TPU system.\n",
            "prediction_loop marked as finished\n",
            "prediction_loop marked as finished\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nE9VImzHaI0z"
      },
      "source": [
        "# Evaluating the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XGGbkgaFfp6f"
      },
      "source": [
        "This section assumes you are using a pretrained model and relies on variables created in the `Pretrained model` section."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I45yUIpbaLUJ"
      },
      "source": [
        "## Wikitext"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zwBDB9U2keFV"
      },
      "source": [
        "Download the wikitext test set:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uuugiBmJaNxf"
      },
      "source": [
        "wikitext103_src = \"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip\"\n",
        "!wget $wikitext103_src\n",
        "!unzip wikitext-103-raw-v1.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J5wf3QWKkhZt"
      },
      "source": [
        "Tokenize and upload to bucket:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6mo8UUtDdctH"
      },
      "source": [
        "\n",
        "!mkdir wikitext\n",
        "!mv /content/GPTNeo/wikitext-103-raw/wiki.test.raw wikitext/wikitext_test.txt\n",
        "\n",
        "# Tokenize Data\n",
        "!python data/create_tfrecords.py --input_dir wikitext --name wikitext --files_per 1000 --output_dir wikitext_tokenized --write_dataset_config --processes 1 --wikitext-detokenize\n",
        "\n",
        "# copy the data to your bucket\n",
        "if not path_to_cloud_bucket.endswith('/'):\n",
        "  path_to_cloud_bucket += '/'\n",
        "copy_loc = path_to_cloud_bucket \n",
        "!gsutil -m cp -r wikitext_tokenized $copy_loc\n",
        "!gsutil ls $path_to_cloud_bucket"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GE84TUd1fAzf"
      },
      "source": [
        "Now make a dataset config that points to the tokenized wikitext data:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z5UU7DQeeY0S"
      },
      "source": [
        "%%writefile configs/dataset_configs/wikitext.json\n",
        "\n",
        "{\n",
        "  \"path\": \"\",\n",
        "  \"eval_path\": \"gs://test-bucket-neo/wikitext_tokenized/*.tfrecords\",\n",
        "  \"n_vocab\": 50256,\n",
        "  \"tokenizer_is_pretrained\": true,\n",
        "  \"tokenizer_path\": \"gpt2\",\n",
        "  \"eos_id\": 50256,\n",
        "  \"padding_id\": 50257\n",
        "}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "egvdwIOqfFER"
      },
      "source": [
        "And update your model config to point to that dataset:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "AtdoIFMgfOe8"
      },
      "source": [
        "# @title Modify config for wikitext. \n",
        "  \n",
        "import json\n",
        "from pprint import pprint\n",
        "\n",
        "batch_size = 8 #@param {type:\"integer\"}\n",
        "assert pretrained_model is not None\n",
        "with open(f'configs/{pretrained_model}.json', 'r') as f:\n",
        "  data = json.load(f)\n",
        "  pprint(data)\n",
        "  dset_val = [[\"wikitext\", None, None, None]]\n",
        "  mods = {\n",
        "          \"datasets\": dset_val,\n",
        "          \"eval_steps\": 139 // batch_size,\n",
        "          \"train_batch_size\": batch_size,\n",
        "          \"eval_batch_size\": batch_size,\n",
        "        }\n",
        "  data.update(mods)\n",
        "  print('\\n--->\\n')\n",
        "  pprint(data)\n",
        "  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\n",
        "    json.dump(data, outfile, indent=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2d5eTHEg6Xj"
      },
      "source": [
        "Now run the model in eval mode over the tokenized data:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s1Uz3PXzg5Pm"
      },
      "source": [
        "!python3 main.py --eval --tpu colab --model $pretrained_model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9dbkPVcMhVaR"
      },
      "source": [
        "## Lambada\n",
        "\n",
        "Lambada eval is built into the codebase and can be run by adding a field to your model config:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "z4FJXOlJiEYo"
      },
      "source": [
        "# @title Modify config for Lambada. \n",
        "  \n",
        "import json\n",
        "from pprint import pprint\n",
        "\n",
        "batch_size = 8 #@param {type:\"integer\"}\n",
        "assert pretrained_model is not None\n",
        "with open(f'configs/{pretrained_model}.json', 'r') as f:\n",
        "  data = json.load(f)\n",
        "  mods = {\n",
        "          \"datasets\": dset_val,\n",
        "          \"eval_steps\": 0,\n",
        "          \"train_batch_size\": batch_size,\n",
        "          \"eval_batch_size\": batch_size,\n",
        "          \"eval_tasks\": [\"lambada\"]\n",
        "        }\n",
        "  data.update(mods)\n",
        "  print('\\n--->\\n')\n",
        "  pprint(data)\n",
        "  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\n",
        "    json.dump(data, outfile, indent=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Upp-bGMriVPK"
      },
      "source": [
        "Now run the eval:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OOA1YZDRiUhN"
      },
      "source": [
        "!python3 main.py --eval --tpu colab --model $pretrained_model"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2020 EleutherAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# GPT Neo

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5297715.svg)](https://doi.org/10.5281/zenodo.5297715) [![arXiv](https://img.shields.io/badge/arXiv-2101.00027-f9f107.svg)](https://arxiv.org/abs/2101.00027)

**As of August 2021, this code is no longer maintained. It is preserved here in archival form for people who wish to continue to use it.**

🎉 1T or bust my dudes 🎉

An implementation of model & data parallel [GPT3](https://arxiv.org/abs/2005.14165)-like models using the [mesh-tensorflow](https://github.com/tensorflow/mesh) library.

**If you're just here to play with our pre-trained models, we strongly recommend you try out the [HuggingFace Transformer integration](https://huggingface.co/EleutherAI).**

Training and inference are officially supported on TPU and should work on GPU as well. This repository will be (mostly) archived as we move focus to our GPU-specific repo, [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).

In addition to the functionality offered by GPT-3, we also offer the following:
* [Local attention](https://arxiv.org/abs/2004.05150)
* [Linear attention](https://arxiv.org/abs/1812.01243)
* [Mixture of Experts](https://arxiv.org/abs/1701.06538)
* [Axial Positional embedding](https://arxiv.org/abs/1912.12180)

NB: while GPT-Neo can *technically* run a training step at 200B+ parameters, it is very inefficient at those scales. That inefficiency and the fact that many GPUs became available to us, among other things, prompted us to move development over to [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).

# Pretrained Models

**Update 21/03/2021:**

We're proud to release two pretrained GPT-Neo models trained on The Pile. The weights and configs can be freely downloaded from [the-eye.eu](https://the-eye.eu/public/AI/gptneo-release/).

1.3B: https://mystic.the-eye.eu/public/AI/gptneo-release/GPT3_XL/

2.7B: https://mystic.the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/

For more information on how to get these set up, see the colab notebook, or read through the rest of the readme.

## Model Evaluations

#### Linguistic Reasoning

| Model and Size   | Pile BPB   | Pile PPL  | Wikitext PPL | Lambada PPL | Lambada Acc | Winogrande | Hellaswag  |
|------------------|------------|-----------|--------------|-------------|-------------|------------|------------|
| **GPT-Neo 125M** | -----      | -----     | **32.285**   | **30.266**  | **37.36%**  | **50.43%** | **28.67%** |
| GPT-3 125M       | -----      | -----     | -----        | 18.6        | 42.7%       | 52.0%      | 33.7%      |
| **GPT-Neo 350M** | -----      | -----     | **22.5657**  | **13.876**  | **47.27%**  | **51.14%** | **32.16%** |
| GPT-3 350M       | -----      | -----     | -----        | 9.09        | 54.3%       | 52.1%      | 43.6%      |
| GPT-3 Ada        | 0.9631     | -----     | -----        | 9.954       | 51.60%      | 52.90%     | 35.93%     |
| **GPT-Neo 1.3B** | **0.7527** | **6.159** | **13.10**    | **7.498**   | **57.23%**  | **55.01%** | **38.66%** |
| GPT-3 1.3B       | -----      | -----     | -----        | 5.44        | 63.6%       | 58.7%      | 54.7%      |
| GPT-2 1.5B       | 1.0468     | -----     | 17.48        | 10.634      | 51.21%      | 59.40%     | 40.03%     |
| **GPT-Neo 2.7B** | **0.7165** | **5.646** | **11.39**    | **5.626**   | **62.22%**  | **56.50%** | **42.73%** |
| GPT-3 2.7B       | -----      | -----     | -----        | 4.60        | 67.1%       | 62.3%      | 62.8%      |


#### Physical and Scientific Reasoning

| Model and Size   | MathQA     | PubMedQA   | Piqa       |
|------------------|------------|------------|------------|
| **GPT-Neo 125M** | **22.78%** | **55.10%** | **63.06%** |
| GPT-3 125M       | -----      | -----      | 64.6%      |
| **GPT-Neo 350M** | **23.45%** | **53.80%** | **65.07%** |
| GPT-3 350M       | -----      | -----      | 70.2%      |
| GPT-3 Ada        | 24.29%     | 52.80%     | 68.88%     |
| **GPT-Neo 1.3B** | **24.05%** | **54.40%** | **71.11%** |
| GPT-3 1.3B       | -----      | -----      | 75.1%      |
| GPT-2 1.5B       | 23.64%     | 58.33%     | 70.78%     |
| **GPT-Neo 2.7B** | **24.72%** | **57.54%** | **72.14%** |
| GPT-3 2.7B       | -----      | -----      | 75.6%      |


**Note:** All evaluations were done using our [evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness). Some results for GPT-2 and GPT-3 are inconsistent with the values reported in the respective papers. We are currently looking into why, and would greatly appreciate feedback and further testing of our eval harness.

# Setup

```bash
git clone https://github.com/EleutherAI/GPTNeo
cd GPTNeo
pip3 install -r requirements.txt
```
# Training Setup

## TPUs:

Sign up for [Google Cloud Platform](https://cloud.google.com/), and create a [storage bucket](https://cloud.google.com/storage). 

Create your VM through a google shell (`https://ssh.cloud.google.com/`) with `ctpu up --vm-only` so that it can connect to your Google bucket and TPUs and install the requirements with pip (see above).

Google Colab provides v2-8 TPUs for free, which should be enough to finetune our models up to GPT3XL (1.3B parameter) sizes.
Click [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EleutherAI/GPTNeo/blob/master/GPTNeo_example_notebook.ipynb) to run through our example colab notebook.

For more detailed instructions, run through our [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below.

## GPUs:

You can also choose to train GPTNeo locally on your GPUs. To do so, you can omit the Google cloud setup steps above, and git clone the repo locally. Run through the [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below, then when running main.py, you simply have to omit the `tpu` flag, and pass in GPU ids instead.

Note: Some users have reported having difficulty getting MTF to recognize their GPUs. See [here](https://github.com/EleutherAI/gpt-neo/issues/150) for details and instructions on how to fix it.

# Generating Text

Once you have a trained model, or you've downloaded one of our pre-trained models, generating text is as simple as running the main.py script with the `--predict` flag on. You can pass a path to your prompt txt file with the `--prompt` flag, like so:

```bash
python3 main.py --predict --prompt <example_prompt.txt> --tpu <tpu_name> --model <config_name>
```

or, if using GPUs:

```bash
python3 main.py --predict --prompt <example_prompt.txt> --gpu_ids <device:GPU:0 device:GPU:1> --model <config_name>
```

# Training Guide

## 1. Create your Tokenizer (OPTIONAL)

We recommend you use [Huggingface's pretrained GPT2 tokenizer](https://huggingface.co/transformers/model_doc/gpt2.html#transformers.GPT2Tokenizer) with our repo (instructions provided below), but if you want to train a model with a different vocabulary size, we provide facilities to train your own tokenizer like so:

```bash
python data/train_tokenizer.py \
    --base_dir ./path/to/your/txt/files \
    --output_dir ./output/path \
    --file_type txt \
    --vocab_size 50257

# if it succeeded, you should see the message
# 'tokenizer saved at ./output/path/byte-level-bpe.tokenizer.json'
```

## 2. Tokenizing your Dataset

If you just want to test training, you can skip this step and download some dummy data like so:

```
wget https://storage.googleapis.com/connors-datasets/bundestag/bundestag_0.tfrecords
```

Then copy the data to your bucket, or if using GPUs, a local directory: 

```
gsutil cp bundestag_0.tfrecords gs://<your bucket>/
```

If using your own data to train, you can use the `data/create_tfrecords.py` script to encode your text data into tfrecords.

Your data must either be in the form of lots of normal .txt files (one document per file), or in any format supported by [lm_dataformat](https://github.com/leogao2/lm_dataformat). 

You can run the script without parameters to see help for all options.

In **document mode**, each example in the tfrecords is one (variably sized) document. This is to be used with the `documents_fixed` and `documents_random` sampling modes (for more details, see the parameter reference section).
Document mode is the default mode.

The command below will tokenize all files in acceptable formats in `input_dir` using the GPT2 tokenizer and save them to `output_dir`:
```
python3 data/create_tfrecords.py --mode documents --input_dir <base> --name <name> --output_dir <output> --use_gpt2_tokenizer --minimum_size <min>
```

- `input_dir`: Defines the folder where your data is located. The script will encode all files present in this folder.
- `name`: Name of output files will be `name_i.tfrecords` where i is the number of the file.
- `output_dir`: Where to save the tfrecords to
- `use_gpt2_tokenizer`: Whether to use the pretrained HuggingFace GPT2 tokenizer, in which case the separator will be set to [50256].
- `encoder_path`: if not using the pretrained gpt2 tokenizer, use this flag to provide a path to your generated tokenizer json.
- `separator`: Written in list format, the separator token(s) to insert between documents (e.g. "[0]"). Will depend on your encoder.
- `minimum_size`: The minimum size (in tokens) a document must have, otherwise it is discarded. This is what will later determine your `stitch` parameter: `stitch * minimum_size` must always be greater than or equal to `n_ctx` (for more details, see the parameter reference section).
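
The `stitch * minimum_size >= n_ctx` constraint above can be turned into a quick calculation. The following is a minimal sketch, not part of the repository; the helper name `min_stitch` is hypothetical and the numbers are only examples.

```python
def min_stitch(n_ctx: int, minimum_size: int) -> int:
    """Smallest stitch value such that stitch * minimum_size >= n_ctx."""
    if n_ctx <= 0 or minimum_size <= 0:
        raise ValueError("n_ctx and minimum_size must be positive")
    # Integer ceiling division, avoids any floating point rounding.
    return (n_ctx + minimum_size - 1) // minimum_size


# With a 2048-token context and documents of at least 100 tokens,
# at least 21 documents must be stitched together (21 * 100 = 2100 >= 2048).
print(min_stitch(2048, 100))  # 21
```

A larger `stitch` than this minimum also satisfies the constraint; the sketch only gives the lower bound.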

## 4. Using a Dataset in a Model

To use a dataset in a model, you must first register that dataset under the `./configs/dataset_configs` folder. First choose a filename with a `.json` extension. That filename will serve as the dataset identifier. The config should be filled out in the following manner.

If you have a dataset encoded using the pretrained gpt2 tokenizer, you can specify that like so:

```json
{
    "n_vocab": 50257,
    "path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords",
    "eval_path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords",
    "tokenizer_is_pretrained": true,
    "tokenizer_path": "gpt2"
}
```

or if you've trained a custom tokenizer, like so:

```json
{
    "n_vocab": 32768,
    "path": "./path/to/your/*.tfrecords",
    "eval_path": "./path/to/your/eval/*.tfrecords",
    "tokenizer_path": "./path/to/your/byte-level-bpe.tokenizer.json"
}
```

Finally, in your model config, add the filename that you created above to the `datasets` array.

The `<dataset id>` will be the filename, excluding the `.json`, that you created above.

```
"datasets": [[<dataset id>, <stitch>, <datatype>, <weight>]] # datasets key defines at run time how each dataset is processed for training
```
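
To catch a mistyped dataset id before launching a run, the lookup can be mimicked in a few lines. This is a rough sketch under the assumption that dataset ids are resolved from the file stems in the dataset config folder (as `configs.py` does); the helper names and the temporary folder are hypothetical and used only so the example runs on its own.

```python
import json
import tempfile
from pathlib import Path


def load_dataset_configs(config_dir):
    """Map dataset id (file name without .json) to its parsed config."""
    return {p.stem: json.loads(p.read_text()) for p in Path(config_dir).glob("*.json")}


def missing_dataset_ids(model_config, known):
    """Return dataset ids used by the model config that are not registered.

    Assumes every entry of "datasets" is a list whose first item is the id.
    """
    return [entry[0] for entry in model_config["datasets"] if entry[0] not in known]


# Demo with a throwaway folder standing in for ./configs/dataset_configs.
with tempfile.TemporaryDirectory() as tmp:
    dataset_config = {"n_vocab": 50257, "path": "./data/*.tfrecords", "tokenizer_path": "gpt2"}
    (Path(tmp) / "my_dataset.json").write_text(json.dumps(dataset_config))
    known = load_dataset_configs(tmp)

good = {"datasets": [["my_dataset", 25, "documents_random", 1.0]]}
typo = {"datasets": [["my_datset", 25, "documents_random", 1.0]]}
print(missing_dataset_ids(good, known))  # []
print(missing_dataset_ids(typo, known))  # ['my_datset']
```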

## 5. Choose a model configuration

Once you have your datasets set up, find a suitable config in `./configs`.

Here we use a GPT3-XL sized model as an example, but there are many more in `./configs`, all of which have short summaries in the Available Configs section.

All you need to do is edit the dataset id as described above, and edit `model_path` (where logs and checkpoints will be saved) to point to a cloud bucket you have write access to (or local path, if using GPUs).

```json
{
    "n_head": 32,
    "n_vocab": 50257,
    "embed_dropout": 0.1,
    "lr": 0.0002,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0.1,
    "train_batch_size": 512,
    "attn_dropout": 0.1,
    "train_steps": 286150,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0.1,
    "eval_batch_size": 128,
    "predict_batch_size": 1,
    "iterations": 2500,
    "n_embd": 2048,
    "datasets": [["your_dataset_name", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_XL",
    "n_ctx": 2048,
    "n_layer": 24,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],24]],
    "mesh_shape": "x:128,y:2",
    "layout": "batch:x,memory_length:y,embd:y",
    "activation_function": "gelu",
    "recompute_grad": true,
    "gradient_clipping": 1.0,
    "tokens_per_mb_per_replica": 2048
}
```
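
Before spending TPU time on a config like the one above, a few of its fields can be cross-checked locally. The sketch below is an illustration only, not a check the repository performs: `parse_mesh_shape` and `check_config` are hypothetical helpers, and the rules they test (embedding size divisible by head count, layout rules pointing at existing mesh dimensions) are taken from the parameter descriptions in this README.

```python
import math


def parse_mesh_shape(mesh_shape):
    """Turn a string like "x:128,y:2" into {"x": 128, "y": 2}."""
    dims = {}
    for part in mesh_shape.split(","):
        name, size = part.split(":")
        dims[name.strip()] = int(size)
    return dims


def check_config(cfg):
    """Return (list of problems, number of processors implied by mesh_shape)."""
    problems = []
    if cfg["n_embd"] % cfg["n_head"] != 0:
        problems.append("n_embd must be divisible by n_head")
    mesh = parse_mesh_shape(cfg["mesh_shape"])
    for rule in cfg["layout"].split(","):
        tensor_dim, mesh_dim = (s.strip() for s in rule.split(":"))
        if mesh_dim not in mesh:
            problems.append(f"layout rule {tensor_dim}:{mesh_dim} names an unknown mesh dimension")
    return problems, math.prod(mesh.values())


example_config = {
    "n_head": 32,
    "n_embd": 2048,
    "mesh_shape": "x:128,y:2",
    "layout": "batch:x,memory_length:y,embd:y",
}
print(check_config(example_config))  # ([], 256)
```

The second value is the processor count the mesh expects, so it should match the size of the TPU (or the number of GPUs) you plan to train on.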


## 6. Run Training

```
python3 main.py --model <your_config_name> --steps_per_checkpoint <n> --tpu <tpu-name>
```

- `tpu`: Name of the TPU to use.
- `steps_per_checkpoint`: The frequency in steps at which to save checkpoints.
- `--auto_layout` and `--auto_layout_and_mesh_shape` (Optional): Disable training and instead auto generate a memory efficient `layout` (and `mesh_shape`)
- `gpu_ids`: if training using GPUs, omit the `tpu` flag and pass in the ids of your gpus. In the example below, we train on 2 GPUs, specifying their device ids delimited by spaces:

```
python3 main.py --model <your_config_name> --steps_per_checkpoint <n> --gpu_ids <device:GPU:0 device:GPU:1>
```

# Available Configs

We have several model sizes available, but some of our configs require large TPUs and will need tweaking to run on smaller machines, or GPUs. Below is a short guide to each model in the configs directory:

TODO

# Extra Features: 

## Training (with Sacred)

[Sacred](https://github.com/IDSIA/sacred) helps track experiments and is much nicer to work with than tensorboard.

To set up:

1. Install Docker and Docker-compose

2. Run `docker-compose up`

To use: 

1. Ensure model_dir doesn't have any metric logs in it (it trips up the metric stuff for tensorboard, which assumes that it's a continuation of the existing run). You can use `gsutil rm -r ...` to delete model dir

2. Run `python3 run_experiment.py --tpu sometpuhere --model someconfig.json` Options are the same as `main.py`. 

3. You can go to http://server_ip_goes_here:8081/ to see the Omniboard overview. If you prefer to see a tensorboard, the script also spins one up and automatically assigns it a port. The script should print out the tensorboard port near the top of the log. 

## Peeking at a Dataset

If you are ever confused by the dataset of a particular config file, you can easily check the minimum and maximum token ids with a single command. This is useful for making sure that the vocabulary size of the model is at least as large as the maximum token id. Tensorflow will not error if you try to gather on a matrix with out of bounds indices, so you need to make sure your vocabulary size is sufficiently large.

```bash
python3 main.py --model {config_name} --check_dataset
```
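
The condition being checked is simply that every token id is a valid row of the embedding matrix. As a small illustration on an in-memory list of ids (this is not how `--check_dataset` is implemented, and `token_id_range_ok` is a hypothetical name):

```python
def token_id_range_ok(token_ids, n_vocab):
    """Return (min id, max id, whether all ids fall in [0, n_vocab))."""
    lowest, highest = min(token_ids), max(token_ids)
    return lowest, highest, (0 <= lowest and highest < n_vocab)


# 50256 is the largest valid id for a vocabulary of size 50257.
print(token_id_range_ok([0, 17, 50256], 50257))  # (0, 50256, True)
print(token_id_range_ok([0, 17, 50257], 50257))  # (0, 50257, False)
```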

## Masked Language Modeling

In addition to being able to train large GPTs, this repository also allows you to easily do masked language modeling (BERT, RoBERTa). In order to do so, you must follow two additional steps.

1. When tokenizing your dataset, you must reserve a special id for the `[mask]` token.

2. In the configs, you will have to define two additional fields

```python
"mlm_training": true,                           # must be set to true
"mlm_mask_id": <mask id>                        # the mask id that you reserved from above
```

That's all you need to train a model with the MLM objective, good for any type of data that you have encoded properly. If you would like to tweak the other related hyperparameters, please continue reading.

```python
"mlm_cls_token_id": <cls token id>,                # auto append specified CLS token id on the left
"mlm_mask_prob": 0.15,                             # the probability of masking a token, defaults to 15%
"mlm_same_token_prob": 0.10,                       # probability of keeping the token the same, defaults to 10%
"mlm_random_token_prob": 0.10,                     # probability of tokens that are replaced with random tokens, 10% was recommended by the BERT paper
"mlm_mask_ignore_ids": [<cls token>, <sep token>]  # ignore masking other special tokens, if any
```
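
To make the roles of these probabilities concrete, here is a pure-Python sketch of BERT-style masking. It assumes the usual BERT semantics (a token is first selected with `mlm_mask_prob`, then kept, randomized, or replaced by the mask id) and is not the repository's TensorFlow input pipeline; the function name and arguments are hypothetical.

```python
import random


def mlm_mask(tokens, mask_id, n_vocab, mask_prob=0.15, same_token_prob=0.10,
             random_token_prob=0.10, ignore_ids=(), seed=None):
    """Return (model inputs, per-token flags marking positions to predict)."""
    rng = random.Random(seed)
    inputs, selected_flags = [], []
    for tok in tokens:
        # Special tokens listed in ignore_ids are never selected.
        selected = tok not in ignore_ids and rng.random() < mask_prob
        selected_flags.append(selected)
        if not selected:
            inputs.append(tok)
            continue
        r = rng.random()
        if r < same_token_prob:
            inputs.append(tok)                    # keep the original token
        elif r < same_token_prob + random_token_prob:
            inputs.append(rng.randrange(n_vocab))  # any id, may hit a special token
        else:
            inputs.append(mask_id)                # replace with the mask id
    return inputs, selected_flags


tokens = [101, 7, 8, 9, 102]
inputs, flags = mlm_mask(tokens, mask_id=103, n_vocab=1000,
                         ignore_ids=(101, 102), seed=0)
print(inputs, flags)
```

With the defaults, roughly 15% of non-special tokens are selected, and of those about 80% become the mask id.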

## Parameter Reference

Pick a valid config from `./configs` and tweak the parameters as needed:

- `n_head`: The number of attention heads.
- `n_embd`: Size of the hidden layers, must be divisible by `n_head`.
- `n_vocab`: Vocabulary size.
- `embed_dropout`, `res_dropout`, `attn_dropout`: Dropout probability for word embedding/residuals/attention
- `lr`: Learning rate
- `warmup_steps`: Number of steps before full learning rate is reached (linear ramp from `0` to `lr`).
- `lr_decay`: `cosine` or `linear`.
- `opt_name`: `adam` or `adafactor`.
- `beta1`, `beta2` and `epsilon`: `adam` optimizer params.
- `beta1`, `ada_epsilon1` and `ada_epsilon2`: `adafactor` optimizer params.
- `weight_decay`: Weight decay parameter, if not present no weight decay is used (the weight decay fix for Adam is used) (default: 0.01) (optional).
- `train_batch_size`: Batch size during training.
- `train_steps`: Number of training steps (batches), set to roughly 1 epoch for now (total number of tokens in your dataset / number of tokens per batch (= `train_batch_size` * `n_ctx`)).
- `eval_steps`: Number of steps to run for each evaluation. Set to `0` for no eval, i.e. after every checkpoint, the model is tested for `eval_steps` steps.
- `iterations`: Number of steps queued to the TPU, must be smaller than `steps_per_checkpoint`. (default: 500)
- `datasets`: List of tfrecords datasets to use. Each dataset is a list with the following parameters: `[dataset_id, stitch, sampling_mode, weight]`. So for example, for a single dataset (note the double list): `[["your_dataset_name", 25, "documents_random", 1.0]]`
    + `dataset_id`: The name of a dataset configuration file in `./configs/dataset_configs` (without the `.json` extension)
    + `stitch`: If `sampling_mode` `documents_random` is used, the input pipeline concatenates this many documents into one text to sample from. You must select stitch so that `stitch * minimum_document_length >= n_ctx`
    + `sampling_mode`: `chunks` (tfrecords are preprocessed into the correct length and are read sequentially) or `documents_random` (`stitch` amount of documents are concatenated and then a `n_ctx` chunk is randomly subsampled)
    + `weight`: How much relative weight this dataset should have compared to others
- `model`: Which model to train. Currently only `GPT` is supported, and it defaults to this if not present.
- `model_path`: Google storage bucket location (or local path, if using GPUs) to save model checkpoints and logs.
- `n_ctx`: Size of context window. Default is 2048
- `n_layer`: Number of layers (blocks) in the model.
- `scale_by_depth`: If true, the weight initialization of layers is scaled by their depth as in the GPT2 paper.
- `scale_by_in`: If true, the weight initialization of layers is scaled by their number of inputs as in the GPT2 paper.
- `mesh_shape`: A Mesh is an n-dimensional array of processors with named dimensions used for parallelism in the mesh-tensorflow library. Each Tensor is split evenly across mesh dimensions according to the layout (see below). The `mesh_shape` is the shape of this array, and the product of its dimensions must be equal to the number of processors, e.g., for a v3-128 TPU `"mesh_shape": "x:16,y:8"`.
- `layout`: A Tensor is laid out on its mesh with one slice on each processor. A Tensor "layout" is an injective partial map specifying which dimensions of the tensor are (evenly) split across which dimensions of the mesh. No dimension of a tensor may be split across two dimensions of its mesh and no two dimensions of a tensor may be split across the same dimension of its mesh. The user defines a global set of layout rules in the form of (tensor-dimension-name, mesh-dimension-name) pairs. A dimension of a tensor is split across a dimension of its mesh if there is a matching rule, e.g. (for the above example mesh_shape) `"layout": "batch:x,heads:y"`.
- `activation_function`: `selu` (self normalizing) or `gelu` (used by OA), activation function used in feed-forward passes. (default: gelu)
- `attention_types`: the type of attention for each layer in a list of the following format [[["attention_type"], n_layers]]. e.g. for a 12 layer net [[["global"], 12]] or [[["local"], 10], [["global"], 2]].
    + Choose from: `linear`, `global`, `local` or `none`. We have found a 50/50 mix of `global` and `linear` to work well. `none` allows you to create feed-forward only layers for more efficient [PAR Transformer](https://arxiv.org/abs/2009.04534) models.
- `precision`: `float32` or `bfloat16`.
- `tokens_per_mb_per_replica`: If not None, will split the batch up into smaller microbatches containing `tokens_per_mb_per_replica` tokens to avoid OOMs. Gradients are accumulated locally and reduced once. IMPORTANT: mb refers to *minibatch* not megabyte here. 
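
Two of the rules above lend themselves to a quick calculation: the `train_steps` estimate for one epoch and the way an `attention_types` pattern unrolls into one entry per layer. The following is a minimal sketch with hypothetical helper names; the token count is an invented example, and the expansion assumes each `[pattern, repeats]` pair means "repeat this pattern that many times".

```python
def steps_per_epoch(total_tokens, train_batch_size, n_ctx):
    """Approximate number of training steps in one pass over the data."""
    tokens_per_batch = train_batch_size * n_ctx
    return total_tokens // tokens_per_batch


def expand_attention_types(attention_types):
    """Flatten [[pattern, repeats], ...] into one attention type per layer."""
    layers = []
    for pattern, repeats in attention_types:
        layers.extend(list(pattern) * repeats)
    return layers


# A 3B-token dataset with batch size 512 and a 2048-token context.
print(steps_per_epoch(3_000_000_000, 512, 2048))  # 2861

# Alternating global/local attention over 40 layers.
layers = expand_attention_types([[["global", "local"], 20]])
print(len(layers), layers[:4])  # 40 ['global', 'local', 'global', 'local']
```

The length of the expanded list should equal `n_layer`, which makes it a convenient check when editing `attention_types` by hand.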

**Mixture of Experts**

- `moe_layers`: A list of layer numbers to append a [mixture of experts](https://arxiv.org/abs/1701.06538) layer onto. E.G: `[2,4,6,8,10,12]`.
We have experimentally found a moe layer for every two self-attention layers to work well.
-  `moe_params`: a dictionary of additional kwargs to pass in to the moe layer. E.G
    `{"moe_dropout_rate": 0.0 }`
    
**Experimental features** 

- `axial_pos_emb_`: If true, uses [axial positional embedding](https://arxiv.org/abs/1912.12180).
- `mlp_glu`: If true, uses a gated linear unit variant of feed forward layers.
- `scalenorm`: If true, uses scalenorm instead of layernorm.
- `rezero`: If true, uses [rezero](https://www.groundai.com/project/rezero-is-all-you-need-fast-convergence-at-large-depth/1) instead of layernorm.
- `num_mem_kv`: adds memory / key values from the [all-attention paper](https://arxiv.org/pdf/1907.01470.pdf). Param is an int with the number of desired mem/key values.
- `macaron`: if true - uses a [macaron transformer](https://arxiv.org/pdf/1906.02762.pdf) for each layer block.

## TODO: 

- [x] finalize documentation
- [ ] update configs

## Citing GPT-Neo

If you have found GPT-Neo helpful in your work, you can cite this repository as

```
@software{gpt-neo,
  author       = {Black, Sid and
                  Gao, Leo and
                  Wang, Phil and
                  Leahy, Connor and
                  Biderman, Stella},
  title        = {{GPT-Neo: Large Scale Autoregressive Language 
                   Modeling with Mesh-Tensorflow}},
  month        = mar,
  year         = 2021,
  note         = {{If you use this software, please cite it using 
                   these metadata.}},
  publisher    = {Zenodo},
  version      = {1.0},
  doi          = {10.5281/zenodo.5297715},
  url          = {https://doi.org/10.5281/zenodo.5297715}
}

```
The version number should be replaced with the version number you are using, and the year corresponds to the project's open-source release.

If you are specifically interested in citing the GPT-Neo models trained on [the Pile](https://arxiv.org/abs/2101.00027), we would appreciate also citing
```
@article{gao2020pile,
  title={The Pile: An 800GB Dataset of Diverse Text for Language Modeling},
  author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others},
  journal={arXiv preprint arXiv:2101.00027},
  year={2020}
}
```


================================================
FILE: configs/dataset_configs/example.json
================================================
{
	"n_vocab": 32768,
	"path": "./tfrecords/openwebtext_*.tfrecords",
	"eval_path": "",
	"tokenizer_path": "./datasets/openwebtext/byte-level-bpe.tokenizer.json",
	"eos_id": 1,
	"padding_id": 0
}


================================================
FILE: configs/dataset_configs/openwebtext2_new_inputs.json
================================================
{
	"n_vocab": 50257,
	"path": "gs://neo-datasets/openwebtext2_new_inputs/train/*.tfrecords",
	"eval_path": "gs://neo-datasets/openwebtext2_new_inputs/eval/*.tfrecords",
	"tokenizer_is_pretrained": true,
	"tokenizer_path": "gpt2",
	"eos_id": 50256,
	"padding_id": 50257
}


================================================
FILE: configs/dataset_configs/pile.json
================================================
{
	"n_vocab": 50257,
	"path": "gs://neo-datasets/pile/pile_*.tfrecords",
	"eval_path": "gs://neo-datasets/pile_val.tfrecords",
	"tokenizer_is_pretrained": true,
	"tokenizer_path": "gpt2",
	"eos_id": 50256,
	"padding_id": 50257
}


================================================
FILE: configs/gpt2_small.json
================================================
{
    "n_head": 6,
    "n_vocab": 50257,
    "embed_dropout": 0.1,
    "lr": 0.0006,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0,
    "train_batch_size": 512,
    "attn_dropout": 0.1,
    "train_steps": 1000000,
    "lr_decay_end": 300000,
    "eval_steps": 30,
    "predict_steps": 0,
    "res_dropout": 0.1,
    "eval_batch_size": 128,
    "predict_batch_size": 8,
    "iterations": 2500,
    "n_embd": 768,
    "datasets": ["openwebtext2_new_inputs"],
    "model_path": "gs://neo-models/GPT2_SMALL",
    "n_ctx": 1024,
    "n_layer": 12,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],12]],
    "activation_function": "gelu",
    "mesh_shape": "all:64",
    "layout": "batch:all",
    "recompute_grad": false,
    "gradient_clipping": 1.0
}

================================================
FILE: configs/gpt3_13B_256.json
================================================
{
    "n_head": 40,
    "n_vocab": 50257,
    "embed_dropout": 0,
    "lr": 0.0001,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "ada_epsilon1": 1e-30,
    "ada_epsilon2": 1e-3,
    "opt_name": "adam",
    "weight_decay": 0.10,
    "train_batch_size": 1024,
    "attn_dropout": 0,
    "train_steps": 143075,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 128,
    "predict_batch_size": 1,
    "iterations": 500,
    "n_embd": 5120,
    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_13B",
    "n_ctx": 2048,
    "n_layer": 40,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global", "local"],20]],
    "mesh_shape": "x:16,y:16",
    "layout": "batch:x,embd:y,memory_length:y",
    "activation_function": "gelu",
    "recompute_grad": true,
    "gradient_clipping": 1.0,
    "tokens_per_mb_per_replica": 2048,
    "precision": "bfloat16"
}



================================================
FILE: configs/gpt3_13B_256_Pile.json
================================================

{
    "n_head": 40,
    "n_vocab": 50257,
    "embed_dropout": 0,
    "lr": 0.0001,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0.1,
    "train_batch_size": 1024,
    "attn_dropout": 0,
    "train_steps": 286150,
    "eval_steps": 10,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 512,
    "predict_batch_size": 1,
    "iterations": 500,
    "n_embd": 5120,
    "datasets": [["pile", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_13B_Pile",
    "n_ctx": 2048,
    "n_layer": 40,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],40]],
    "mesh_shape": "x:16,y:16",
    "layout": "batch:x,memory_length:y,embd:y",
    "activation_function": "gelu",
    "recompute_grad": true,
    "gradient_clipping": 1.0,
    "tokens_per_mb_per_replica": 2048,
    "precision": "bfloat16"
}


================================================
FILE: configs/gpt3_2-7B_256.json
================================================
{
    "n_head": 32,
    "n_vocab": 50257,
    "embed_dropout": 0,
    "lr": 0.00016,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "ada_epsilon1": 1e-30,
    "ada_epsilon2": 1e-3,
    "opt_name": "adam",
    "weight_decay": 0.10,
    "train_batch_size": 512,
    "attn_dropout": 0,
    "train_steps": 286150,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 128,
    "predict_batch_size": 1,
    "iterations": 500,
    "n_embd": 2560,
    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_2-7B",
    "n_ctx": 2048,
    "n_layer": 32,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],32]],
    "mesh_shape": "x:128,y:2",
    "layout": "embd:y,batch:x",
    "activation_function": "gelu",
    "recompute_grad": true,
    "gradient_clipping": 1.0
}



================================================
FILE: configs/gpt3_6-7B_256.json
================================================
{
    "n_head": 32,
    "n_vocab": 50257,
    "embed_dropout": 0,
    "lr": 0.00012,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0.10,
    "train_batch_size": 1024,
    "attn_dropout": 0,
    "train_steps": 143075,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 128,
    "predict_batch_size": 1,
    "iterations": 500,
    "n_embd": 4096,
    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_6-7B",
    "n_ctx": 2048,
    "n_layer": 32,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],32]],
    "mesh_shape": "x:128,y:2",
    "layout": "embd:y,batch:x",
    "activation_function": "gelu",
    "recompute_grad": true,
    "gradient_clipping": 1.0
}



================================================
FILE: configs/gpt3_PAR_small_256.json
================================================
{
    "n_head": 12,
    "n_vocab": 50304,
    "embed_dropout": 0,
    "lr": 0.0006,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0.10,
    "train_batch_size": 256,
    "attn_dropout": 0,
    "train_steps": 572300,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 64,
    "predict_batch_size": 1,
    "iterations": 1000,
    "n_embd": 768,
    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_PAR_SMALL",
    "n_ctx": 2048,
    "n_layer": 19,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types": [[["global", "none", "none"],5], [["none"], 4]],
    "mesh_shape": "x:64,y:4",
    "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y",
    "activation_function": "gelu",
    "recompute_grad": false,
    "gradient_clipping": 1.0
}



================================================
FILE: configs/gpt3_XL_256_Pile.json
================================================
{
    "n_head": 32,
    "n_vocab": 50257,
    "embed_dropout": 0,
    "lr": 0.0002,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0.1,
    "train_batch_size": 512,
    "attn_dropout": 0,
    "train_steps": 286150,
    "eval_steps": 10,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 512,
    "predict_batch_size": 1,
    "iterations": 500,
    "n_embd": 2048,
    "datasets": [["pile", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_XL_Pile",
    "n_ctx": 2048,
    "n_layer": 24,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],24]],
    "mesh_shape": "x:128,y:2",
    "layout": "batch:x,memory_length:y,embd:y",
    "activation_function": "gelu",
    "recompute_grad": true,
    "gradient_clipping": 1.0,
    "tokens_per_mb_per_replica": 2048,
    "precision": "bfloat16"
}


================================================
FILE: configs/gpt3_large_256.json
================================================
{
    "n_head": 16,
    "n_vocab": 50304,
    "embed_dropout": 0,
    "lr": 0.00025,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "ada_epsilon1": 1e-30,
    "ada_epsilon2": 1e-3,
    "opt_name": "adam",
    "weight_decay": 0.10,
    "train_batch_size": 256,
    "attn_dropout": 0,
    "train_steps": 572300,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 64,
    "predict_batch_size": 1,
    "iterations": 2500,
    "n_embd": 1536,
    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_LARGE",
    "n_ctx": 2048,
    "n_layer": 24,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],24]],
    "mesh_shape": "x:64,y:4",
    "layout": "batch:x,vocab:y,heads:y",
    "activation_function": "gelu",
    "recompute_grad": true,
    "gradient_clipping": 1.0,
    "tokens_per_mb_per_replica": 2048
}



================================================
FILE: configs/gpt3_medium_256.json
================================================
{
    "n_head": 16,
    "n_vocab": 50304,
    "embed_dropout": 0,
    "lr": 0.0003,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0.10,
    "train_batch_size": 256,
    "attn_dropout": 0,
    "train_steps": 572300,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 64,
    "predict_batch_size": 1,
    "iterations": 2500,
    "n_embd": 1024,
    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_MEDIUM",
    "n_ctx": 2048,
    "n_layer": 24,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types" :  [[["global"],24]],
    "mesh_shape": "x:64,y:4",
    "layout": "batch:x,heads:y,vocab:y",
    "activation_function": "gelu",
    "recompute_grad": false,
    "gradient_clipping": 1.0
}



================================================
FILE: configs/gpt3_small_256.json
================================================
{
    "n_head": 12,
    "n_vocab": 50304,
    "embed_dropout": 0,
    "lr": 0.0006,
    "lr_decay": "cosine",
    "warmup_steps": 3000,
    "beta1": 0.9,
    "beta2": 0.95,
    "epsilon": 1e-8,
    "opt_name": "adam",
    "weight_decay": 0.10,
    "train_batch_size": 256,
    "attn_dropout": 0,
    "train_steps": 572300,
    "eval_steps": 0,
    "predict_steps": 1,
    "res_dropout": 0,
    "eval_batch_size": 64,
    "predict_batch_size": 1,
    "iterations": 2500,
    "n_embd": 768,
    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
    "model_path": "gs://neo-models/GPT3_SMALL",
    "n_ctx": 2048,
    "n_layer": 12,
    "scale_by_depth": true,
    "scale_by_in": false,
    "attention_types": [[["global"],12]],
    "mesh_shape": "x:64,y:4",
    "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y",
    "activation_function": "gelu",
    "recompute_grad": false,
    "gradient_clipping": 1.0
}



================================================
FILE: configs.py
================================================
import json
from pathlib import Path
from collections import defaultdict

DATASETS = {}

for path in Path("configs/dataset_configs").glob("*.json"):
    dataset_id = path.stem
    DATASETS[dataset_id] = json.loads(path.read_text())


def fetch_model_params(model):
    model_path = model if model.endswith(".json") else f"configs/{model}.json"
    with open(model_path) as f:
        params = json.load(f)

    dataset_ids = []
    for d in params.get("datasets") or []:  # tolerate a missing "datasets" key (e.g. when no_dataset is set)
        if isinstance(d, list):
            dataset_ids.append(d[0])
        else:
            dataset_ids.append(d)
    no_datasets = params.get("no_dataset", False)
    assert no_datasets or len(dataset_ids) > 0, "You must specify at least one dataset id in the model config"

    datasets = {}
    last_dataset = None
    for dataset_id in dataset_ids:
        assert dataset_id in DATASETS, f"Dataset '{dataset_id}' was not found under dataset_configs/ folder. Please follow the example.json in that folder."
        dataset = DATASETS[dataset_id]
        assert params["n_vocab"] >= dataset["n_vocab"], f"The embedding table size '{params['n_vocab']}' must be greater than or equal to the vocab size used to encode the dataset '{dataset_id}' ({dataset['n_vocab']})"
        datasets[dataset_id] = dataset
        last_dataset = dataset

    if last_dataset is not None:
        params["padding_id"] = last_dataset.get("padding_id", 0)
        params["eos_id"] = last_dataset.get("eos_id", 1)

    params["dataset_configs"] = datasets

    # Set some other parameter defaults
    params["mlm_training"] = params.get("mlm_training") == True
    params["causal"] = not params["mlm_training"]

    # Set all other parameter values to default to None
    params = defaultdict(lambda: None, params)
    return params


================================================
FILE: data/create_tfrecords.py
================================================
import argparse
import os
from pathlib import Path

import ftfy
import tensorflow as tf
from lm_dataformat import Reader
from tokenizers import Tokenizer
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import logging
from multiprocessing import Pool, cpu_count
from itertools import repeat
import re

logging.getLogger("transformers").setLevel(logging.ERROR)

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, help="Path to where your files are located. Files ending in .zst are "
                                                  "treated as archives, all others as raw text.")
parser.add_argument("--files_per", type=int, default=100000, help="Text files per tfrecord")
parser.add_argument("--name", type=str, default="openwebtext",
                    help="Name of output files will be name_i.tfrecords where i is the number of the file")
parser.add_argument("--output_dir", type=str, default="./tfrecords", help="Where to put tfrecords")
parser.add_argument("--encoder_path", type=str,
                    help="Path to encoder files, or leave unspecified to use GPT2 tokenizer")
parser.add_argument("--minimum_size", type=int, default=100, help="Minimum size a document has to be to be included")
parser.add_argument("--ftfy", action="store_false", help="Disable ftfy normalization (enabled by default)")
parser.add_argument("--wikitext-detokenize", action="store_false",
                    help="Disable the wikitext detokenizer (enabled by default)")
parser.add_argument("--separator", nargs="+", type=int, default=[50256],
                    help="separator to place between files in chunk mode")
parser.add_argument("--chunk_size", type=int, default=2048, help="How big a chunk should be in chunk mode. "
                                                                 "Should equal your model's context size")
parser.add_argument("--write_dataset_config", action="store_true", help="Write the dataset config file on completion")
parser.add_argument("--processes", type=int, default=0, help="Number of processes to use. Defaults to cpu count.")

args = parser.parse_args()
if not args.output_dir.endswith("/"):
    args.output_dir = args.output_dir + "/"
if not args.input_dir.endswith("/"):
    args.input_dir = args.input_dir + "/"
assert len(args.separator) == 1


def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"' ([0-9])", r"'\1", string)  # drop the space between an apostrophe and a digit
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    return string


def _int64_feature(value):
    """
    Returns an int64_list from a bool / enum / int / uint.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def write_to_file(writer, data):
    """
    writes data to tfrecord file
    """
    feature = {
        "text": _int64_feature(data)
    }
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(tf_example.SerializeToString())


def get_tokenizer(args):
    if args.encoder_path is None:
        return GPT2TokenizerFast.from_pretrained('gpt2')
    else:
        return Tokenizer.from_file(args.encoder_path)


def split_list(l, n):
    # splits list/string into n size chunks
    return [l[i:i + n] for i in range(0, len(l), n)]


def archive_to_tokens(f, encoder, args, prefix=None):
    # Generator that yields the tokenized documents of an archive, each split into chunks.
    # If prefix is given, it is prepended to the first document only (leftover tokens from the previous file).
    prefix = list(prefix) if prefix else []
    reader = Reader(f)
    for doc in reader.stream_data(threaded=False):
        if args.ftfy:  # fix text with ftfy if specified
            doc = ftfy.fix_text(doc, normalization='NFKC')
        if args.wikitext_detokenize:
            doc = wikitext_detokenizer(doc)
        ids = encoder.encode(doc)
        if not isinstance(ids, list):  # tokenizers.Tokenizer returns an Encoding rather than a list of ids
            ids = ids.ids
        doc = ids + args.separator  # append the separator token to the encoded document
        yield split_list(prefix + doc, args.chunk_size)  # split into chunk_size chunks (n_ctx + 1 when run as a script)
        prefix = []


def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
    # writes a list of files to .tfrecords
    if files is None:
        return
    chunks = split_list(files, files_per)
    if not chunks:
        return

    if len(chunks[-1]) != files_per and not write_remainder:  # pop the last chunk if its length != files_per
        remainder = chunks.pop(-1)
    else:
        remainder = None  # assuming files = remainder from an old chunk here
        files_per = len(chunks[-1])

    for files in chunks:
        fp = f"{output_dir}/{out_name}_{start_no}"
        if process_no is not None:
            fp += f"_{process_no}"
        fp += f"_{files_per}"  # add number of files in tfrecord to end of fp
        fp += ".tfrecords"
        with tf.io.TFRecordWriter(fp) as writer:
            for f in files:
                write_to_file(writer, f)
        start_no += 1
    return start_no, remainder


def get_files(input_dir, filetypes=None):
    # gets all files of <filetypes> in input_dir
    if filetypes is None:
        filetypes = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
    files = [list(Path(input_dir).glob(f"*{ft}")) for ft in filetypes]
    # flatten list of list -> list and stringify Paths
    flattened_list = [str(item) for sublist in files for item in sublist]
    if not flattened_list:
        raise Exception(f"""did not find any files at this path {input_dir},\
 please also ensure your files are in format {filetypes}""")
    return flattened_list


def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
    # init checkpointing
    if resume_from_checkpoint and os.path.isfile(checkpoint_path):
        try:
            with open(checkpoint_path, "r") as checkpoint_file:
                resume_files_processed, tfrecord_count = [int(i) for i in checkpoint_file.read().split(", ")]
            print(f"\nResuming from tfrecord no. {tfrecord_count} / file no. {resume_files_processed}")
            return resume_files_processed, tfrecord_count
        except (OSError, ValueError):
            pass  # unreadable or malformed checkpoint: start from scratch
    return 0, 0


def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False,
                     resume_from_checkpoint=False, display_pbar=False):
    # iterates through files in input_dir, splitting into <args.chunk_size> chunks and saving a tfrecords file every <args.files_per> chunks.
    files, args, process_no = params
    enc = get_tokenizer(args)  # get tokenizer

    # init metadata
    discarded_files = 0
    files_processed = 0
    pbar = tqdm(desc=f"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. files_written ",
                disable=not display_pbar)
    checkpoint_path = f"{args.output_dir}/checkpoint.txt"
    resume_files_processed, tfrecord_count = read_checkpoint(checkpoint_path, resume_from_checkpoint)

    data_to_prepend = []
    tokenized_files_array = []

    for f in files:
        for tokenized_files in archive_to_tokens(f, enc, args, prefix=data_to_prepend):
            files_processed += 1
            if files_processed < resume_files_processed:
                continue  # resume from checkpoint

            # if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file
            data_to_prepend = []
            n_tokens = len(tokenized_files[-1])
            if n_tokens < args.chunk_size:
                data = tokenized_files.pop(-1)
                if n_tokens >= args.minimum_size:
                    data_to_prepend = data
                else:
                    discarded_files += 1

            # add tokenized files > chunk size to main array
            tokenized_files_array.extend(tokenized_files)

            if len(tokenized_files_array) >= args.files_per * write_every_n_files:  # write every n files
                _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
                                                         output_dir=args.output_dir, out_name=args.name,
                                                         start_no=tfrecord_count, process_no=process_no)
                pbar.update(_tfrecord_count - tfrecord_count)  # update progress bar
                pbar.set_description(
                    f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
                tfrecord_count = _tfrecord_count
                tokenized_files_array = remainder if remainder is not None else []  # add remaining files to next chunk
                with open(checkpoint_path, "w") as checkpoint_file:
                    checkpoint_file.write(f"{files_processed}, {tfrecord_count}")

    if len(tokenized_files_array) >= args.files_per:  # also write at end
        _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
                                                 output_dir=args.output_dir, out_name=args.name,
                                                 start_no=tfrecord_count, process_no=process_no)
        pbar.update(_tfrecord_count - tfrecord_count)
        pbar.set_description(
            f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
        tfrecord_count = _tfrecord_count
        with open(checkpoint_path, "w") as checkpoint_file:
            checkpoint_file.write(f"{files_processed}, {tfrecord_count}")
    else:
        remainder = tokenized_files_array  # add remaining to remainder

    if write_remainder:
        # write out the remaining files even if there's less than files_per
        write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name,
                    start_no=tfrecord_count, write_remainder=True,
                    process_no=process_no)  # keep remainder filenames unique per process

    successful_files = files_processed - discarded_files
    return {"discarded": discarded_files, "processed": files_processed, "successful": successful_files}


def create_tfrecords_mp(files, args):
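    # split files into roughly equal groups; max(1, ...) avoids a zero chunk size when there are fewer files than processes
    files = split_list(files, max(1, len(files) // args.processes))
    with Pool(processes=args.processes) as pool:
        pbar = tqdm(pool.imap(create_tfrecords, zip(files, repeat(args), range(len(files)))), total=len(files))
        meta = {"discarded": 0, "processed": 0, "successful": 0}
        for results in pbar:  # iterating over tqdm already advances the progress bar
            for k, v in results.items():
                meta[k] += v  # update metadata
        return meta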


if __name__ == "__main__":
    os.makedirs(args.output_dir, exist_ok=True)  # make output dir if it doesn't exist
    files = get_files(args.input_dir)
    args.chunk_size += 1  # we shift the data by 1 to the right for targets, so increment the chunk size here

    if args.processes == 0:
        args.processes = cpu_count()
    if args.processes > 1:
        results = create_tfrecords_mp(files, args)
    else:
        results = create_tfrecords((files, args, 0), display_pbar=True)
    print(results)


================================================
FILE: data/encoders.py
================================================
from tokenizers import Tokenizer
from transformers import GPT2Tokenizer, GPT2TokenizerFast

def fetch_encoder(params):
    no_dataset = params.get('no_dataset', False)
    if no_dataset:
        return None

    dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict
    path = dataset["tokenizer_path"]
    is_pretrained = dataset.get("tokenizer_is_pretrained", False)

    if is_pretrained:
        tok = GPT2TokenizerFast.from_pretrained(path)

        # Will add a padding token id of 50257 at run-time
        tok.add_special_tokens({'pad_token': '<|padding|>'})
        return tok

    return Tokenizer.from_file(path)


# GPT2Tokenizer and Tokenizer have different ways of fetching token ids
def encode(encoder, text):
    result = encoder.encode(text)
    if isinstance(result, list):
        return result
    return result.ids


================================================
FILE: data/train_tokenizer.py
================================================
import os
import random
import argparse
import shutil
from glob import glob
from pathlib import Path

from lm_dataformat import Reader
from tokenizers import (Tokenizer, decoders, models, pre_tokenizers,
                        processors, trainers)
from tokenizers.normalizers import NFKC
from tqdm import tqdm

# parser

parser = argparse.ArgumentParser()
parser.add_argument("--base_dir", type=str, help="Path to where your files are located. Files ending in .zst are treated as \
                    archives, all others as raw text.")
parser.add_argument("--output_dir", type=str, default="tokenizers", help="Where to put the tokenizer")
parser.add_argument("--file_type", type=str, choices=["xz", "txt"], default="xz", help="Extension of file to parse")
parser.add_argument("--vocab_size", type=int, help="Size of vocabulary", required = True)
args = parser.parse_args()

# main script

data_path = Path(args.base_dir)
archives = glob(str(data_path / f"*.{args.file_type}"))

out_path = Path(args.output_dir)

if os.path.exists(out_path):
    shutil.rmtree(out_path)

if not out_path.is_dir():
    out_path.mkdir()

    for arch in tqdm(archives):
        name = os.path.basename(arch).split(".")[0] + ".txt"
        fp = out_path / name

        if args.file_type == 'xz':
            g = Reader(arch).stream_data()

            with open(fp, "w") as f:
                for s in g:
                    f.write(s)
                    f.write("\n\n")
        elif args.file_type == 'txt':
            shutil.copyfile(str(arch), str(fp))

data_files = glob(str(out_path / "*.txt"))
assert len(data_files) > 0, 'No data files found'

# train on a random 20% of the files, but always keep at least one
data_files = random.sample(data_files, max(1, int(0.2 * len(data_files))))

# Initialize a tokenizer
tokenizer = Tokenizer(models.BPE())

# Customize pre-tokenization and decoding
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
tokenizer.normalizer = NFKC()

# And then train
trainer = trainers.BpeTrainer(vocab_size=args.vocab_size, min_frequency=2, special_tokens=["<|endoftext|>", "<|padding|>"])
tokenizer.train(trainer, data_files)

# And Save it
tokenizer_path = out_path / "byte-level-bpe.tokenizer.json"
tokenizer.save(str(tokenizer_path), pretty=True)

print(f'tokenizer saved at {str(tokenizer_path)}')

================================================
FILE: docker-compose.yml
================================================
version: '3'
services:

  mongo:
    image: mongo
    ports:
      - 127.0.0.1:27017:27017
    environment:
      MONGO_INITDB_ROOT_USERNAME: user
      MONGO_INITDB_ROOT_PASSWORD: password
      MONGO_INITDB_DATABASE: db
    expose:
      - 27017
    networks:
      - omniboard
    volumes:
      - ./data:/data/db

  mongoClientTemp:
   image: mongo:latest
   container_name: mongoClientTemp
   links:
    - mongo:mongo
   command: mongo --host mongo -u user -p password --eval  "db.getSiblingDB('db').createUser({user:'readonly', pwd:'password', roles:[{role:'read',db:'db'}]});"
   depends_on:
    - mongo
   networks:
    - omniboard

  omniboard_readonly:
          #image: vivekratnavel/omniboard:latest
    build: https://github.com/lucidrains/omniboard.git
    command: ["--mu", "mongodb://readonly:password@mongo:27017/db"]
    ports:
            - 0.0.0.0:8081:9000
    networks:
      - omniboard
    depends_on:
      - mongo

  omniboard:
          #image: vivekratnavel/omniboard:latest
    build: https://github.com/lucidrains/omniboard.git
    command: ["--mu", "mongodb://user:password@mongo:27017/db?authSource=admin"]
    expose:
      - 9000
    networks:
      - omniboard
    depends_on:
      - mongo

  nginx:
    image: dhswt/nginx-basic-auth:1.3
    environment:
      - HTPASSWD=isaac: #put passwd here
      - FORWARD_HOST=omniboard
      - FORWARD_PORT=9000
    networks:
      - omniboard
    depends_on:
      - omniboard
    ports:
            - 0.0.0.0:8080:80
    expose:
      - 8080
networks:
  omniboard:


================================================
FILE: encoders.py
================================================
from tokenizers import Tokenizer
from transformers import GPT2Tokenizer, GPT2TokenizerFast

def fetch_encoder(params):
    no_dataset = params.get('no_dataset', False)
    if no_dataset:
        return None

    dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict
    path = dataset["tokenizer_path"]
    is_pretrained = dataset.get("tokenizer_is_pretrained", False)

    if is_pretrained:
        tok = GPT2TokenizerFast.from_pretrained(path)

        # Will add a padding token id of 50257 at run-time
        tok.add_special_tokens({'pad_token': '<|padding|>'})
        return tok

    return Tokenizer.from_file(path)


# GPT2Tokenizer and Tokenizer have different ways of fetching token ids
def encode(encoder, text, gpt=True):
    result = encoder.encode(text)
    if isinstance(result, list):
        return result
    return result.ids


================================================
FILE: export.py
================================================
import tensorflow.compat.v1 as tf

def export_model(estimator, export_dir, params,
                 checkpoint_path=None):


    def serving_input_receiver_fn():
        t = tf.placeholder(dtype=tf.int64,
                            shape=[1, params["n_ctx"]],
                            name='input_example_tensor')
        return tf.estimator.export.ServingInputReceiver(t, t)

    return estimator.export_saved_model(
        export_dir, serving_input_receiver_fn, checkpoint_path=checkpoint_path)

================================================
FILE: inputs.py
================================================
import numpy as np
import tensorflow.compat.v1 as tf
from functools import partial
from data.encoders import encode
import random
import re
import logging
from itertools import cycle
from utils import natural_sort


### IN USE ###

def _get_number_of_documents(filename):
    # extracts number of files from a filename formatted "<name>_<num_documents>.tfrecords"
    # if no pattern is matched, returns None
    match = re.search(r"_(\d+)\.tfrecords$", filename)
    return int(match.group(1)) if match is not None else None


def _get_number_of_documents_by_iteration(filename):
    # extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename
    # this could be very slow.
    logging.warning(
        "inputs/sequential_input() found no metadata in filename - iterating through first tfrecord to find global length")
    count = 0
    for item in tf.io.tf_record_iterator(filename):
        count += 1
    return count


def _get_skip_index(all_files, n_batches):
    prev_cumsum = 0
    cumsum = 0
    global_n_documents = None
    for count, f in cycle(enumerate(all_files)):
        prev_cumsum = cumsum
        if _get_number_of_documents(f) is not None:
            cumsum += _get_number_of_documents(f)
        elif global_n_documents is None:
            global_n_documents = _get_number_of_documents_by_iteration(f)
            cumsum += global_n_documents
        else:
            cumsum += global_n_documents
        if cumsum == n_batches:
            remainder = 0
            skip_idx = count + 1
        elif cumsum > n_batches:
            remainder = n_batches - prev_cumsum
            skip_idx = count
            break
    return skip_idx, remainder


def _parse_function(example_proto):
    features = {
        "text": tf.VarLenFeature(tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])


def autoregressive_sample_text(params, x):
    vals1 = x[:params["n_ctx"]]
    vals2 = x[1:params["n_ctx"] + 1]

    vals1 = tf.reshape(vals1, [params["n_ctx"]])
    vals2 = tf.reshape(vals2, [params["n_ctx"]])
    vals1 = tf.cast(vals1, dtype=tf.int32)
    vals2 = tf.cast(vals2, dtype=tf.int32)
    return vals1, vals2


def sequential_input(params, global_step=None, eval=False):
    """
    Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either:

        - has the number of documents for each tfrecord file encoded in the title in the format
          <name>_<n_documents>.tfrecords.

          OR

        - has a fixed number of documents per tfrecord file.

    If the filename pattern above isn't matched, we assume that every tfrecord file holds the same number of documents as the first tfrecord read.
    If this isn't the case, it may result in errors, or some samples being missed.

    This means we can calculate the number of samples we've seen so far using the global step,
    and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient.

    If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model
    performance, as it results in less repeated data.
    """
    if not eval:
        assert global_step is not None
    logging.warning(
        "Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")
    batch_size = params['eval_batch_size' if eval else 'train_batch_size']

    filenames = []
    for dataset_config in params['dataset_configs'].values():  # iterate through each dataset and read params
        path_key = 'path' if not eval else 'eval_path'
        path = dataset_config[path_key]
        filenames.extend(
            tf.io.gfile.glob(path))  # then glob all files that fit the pattern specified in dataset_configs

    filenames = natural_sort(filenames)
    shuffle_filenames = params.get("shuffle_input_filenames", True)
    if shuffle_filenames:
        seed = params.get('seed', 1)  # shuffle deterministically
        random.seed(seed)
        random.shuffle(filenames)

    dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()  # repeat filenames to infinity

    if not eval:
        # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
        skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[
            "train_batch_size"])  # TODO: fix for > 1 epoch
        dataset = dataset.skip(skip_idx)  # skip to skip idx

        # read tfrecord examples and skip remainder
        dataset = dataset.apply(tf.data.TFRecordDataset)
        dataset = dataset.skip(remainder)
    else:
        # shuffle filenames if in eval mode
        dataset = dataset.shuffle(len(filenames))
        dataset = dataset.apply(tf.data.TFRecordDataset)

    # parse the tokenized data from the tfrecord files, then split each example into (input, target) pairs
    dataset = dataset.map(_parse_function, num_parallel_calls=1)
    dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)

    # batch data and repeat to infinity
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
    return dataset.repeat()


def pred_input(params, logger, enc=None,
               path_to_prompt=""):
    unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
               "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
               "researchers was the fact that the unicorns spoke perfect English."

    if path_to_prompt == "":
        text = unicorns
    else:
        with open(path_to_prompt, "r") as prompt_file:  # close the prompt file once it has been read
            text = prompt_file.read()
    tokens = encode(enc, text)

    if len(tokens) > params["n_ctx"]:
        logger.info("The length of your input prompt is longer than the model's context length - truncating input.")
        tokens = tokens[len(tokens) - params["n_ctx"]:]
    if len(tokens) < params["n_ctx"]:
        tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"])
    t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
    dataset = tf.data.Dataset.from_tensors(t)

    def _dummy_labels(x):
        return x, x

    dataset = dataset.map(_dummy_labels)
    return dataset


def handle_pred_output(predictions, logger, enc, params, out_name="test"):
    with tf.gfile.Open(f"{out_name}.txt", "w") as f:
        for i, p in enumerate(predictions):
            p = p["outputs"]

            # remove eos + padding ids from output
            idx = np.argmax(p == params['eos_id'])
            if idx > 0:
                p = p[:idx]
            idx = np.argmax(p == params['padding_id'])
            if idx > 0:
                p = p[:idx]

            text = enc.decode(p)
            f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
            f.write(text)
            f.write("\n" + "=" * 80 + "\n")

            logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
            logger.info(text)
            logger.info("\n" + "=" * 80 + "\n")


### DEPRECATED ###

def generic_text(params, eval=False, sample_text_fn=None, **kwargs):
    logging.warning("DEPRECATION WARNING: generic_text will be phased out in future versions.")
    i = 0 if not eval else 1

    weights = []
    datasets = []

    for dataset in params["datasets"]:
        dataset_id, stitch, datatype, weight = dataset

        assert dataset_id in params[
            'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'
        dataset_config = params['dataset_configs'][dataset_id]

        path_key = 'path' if not eval else 'eval_path'
        path = dataset_config[path_key]

        datasets.append(text_dataset(
            tf.io.gfile.glob(path),
            params,
            stitch=stitch,
            datatype=datatype,
            batch=False,
            sample_text_fn=sample_text_fn
        ))

        weights.append(weight)

    batch_size = params['eval_batch_size' if eval else 'train_batch_size']

    seed = params.get('seed', None)
    dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
    return dataset


def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):
    seed = params.get('seed', None)
    deterministic = seed is not None
    num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE

    dataset = tf.data.Dataset.from_tensor_slices(files)

    if deterministic:
        dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
    else:
        dataset = dataset.apply(
            tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))

    if "documents" in datatype:
        def _parse_function(example_proto):
            features = {
                # "hash": tf.VarLenFeature(tf.string),
                "text": tf.VarLenFeature(tf.int64)
            }
            parsed_features = tf.parse_single_example(example_proto, features)
            return parsed_features["text"], parsed_features["text"].dense_shape[0]
    else:
        def _parse_function(example_proto):
            features = {
                "text": tf.VarLenFeature(tf.int64)
            }
            parsed_features = tf.parse_single_example(example_proto, features)
            return parsed_features["text"]  # Assuming the text is not sparse

    dataset = dataset.map(_parse_function, num_parallel_calls=1)

    # Subsample method
    if "documents" in datatype:
        # Since samples can be shorter than the required length, and TPUs don't like variable lengths, this function
        # stitches together enough samples to have a text at least 1024 tokens long. For this to work, the stitch
        # parameter must be tuned so that stitch * min(tokens_in_text) >= amount
        def _stitch_text(x, y):
            x = tf.sparse.to_dense(x)

            def _get_x(i):
                return tf.gather(x[i], tf.range(y[i]))

            out = _get_x(0)
            eos_id = params['eos_id']

            for i in range(1, stitch):
                out = tf.concat([out, [eos_id], _get_x(i)], axis=0)  # text1<|endoftext|>text2

            return out

        # Hacky way to stitch together multiple texts
        dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True)
        dataset = dataset.map(_stitch_text, num_parallel_calls=num_parallel_calls)

        # Sample 1024(+1) tokens from the stitched together text
        is_random_documents = datatype == "documents_random"
        if sample_text_fn is not None:
            _sample_text = partial(sample_text_fn, random_documents=is_random_documents)
        else:
            _sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text
            _sample_text = partial(_sample_text, params)

        dataset = dataset.map(_sample_text, num_parallel_calls=num_parallel_calls)

    if batch:
        dataset = dataset.batch(params["train_batch_size"], drop_remainder=True).prefetch(params["iterations"] * 2)

    dataset = dataset.repeat()

    return dataset
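

# Illustrative sketch (not part of the original inputs.py): the stitching step in
# text_dataset joins `stitch` documents with an end-of-text id between them. The plain
# Python version below shows the same layout on token lists, plus one way to estimate a
# workable `stitch` value. `stitch_documents` and `min_stitch` are hypothetical helpers,
# and the "n_ctx + 1 tokens needed" figure is an assumption based on the input/target
# pair being offset by one token.

```python
def stitch_documents(docs, eos_id):
    """Join token lists as doc1 + [eos] + doc2 + [eos] + ... + docN."""
    out = list(docs[0])
    for doc in docs[1:]:
        out.append(eos_id)
        out.extend(doc)
    return out


def min_stitch(shortest_doc_len, n_ctx):
    """Smallest stitch count whose documents alone cover n_ctx + 1 tokens.

    The EOS separators are ignored, so the bound is slightly conservative.
    """
    needed = n_ctx + 1
    return -(-needed // shortest_doc_len)  # ceiling division


stitched = stitch_documents([[1, 2, 3], [4, 5], [6]], eos_id=0)
```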


def autoregressive_sample_text_random_documents(params, x):
    seed = params.get('seed', None)
    s = tf.size(x)
    r = tf.random.uniform([], maxval=s - (params["n_ctx"] + 1), dtype=tf.dtypes.int32, seed=seed)
    r1 = tf.range(r, r + params["n_ctx"])
    r2 = tf.range(r + 1, (r + 1) + params["n_ctx"])
    # TPUs want constant-sized input; these reshapes make the compiler recognize the shape of the input
    r1 = tf.reshape(r1, [params["n_ctx"]])
    r2 = tf.reshape(r2, [params["n_ctx"]])
    vals1 = tf.gather(x, r1)
    vals2 = tf.gather(x, r2)

    vals1 = tf.reshape(vals1, [params["n_ctx"]])
    vals2 = tf.reshape(vals2, [params["n_ctx"]])
    vals1 = tf.cast(vals1, dtype=tf.int32)
    vals2 = tf.cast(vals2, dtype=tf.int32)
    return vals1, vals2
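

# Illustrative sketch (not part of the original inputs.py): the function above picks a
# random offset and returns two windows of n_ctx tokens, the second shifted right by one,
# so that targets[i] is the token following inputs[i]. The same idea on a plain list is
# shown below. `sample_window` is a hypothetical helper. Unlike the TF code, whose
# exclusive maxval is s - (n_ctx + 1), this sketch also allows the last valid offset.

```python
import random


def sample_window(tokens, n_ctx, rng):
    """Return (inputs, targets), each n_ctx long, with targets shifted by one."""
    if len(tokens) < n_ctx + 1:
        raise ValueError("need at least n_ctx + 1 tokens")
    # Valid start offsets are 0 .. len(tokens) - (n_ctx + 1), inclusive.
    start = rng.randrange(len(tokens) - n_ctx)
    inputs = tokens[start:start + n_ctx]
    targets = tokens[start + 1:start + 1 + n_ctx]
    return inputs, targets


rng = random.Random(0)
inputs, targets = sample_window(list(range(10)), n_ctx=4, rng=rng)
```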


def mlm_sample_text(params, x, random_documents=False):
    seed = params.get('seed', None)
    ctx_len = params["n_ctx"]
    assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'

    mask_id = params['mlm_mask_id']
    cls_token_id = params.get('mlm_cls_token_id', None)
    num_tokens = params.get('n_vocab', None)

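    mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
    # only ignore the cls token when one is configured, so None never reaches tf.not_equal
    if cls_token_id is not None:
        mask_ignore_ids.add(cls_token_id)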

    mask_prob = params.get('mlm_mask_prob', 0.15)
    same_token_prob = params.get('mlm_same_token_prob', 0.10)
    random_token_prob = params.get('mlm_random_token_prob', 0.)

    seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)

    if random_documents:
        s = tf.size(x)
        r = tf.random.uniform([], maxval=(s - seq_len), dtype=tf.dtypes.int32, seed=seed)
        r1 = tf.range(r, r + seq_len)
        r1 = tf.reshape(r1, [seq_len])
        features = tf.gather(x, r1)
    else:
        features = x[:seq_len]

    # add cls token id if specified by `mlm_cls_token_id`
    if cls_token_id is not None:
        features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)

    features = tf.cast(features, dtype=tf.int32)
    shape = features.shape

    # determine which tokens are mask-able
    can_mask = tf.not_equal(features, 0)
    for ignore_id in mask_ignore_ids:
        can_mask &= tf.not_equal(features, ignore_id)

    # generate boolean mask for masking ids
    mask_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), mask_prob)
    mask_mask &= can_mask

    # generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same
    replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
                           1 - same_token_prob)

    # randomly replace some tokens with random tokens before masking
    if random_token_prob > 0:
        random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
                                    random_token_prob)
        random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed)

        # make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids`
        random_can_mask = tf.not_equal(random_tokens, 0)
        for ignore_id in mask_ignore_ids:
            random_can_mask &= tf.not_equal(random_tokens, ignore_id)

        features = tf.where(random_token_mask & random_can_mask, random_tokens, features)

    # mask the tokens
    mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
    masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)

    # labels will be set to 0 for all non-masked tokens
    labels = tf.where(mask_mask, features, tf.zeros(shape, dtype=tf.int32))

    masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels))
    return masked_features, labels
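

# Illustrative sketch (not part of the original inputs.py): a per-token, pure-Python
# version of the masking scheme above, following the intent stated in its comment
# ("labels will be set to 0 for all non-masked tokens"). Token id 0 is treated as
# unmaskable, as in `can_mask`. The random-token replacement branch is left out.
# `mlm_mask` is a hypothetical helper, not an API of this repository.

```python
import random


def mlm_mask(tokens, mask_id, rng, mask_prob=0.15, same_token_prob=0.10, ignore_ids=()):
    """Return (masked_tokens, labels) for one sequence.

    labels holds the original token at selected positions and 0 elsewhere.
    """
    masked, labels = [], []
    for tok in tokens:
        maskable = tok != 0 and tok not in ignore_ids
        selected = maskable and rng.random() < mask_prob
        if not selected:
            masked.append(tok)
            labels.append(0)
            continue
        # A selected token is predicted either way; it is only replaced by
        # the mask id with probability 1 - same_token_prob.
        replace = rng.random() < 1 - same_token_prob
        masked.append(mask_id if replace else tok)
        labels.append(tok)
    return masked, labels


rng = random.Random(0)
tokens = [5, 6, 7, 0, 8, 9]
masked, labels = mlm_mask(tokens, mask_id=99, rng=rng, mask_prob=0.5)
```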


================================================
FILE: main.py
================================================
"""GPT-like model in Mesh-Tensorflow"""

from functools import partial
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_config, tpu_estimator
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from utils import save_config, expand_attention_types_params, yes_or_no, remove_gs_or_filepath, setup_logging, \
    check_dataset
from inputs import sequential_input, pred_input, handle_pred_output, mlm_sample_text, generic_text
from export import export_model
from model_fns import model_fn
from data.encoders import fetch_encoder
from configs import fetch_model_params
from tasks import task_descriptors
import argparse
import json
import numpy


def parse_args():
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--tpu", type=str, help="Name of TPU to train on, if any.")
    parser.add_argument("--gpu_ids", nargs="+", type=str, default=["device:GPU:0"],
                        help="If training on GPU, can specify your GPU names in a list, e.g. 'device:GPU:0 device:GPU:1'")
    parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.")
    parser.add_argument("--steps_per_checkpoint", type=int, default=5000, help="Save a model checkpoint every X steps.")
    parser.add_argument("--auto_layout", action="store_true", help="If set, generates and prints the most memory "
                                                                   "efficient layout according to MTF auto layout.")
    parser.add_argument("--auto_layout_and_mesh_shape", action="store_true",
                        help="If set, generates and prints the most memory efficient layout and mesh shape according to"
                             " MTF auto layout.")
    parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
                                                           "starts a new training run")
    parser.add_argument("--predict", action="store_true", help="If set, uses the model to predict rather than train.")
    parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.")
    parser.add_argument("--prompt", type=str, help="path to .txt file containing a prompt for prediction. If empty, "
                                                   "defaults to unicorns.",
                        default="")
    parser.add_argument("--check_dataset", action="store_true",
                        help="If set, outputs sample from the dataset and quits.")
    parser.add_argument("--sacred_id", type=str, default="nosacred", help="Sacred run id.")
    parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling")
    parser.add_argument("--export", action="store_true", help="If set, will export the model.")
    args = parser.parse_args()
    assert args.model is not None, "Model must be set"
    return args


def main(args):
    # Setup logging
    logger = setup_logging(args)

    # Read params of model
    params = fetch_model_params(args.model)

    # Fetch appropriate input functions
SYMBOL INDEX (104 symbols across 17 files)

FILE: configs.py
  function fetch_model_params (line 12) | def fetch_model_params(model):

FILE: data/create_tfrecords.py
  function wikitext_detokenizer (line 45) | def wikitext_detokenizer(string):
  function _int64_feature (line 79) | def _int64_feature(value):
  function write_to_file (line 86) | def write_to_file(writer, data):
  function get_tokenizer (line 97) | def get_tokenizer(args):
  function split_list (line 104) | def split_list(l, n):
  function archive_to_tokens (line 109) | def archive_to_tokens(f, encoder, args, prefix=[]):
  function write_files (line 123) | def write_files(files, files_per, output_dir, out_name, start_no, write_...
  function get_files (line 150) | def get_files(input_dir, filetypes=None):
  function read_checkpoint (line 163) | def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
  function create_tfrecords (line 175) | def create_tfrecords(params, write_remainder=True, write_every_n_files=1...
  function create_tfrecords_mp (line 245) | def create_tfrecords_mp(files, args):

FILE: data/encoders.py
  function fetch_encoder (line 4) | def fetch_encoder(params):
  function encode (line 24) | def encode(encoder, text):

FILE: encoders.py
  function fetch_encoder (line 4) | def fetch_encoder(params):
  function encode (line 24) | def encode(encoder, text, gpt=True):

FILE: export.py
  function export_model (line 3) | def export_model(estimator, export_dir, params,

FILE: inputs.py
  function _get_number_of_documents (line 14) | def _get_number_of_documents(filename):
  function _get_number_of_documents_by_iteration (line 21) | def _get_number_of_documents_by_iteration(filename):
  function _get_skip_index (line 32) | def _get_skip_index(all_files, n_batches):
  function _parse_function (line 55) | def _parse_function(example_proto):
  function autoregressive_sample_text (line 63) | def autoregressive_sample_text(params, x):
  function sequential_input (line 74) | def sequential_input(params, global_step=None, eval=False):
  function pred_input (line 139) | def pred_input(params, logger, enc=None,
  function handle_pred_output (line 163) | def handle_pred_output(predictions, logger, enc, params, out_name="test"):
  function generic_text (line 188) | def generic_text(params, eval=False, sample_text_fn=None, **kwargs):
  function text_dataset (line 224) | def text_dataset(files, params, stitch, datatype, batch=True, sample_tex...
  function autoregressive_sample_text_random_documents (line 297) | def autoregressive_sample_text_random_documents(params, x):
  function mlm_sample_text (line 316) | def mlm_sample_text(params, x, random_documents=False):

FILE: main.py
  function parse_args (line 21) | def parse_args():
  function main (line 51) | def main(args):

FILE: model_fns.py
  function model_fn (line 15) | def model_fn(features, labels, mode, params):

FILE: models/activations.py
  function _arcsinh (line 20) | def _arcsinh(x):
  function _var (line 24) | def _var(x, init):
  function _pos_var (line 29) | def _pos_var(x, val):
  function _rrelu (line 33) | def _rrelu(x):
  function _elish (line 38) | def _elish(x):
  function get_activation_fn (line 79) | def get_activation_fn(params):

FILE: models/gpt2/gpt2.py
  function block (line 12) | def block(params, scope, layer_num, bias, sequence_dim, memory_length_di...
  function model (line 99) | def model(mtf_features, other_features, params, mesh, variable_dtype, co...

FILE: models/layers.py
  function exists (line 15) | def exists(x):
  function identity (line 19) | def identity(x, *args, **kwargs):
  function is_incremental_inference (line 23) | def is_incremental_inference(context):
  function norm (line 27) | def norm(x, axis, epsilon=1e-8):
  function rezero (line 33) | def rezero(x, scope, dtype):
  function scale_norm (line 39) | def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5,...
  function layer_norm (line 54) | def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5,...
  function linear_attention (line 76) | def linear_attention(q, k, v):
  function causal_linear_attention (line 91) | def causal_linear_attention(q, k, v, eps = 1e-6):
  function linear (line 111) | def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=No...
  function memory_key_values (line 127) | def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_d...
  function attn (line 156) | def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, me...
  function mlp (line 277) | def mlp(x, scope, n_state, *, variable_dtype, params):
  function mlp_glu (line 288) | def mlp_glu(x, scope, n_state, *, variable_dtype, params):
  function axial_positional_emb (line 303) | def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
  function rotary_positional_emb (line 330) | def rotary_positional_emb(mesh, sequence_dim, params, variable_dtype):
  function rotate_half (line 347) | def rotate_half(x):
  function apply_rotary_emb (line 355) | def apply_rotary_emb(x, cos, sin):

FILE: models/utils.py
  function entmax_backward (line 6) | def entmax_backward(explicit_inputs, all_inputs, forward_operations, out...
  function entmax_forward (line 21) | def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
  function entmax (line 55) | def entmax(x, alpha=1.3, dim=None, n_iter=50):
  function entmax_cross_entropy_with_logits (line 65) | def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=...
  function sample_categorical (line 90) | def sample_categorical(x, dim=None):
  function biasmask_attn_weights (line 99) | def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
  function parse_inputs (line 113) | def parse_inputs(mtf_features, other_features):

FILE: optimizers.py
  function clip_by_global_norm (line 9) | def clip_by_global_norm(grads, clip_norm):
  function get_optimizer (line 16) | def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
  class AdamWeightDecayOptimizer (line 95) | class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
    method __init__ (line 98) | def __init__(self,
    method apply_grad (line 116) | def apply_grad(self, grad, var):
    method _do_use_weight_decay (line 168) | def _do_use_weight_decay(self, param_name):

FILE: run_experiment.py
  function get_open_port (line 44) | def get_open_port(lo=8000, hi=8100):
  function train_thread (line 51) | def train_thread(args, tpu, id, q):
  function get_json (line 111) | def get_json(uri, params=None, timeout=15):
  function get_tag_sets (line 117) | def get_tag_sets(base_uri):
  function get_scalar_data (line 126) | def get_scalar_data(base_uri, run, tag):
  function get_run_data (line 132) | def get_run_data(port):
  function main (line 159) | def main(_run):
  function goodbye (line 247) | def goodbye(id):

FILE: sample.py
  function sample_autoregressive (line 8) | def sample_autoregressive(partial_sequences,

FILE: tasks.py
  function lambada_create_tokens_data (line 22) | def lambada_create_tokens_data(params, path):
  function lambada_read_or_create_tokens_data (line 34) | def lambada_read_or_create_tokens_data(params, path):
  function bin_pack (line 42) | def bin_pack(params, tokens_data):
  function lambada_init (line 61) | def lambada_init(params):
  function lambada_get_task_info (line 77) | def lambada_get_task_info(params):
  function lambada_input (line 84) | def lambada_input(params):

FILE: utils.py
  function setup_logging (line 15) | def setup_logging(args):
  function get_batch_size (line 29) | def get_batch_size(params):
  function add_mode_to_params (line 33) | def add_mode_to_params(params, mode):
  function simd_mesh_setup (line 45) | def simd_mesh_setup(params, mesh_shape, layout_rules):
  function remove_batch_from_layout (line 67) | def remove_batch_from_layout(layout):
  function yes_or_no (line 85) | def yes_or_no(question):
  function remove_gs_or_filepath (line 94) | def remove_gs_or_filepath(path):
  function save_config (line 102) | def save_config(params_dict, logdir):
  function expand_attention_types_params (line 132) | def expand_attention_types_params(params_list):
  function get_n_trainable_vars (line 140) | def get_n_trainable_vars(graph):
  function print_dim_names (line 157) | def print_dim_names(graph):
  function get_graph_info (line 177) | def get_graph_info(graph):
  function loss_denominator (line 189) | def loss_denominator(targets, num_microbatches):
  function check_dataset (line 206) | def check_dataset(input_fn, params, global_step=None):
  function auto_layout (line 224) | def auto_layout(graph, mesh_shape, logits, loss):
  function auto_layout_and_mesh_shape (line 229) | def auto_layout_and_mesh_shape(graph, num_cores, logits, loss):
  function create_host_call (line 236) | def create_host_call(model_dir):
  function natural_sort (line 289) | def natural_sort(l):
Condensed preview: 43 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "chars": 713,
    "preview": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: bug\nassignees: ''\n\n---\n\n**Describe the "
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "chars": 608,
    "preview": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: feature request\nassignees: ''\n\n---\n\n"
  },
  {
    "path": ".github/workflows/pytest.yml",
    "chars": 885,
    "preview": "# This workflow will install Python dependencies, run tests and lint with a variety of Python versions\n# For more inform"
  },
  {
    "path": ".gitignore",
    "chars": 1270,
    "preview": "# testing\n.test/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Dist"
  },
  {
    "path": "CITATION.bib",
    "chars": 505,
    "preview": "@software{gpt-neo,\n  author       = {Black, Sid and\n                  Gao, Leo and\n                  Wang, Phil and\n    "
  },
  {
    "path": "CODEOWNERS",
    "chars": 23,
    "preview": "* EleutherAI/pm-gptneo\n"
  },
  {
    "path": "Dockerfile",
    "chars": 455,
    "preview": "FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15\n\nWORKDIR /neogpt\n\n# Make RUN commands use `bash --login`:\nSHELL [\""
  },
  {
    "path": "GPTNeo_example_notebook.ipynb",
    "chars": 118381,
    "preview": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"GPTNeo_example_notebook.ipynb\",\n"
  },
  {
    "path": "LICENSE",
    "chars": 1067,
    "preview": "MIT License\n\nCopyright (c) 2020 EleutherAI\n\nPermission is hereby granted, free of charge, to any person obtaining a copy"
  },
  {
    "path": "README.md",
    "chars": 22999,
    "preview": "# GPT Neo\n\n[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5297715.svg)](https://doi.org/10.5281/zenodo.5297715) [!["
  },
  {
    "path": "configs/dataset_configs/example.json",
    "chars": 195,
    "preview": "{\n\t\"n_vocab\": 32768,\n\t\"path\": \"./tfrecords/openwebtext_*.tfrecords\",\n\t\"eval_path\": \"\",\n\t\"tokenizer_path\": \"./datasets/op"
  },
  {
    "path": "configs/dataset_configs/openwebtext2_new_inputs.json",
    "chars": 271,
    "preview": "{\n\t\"n_vocab\": 50257,\n\t\"path\": \"gs://neo-datasets/openwebtext2_new_inputs/train/*.tfrecords\",\n\t\"eval_path\": \"gs://neo-dat"
  },
  {
    "path": "configs/dataset_configs/pile.json",
    "chars": 229,
    "preview": "{\n\t\"n_vocab\": 50257,\n\t\"path\": \"gs://neo-datasets/pile/pile_*.tfrecords\",\n\t\"eval_path\": \"gs://neo-datasets/pile_val.tfrec"
  },
  {
    "path": "configs/gpt2_small.json",
    "chars": 900,
    "preview": "{\n    \"n_head\": 6,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0.1,\n    \"lr\": 0.0006,\n    \"lr_decay\": \"cosine\",\n    \"warm"
  },
  {
    "path": "configs/gpt3_13B_256.json",
    "chars": 1051,
    "preview": "{\n    \"n_head\": 40,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0001,\n    \"lr_decay\": \"cosine\",\n    \"warmu"
  },
  {
    "path": "configs/gpt3_13B_256_Pile.json",
    "chars": 977,
    "preview": "\n{\n    \"n_head\": 40,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0001,\n    \"lr_decay\": \"cosine\",\n    \"warm"
  },
  {
    "path": "configs/gpt3_2-7B_256.json",
    "chars": 959,
    "preview": "{\n    \"n_head\": 32,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.00016,\n    \"lr_decay\": \"cosine\",\n    \"warm"
  },
  {
    "path": "configs/gpt3_6-7B_256.json",
    "chars": 907,
    "preview": "{\n    \"n_head\": 32,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.00012,\n    \"lr_decay\": \"cosine\",\n    \"warm"
  },
  {
    "path": "configs/gpt3_PAR_small_256.json",
    "chars": 970,
    "preview": "{\n    \"n_head\": 12,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0006,\n    \"lr_decay\": \"cosine\",\n    \"warmu"
  },
  {
    "path": "configs/gpt3_XL_256_Pile.json",
    "chars": 974,
    "preview": "{\n    \"n_head\": 32,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0002,\n    \"lr_decay\": \"cosine\",\n    \"warmu"
  },
  {
    "path": "configs/gpt3_large_256.json",
    "chars": 1007,
    "preview": "{\n    \"n_head\": 16,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.00025,\n    \"lr_decay\": \"cosine\",\n    \"warm"
  },
  {
    "path": "configs/gpt3_medium_256.json",
    "chars": 916,
    "preview": "{\n    \"n_head\": 16,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0003,\n    \"lr_decay\": \"cosine\",\n    \"warmu"
  },
  {
    "path": "configs/gpt3_small_256.json",
    "chars": 936,
    "preview": "{\n    \"n_head\": 12,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0006,\n    \"lr_decay\": \"cosine\",\n    \"warmu"
  },
  {
    "path": "configs.py",
    "chars": 1776,
    "preview": "import json\nfrom pathlib import Path\nfrom collections import defaultdict\n\nDATASETS = {}\n\nfor path in Path(\"configs/datas"
  },
  {
    "path": "data/create_tfrecords.py",
    "chars": 11905,
    "preview": "import argparse\nimport os\nfrom pathlib import Path\n\nimport ftfy\nimport tensorflow as tf\nfrom lm_dataformat import Reader"
  },
  {
    "path": "data/encoders.py",
    "chars": 875,
    "preview": "from tokenizers import Tokenizer\nfrom transformers import GPT2Tokenizer, GPT2TokenizerFast\n\ndef fetch_encoder(params):\n "
  },
  {
    "path": "data/train_tokenizer.py",
    "chars": 2375,
    "preview": "import os\nimport random\nimport argparse\nimport shutil\nfrom glob import glob\nfrom pathlib import Path\n\nfrom lm_dataformat"
  },
  {
    "path": "docker-compose.yml",
    "chars": 1544,
    "preview": "version: '3'\nservices:\n\n  mongo:\n    image: mongo\n    ports:\n      - 127.0.0.1:27017:27017\n    environment:\n      MONGO_"
  },
  {
    "path": "encoders.py",
    "chars": 885,
    "preview": "from tokenizers import Tokenizer\nfrom transformers import GPT2Tokenizer, GPT2TokenizerFast\n\ndef fetch_encoder(params):\n "
  },
  {
    "path": "export.py",
    "chars": 501,
    "preview": "import tensorflow.compat.v1 as tf\n\ndef export_model(estimator, export_dir, params,\n                 checkpoint_path=None"
  },
  {
    "path": "inputs.py",
    "chars": 15608,
    "preview": "import numpy as np\nimport tensorflow.compat.v1 as tf\nfrom functools import partial\nfrom data.encoders import encode\nimpo"
  },
  {
    "path": "main.py",
    "chars": 11270,
    "preview": "\"\"\"GPT-like model in Mesh-Tensorflow\"\"\"\n\nfrom functools import partial\nimport mesh_tensorflow as mtf\nimport tensorflow.c"
  },
  {
    "path": "model_fns.py",
    "chars": 14730,
    "preview": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.python.tpu import tpu_estimator\nimport m"
  },
  {
    "path": "models/activations.py",
    "chars": 3666,
    "preview": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nimport random\n\nBASE_FNS = {'gelu': mtf.gelu,\n           "
  },
  {
    "path": "models/gpt2/gpt2.py",
    "chars": 10173,
    "preview": "\"\"\"GPT-like model in Mesh-Tensorflow\"\"\"\nimport tensorflow.compat.v1 as tf\nimport mesh_tensorflow.transformer as mtf_tran"
  },
  {
    "path": "models/layers.py",
    "chars": 15197,
    "preview": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nimport math\nimport mesh_tensorflow.transformer as mtf_tr"
  },
  {
    "path": "models/utils.py",
    "chars": 4373,
    "preview": "import tensorflow as tf\nimport mesh_tensorflow as mtf\nfrom functools import partial\n\n\ndef entmax_backward(explicit_input"
  },
  {
    "path": "optimizers.py",
    "chars": 6633,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\n"
  },
  {
    "path": "requirements.txt",
    "chars": 235,
    "preview": "google-api-python-client\njsonlines\nlm_dataformat\nmesh-tensorflow==0.1.18\nnumpy\noauth2client\nortools\npytest\nsacred\ntensor"
  },
  {
    "path": "run_experiment.py",
    "chars": 9794,
    "preview": "import atexit\nimport sacred\nimport argparse\nimport time\nimport math\nimport subprocess\nimport shutil\nimport os\nimport jso"
  },
  {
    "path": "sample.py",
    "chars": 9324,
    "preview": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nimport mesh_tensorflow.transformer as mtf_transformer\n\nf"
  },
  {
    "path": "tasks.py",
    "chars": 4043,
    "preview": "import os.path\nimport json\nimport requests\nimport numpy as np\nimport ftfy\nfrom data.encoders import fetch_encoder, encod"
  },
  {
    "path": "utils.py",
    "chars": 10118,
    "preview": "import re\nfrom urllib.parse import urlparse\nfrom shutil import rmtree\nimport logging\nimport os\nfrom pathlib import Path\n"
  }
]

About this extraction

This page contains source code of the EleutherAI/gpt-neo GitHub repository, extracted and formatted as plain text. The extraction includes 43 files (285.4 KB), approximately 75.8k tokens, and a symbol index with 104 extracted functions, classes, methods, constants, and types.

