[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: bug\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Proposed solution**\nIf you have an idea for how we can fix this problem, describe it here.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Environment (please complete the following information):**\n - GPUs:\n - Configs:\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: feature request\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/workflows/pytest.yml",
    "content": "# This workflow will install Python dependencies and run tests with a variety of Python versions\n# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions\n\nname: Python package\n\non:\n  push:\n    branches: [ master ]\n  pull_request:\n    branches: [ master ]\n\njobs:\n  build:\n\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [3.6, 3.7]\n\n    steps:\n    - uses: actions/checkout@v2\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v2\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install pytest\n        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi\n    - name: Test with pytest\n      run: |\n        pytest\n"
  },
  {
    "path": ".gitignore",
    "content": "# testing\n.test/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\nlogs/\n*.log\ntest_*\ntest/\n.vscode\n\n\nrun_configs/\n"
  },
  {
    "path": "CITATION.bib",
    "content": "@software{gpt-neo,\n  author       = {Black, Sid and\n                  Gao, Leo and\n                  Wang, Phil and\n                  Leahy, Connor and\n                  Biderman, Stella},\n  title        = {{GPT-Neo: Large Scale Autoregressive Language \n                   Modeling with Mesh-Tensorflow}},\n  month        = mar,\n  year         = 2021,\n  publisher    = {Zenodo},\n  version      = {1.0},\n  doi          = {10.5281/zenodo.5297715},\n  url          = {https://doi.org/10.5281/zenodo.5297715}\n}\n"
  },
  {
    "path": "CODEOWNERS",
    "content": "* EleutherAI/pm-gptneo\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15\n\nWORKDIR /neogpt\n\n# Make RUN commands use `bash --login`:\nSHELL [\"/bin/bash\", \"--login\", \"-c\"]\nENV DEBIAN_FRONTEND=noninteractive \nRUN apt-get update -y && apt-get install tmux -y\nRUN conda install gcc_linux-64 gxx_linux-64 -y \nADD requirements.txt .\nRUN pip install -r requirements.txt \nRUN apt-get install screen htop -y\nRUN python -m pip install tensorboard==1.15 cloud_tpu_profiler==1.15\n\nCMD tmux"
  },
  {
    "path": "GPTNeo_example_notebook.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"GPTNeo_example_notebook.ipynb\",\n      \"provenance\": [],\n      \"collapsed_sections\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"accelerator\": \"TPU\"\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"J0i5MRP0SV8D\"\n      },\n      \"source\": [\n        \"Welcome to the Colab notebook for [GPTNeo](https://github.com/EleutherAI/GPTNeo), a fully open-source implementation of GPT-like models for mesh-tensorflow by [EleutherAI](https://www.eleuther.ai).\\n\",\n        \"\\n\",\n        \"Our library provides training and inference for GPT models up to GPT-3 sizes on both TPUs and GPUs.\\n\",\n        \"\\n\",\n        \"In this notebook we walk you through TPU training (or finetuning!) and sampling using the freely available Colab TPUs.\\n\",\n        \"\\n\",\n        \"If you find our repo useful, come join [our Discord](https://discord.gg/BK2v3EJ) and say hi! 😬\\n\",\n        \"\\n\",\n        \"Before we get going, make sure you are running this notebook with a TPU available. 
Go to Runtime -> Change runtime type and select 'TPU' under Hardware accelerator.\\n\",\n        \"\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"K-53qkZV6Lv9\",\n        \"cellView\": \"form\"\n      },\n      \"source\": [\n        \"#@title Setup\\n\",\n        \"%tensorflow_version 2.x\\n\",\n        \"!git clone https://github.com/EleutherAI/GPTNeo\\n\",\n        \"%cd GPTNeo\\n\",\n        \"!pip3 install -q -r requirements.txt\\n\",\n        \"pretrained_model = None\\n\",\n        \"dataset = None\\n\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"M0R1owh2qvp8\"\n      },\n      \"source\": [\n        \"## Set Up Google Cloud\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0PmzM4dy7diP\"\n      },\n      \"source\": [\n        \"To train on TPUs, we need to store our data in a Google Cloud Storage bucket, as TPUs can't read from local filesystems.\\n\",\n        \"\\n\",\n        \"You can set up a bucket by signing up for a free trial here: https://console.cloud.google.com/\\n\",\n        \"\\n\",\n        \"Make a bucket at https://console.cloud.google.com/storage and come back when that's done.\\n\",\n        \"\\n\",\n        \"Make sure to select 'Uniform' access control when setting up the bucket, or the Colab notebook won't have the required permissions to read from it.\\n\",\n        \"\\n\",\n        \"The next cell sets up Google authentication and gives the notebook read and write access to your bucket.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"71bQUjPA7qvj\"\n      },\n      \"source\": [\n        \"from google.colab import auth\\n\",\n        \"auth.authenticate_user()\\n\",\n        \"!gcloud init\"\n      ],\n      \"execution_count\": null,\n      
\"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"Cr_c6A2NBK5i\",\n        \"cellView\": \"form\"\n      },\n      \"source\": [\n        \"path_to_cloud_bucket = 'gs://your-cloud-bucket/' #@param {type:\\\"string\\\"}\"\n      ],\n      \"execution_count\": 3,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"EZGbzUPD0tad\"\n      },\n      \"source\": [\n        \"## Set Up Dataset\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"R918l14UhrBR\"\n      },\n      \"source\": [\n        \"We first need to download and tokenize a dataset. If you just want to sample from a pretrained model, you can skip this step and move on to the `Pretrained Model` section.\\n\",\n        \"\\n\",\n        \"You can choose from:\\n\",\n        \"\\n\",\n        \"*   Sampling Only - choose this option if you only wish to sample from our trained models, then move on to the `Pretrained Model` section.\\n\",\n        \"\\n\",\n        \"*   OpenWebText - an open-source clone of OpenAI's WebText dataset, the original training data of GPT-2.\\n\",\n        \"\\n\",\n        \"*   YoutubeSubtitles - a dataset of subtitles scraped from YouTube videos.\\n\",\n        \"\\n\",\n        \"*   HackerNews - comments scraped from Hacker News.\\n\",\n        \"\\n\",\n        \"*   NIHExporter - data relating to various projects from the National Institutes of Health.\\n\",\n        \"\\n\",\n        \"*   Custom - if you choose this option, you will be prompted to enter the path to your own dataset. 
It should be a directory containing .txt or .jsonl files.\\n\",\n        \"\\n\",\n        \"All these datasets are from EleutherAI's side project, [The Pile™](https://github.com/EleutherAI/The-Pile): an effort to gather a general-purpose, diverse, and open-source plain-text dataset large enough to train 1T+ parameter language models.\\n\",\n        \"\\n\",\n        \"Even the smallest datasets are fairly large files, so this step will likely take a while. Select a dataset in the next cell, run the next two cells, then go grab a snack and a cup of tea 😊\\n\",\n        \"\\n\",\n        \"Alternatively, you can provide your own dataset in the form of a folder or gzip archive of .txt files. Simply select 'Custom' below and enter the path to your data and the name of your dataset when prompted.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"pM8jP3Am_hsx\",\n        \"cellView\": \"form\"\n      },\n      \"source\": [\n        \"# Select a Dataset:\\n\",\n        \"import os\\n\",\n        \"dataset = 'Sampling_Only' #@param [\\\"Sampling_Only\\\", \\\"OpenWebText\\\", \\\"YoutubeSubtitles\\\", \\\"HackerNews\\\", \\\"NIHExporter\\\", \\\"Custom\\\"]\\n\",\n        \"\\n\",\n        \"if dataset == \\\"Sampling_Only\\\":\\n\",\n        \"  pass\\n\",\n        \"elif dataset == 'OpenWebText':\\n\",\n        \"  !wget https://the-eye.eu/public/AI/pile_preliminary_components/openwebtext2.jsonl.zst.tar -O openwebtext.tar.xz\\n\",\n        \"  !tar xf openwebtext.tar.xz\\n\",\n        \"  dataset_path = \\\"openwebtext\\\"\\n\",\n        \"  dataset_name = dataset_path\\n\",\n        \"  out_name = dataset_name + \\\"_tokenized\\\"\\n\",\n        \"elif dataset == 'YoutubeSubtitles':\\n\",\n        \"  os.makedirs('data', exist_ok=True)\\n\",\n        \"  !wget https://the-eye.eu/public/AI/pile_preliminary_components/yt_subs.jsonl.zst -O data/yt_subs.jsonl.zst\\n\",\n        \"  dataset_path = 'data'\\n\",\n 
       \"  dataset_name = 'ytsubs'\\n\",\n        \"  out_name = dataset_name + \\\"_tokenized\\\"\\n\",\n        \"elif dataset == 'HackerNews':\\n\",\n        \"  os.makedirs('data', exist_ok=True)\\n\",\n        \"  !wget https://the-eye.eu/public/AI/pile_preliminary_components/hn.tar.gz -O data/hn.tar.gz\\n\",\n        \"  dataset_path = 'data'\\n\",\n        \"  dataset_name = 'hackernews'\\n\",\n        \"  out_name = dataset_name + \\\"_tokenized\\\"\\n\",\n        \"elif dataset == \\\"NIHExporter\\\":\\n\",\n        \"  os.makedirs('data', exist_ok=True)\\n\",\n        \"  !wget https://the-eye.eu/public/AI/pile_preliminary_components/NIH_ExPORTER_awarded_grant_text.jsonl.zst -O data/NIH_ExPORTER_awarded_grant_text.jsonl.zst\\n\",\n        \"  dataset_path = 'data'\\n\",\n        \"  dataset_name = 'nihexporter'\\n\",\n        \"  out_name = dataset_name + \\\"_tokenized\\\"\\n\",\n        \"elif dataset == \\\"Custom\\\":\\n\",\n        \"  dataset_path = input('Enter the path to the folder containing your data: ')\\n\",\n        \"  dataset_name = input('Enter the name of your dataset: ')\\n\",\n        \"  out_name = dataset_name + \\\"_tokenized\\\"\\n\",\n        \"else:\\n\",\n        \"  raise NotImplementedError('please select from available options: [\\\"Sampling_Only\\\", \\\"OpenWebText\\\", \\\"YoutubeSubtitles\\\", \\\"HackerNews\\\", \\\"NIHExporter\\\", \\\"Custom\\\"]')\\n\"\n      ],\n      \"execution_count\": 4,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"zMl1cHtN5I_W\"\n      },\n      \"source\": [\n        \"### Tokenize and Upload Data\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"6IBIompTJaqm\"\n      },\n      \"source\": [\n        \"Now tokenize the dataset and copy it over to your Google Cloud bucket. 
You may skip this step if you are sampling from a pre-trained model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"Pq5u0WUSJWwz\",\n        \"cellView\": \"both\"\n      },\n      \"source\": [\n        \"# Tokenize Data\\n\",\n        \"!python data/create_tfrecords.py --input_dir /content/GPTNeo/$dataset_path --name $dataset_name --files_per 1000 --output_dir $out_name --write_dataset_config --processes 1\\n\",\n        \"\\n\",\n        \"# copy the data to your bucket\\n\",\n        \"if not path_to_cloud_bucket.endswith('/'):\\n\",\n        \"  path_to_cloud_bucket += '/'\\n\",\n        \"copy_loc = path_to_cloud_bucket + \\\"datasets/\\\" + dataset\\n\",\n        \"!gsutil -m cp -r /content/GPTNeo/$out_name $copy_loc\\n\",\n        \"!gsutil ls $path_to_cloud_bucket\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"NhvmTFD7b_fb\"\n      },\n      \"source\": [\n        \"Before starting training - you'll need to edit your dataset & model configs to point to your buckets / data. 
You need to do this even if you are sampling from a pre-trained model.\\n\",\n        \"\\n\",\n        \"*   First, change the `%%writefile` path to point to your chosen dataset, i.e. replace `Sampling_Only` with your dataset's name, e.g. `%%writefile configs/dataset_configs/ytsubs.json`.\\n\",\n        \"*   Change the \\\"path\\\" field to point to your cloud bucket location, e.g. `gs://neo_lmdatasets/datasets/ytsubs_*.tfrecords`.\\n\",\n        \"*   Once you've made the edits, run the cell below to overwrite the existing file.\\n\",\n        \"\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"MCsZP48vavCP\"\n      },\n      \"source\": [\n        \"%%writefile configs/dataset_configs/Sampling_Only.json\\n\",\n        \"\\n\",\n        \"{\\n\",\n        \"  \\\"path\\\": \\\"gs://eleutherai/datasets/Sampling_Only/Sampling_Only*.tfrecords\\\",\\n\",\n        \"  \\\"eval_path\\\": \\\"\\\",\\n\",\n        \"  \\\"n_vocab\\\": 50257,\\n\",\n        \"  \\\"tokenizer_is_pretrained\\\": true,\\n\",\n        \"  \\\"tokenizer_path\\\": \\\"gpt2\\\",\\n\",\n        \"  \\\"eos_id\\\": 50256,\\n\",\n        \"  \\\"padding_id\\\": 50257\\n\",\n        \"}\\n\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"dH0x3dI9j85P\"\n      },\n      \"source\": [\n        \"## Set Model Configs\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"I6GnCgAkB7GQ\"\n      },\n      \"source\": [\n        \"The model below is identical to our pretrained GPT3_XL model (1.3B parameters). 
\\n\",\n        \"\\n\",\n        \"If you want to use a smaller model, you can modify any of the config files in ./configs/ ending in _8.json, all of which are designed to train on 8-core TPUs (e.g. v2-8 or v3-8).\\n\",\n        \"\\n\",\n        \"For a more detailed breakdown of what each item in the configuration file means, please read through the training and config guides in our [GitHub README](https://github.com/EleutherAI/GPTNeo#training-guide).\\n\",\n        \"\\n\",\n        \"You'll want to change the first item in the `datasets` list to the name of your chosen dataset (the filename minus .json in ./configs/dataset_configs).\\n\",\n        \"\\n\",\n        \"You'll also want to modify the `model_path` field to point to your Google Cloud bucket, so checkpoints get saved there.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"L9hUDdokiWj6\"\n      },\n      \"source\": [\n        \"%%writefile configs/GPT3_XL.json\\n\",\n        \"\\n\",\n        \"{\\n\",\n        \"    \\\"n_head\\\": 16,\\n\",\n        \"    \\\"n_vocab\\\": 50257,\\n\",\n        \"    \\\"embed_dropout\\\": 0,\\n\",\n        \"    \\\"lr\\\": 0.0002,\\n\",\n        \"    \\\"lr_decay\\\": \\\"cosine\\\",\\n\",\n        \"    \\\"warmup_steps\\\": 3000,\\n\",\n        \"    \\\"beta1\\\": 0.9,\\n\",\n        \"    \\\"beta2\\\": 0.95,\\n\",\n        \"    \\\"epsilon\\\": 1e-8,\\n\",\n        \"    \\\"opt_name\\\": \\\"adam\\\",\\n\",\n        \"    \\\"weight_decay\\\": 0,\\n\",\n        \"    \\\"train_batch_size\\\": 256,\\n\",\n        \"    \\\"attn_dropout\\\": 0,\\n\",\n        \"    \\\"train_steps\\\": 600000,\\n\",\n        \"    \\\"eval_steps\\\": 0,\\n\",\n        \"    \\\"predict_steps\\\": 1,\\n\",\n        \"    \\\"res_dropout\\\": 0,\\n\",\n        \"    \\\"eval_batch_size\\\": 4,\\n\",\n        \"    \\\"predict_batch_size\\\": 1,\\n\",\n        \"    \\\"iterations\\\": 100,\\n\",\n        \"    \\\"n_embd\\\": 2048,\\n\",\n     
   \"    \\\"datasets\\\": [[\\\"pile\\\", null, null, null]],\\n\",\n        \"    \\\"model\\\": \\\"GPT\\\",\\n\",\n        \"    \\\"model_path\\\": \\\"gs://eleutherai/GPT3_XL\\\",\\n\",\n        \"    \\\"n_ctx\\\": 2048,\\n\",\n        \"    \\\"n_layer\\\": 24,\\n\",\n        \"    \\\"scale_by_depth\\\": true,\\n\",\n        \"    \\\"scale_by_in\\\": false,\\n\",\n        \"    \\\"attention_types\\\" :  [[[\\\"global\\\", \\\"local\\\"],12]],\\n\",\n        \"    \\\"mesh_shape\\\": \\\"x:4,y:2\\\",\\n\",\n        \"    \\\"layout\\\": \\\"intermediate_expanded:x,heads:x,vocab:n_vocab,memory_length:y,embd:y\\\",\\n\",\n        \"    \\\"activation_function\\\": \\\"gelu\\\",\\n\",\n        \"    \\\"recompute_grad\\\": true,\\n\",\n        \"    \\\"gradient_clipping\\\": 1.0,\\n\",\n        \"    \\\"tokens_per_mb_per_replica\\\": 2048,\\n\",\n        \"    \\\"precision\\\": \\\"bfloat16\\\"\\n\",\n        \"}\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"GWK9MJqwcXKn\"\n      },\n      \"source\": [\n        \"## Training from Scratch\\n\",\n        \"\\n\",\n        \"Now we will begin to train the model. If no previous model is found in \\\"model_path\\\", the model will start training from scratch. 
If you'd prefer to finetune from a pretrained model, skip to the `Pretrained Model` section.\\n\",\n        \"\\n\",\n        \"If everything's set up correctly, you can now run main.py to start training!\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"VUtrysOSBzjJ\"\n      },\n      \"source\": [\n        \"!python3 main.py --model colab_XL --steps_per_checkpoint 500 --tpu colab\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"koKQHA5ikCvD\"\n      },\n      \"source\": [\n        \"## Pretrained Model\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0QZv4_pnkk26\"\n      },\n      \"source\": [\n        \"If you want to sample from or finetune a pretrained model, EleutherAI has released two pretrained models: one with [1.3B parameters](https://the-eye.eu/public/AI/gptneo-release/GPT3_XL/) and another with [2.7B parameters](https://the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/).\\n\",\n        \"\\n\",\n        \"Select an option below to download the weights locally. You will then need to upload them to your cloud bucket in order to finetune from them. 
If the download command isn't working, try the commented-out code to download from a different source.\\n\",\n        \"\\n\",\n        \"The 2.7B model likely won't fit into the Colab TPU's memory for finetuning, so you may need a larger TPU pod to finetune it.\\n\",\n        \"\\n\",\n        \"Sampling from it, however, works just fine.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"lgTG1ammqGB0\",\n        \"cellView\": \"form\"\n      },\n      \"source\": [\n        \"# @title Download pretrained model weights:\\n\",\n        \"pretrained_model = 'GPT3_2-7B' #@param [\\\"GPT3_XL\\\", \\\"GPT3_2-7B\\\"]\\n\",\n        \"!wget -m -np -c -U \\\"eye02\\\" -w 2 -R \\\"index.html*\\\" \\\"https://the-eye.eu/public/AI/gptneo-release/$pretrained_model/\\\"\\n\",\n        \"path_to_local_weights = f\\\"/content/GPTNeo/the-eye.eu/public/AI/gptneo-release/{pretrained_model}\\\"\\n\",\n        \"\\n\",\n        \"# URL = f\\\"http://eaidata.bmk.sh/data/gptneo-release/{pretrained_model}/\\\"\\n\",\n        \"# FOLDER_NAME = \\\"GPT3_XL\\\"\\n\",\n        \"# !curl $URL | grep -i \\\"</a>\\\" | sed -n 's/.*href=\\\"\\\\([^\\\"]*\\\\).*/\\\\1/p' | sed \\\"s|^|$URL|\\\" | xargs -n 1 -P 4 wget -P $pretrained_model\\n\",\n        \"# path_to_local_weights = pretrained_model\\n\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"GU3BDNJN_ZXE\"\n      },\n      \"source\": [\n        \"# upload to your bucket\\n\",\n        \"bucket_base = \\\"gs://\\\" + path_to_cloud_bucket.replace('gs://', '').split('/')[0]\\n\",\n        \"!gsutil -m cp -r $path_to_local_weights $bucket_base\"\n      ],\n      \"execution_count\": 9,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"bnqkKBTOn0ox\"\n      },\n      \"source\": [\n        \"If everything has worked 
successfully, you should now see your model listed in your bucket below.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"80t9MMionm2h\"\n      },\n      \"source\": [\n        \"!gsutil ls $bucket_base\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"QDKL8fCSoApL\"\n      },\n      \"source\": [\n        \"Now we want to make a few modifications to the model config in order to get training / sampling working on Colab.\\n\",\n        \"\\n\",\n        \"If you are just sampling from our pretrained models, you can leave the settings as is, run the cell below, then move on to the `Sample from your model` section.\\n\",\n        \"\\n\",\n        \"If finetuning, you can change the parameters below.\\n\",\n        \"\\n\",\n        \"* `path_to_model` should point to the model weights location in your cloud bucket, and will default to `$bucket_base/${pretrained_model}` if nothing is entered.\\n\",\n        \"\\n\",\n        \"* `batch_size` is your train batch size (it is also used as the predict batch size); if you're encountering memory errors, try lowering this.\\n\",\n        \"\\n\",\n        \"* `dset` is the name of your dataset; if nothing is entered, this defaults to the dataset you selected in the `Set Up Dataset` section.\\n\",\n        \"\\n\",\n        \"* `mesh_shape` specifies how the model is divided up across the TPU cores. We suggest leaving this alone unless you know what you're doing.\\n\",\n        \"\\n\",\n        \"* `train_steps` specifies how many steps you want the model to finetune for. We set this to 1000 for demonstration purposes, but you may need to increase it depending on your goals. 
If you are just sampling from the model, you can leave this as is.\\n\",\n        \"\\n\",\n        \"* `steps_per_checkpoint` specifies how often you want to save model weights during training.\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"Laf0slBMDCUj\",\n        \"cellView\": \"form\"\n      },\n      \"source\": [\n        \"# @title Modify config for colab. \\n\",\n        \"  \\n\",\n        \"import json\\n\",\n        \"from pprint import pprint\\n\",\n        \"\\n\",\n        \"path_to_model = \\\"\\\" #@param {type:\\\"string\\\"}\\n\",\n        \"batch_size = 8 #@param {type:\\\"integer\\\"}\\n\",\n        \"dset = \\\"\\\"  #@param {type:\\\"string\\\"}\\n\",\n        \"mesh_shape = \\\"x:4,y:2\\\" #@param {type:\\\"string\\\"}\\n\",\n        \"train_steps = 1000 #@param {type:\\\"integer\\\"}\\n\",\n        \"steps_per_checkpoint = 500 #@param {type:\\\"integer\\\"}\\n\",\n        \"start_step = 400000 if pretrained_model == \\\"GPT3_2-7B\\\" else 362000\\n\",\n        \"\\n\",\n        \"if path_to_model == \\\"\\\":\\n\",\n        \"  path_to_model = f'{bucket_base.strip(\\\"/\\\")}/{pretrained_model}'\\n\",\n        \"print(f'MODEL PATH: {path_to_model}\\\\n')\\n\",\n        \"\\n\",\n        \"# check for None first: if the dataset cell was skipped, dataset is None\\n\",\n        \"if dset == \\\"\\\" and dataset is None:\\n\",\n        \"  dset = \\\"pile\\\"\\n\",\n        \"elif dset == \\\"\\\" and dataset != \\\"Sampling_Only\\\":\\n\",\n        \"  dset = dataset\\n\",\n        \"\\n\",\n        \"def pad_to_multiple_of(n, mult):\\n\",\n        \"  \\\"\\\"\\\"\\n\",\n        \"  pads n to a multiple of mult\\n\",\n        \"  \\\"\\\"\\\"\\n\",\n        \"  extra = n % mult\\n\",\n        \"  if extra > 0:\\n\",\n        \"      n = n + mult - extra\\n\",\n        \"  return n\\n\",\n        \"\\n\",\n        \"with open(f'{path_to_local_weights}/config.json', 'r') as f:\\n\",\n        \"  data = json.load(f)\\n\",\n        \"  pprint(data)\\n\",\n        
\"  dset_val = [[dset, None, None, None]] if dset != \\\"\\\" else data[\\\"datasets\\\"]\\n\",\n        \"  mods = {\\n\",\n        \"          \\\"mesh_shape\\\": mesh_shape,\\n\",\n        \"          \\\"layout\\\": \\\"intermediate_expanded:x,heads:x,memory_length:y,embd:y\\\",\\n\",\n        \"          \\\"model_path\\\": path_to_model,\\n\",\n        \"          \\\"datasets\\\": dset_val,\\n\",\n        \"          \\\"train_steps\\\": start_step + train_steps,\\n\",\n        \"          \\\"eval_steps\\\": 0,\\n\",\n        \"          \\\"train_batch_size\\\": batch_size,\\n\",\n        \"          \\\"predict_batch_size\\\": batch_size\\n\",\n        \"        }\\n\",\n        \"  data.update(mods)\\n\",\n        \"  print('\\\\n--->\\\\n')\\n\",\n        \"  pprint(data)\\n\",\n        \"  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\\n\",\n        \"    json.dump(data, outfile, indent=2)\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fPwwbPCA6O7r\"\n      },\n      \"source\": [\n        \"### Begin Fine-Tuning\\n\",\n        \"\\n\",\n        \"If you are fine-tuning the pretrained model, this line of code will begin the training.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"0YlaHzyXuMaj\"\n      },\n      \"source\": [\n        \"!python3 main.py --model $pretrained_model --steps_per_checkpoint $steps_per_checkpoint --tpu colab\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"I_HxtEmBGTGT\"\n      },\n      \"source\": [\n        \"### Sample from your model\\n\",\n        \"\\n\",\n        \"Once training is finished (or your pretrained model is in your bucket), you can run the same command with the `--predict` flag to sample from your model.\\n\",\n     
   \"\\n\",\n        \"To pass in a prompt, save it to a .txt file and pass the file name with the `--prompt` flag.\\n\",\n        \"\\n\",\n        \"Use the cell below to enter your prompt, and run it to save it to example_prompt.txt.\\n\",\n        \"\\n\",\n        \"You may need to decrease the predict batch size in your config if you're facing out-of-memory (OOM) errors.\\n\",\n        \"\\n\",\n        \"Let's see if the GPTNeo model can finish coding itself, with a sample prompt consisting of the beginning of a `torch.nn.Module`:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"CQE1Y5wPFx7h\",\n        \"outputId\": \"e1a92c0c-18ee-4014-a0b8-d67161384940\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        }\n      },\n      \"source\": [\n        \"%%writefile example_prompt.txt\\n\",\n        \"\\n\",\n        \"class GPT(nn.Module):\\n\",\n        \"    \\\"\\\"\\\"  the full GPT language model, with a context size of block_size \\\"\\\"\\\"\\n\",\n        \"\\n\",\n        \"    def __init__(self, config):\\n\",\n        \"        super().__init__()\\n\",\n        \"\\n\",\n        \"        # input embedding stem\\n\",\n        \"        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\\n\",\n        \"        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\\n\",\n        \"        self.drop = nn.Dropout(config.embd_pdrop)\\n\",\n        \"        # transformer\\n\",\n        \"        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\\n\",\n        \"        # decoder head\\n\",\n        \"        self.ln_f = nn.LayerNorm(config.n_embd)\\n\",\n        \"        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\\n\",\n        \"\\n\",\n        \"        self.block_size = config.block_size\\n\",\n        \"        self.apply(self._init_weights)\\n\",\n        \"\\n\",\n        
logger.info(\\\"number of parameters: %e\\\", sum(p.numel() for p in self.parameters()))\"\n      ],\n      \"execution_count\": 13,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Overwriting example_prompt.txt\\n\"\n          ],\n          \"name\": \"stdout\"\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"sf_5E4fHFQIh\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"outputId\": \"f3c12a94-7ef8-43c1-a668-6365966d42b4\"\n      },\n      \"source\": [\n        \"!python3 main.py --model $pretrained_model --steps_per_checkpoint 500 --tpu colab --predict --prompt example_prompt.txt\"\n      ],\n      \"execution_count\": 14,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"2021-03-22 12:20:43.411018: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0\\n\",\n            \"WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/compat/v2_compat.py:96: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\\n\",\n            \"Instructions for updating:\\n\",\n            \"non-resource variables are not supported in the long term\\n\",\n            \"Current step 400000\\n\",\n            \"Saving config to gs://test-bucket-neo/GPT3_2-7B\\n\",\n            \"2021-03-22 12:20:50.689547: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\\n\",\n            \"2021-03-22 12:20:50.691059: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1\\n\",\n            \"2021-03-22 12:20:50.701975: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: 
CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\\n\",\n            \"2021-03-22 12:20:50.702051: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (eeb4af61eb99): /proc/driver/nvidia/version does not exist\\n\",\n            \"2021-03-22 12:20:52.229703: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:196] None of the MLIR optimization passes are enabled (registered 0 passes)\\n\",\n            \"Done!\\n\",\n            \"params = defaultdict(<function fetch_model_params.<locals>.<lambda> at 0x7f64ee76fb90>, {'n_head': 20, 'n_vocab': 50257, 'embed_dropout': 0, 'lr': 0.00016, 'lr_decay': 'cosine', 'warmup_steps': 3000, 'beta1': 0.9, 'beta2': 0.95, 'epsilon': 1e-08, 'ada_epsilon1': '1e-30', 'ada_epsilon2': 0.001, 'opt_name': 'adam', 'weight_decay': 0, 'train_batch_size': 16, 'attn_dropout': 0, 'train_steps': 401000, 'lr_decay_end': 300000, 'eval_steps': 0, 'predict_steps': 0, 'res_dropout': 0, 'eval_batch_size': 128, 'predict_batch_size': 4, 'iterations': 500, 'n_embd': 2560, 'datasets': [['pile', None, None, None]], 'model_path': 'gs://test-bucket-neo/GPT3_2-7B', 'n_ctx': 2048, 'n_layer': 32, 'scale_by_depth': True, 'scale_by_in': False, 'attention_types': ['global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local'], 'mesh_shape': 'x:4,y:2', 'layout': 'intermediate_expanded:x,heads:x,memory_length:y,embd:y', 'activation_function': 'gelu', 'recompute_grad': True, 'gradient_clipping': 1.0, 'tokens_per_mb_per_replica': 4096, 'padding_id': 50257, 'eos_id': 50256, 'dataset_configs': {'pile': {'n_vocab': 50257, 'path': 'gs://neo-datasets/pile/pile_*.tfrecords', 'eval_path': 'gs://neo-datasets/pile_val.tfrecords', 'tokenizer_is_pretrained': True, 
'tokenizer_path': 'gpt2', 'eos_id': 50256, 'padding_id': 50257}}, 'mlm_training': False, 'causal': True, 'num_cores': 8, 'auto_layout': False, 'auto_layout_and_mesh_shape': False, 'use_tpu': True, 'gpu_ids': ['device:GPU:0'], 'steps_per_checkpoint': 500, 'predict': True, 'model': 'GPT', 'export': False, 'sampling_use_entmax': False, 'moe_layers': None, 'slow_sampling': False})\\n\",\n            \"Using config: {'_model_dir': 'gs://test-bucket-neo/GPT3_2-7B', '_tf_random_seed': None, '_save_summary_steps': 500, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\\n\",\n            \"cluster_def {\\n\",\n            \"  job {\\n\",\n            \"    name: \\\"worker\\\"\\n\",\n            \"    tasks {\\n\",\n            \"      key: 0\\n\",\n            \"      value: \\\"10.82.219.162:8470\\\"\\n\",\n            \"    }\\n\",\n            \"  }\\n\",\n            \"}\\n\",\n            \"isolate_session_state: true\\n\",\n            \", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.82.219.162:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.82.219.162:8470', '_evaluation_master': 'grpc://10.82.219.162:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=500, num_shards=8, num_cores_per_replica=1, per_host_input_for_training=4, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, 
experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver object at 0x7f64ee774a90>}\\n\",\n            \"_TPUContext: eval_on_tpu True\\n\",\n            \"Predictions generated\\n\",\n            \"Querying Tensorflow master (grpc://10.82.219.162:8470) for TPU system metadata.\\n\",\n            \"2021-03-22 12:20:53.623443: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:373] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.\\n\",\n            \"Initializing TPU system (master: grpc://10.82.219.162:8470) to fetch topology for model parallelism. This might take a while.\\n\",\n            \"Found TPU system:\\n\",\n            \"*** Num TPU Cores: 8\\n\",\n            \"*** Num TPU Workers: 1\\n\",\n            \"*** Num TPU Cores Per Worker: 8\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 6478766768852144079)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1341089584581626564)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -607673649088781696)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -4050793109911027603)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, -6683233089843062258)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -4741539030516422912)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 2164395643386766058)\\n\",\n            
\"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 3352841220362516620)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 5726423099271110669)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 7316344872981758207)\\n\",\n            \"*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 7432402242254058183)\\n\",\n            \"Calling model_fn.\\n\",\n            \"num_cores_per_replica: 1\\n\",\n            \"computation_shape: [1, 1, 1, 1]\\n\",\n            \"num_replicas: 8\\n\",\n            \"device_assignment.topology.device_coordinates: [[[0 0 0 0]\\n\",\n            \"  [0 0 0 1]\\n\",\n            \"  [1 0 0 0]\\n\",\n            \"  [1 0 0 1]\\n\",\n            \"  [0 1 0 0]\\n\",\n            \"  [0 1 0 1]\\n\",\n            \"  [1 1 0 0]\\n\",\n            \"  [1 1 0 1]]]\\n\",\n            \"device_assignment.core_assignment: [[[0 0 0 0]]\\n\",\n            \"\\n\",\n            \" [[0 0 0 1]]\\n\",\n            \"\\n\",\n            \" [[1 0 0 0]]\\n\",\n            \"\\n\",\n            \" [[1 0 0 1]]\\n\",\n            \"\\n\",\n            \" [[0 1 0 0]]\\n\",\n            \"\\n\",\n            \" [[0 1 0 1]]\\n\",\n            \"\\n\",\n            \" [[1 1 0 0]]\\n\",\n            \"\\n\",\n            \" [[1 1 0 1]]]\\n\",\n            \"2021-03-22 12:21:11.005988: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set\\n\",\n            \"device_list = ['/job:worker/task:0/device:CPU:0']\\n\",\n            \"SimdMeshImpl ignoring devices ['', '', '', '', '', '', '', '']\\n\",\n            \"SimdMeshImpl init: Shape[x=4, y=2] LayoutRules{('heads', 'x'), ('embd', 'y'), ('intermediate_expanded', 'x'), ('memory_length', 'y')}\\n\",\n  
          \"Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f64e9078050>\\n\",\n            \"Create pnum_tensor\\n\",\n            \"Variable gpt2/h0/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h0/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h0/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h0/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h0/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h0/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h1/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h1/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h1/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h1/attn/v                                               size 
6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h1/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h1/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h10/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h10/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h10/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h10/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h10/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h10/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h11/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h11/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, 
embd=2560]                                \\n\",\n            \"Variable gpt2/h11/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h11/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h11/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h11/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h12/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h12/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h12/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h12/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h12/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h12/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n          
  \"Variable gpt2/h13/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h13/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h13/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h13/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h13/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h13/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h14/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h14/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h14/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h14/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h14/mlp/conv1d_main/c_fc/kernel           
              size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h14/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h15/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h15/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h15/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h15/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h15/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h15/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h16/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h16/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h16/attn/q                                              size 6553600      slice_size 819200       
Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h16/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h16/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h16/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h17/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h17/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h17/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h17/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h17/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h17/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h18/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                
\\n\",\n            \"Variable gpt2/h18/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h18/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h18/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h18/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h18/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h19/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h19/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h19/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h19/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h19/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable 
gpt2/h19/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h2/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h2/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h2/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h2/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h2/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h2/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h20/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h20/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h20/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h20/attn/v                                             
 size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h20/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h20/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h21/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h21/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h21/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h21/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h21/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h21/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h22/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h22/attn/o                                              size 6553600      slice_size 819200       
Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h22/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h22/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h22/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h22/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h23/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h23/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h23/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h23/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h23/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h23/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               
\\n\",\n            \"Variable gpt2/h24/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h24/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h24/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h24/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h24/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h24/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h25/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h25/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h25/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h25/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable 
gpt2/h25/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h25/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h26/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h26/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h26/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h26/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h26/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h26/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h27/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h27/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h27/attn/q                                             
 size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h27/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h27/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h27/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h28/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h28/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h28/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h28/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h28/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h28/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h29/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, 
heads=2560]                                \\n\",\n            \"Variable gpt2/h29/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h29/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h29/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h29/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h29/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h3/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h3/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h3/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h3/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h3/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n         
   \"Variable gpt2/h3/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h30/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h30/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h30/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h30/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h30/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h30/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h31/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h31/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h31/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h31/attn/v                               
               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h31/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h31/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h4/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h4/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h4/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h4/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h4/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h4/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h5/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h5/attn/o                                               size 6553600      slice_size 819200       
Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h5/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h5/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h5/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h5/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h6/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h6/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h6/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h6/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h6/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h6/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               
\\n\",\n            \"Variable gpt2/h7/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h7/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h7/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h7/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h7/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h7/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h8/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h8/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h8/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h8/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable 
gpt2/h8/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h8/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/h9/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h9/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \\n\",\n            \"Variable gpt2/h9/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h9/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \\n\",\n            \"Variable gpt2/h9/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \\n\",\n            \"Variable gpt2/h9/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \\n\",\n            \"Variable gpt2/wpe                                                     size 5242880      slice_size 2621440      Shape[embed_sequence=2048, embd=2560]                       \\n\",\n            \"Variable gpt2/wte                                                     size 128657920    slice_size 64328960     Shape[vocab=50257, embd=2560]                               \\n\",\n            \"Variable stacked/gpt2/h0/mlp/conv1d_main/c_fc/bias                   
 size 256000       slice_size 64000        Shape[stacked=25, intermediate_expanded=10240]              \\n\",\n            \"    gpt2/h0/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h1/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h2/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h3/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h4/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h5/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h6/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h7/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h8/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h9/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h10/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h11/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h12/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h13/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h14/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h15/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h16/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h17/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h18/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h19/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h20/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h21/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h22/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h23/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h24/mlp/conv1d_main/c_fc/bias\\n\",\n            \"Variable stacked/gpt2/h0/norm_1/g                                     size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \\n\",\n            \"    gpt2/h0/norm_1/g\\n\",\n            \"    gpt2/h0/norm_1/b\\n\",\n            \"    gpt2/h0/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h0/norm_2/g\\n\",\n            \"   
 gpt2/h0/norm_2/b\\n\",\n            \"    gpt2/h0/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h1/norm_1/g\\n\",\n            \"    gpt2/h1/norm_1/b\\n\",\n            \"    gpt2/h1/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h1/norm_2/g\\n\",\n            \"    gpt2/h1/norm_2/b\\n\",\n            \"    gpt2/h1/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h2/norm_1/g\\n\",\n            \"    gpt2/h2/norm_1/b\\n\",\n            \"    gpt2/h2/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h2/norm_2/g\\n\",\n            \"    gpt2/h2/norm_2/b\\n\",\n            \"    gpt2/h2/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h3/norm_1/g\\n\",\n            \"    gpt2/h3/norm_1/b\\n\",\n            \"    gpt2/h3/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h3/norm_2/g\\n\",\n            \"    gpt2/h3/norm_2/b\\n\",\n            \"    gpt2/h3/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h4/norm_1/g\\n\",\n            \"    gpt2/h4/norm_1/b\\n\",\n            \"    gpt2/h4/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h4/norm_2/g\\n\",\n            \"    gpt2/h4/norm_2/b\\n\",\n            \"    gpt2/h4/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h5/norm_1/g\\n\",\n            \"    gpt2/h5/norm_1/b\\n\",\n            \"    gpt2/h5/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h5/norm_2/g\\n\",\n            \"    gpt2/h5/norm_2/b\\n\",\n            \"    gpt2/h5/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h6/norm_1/g\\n\",\n            \"    gpt2/h6/norm_1/b\\n\",\n            \"    gpt2/h6/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h6/norm_2/g\\n\",\n            \"    gpt2/h6/norm_2/b\\n\",\n            \"    gpt2/h6/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h7/norm_1/g\\n\",\n            \"    gpt2/h7/norm_1/b\\n\",\n            \"    gpt2/h7/attn/compute_output_bias/o_b\\n\",\n            \"    
gpt2/h7/norm_2/g\\n\",\n            \"    gpt2/h7/norm_2/b\\n\",\n            \"    gpt2/h7/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h8/norm_1/g\\n\",\n            \"    gpt2/h8/norm_1/b\\n\",\n            \"    gpt2/h8/attn/compute_output_bias/o_b\\n\",\n            \"Variable stacked/gpt2/h17/norm_1/g                                    size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \\n\",\n            \"    gpt2/h17/norm_1/g\\n\",\n            \"    gpt2/h17/norm_1/b\\n\",\n            \"    gpt2/h17/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h17/norm_2/g\\n\",\n            \"    gpt2/h17/norm_2/b\\n\",\n            \"    gpt2/h17/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h18/norm_1/g\\n\",\n            \"    gpt2/h18/norm_1/b\\n\",\n            \"    gpt2/h18/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h18/norm_2/g\\n\",\n            \"    gpt2/h18/norm_2/b\\n\",\n            \"    gpt2/h18/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h19/norm_1/g\\n\",\n            \"    gpt2/h19/norm_1/b\\n\",\n            \"    gpt2/h19/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h19/norm_2/g\\n\",\n            \"    gpt2/h19/norm_2/b\\n\",\n            \"    gpt2/h19/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h20/norm_1/g\\n\",\n            \"    gpt2/h20/norm_1/b\\n\",\n            \"    gpt2/h20/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h20/norm_2/g\\n\",\n            \"    gpt2/h20/norm_2/b\\n\",\n            \"    gpt2/h20/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h21/norm_1/g\\n\",\n            \"    gpt2/h21/norm_1/b\\n\",\n            \"    gpt2/h21/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h21/norm_2/g\\n\",\n            \"    gpt2/h21/norm_2/b\\n\",\n            \"    gpt2/h21/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h22/norm_1/g\\n\",\n     
       \"    gpt2/h22/norm_1/b\\n\",\n            \"    gpt2/h22/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h22/norm_2/g\\n\",\n            \"    gpt2/h22/norm_2/b\\n\",\n            \"    gpt2/h22/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h23/norm_1/g\\n\",\n            \"    gpt2/h23/norm_1/b\\n\",\n            \"    gpt2/h23/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h23/norm_2/g\\n\",\n            \"    gpt2/h23/norm_2/b\\n\",\n            \"    gpt2/h23/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h24/norm_1/g\\n\",\n            \"    gpt2/h24/norm_1/b\\n\",\n            \"    gpt2/h24/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h24/norm_2/g\\n\",\n            \"    gpt2/h24/norm_2/b\\n\",\n            \"    gpt2/h24/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h25/norm_1/g\\n\",\n            \"    gpt2/h25/norm_1/b\\n\",\n            \"    gpt2/h25/attn/compute_output_bias/o_b\\n\",\n            \"Variable stacked/gpt2/h25/mlp/conv1d_main/c_fc/bias                   size 71680        slice_size 17920        Shape[stacked=7, intermediate_expanded=10240]               \\n\",\n            \"    gpt2/h25/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h26/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h27/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h28/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h29/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h30/mlp/conv1d_main/c_fc/bias\\n\",\n            \"    gpt2/h31/mlp/conv1d_main/c_fc/bias\\n\",\n            \"Variable stacked/gpt2/h25/norm_2/g                                    size 104960       slice_size 52480        Shape[stacked=41, embd=2560]                                \\n\",\n            \"    gpt2/h25/norm_2/g\\n\",\n            \"    gpt2/h25/norm_2/b\\n\",\n            \"    gpt2/h25/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h26/norm_1/g\\n\",\n         
   \"    gpt2/h26/norm_1/b\\n\",\n            \"    gpt2/h26/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h26/norm_2/g\\n\",\n            \"    gpt2/h26/norm_2/b\\n\",\n            \"    gpt2/h26/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h27/norm_1/g\\n\",\n            \"    gpt2/h27/norm_1/b\\n\",\n            \"    gpt2/h27/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h27/norm_2/g\\n\",\n            \"    gpt2/h27/norm_2/b\\n\",\n            \"    gpt2/h27/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h28/norm_1/g\\n\",\n            \"    gpt2/h28/norm_1/b\\n\",\n            \"    gpt2/h28/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h28/norm_2/g\\n\",\n            \"    gpt2/h28/norm_2/b\\n\",\n            \"    gpt2/h28/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h29/norm_1/g\\n\",\n            \"    gpt2/h29/norm_1/b\\n\",\n            \"    gpt2/h29/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h29/norm_2/g\\n\",\n            \"    gpt2/h29/norm_2/b\\n\",\n            \"    gpt2/h29/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h30/norm_1/g\\n\",\n            \"    gpt2/h30/norm_1/b\\n\",\n            \"    gpt2/h30/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h30/norm_2/g\\n\",\n            \"    gpt2/h30/norm_2/b\\n\",\n            \"    gpt2/h30/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h31/norm_1/g\\n\",\n            \"    gpt2/h31/norm_1/b\\n\",\n            \"    gpt2/h31/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h31/norm_2/g\\n\",\n            \"    gpt2/h31/norm_2/b\\n\",\n            \"    gpt2/h31/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/ln_f/g\\n\",\n            \"    gpt2/ln_f/b\\n\",\n            \"Variable stacked/gpt2/h8/norm_2/g                                     size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \\n\",\n  
          \"    gpt2/h8/norm_2/g\\n\",\n            \"    gpt2/h8/norm_2/b\\n\",\n            \"    gpt2/h8/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h9/norm_1/g\\n\",\n            \"    gpt2/h9/norm_1/b\\n\",\n            \"    gpt2/h9/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h9/norm_2/g\\n\",\n            \"    gpt2/h9/norm_2/b\\n\",\n            \"    gpt2/h9/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h10/norm_1/g\\n\",\n            \"    gpt2/h10/norm_1/b\\n\",\n            \"    gpt2/h10/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h10/norm_2/g\\n\",\n            \"    gpt2/h10/norm_2/b\\n\",\n            \"    gpt2/h10/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h11/norm_1/g\\n\",\n            \"    gpt2/h11/norm_1/b\\n\",\n            \"    gpt2/h11/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h11/norm_2/g\\n\",\n            \"    gpt2/h11/norm_2/b\\n\",\n            \"    gpt2/h11/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h12/norm_1/g\\n\",\n            \"    gpt2/h12/norm_1/b\\n\",\n            \"    gpt2/h12/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h12/norm_2/g\\n\",\n            \"    gpt2/h12/norm_2/b\\n\",\n            \"    gpt2/h12/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h13/norm_1/g\\n\",\n            \"    gpt2/h13/norm_1/b\\n\",\n            \"    gpt2/h13/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h13/norm_2/g\\n\",\n            \"    gpt2/h13/norm_2/b\\n\",\n            \"    gpt2/h13/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h14/norm_1/g\\n\",\n            \"    gpt2/h14/norm_1/b\\n\",\n            \"    gpt2/h14/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h14/norm_2/g\\n\",\n            \"    gpt2/h14/norm_2/b\\n\",\n            \"    gpt2/h14/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h15/norm_1/g\\n\",\n            \"    
gpt2/h15/norm_1/b\\n\",\n            \"    gpt2/h15/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h15/norm_2/g\\n\",\n            \"    gpt2/h15/norm_2/b\\n\",\n            \"    gpt2/h15/mlp/conv1d_main/c_proj/bias\\n\",\n            \"    gpt2/h16/norm_1/g\\n\",\n            \"    gpt2/h16/norm_1/b\\n\",\n            \"    gpt2/h16/attn/compute_output_bias/o_b\\n\",\n            \"    gpt2/h16/norm_2/g\\n\",\n            \"    gpt2/h16/norm_2/b\\n\",\n            \"    gpt2/h16/mlp/conv1d_main/c_proj/bias\\n\",\n            \"Trainable Variables            count: 200     Total size: 2651307520       Total slice_size: 381853440      \\n\",\n            \"All Variables                  count: 200     Total size: 2651307520       Total slice_size: 381853440      \\n\",\n            \"Counters:\\n\",\n            \"allreduce: 1.68e+10\\n\",\n            \" allreduce/[0]: 5.37e+09\\n\",\n            \"  allreduce/[0]/einsum_op: 5.37e+09\\n\",\n            \" allreduce/[1]: 1.14e+10\\n\",\n            \"  allreduce/[1]/einsum_op: 1.14e+10\\n\",\n            \"  allreduce/[1]/reduce_op: 1.9e+07\\n\",\n            \"einsum: 3.19e+13\\n\",\n            \"einsum_unique: 2.48e+13\\n\",\n            \"output: 2.02e+11\\n\",\n            \" output/AddOperation: 5.68e+10\\n\",\n            \" output/BinaryOpWithBroadcasting: 6.88e+08\\n\",\n            \" output/BroadcastOperation: 5.4e+09\\n\",\n            \" output/ConcatOperation: 2.69e+09\\n\",\n            \" output/Constant: 2.62e+05\\n\",\n            \" output/EinsumOperation: 5.59e+10\\n\",\n            \" output/ImportOperation: 1.31e+05\\n\",\n            \" output/OneHotOperation: 3.33e+09\\n\",\n            \" output/RangeOperation: 3.19e+05\\n\",\n            \" output/ReduceOperation: 2.95e+07\\n\",\n            \" output/ReshapeOperation: 1.01e+10\\n\",\n            \" output/ScalarAddOperation: 5.37e+09\\n\",\n            \" output/ScalarMultiplyOperation: 1.89e+10\\n\",\n            \" 
output/ShiftOperation: 1.34e+09\\n\",\n            \" output/SlicewiseOperation: 2.73e+10\\n\",\n            \" output/StackedVariable: 2.64e+06\\n\",\n            \" output/StopGradient: 8.05e+09\\n\",\n            \" output/UnstackOperation: 2.64e+06\\n\",\n            \" output/Variable: 3.05e+09\\n\",\n            \" output/WhileLoopOperation: 2.68e+09\\n\",\n            \"output_unique: 1.09e+11\\n\",\n            \" output_unique/AddOperation: 3.1e+10\\n\",\n            \" output_unique/BinaryOpWithBroadcasting: 8.81e+07\\n\",\n            \" output_unique/BroadcastOperation: 5.38e+09\\n\",\n            \" output_unique/ConcatOperation: 1.34e+09\\n\",\n            \" output_unique/Constant: 3.28e+04\\n\",\n            \" output_unique/EinsumOperation: 2.53e+10\\n\",\n            \" output_unique/ImportOperation: 1.64e+04\\n\",\n            \" output_unique/OneHotOperation: 4.16e+08\\n\",\n            \" output_unique/RangeOperation: 4.1e+04\\n\",\n            \" output_unique/ReduceOperation: 1.16e+07\\n\",\n            \" output_unique/ReshapeOperation: 5.37e+09\\n\",\n            \" output_unique/ScalarAddOperation: 2.68e+09\\n\",\n            \" output_unique/ScalarMultiplyOperation: 8.75e+09\\n\",\n            \" output_unique/ShiftOperation: 6.71e+08\\n\",\n            \" output_unique/SlicewiseOperation: 1.75e+10\\n\",\n            \" output_unique/StackedVariable: 8.24e+05\\n\",\n            \" output_unique/StopGradient: 6.71e+09\\n\",\n            \" output_unique/UnstackOperation: 8.24e+05\\n\",\n            \" output_unique/Variable: 2.65e+09\\n\",\n            \" output_unique/WhileLoopOperation: 1.34e+09\\n\",\n            \"variables: 2.65e+09\\n\",\n            \" variables/trainable: 2.65e+09\\n\",\n            \"Done calling model_fn.\\n\",\n            \"TPU job name worker\\n\",\n            \"Graph was finalized.\\n\",\n            \"Restoring parameters from gs://test-bucket-neo/GPT3_2-7B/model.ckpt-400000\\n\",\n            \"Running 
local_init_op.\\n\",\n            \"Done running local_init_op.\\n\",\n            \"From /usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:840: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\\n\",\n            \"Instructions for updating:\\n\",\n            \"Prefer Variable.assign which has equivalent behavior in 2.X.\\n\",\n            \"Starting infeed thread controller.\\n\",\n            \"Starting outfeed thread controller.\\n\",\n            \"Initialized dataset iterators in 0 seconds\\n\",\n            \"Before copy master to slices.\\n\",\n            \"Done with copy master to slices.\\n\",\n            \"Enqueue next (1) batch(es) of data to infeed.\\n\",\n            \"Dequeue next (1) batch(es) of data from outfeed.\\n\",\n            \"Outfeed finished for iteration (0, 0)\\n\",\n            \"======================================== SAMPLE 0 ========================================\\n\",\n            \"\\n\",\n            \"\\n\",\n            \"class GPT(nn.Module):\\n\",\n            \"    \\\"\\\"\\\"  the full GPT language model, with a context size of block_size \\\"\\\"\\\"\\n\",\n            \"\\n\",\n            \"    def __init__(self, config):\\n\",\n            \"        super().__init__()\\n\",\n            \"\\n\",\n            \"        # input embedding stem\\n\",\n            \"        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\\n\",\n            \"        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\\n\",\n            \"        self.drop = nn.Dropout(config.embd_pdrop)\\n\",\n            \"        # transformer\\n\",\n            \"        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\\n\",\n            \"        # decoder head\\n\",\n            \"        self.ln_f = nn.LayerNorm(config.n_embd)\\n\",\n            \"        self.head = 
nn.Linear(config.n_embd, config.vocab_size, bias=False)\\n\",\n            \"\\n\",\n            \"        self.block_size = config.block_size\\n\",\n            \"        self.apply(self._init_weights)\\n\",\n            \"\\n\",\n            \"        logger.info(\\\"number of parameters: %e\\\", sum(p.numel() for p in self.parameters()))\\n\",\n            \"\\n\",\n            \"    def forward(self, input):\\n\",\n            \"        \\\"\\\"\\\" return gpt from position embedding (embedding for position and context)\\\"\\\"\\\"\\n\",\n            \"        return GPT(input, self.pos_emb, self.tok_emb, self.drop, self.ln_f, self.head)\\n\",\n            \"\\n\",\n            \"    def get_type_log_probability(self, input, target, p_type):\\n\",\n            \"        \\\"\\\"\\\" get negative log-likelihood for the current probability (p_type)\\n\",\n            \"        \\\"\\\"\\\"\\n\",\n            \"        embedding = self.tok_emb(input)\\n\",\n            \"        return nn.log_softmax(embedding, dim=1) / sum(input.size(1) for input in input)\\n\",\n            \"\\n\",\n            \"\\n\",\n            \"def update_parameters_for_training(model, input_length, targets,\\n\",\n            \"                                   target_length, context_size, apply_onehot=False):\\n\",\n            \"    \\\"\\\"\\\" update parameters after re-training or training in 2-shot.\\n\",\n            \"\\n\",\n            \"            model.set_params(...)..returns(model_post_training)\\n\",\n            \"            model_post_training: the updated model\\n\",\n            \"    \\\"\\\"\\\"\\n\",\n            \"    if not model.sampler:\\n\",\n            \"        model.reset_params()\\n\",\n            \"    elif model.sampler.get_seed()!= 0 or limit_sampled_sequences(model.sampler.get_seed()):\\n\",\n            \"        if apply_onehot:\\n\",\n            \"            model.reset_params()\\n\",\n            \"\\n\",\n            \"    loss = 
nn.BCELoss()\\n\",\n            \"    model.loss = loss\\n\",\n            \"    model.disp = model.disp + (1.0 - model.disp) * model.log_prob(input_length, target_length)\\n\",\n            \"    model.mean_disp = model.disp\\n\",\n            \"    model.mean_pos = model.pos\\n\",\n            \"    score = model.log_prob(target_length, target_length)\\n\",\n            \"    if (input_length == target_length):\\n\",\n            \"        # single shot - ignore intro, ilux and outros\\n\",\n            \"        if apply_onehot:\\n\",\n            \"            target[0][0] = '%s %s' % (target_length, target_length)\\n\",\n            \"        else:\\n\",\n            \"            target[0][0] = '%s %d' % (target_length, target_length)\\n\",\n            \"    else:\\n\",\n            \"        # 2-shot - batch one of the input embedding, multi-shot - batch by sequence.\\n\",\n            \"        targets = torch.cat([tuple([chr[0] if chr[0] in target[0] else '?' for chr in target])\\n\",\n            \"                             for target in target_length], 2)\\n\",\n            \"        target_length = len(targets)\\n\",\n            \"\\n\",\n            \"    pos_emb = self.pos_emb(input)\\n\",\n            \"    tok_emb = self.tok_emb(input)\\n\",\n            \"    drop = self.drop(input)\\n\",\n            \"\\n\",\n            \"    head_drop = tok_emb.nonlinearity * drop\\n\",\n            \"\\n\",\n            \"    for ln_f in self.ln_f:\\n\",\n            \"        self.ln_f = nn.LayerNorm(self.n_embd)\\n\",\n            \"        self.ln_f.weight.data.zero_()\\n\",\n            \"        self.ln_f.bias.data.zero_()\\n\",\n            \"\\n\",\n            \"    for block in self.blocks:\\n\",\n            \"        self.head_drop.weight.data.zero_()\\n\",\n            \"        self.head_drop.bias.data.zero_()\\n\",\n            \"\\n\",\n            \"    for i in range(self.n_layer):\\n\",\n            \"        param_tuple = (i, block, 
head_drop, len(targets), config.init_lstm_c)\\n\",\n            \"        t_pos, t_targets, _ = torch.max(target, param_tuple[0], param_tuple[1])\\n\",\n            \"\\n\",\n            \"        # fast threshold -> 1 will be equal to target, non-zero will not be all 0\\n\",\n            \"        t_pos = t_pos if t_pos == 0 else 1\\n\",\n            \"        t_targets = t_targets if t_targets == 0 else 1\\n\",\n            \"        self.pixel_to_pos = target[t_pos:t_pos+1]\\n\",\n            \"\\n\",\n            \"        # linear decrease\\n\",\n            \"        self.disp_drop = tok_emb.nonlinearity * self.drop(t_pos)\\n\",\n            \"        self.disp_drop.weight.data.zero_()\\n\",\n            \"\\n\",\n            \"        self.weight_reset = torch.zeros(2)\\n\",\n            \"        self.bias_reset = torch.zeros(2)\\n\",\n            \"\\n\",\n            \"        emb_tok_id = model.pixel_to_pos\\n\",\n            \"        weight_last = Embedding(1, config.n_embd)\\n\",\n            \"        self.q = weight_last(emb_tok_id)\\n\",\n            \"        self.q_last = weight_last(self.q)\\n\",\n            \"\\n\",\n            \"        mask_name = '%s/%s/%s_%d' % (config.tok_id, config.pos_id, tok_emb.size(), pos_emb.size())\\n\",\n            \"        self.loss_mask = nn.LogSoftmax(dim=1)\\n\",\n            \"        self.loss_state = nn.Linear(config.n_embd+num_tok_c, config.n_embd)\\n\",\n            \"        self.target_to_pos = per_target_pos(target, param_tuple[0], param_tuple[1], self.head_drop, label=targets)\\n\",\n            \"        self.loss_target_to_pos = per_target_pos(target, param_tuple[0], param_tuple[1], self.head_drop, label=targets)\\n\",\n            \"        self.mask_loss_name = \\\"loss_mask\\\"\\n\",\n            \"        target_to_pos = nn.LogSoftmax(dim=1)\\n\",\n            \"        for i in\\n\",\n            \"\\n\",\n            
\"================================================================================\\n\",\n            \"\\n\",\n            \"======================================== SAMPLE 1 ========================================\\n\",\n            \"\\n\",\n            \"\\n\",\n            \"class GPT(nn.Module):\\n\",\n            \"    \\\"\\\"\\\"  the full GPT language model, with a context size of block_size \\\"\\\"\\\"\\n\",\n            \"\\n\",\n            \"    def __init__(self, config):\\n\",\n            \"        super().__init__()\\n\",\n            \"\\n\",\n            \"        # input embedding stem\\n\",\n            \"        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\\n\",\n            \"        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\\n\",\n            \"        self.drop = nn.Dropout(config.embd_pdrop)\\n\",\n            \"        # transformer\\n\",\n            \"        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\\n\",\n            \"        # decoder head\\n\",\n            \"        self.ln_f = nn.LayerNorm(config.n_embd)\\n\",\n            \"        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\\n\",\n            \"\\n\",\n            \"        self.block_size = config.block_size\\n\",\n            \"        self.apply(self._init_weights)\\n\",\n            \"\\n\",\n            \"        logger.info(\\\"number of parameters: %e\\\", sum(p.numel() for p in self.parameters()))\\n\",\n            \"        # normalization\\n\",\n            \"        self.weight_gpu = nn.Parameter(torch.Tensor(self.weight.size(1).num()))\\n\",\n            \"        self.bias_gpu = nn.Parameter(torch.zeros(1).type(torch.float32))\\n\",\n            \"\\n\",\n            \"    def _init_weights(self):\\n\",\n            \"        num_b = self.head.weight.size(1)\\n\",\n            \"        drop_b = self.head.bias.size(0)\\n\",\n            \"        
self.weight = nn.Parameter(torch.Tensor(num_b, drop_b))\\n\",\n            \"        self.bias = nn.Parameter(torch.zeros(drop_b).type(torch.float32))\\n\",\n            \"\\n\",\n            \"    def forward(self, H, g, X): \\n\",\n            \"        \\\"\\\"\\\"  - token-level feed forward\\n\",\n            \"        - Embed Otherwise\\n\",\n            \"            (f) g is ignored for the embeddings, and this is only used to save the\\n\",\n            \"                gpt translation encoder memory.\\n\",\n            \"        \\\"\\\"\\\"\\n\",\n            \"        output = {}\\n\",\n            \"        if self.head.keep:\\n\",\n            \"            X_top = X_top.view(-1, self.emb_size, 1)\\n\",\n            \"\\n\",\n            \"            for j in range(self.head.nheads):\\n\",\n            \"                dX = X_top[:, 0]\\n\",\n            \"                dX = dX.transpose(0, 1)[0]\\n\",\n            \"                dX /= X_top[:, 1].sum(1, keepdim=1)[0]\\n\",\n            \"                X_top = self.head(dX)\\n\",\n            \"                dX = X_top[:, 0]\\n\",\n            \"                dX = dX.transpose(0, 1)[0]\\n\",\n            \"                dX /= X_top[:, 1].sum(0, keepdim=1)[0]\\n\",\n            \"                X_top = self.head(dX)\\n\",\n            \"\\n\",\n            \"            for i in range(self.head.n_layer):\\n\",\n            \"                H = torch.cat([H, self.ln_f(H)[0]]).view(-1)\\n\",\n            \"                if self.drop > 0:\\n\",\n            \"                    g = torch.zeros_like(H.long()).float()\\n\",\n            \"                else:\\n\",\n            \"                    g = H.long()\\n\",\n            \"\\n\",\n            \"                g = g.transpose(1, 2).contiguous().view(-1, g.size(1))\\n\",\n            \"                if self.apply_del_emb:\\n\",\n            \"                    output[j] = g.transpose(0, 1)\\n\",\n            \"             
   else:\\n\",\n            \"                    output[j] = self.head(g)\\n\",\n            \"\\n\",\n            \"                H = H.transpose(0, 1)\\n\",\n            \"        else:\\n\",\n            \"            X_top = X_top.view(-1, self.emb_size, 1)\\n\",\n            \"            for j in range(self.head.nheads):\\n\",\n            \"                dX = X_top[:, 0]\\n\",\n            \"                dX = dX.transpose(0, 1)[0]\\n\",\n            \"                dX /= X_top[:, 1].sum(1, keepdim=1)[0]\\n\",\n            \"                X_top = self.head(dX)\\n\",\n            \"                dX = X_top[:, 0]\\n\",\n            \"                dX = dX.transpose(0, 1)[0]\\n\",\n            \"                dX /= X_top[:, 1].sum(0, keepdim=1)[0]\\n\",\n            \"                X_top = self.head(dX)\\n\",\n            \"\\n\",\n            \"            for i in range(self.head.n_layer):\\n\",\n            \"                g = torch.cat([self.ln_f(H)[0], g])[0]\\n\",\n            \"                if self.drop > 0:\\n\",\n            \"                    g = torch.zeros_like(g).float()\\n\",\n            \"                else:\\n\",\n            \"                    g = g.transpose(1, 2).contiguous().view(-1, g.size(1))\\n\",\n            \"                if self.apply_del_emb:\\n\",\n            \"                    output[j] = g.transpose(0, 1)\\n\",\n            \"                else:\\n\",\n            \"                    output[j] = self.head(g)\\n\",\n            \"\\n\",\n            \"        output = output[\\\"h\\\"].transpose(0, 1)\\n\",\n            \"        return output\\n\",\n            \"\\n\",\n            \"\\n\",\n            \"\\n\",\n            \"================================================================================\\n\",\n            \"\\n\",\n            \"======================================== SAMPLE 2 ========================================\\n\",\n            \"\\n\",\n            
\"\\n\",\n            \"class GPT(nn.Module):\\n\",\n            \"    \\\"\\\"\\\"  the full GPT language model, with a context size of block_size \\\"\\\"\\\"\\n\",\n            \"\\n\",\n            \"    def __init__(self, config):\\n\",\n            \"        super().__init__()\\n\",\n            \"\\n\",\n            \"        # input embedding stem\\n\",\n            \"        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\\n\",\n            \"        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\\n\",\n            \"        self.drop = nn.Dropout(config.embd_pdrop)\\n\",\n            \"        # transformer\\n\",\n            \"        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\\n\",\n            \"        # decoder head\\n\",\n            \"        self.ln_f = nn.LayerNorm(config.n_embd)\\n\",\n            \"        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\\n\",\n            \"\\n\",\n            \"        self.block_size = config.block_size\\n\",\n            \"        self.apply(self._init_weights)\\n\",\n            \"\\n\",\n            \"        logger.info(\\\"number of parameters: %e\\\", sum(p.numel() for p in self.parameters()))\\n\",\n            \"\\n\",\n            \"        self.optimizer = optim.Adam(\\n\",\n            \"            self.head,\\n\",\n            \"            parameters_ub=self.parameters(),\\n\",\n            \"            lam=config.initial_learning_rate\\n\",\n            \"        )\\n\",\n            \"\\n\",\n            \"    def forward(self, input_text):\\n\",\n            \"        \\\"\\\"\\\" the overall model.co: forward pass \\\"\\\"\\\"\\n\",\n            \"\\n\",\n            \"        limit = self.head.output_size(0)\\n\",\n            \"        head = self.head\\n\",\n            \"        attn = self.head.weight\\n\",\n            \"        # tagwith = self.head.weight\\n\",\n            \"\\n\",\n  
          \"        block = self.blocks[:,0][self.block_size:,:]\\n\",\n            \"        forward_attn = block(attn)\\n\",\n            \"        forward_text = forward_attn + input_text\\n\",\n            \"        forward_text = conv_block(forward_text)\\n\",\n            \"        forward_text = forward_linear(forward_text)\\n\",\n            \"        forward_text = forward_linear(forward_linear(forward_text))\\n\",\n            \"\\n\",\n            \"        lower_attn = (self.tok_emb(forward_text)).sum(1, keepdim=True)\\n\",\n            \"        # lower_attn = self.tok_emb(1)\\n\",\n            \"\\n\",\n            \"        #rnn_basic_block1 = forward_attn[:self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0),:].view(1, 1, self.block_size, -1)\\n\",\n            \"        #rnn_basic_block1 = rnn_tok(forward_text[:self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0), :].transpose(1, 0, 2) + forward_text[self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0),:])\\n\",\n            \"        #rnn_basic_block1_drop = nn.Dropout(config.drop_rate)\\n\",\n            \"        #print(rnn_basic_block1_drop.shape)\\n\",\n            \"        #print(rnn_basic_block1.weight.shape)\\n\",\n            \"        # post_drop = rnn_basic_block1_drop.view(1, 1, self.block_size, 1)\\n\",\n            \"        # rnn_part_block1 = forward_attn[self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(1, self.block_size, self.head.n_layer)\\n\",\n            \"        # post_drop = post_drop + rnn_part_block1.weight.view(self.block_size, self.head.n_layer, 1) + rnn_part_block1.bias.view(self.head.n_layer, 1, 1).expand_as(rnn_part_block1)\\n\",\n            \"\\n\",\n            \"        #rnn_part_block1 = (head(head(numpy.squeeze(forward_text), 1)))[:self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(self.block_size, self.head.layers_[0].n_layer, 
-1)\\n\",\n            \"        #rnn_part_block1 = rnn_basic_block1_drop + post_drop + rnn_part_block1.weight.view(self.block_size, self.head.n_layer, 1) + rnn_part_block1.bias.view(self.head.n_layer, 1, 1).expand_as(rnn_part_block1)\\n\",\n            \"\\n\",\n            \"        lower_rnn_text = (head(head(numpy.squeeze(forward_text), 1)))[self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(self.head.n_layer, self.block_size, -1)\\n\",\n            \"        lower_rnn = rnn_tok(lower_rnn_text)\\n\",\n            \"\\n\",\n            \"        #attention_layers = self.attention_layer\\n\",\n            \"        #context_attention_layers = self.context_attention_layer\\n\",\n            \"        #attn_context_layers = self.attention_layer + self.context_attention_layer\\n\",\n            \"        #attn_context_layers = self.attention_layer\\n\",\n            \"        #propagation_layers = self.proper_layer + self.context_attention_layer\\n\",\n            \"\\n\",\n            \"        return lower_attn + lower_rnn_text + lower_rnn\\n\",\n            \"\\n\",\n            \"    def backward(self, grad_output, grad_input):\\n\",\n            \"        \\\"\\\"\\\" the model.co: backward pass \\\"\\\"\\\"\\n\",\n            \"\\n\",\n            \"        grad_weight = torch.matmul(grad_output[self.head.layers_[0].output_size(0)], grad_input.contiguous())\\n\",\n            \"        return grad_weight.view(batch_size, -1, self.head.n_layer), grad_weight.view(batch_size, -1, self.head.n_layer)\\n\",\n            \"\\n\",\n            \"    def clip_gradient(self, grad_input):\\n\",\n            \"        \\\"\\\"\\\" clip gradient \\\"\\\"\\\"\\n\",\n            \"        logger.warning(\\\"clip_gradient: clip(grad_input, 0.0 - 1.0)\\\")\\n\",\n            \"        return grad_input.clamp(0.0 - 1.0).detach().cpu().numpy()\\n\",\n            \"\\n\",\n            \"    def _get_cell(self, name):\\n\",\n            \"        if 
self.args.tied_base_model:\\n\",\n            \"            return self.head.layers_[name].n_op\\n\",\n            \"\\n\",\n            \"        return self.head.layers_[name]\\n\",\n            \"\\n\",\n            \"    def _get_head(self, head_name):\\n\",\n            \"        if self.args.tied_base_model:\\n\",\n            \"            return head_name\\n\",\n            \"\\n\",\n            \"        return self.head.n_layer\\n\",\n            \"\\n\",\n            \"\\n\",\n            \"    def forward_gpt_cell(self, head):\\n\",\n            \"        \\\"\\\"\\\" the forward pass of the gpt\\n\",\n            \"\\n\",\n            \"================================================================================\\n\",\n            \"\\n\",\n            \"======================================== SAMPLE 3 ========================================\\n\",\n            \"\\n\",\n            \"\\n\",\n            \"class GPT(nn.Module):\\n\",\n            \"    \\\"\\\"\\\"  the full GPT language model, with a context size of block_size \\\"\\\"\\\"\\n\",\n            \"\\n\",\n            \"    def __init__(self, config):\\n\",\n            \"        super().__init__()\\n\",\n            \"\\n\",\n            \"        # input embedding stem\\n\",\n            \"        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\\n\",\n            \"        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\\n\",\n            \"        self.drop = nn.Dropout(config.embd_pdrop)\\n\",\n            \"        # transformer\\n\",\n            \"        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\\n\",\n            \"        # decoder head\\n\",\n            \"        self.ln_f = nn.LayerNorm(config.n_embd)\\n\",\n            \"        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\\n\",\n            \"\\n\",\n            \"        self.block_size = config.block_size\\n\",\n      
      \"        self.apply(self._init_weights)\\n\",\n            \"\\n\",\n            \"        logger.info(\\\"number of parameters: %e\\\", sum(p.numel() for p in self.parameters()))\\n\",\n            \"        logger.info(\\\"images size: %e\\\", config.images_len)\\n\",\n            \"        logger.info(\\\"embedding size: %e\\\", config.embedding_size)     \\n\",\n            \"\\n\",\n            \"        self.vocab_size = config.vocab_size\\n\",\n            \"        self.hidden_size = config.hidden_size\\n\",\n            \"        self.n_layer = config.n_layer\\n\",\n            \"        self.block_size = config.block_size\\n\",\n            \"        self.cell_dim = config.cell_dim\\n\",\n            \"        self.n_embd = config.n_embd\\n\",\n            \"        self.n_embd = config.n_embd\\n\",\n            \"        self.embd_pdrop = config.embd_pdrop\\n\",\n            \"        self.n_batch = config.n_batch\\n\",\n            \"        self.n_dembd = config.n_dembd\\n\",\n            \"        self.101k_embd = config.101k_embd\\n\",\n            \"        self.shotting_dist = config.shotting_dist\\n\",\n            \"        self.dropout = config.embd_pdrop\\n\",\n            \"\\n\",\n            \"        # init variables\\n\",\n            \"        self._init_weights()\\n\",\n            \"\\n\",\n            \"    def _init_weights(self):\\n\",\n            \"        for layer in self.blocks:\\n\",\n            \"            for cell in layer:\\n\",\n            \"                param_init = cell.init_weights()\\n\",\n            \"                self.parameters()[layer][cell] = param_init.assign(param_init)\\n\",\n            \"\\n\",\n            \"    def forward(self, x, gpt_emb, gpt_state, gpt_emb_dim, gpt_state_dim):\\n\",\n            \"        \\\"\\\"\\\"  a forward pass for language model derivations\\n\",\n            \"\\n\",\n            \"            input, latent and context embeddings of convolutional layers as well 
as the entity embedding to obtain the\\n\",\n            \"            topic-embedding applied to the gpt entity embedding to generate the knowledge graph representation\\n\",\n            \"            gpt latent representations are then transformed into some vector representation\\n\",\n            \"\\n\",\n            \"            latent representation is then used as input to the decoder head, to produce the gpt entity representation\\n\",\n            \"\\n\",\n            \"            finally, the gpt entity representation is used as input to the decoder head, to produce the gpt latent representation on top of which\\n\",\n            \"            the knowledge graph representation is constructed\\n\",\n            \"        \\\"\\\"\\\"\\n\",\n            \"\\n\",\n            \"        n_ctx = len(x)\\n\",\n            \"        x_bn = x.nonzero()[0]/n_ctx\\n\",\n            \"        latent_bn = x_bn.nonzero()[0]/n_ctx\\n\",\n            \"        cv_emb = self.tok_emb(x_bn)\\n\",\n            \"        # consider the entity embedding to get the gpt latent representation\\n\",\n            \"        entity_mask = self.apply(gpt_emb_dim) if self.embd else 0\\n\",\n            \"        latent = self.apply(gpt_state_dim)\\n\",\n            \"        latent = latent * gpt_emb + entity_mask * gpt_state + self.drop\\n\",\n            \"\\n\",\n            \"        self.ln_f.weight.data.fill_(1.0)\\n\",\n            \"        self.ln_f.bias.data.zero_()\\n\",\n            \"        self.ln_f.weight.data[0].copy_(self.tok_emb)\\n\",\n            \"        self.ln_f.bias.data[0].copy_(self.pos_emb)\\n\",\n            \"        mlp = nn.Linear(config.hidden_size, config.n_embd)\\n\",\n            \"        mlp.bias.data[0].copy_(self.hidden_size)\\n\",\n            \"        self.ln_f.weight.data[0].copy_(mlp.weight.data)\\n\",\n            \"        # get the gpt latent representation on top of which the knowledge graph\\n\",\n            \"        ln_gpt_emb 
= self.apply(gpt_emb_dim)\\n\",\n            \"        # ln_gpt_emb = logits.sample(self.shotting_dist)\\n\",\n            \"        # ln_gpt_emb_shape = [1]\\n\",\n            \"        # gpt_ln_emb.data[0].copy_(ln_gpt_emb.data[0])\\n\",\n            \"        # gpt_ln_emb_shape = [0]\\n\",\n            \"        # gpt_gpt_emb = gpt_ln_emb.gather([0], gpt_ln_emb.shape)\\n\",\n            \"        # get the gpt latent representation to be used as the starting latent embedding of the decoder\\n\",\n            \"        ln_src_emb = gpt_ln_emb\\n\",\n            \"        ln_state = gpt_ln_emb.gather([0], ln_gpt_emb.shape)\\n\",\n            \"\\n\",\n            \"        # get the context and latent embedding representation of the entire input\\n\",\n            \"        # x_ext = x_bn[latex_str].squeeze()\\n\",\n            \"        self.apply(n)\\n\",\n            \"        # get the context representation used to decode the embedded gpt representation\\n\",\n            \"        x_ext = x_bn[latex_str].squeeze()\\n\",\n            \"        x_ext = x_ext.transpose(1, 0)\\n\",\n            \"        x_ext = F.relu(self.apply(x_ext))\\n\",\n            \"        x_ext_bn = x_ext.transpose(1, 0)\\n\",\n            \"        # x_ext_bn = x_ext_bn.transpose(1, 0)\\n\",\n            \"        # initialize the decoder hidden state\\n\",\n            \"        ln_src_emb, diff_emb = collections.defaultdict(list), []\\n\",\n            \"        for i, i_emb in enumerate(self.ln_f):\\n\",\n            \"            i_blk = int(self.block_size*(i+1))\\n\",\n            \"            mlp = nn.Linear(context_embedding_dim, n\\n\",\n            \"\\n\",\n            \"================================================================================\\n\",\n            \"\\n\",\n            \"Enqueue next (1) batch(es) of data to infeed.\\n\",\n            \"Dequeue next (1) batch(es) of data from outfeed.\\n\",\n            \"Outfeed finished for iteration (1, 0)\\n\",\n 
           \"Stop infeed thread controller\\n\",\n            \"Shutting down InfeedController thread.\\n\",\n            \"InfeedController received shutdown signal, stopping.\\n\",\n            \"Infeed thread finished, shutting down.\\n\",\n            \"infeed marked as finished\\n\",\n            \"Stop output thread controller\\n\",\n            \"Shutting down OutfeedController thread.\\n\",\n            \"OutfeedController received shutdown signal, stopping.\\n\",\n            \"Outfeed thread finished, shutting down.\\n\",\n            \"outfeed marked as finished\\n\",\n            \"Shutdown TPU system.\\n\",\n            \"prediction_loop marked as finished\\n\",\n            \"prediction_loop marked as finished\\n\"\n          ],\n          \"name\": \"stdout\"\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"nE9VImzHaI0z\"\n      },\n      \"source\": [\n        \"# Evaluating the model\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"XGGbkgaFfp6f\"\n      },\n      \"source\": [\n        \"This section assumes you are using a pretrained model and relies on variables created in the `Pretrained model` section.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"I45yUIpbaLUJ\"\n      },\n      \"source\": [\n        \"## Wikitext\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"zwBDB9U2keFV\"\n      },\n      \"source\": [\n        \"Download the wikitext test set:\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"uuugiBmJaNxf\"\n      },\n      \"source\": [\n        \"wikitext103_src = \\\"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip\\\"\\n\",\n        \"!wget $wikitext103_src\\n\",\n        \"!unzip wikitext-103-raw-v1.zip\"\n      ],\n      
\"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"J5wf3QWKkhZt\"\n      },\n      \"source\": [\n        \"Tokenize and upload to bucket:\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"6mo8UUtDdctH\"\n      },\n      \"source\": [\n        \"\\n\",\n        \"!mkdir wikitext\\n\",\n        \"!mv /content/GPTNeo/wikitext-103-raw/wiki.test.raw wikitext/wikitext_test.txt\\n\",\n        \"\\n\",\n        \"# Tokenize Data\\n\",\n        \"!python data/create_tfrecords.py --input_dir wikitext --name wikitext --files_per 1000 --output_dir wikitext_tokenized --write_dataset_config --processes 1 --wikitext-detokenize\\n\",\n        \"\\n\",\n        \"# copy the data to your bucket\\n\",\n        \"if not path_to_cloud_bucket.endswith('/'):\\n\",\n        \"  path_to_cloud_bucket += '/'\\n\",\n        \"copy_loc = path_to_cloud_bucket \\n\",\n        \"!gsutil -m cp -r wikitext_tokenized $copy_loc\\n\",\n        \"!gsutil ls $path_to_cloud_bucket\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"GE84TUd1fAzf\"\n      },\n      \"source\": [\n        \"Now make a dataset config that points to the tokenized wikitext data:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"Z5UU7DQeeY0S\"\n      },\n      \"source\": [\n        \"%%writefile configs/dataset_configs/wikitext.json\\n\",\n        \"\\n\",\n        \"{\\n\",\n        \"  \\\"path\\\": \\\"\\\",\\n\",\n        \"  \\\"eval_path\\\": \\\"gs://test-bucket-neo/wikitext_tokenized/*.tfrecords\\\",\\n\",\n        \"  \\\"n_vocab\\\": 50256,\\n\",\n        \"  \\\"tokenizer_is_pretrained\\\": true,\\n\",\n        \"  \\\"tokenizer_path\\\": \\\"gpt2\\\",\\n\",\n        \"  \\\"eos_id\\\": 50256,\\n\",\n        \"  
\\\"padding_id\\\": 50257\\n\",\n        \"}\\n\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"egvdwIOqfFER\"\n      },\n      \"source\": [\n        \"And update your model config to point to that dataset:\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"AtdoIFMgfOe8\"\n      },\n      \"source\": [\n        \"# @title Modify config for wikitext. \\n\",\n        \"  \\n\",\n        \"import json\\n\",\n        \"from pprint import pprint\\n\",\n        \"\\n\",\n        \"batch_size = 8 #@param {type:\\\"integer\\\"}\\n\",\n        \"assert pretrained_model is not None\\n\",\n        \"with open(f'configs/{pretrained_model}.json', 'r') as f:\\n\",\n        \"  data = json.load(f)\\n\",\n        \"  pprint(data)\\n\",\n        \"  dset_val = [[\\\"wikitext\\\", None, None, None]]\\n\",\n        \"  mods = {\\n\",\n        \"          \\\"datasets\\\": dset_val,\\n\",\n        \"          \\\"eval_steps\\\": 139 // batch_size,\\n\",\n        \"          \\\"train_batch_size\\\": batch_size,\\n\",\n        \"          \\\"eval_batch_size\\\": batch_size,\\n\",\n        \"        }\\n\",\n        \"  data.update(mods)\\n\",\n        \"  print('\\\\n--->\\\\n')\\n\",\n        \"  pprint(data)\\n\",\n        \"  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\\n\",\n        \"    json.dump(data, outfile, indent=2)\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"U2d5eTHEg6Xj\"\n      },\n      \"source\": [\n        \"Now run model in eval mode over tokenized data:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"s1Uz3PXzg5Pm\"\n      },\n      \"source\": [\n        \"!python3 main.py --eval 
--tpu colab --model $pretrained_model\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"9dbkPVcMhVaR\"\n      },\n      \"source\": [\n        \"## Lambada\\n\",\n        \"\\n\",\n        \"The Lambada eval is built into the codebase and can be run by adding an `eval_tasks` field to your model config.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"z4FJXOlJiEYo\"\n      },\n      \"source\": [\n        \"# @title Modify config for Lambada. \\n\",\n        \"  \\n\",\n        \"import json\\n\",\n        \"from pprint import pprint\\n\",\n        \"\\n\",\n        \"batch_size = 8 #@param {type:\\\"integer\\\"}\\n\",\n        \"assert pretrained_model is not None\\n\",\n        \"with open(f'configs/{pretrained_model}.json', 'r') as f:\\n\",\n        \"  data = json.load(f)\\n\",\n        \"  # dset_val is defined in the wikitext config cell above; run that cell first\\n\",\n        \"  mods = {\\n\",\n        \"          \\\"datasets\\\": dset_val,\\n\",\n        \"          \\\"eval_steps\\\": 0,\\n\",\n        \"          \\\"train_batch_size\\\": batch_size,\\n\",\n        \"          \\\"eval_batch_size\\\": batch_size,\\n\",\n        \"          \\\"eval_tasks\\\": [\\\"lambada\\\"]\\n\",\n        \"        }\\n\",\n        \"  data.update(mods)\\n\",\n        \"  print('\\\\n--->\\\\n')\\n\",\n        \"  pprint(data)\\n\",\n        \"  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\\n\",\n        \"    json.dump(data, outfile, indent=2)\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Upp-bGMriVPK\"\n      },\n      \"source\": [\n        \"Now run the eval:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"OOA1YZDRiUhN\"\n      },\n      \"source\": [\n        \"!python3 main.py --eval --tpu colab 
--model $pretrained_model\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    }\n  ]\n}\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 EleutherAI\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# GPT Neo\n\n[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5297715.svg)](https://doi.org/10.5281/zenodo.5297715) [![arXiv](https://img.shields.io/badge/arXiv-2101.00027-f9f107.svg)](https://arxiv.org/abs/2101.00027)\n\n**As of August 2021, this code is no longer maintained. It is preserved here in archival form for people who wish to continue to use it.**\n\n🎉 1T or bust my dudes 🎉\n\nAn implementation of model & data parallel [GPT3](https://arxiv.org/abs/2005.14165)-like models using the [mesh-tensorflow](https://github.com/tensorflow/mesh) library.\n\n**If you're just here to play with our pre-trained models, we strongly recommend you try out the [HuggingFace Transformers integration](https://huggingface.co/EleutherAI).**\n\nTraining and inference are officially supported on TPU and should work on GPU as well. This repository will be (mostly) archived as we shift our focus to our GPU-specific repo, [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).\n\nIn addition to the functionality offered by GPT-3, we also offer the following:\n* [Local attention](https://arxiv.org/abs/2004.05150)\n* [Linear attention](https://arxiv.org/abs/1812.01243)\n* [Mixture of Experts](https://arxiv.org/abs/1701.06538)\n* [Axial positional embedding](https://arxiv.org/abs/1912.12180)\n\nNB: while Neo can *technically* run a training step at 200B+ parameters, it is very inefficient at those scales. 
This, as well as the fact that many GPUs became available to us, among other things, prompted us to move development over to [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).\n\n# Pretrained Models\n\n**Update 21/03/2021:**\n\nWe're proud to release two pretrained GPT-Neo models trained on The Pile; the weights and configs can be freely downloaded from [the-eye.eu](https://the-eye.eu/public/AI/gptneo-release/).\n\n1.3B: https://mystic.the-eye.eu/public/AI/gptneo-release/GPT3_XL/\n\n2.7B: https://mystic.the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/\n\nFor more information on how to get these set up, see the Colab notebook, or read through the rest of the README.\n\n## Model Evaluations\n\n#### Linguistic Reasoning\n\n| Model and Size   | Pile BPB   | Pile PPL  | Wikitext PPL | Lambada PPL | Lambada Acc | Winogrande | Hellaswag  |\n|------------------|------------|-----------|--------------|-------------|-------------|------------|------------|\n| **GPT-Neo 125M** | -----      | -----     | **32.285**   | **30.266**  | **37.36%**  | **50.43%** | **28.67%** |\n| GPT-3 125M       | -----      | -----     | -----        | 18.6        | 42.7%       | 52.0%      | 33.7%      |\n| **GPT-Neo 350M** | -----      | -----     | **22.5657**  | **13.876**  | **47.27%**  | **51.14%** | **32.16%** |\n| GPT-3 350M       | -----      | -----     | -----        | 9.09        | 54.3%       | 52.1%      | 43.6%      |\n| GPT-3 Ada        | 0.9631     | -----     | -----        | 9.954       | 51.60%      | 52.90%     | 35.93%     |\n| **GPT-Neo 1.3B** | **0.7527** | **6.159** | **13.10**    | **7.498**   | **57.23%**  | **55.01%** | **38.66%** |\n| GPT-3 1.3B       | -----      | -----     | -----        | 5.44        | 63.6%       | 58.7%      | 54.7%      |\n| GPT-2 1.5B       | 1.0468     | -----     | 17.48        | 10.634      | 51.21%      | 59.40%     | 40.03%     |\n| **GPT-Neo 2.7B** | **0.7165** | **5.646** | **11.39**    | **5.626**   | **62.22%**  | **56.50%** | 
**42.73%** |\n| GPT-3 2.7B       | -----      | -----     | -----        | 4.60        | 67.1%       | 62.3%      | 62.8%      |\n\n\n#### Physical and Scientific Reasoning\n\n| Model and Size   | MathQA     | PubMedQA   | Piqa       |\n|------------------|------------|------------|------------|\n| **GPT-Neo 125M** | **22.78%** | **55.10%** | **63.06%** |\n| GPT-3 125M       | -----      | -----      | 64.6%      |\n| **GPT-Neo 350M** | **23.45%** | **53.80%** | **65.07%** |\n| GPT-3 350M       | -----      | -----      | 70.2%      |\n| GPT-3 Ada        | 24.29%     | 52.80%     | 68.88%     |\n| **GPT-Neo 1.3B** | **24.05%** | **54.40%** | **71.11%** |\n| GPT-3 1.3B       | -----      | -----      | 75.1%      |\n| GPT-2 1.5B       | 23.64%     | 58.33%     | 70.78%     |\n| **GPT-Neo 2.7B** | **24.72%** | **57.54%** | **72.14%** |\n| GPT-3 2.7B       | -----      | -----      | 75.6%      |\n\n\n**Note:** All evaluations were done using our [evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness). Some results for GPT-2 and GPT-3 are inconsistent with the values reported in the respective papers. We are currently looking into why, and would greatly appreciate feedback and further testing of our eval harness.\n\n# Setup\n\n```bash\ngit clone https://github.com/EleutherAI/GPTNeo\ncd GPTNeo\npip3 install -r requirements.txt\n```\n# Training Setup\n\n## TPUs:\n\nSign up for [Google Cloud Platform](https://cloud.google.com/), and create a [storage bucket](https://cloud.google.com/storage). 
\n\nCreate your VM through Google Cloud Shell (`https://ssh.cloud.google.com/`) with `ctpu up --vm-only` so that it can connect to your Google bucket and TPUs, then install the requirements with pip (see above).\n\nGoogle Colab provides TPU v2-8s for free, which should be enough to finetune our models up to GPT3XL (1.3B parameter) sizes.\nClick [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EleutherAI/GPTNeo/blob/master/GPTNeo_example_notebook.ipynb) to run through our example Colab notebook.\n\nFor more detailed instructions, run through our [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below.\n\n## GPUs:\n\nYou can also choose to train GPTNeo locally on your GPUs. To do so, you can omit the Google Cloud setup steps above and git clone the repo locally. Run through the [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below; then, when running main.py, simply omit the `tpu` flag and pass in GPU ids instead.\n\nNote: Some users have reported having difficulty getting MTF (mesh-tensorflow) to recognize their GPUs. See [here](https://github.com/EleutherAI/gpt-neo/issues/150) for details and instructions on how to fix it.\n\n# Generating Text\n\nOnce you have a trained model, or you've downloaded one of our pre-trained models, generating text is as simple as running the main.py script with the `--predict` flag on. You can pass a path to your prompt txt file with the `--prompt` flag, like so:\n\n```bash\npython3 main.py --predict --prompt <example_prompt.txt> --tpu <tpu_name> --model <config_name>\n```\n\nor, if using GPUs:\n\n```bash\npython3 main.py --predict --prompt <example_prompt.txt> --gpu_ids <device:GPU:0 device:GPU:1> --model <config_name>\n```\n\n# Training Guide\n\n## 1. 
Create your Tokenizer (OPTIONAL)\n\nWe recommend you use [HuggingFace's pretrained GPT2 tokenizer](https://huggingface.co/transformers/model_doc/gpt2.html#transformers.GPT2Tokenizer) with our repo (instructions provided below), but if you want to train a model with a different vocabulary size, we provide facilities to train your own tokenizer like so:\n\n```bash\npython data/train_tokenizer.py \\\n    --base_dir ./path/to/your/txt/files \\\n    --output_dir ./output/path \\\n    --file_type txt \\\n    --vocab_size 50257\n\n# if it succeeded, you should see the message\n# 'tokenizer saved at ./output/path/byte-level-bpe.tokenizer.json'\n```\n\n## 2. Tokenizing your Dataset\n\nIf you just want to test training, you can skip this step and download some dummy data like so:\n\n```\nwget https://storage.googleapis.com/connors-datasets/bundestag/bundestag_0.tfrecords\n```\n\nThen copy the data to your bucket (or, if using GPUs, to a local directory): \n\n```\ngsutil cp bundestag_0.tfrecords gs://<your bucket>/\n```\n\nIf using your own data to train, you can use the `data/create_tfrecords.py` script to encode your text data into tfrecords.\n\nYour data must either be in the form of lots of normal .txt files (one document per file), or in any format supported by [lm_dataformat](https://github.com/leogao2/lm_dataformat). \n\nYou can run the script without parameters to see help for all options.\n\nIn **document mode**, each example in the tfrecords is one (variably sized) document. 
This is to be used with the `documents_fixed` and `documents_random` sampling modes (for more details, see the Parameter Reference section).\nDocument mode is the default mode.\n\nThe command below will tokenize all files in acceptable formats in *input_dir* using the GPT2 tokenizer and save them to *output_dir*:\n```\npython3 data/create_tfrecords.py --mode documents --input_dir <base> --name <name> --output_dir <output> --use_gpt2_tokenizer --minimum_size <min> \n```\n\n- `input_dir`: Defines the folder where your data is located. The script will encode all files present in this folder.\n- `name`: Name of output files will be `name_i.tfrecords` where i is the number of the file.\n- `output_dir`: Where to save the tfrecords.\n- `use_gpt2_tokenizer`: Whether to use the pretrained HuggingFace GPT2 tokenizer, in which case the separator will be set to [50256].\n- `encoder_path`: If not using the pretrained GPT2 tokenizer, use this flag to provide a path to your generated tokenizer json.\n- `separator`: Written in list format, the separator token(s) to insert between documents (e.g. \"[0]\"). Will depend on your encoder.\n- `minimum_size`: The minimum size (in tokens) a document must have, otherwise it is discarded. This is what will later determine your `stitch` parameter: `stitch * minimum_size` must always be greater than or equal to `n_ctx` (for more details, see the Parameter Reference section).\n\n## 3. Using a Dataset in a Model\n\nTo use a dataset in a model, you must first register that dataset under the `./configs/dataset_configs` folder. First, choose a filename with a `.json` extension. That filename will serve as the dataset identifier. 
The config should be filled out in the following manner.\n\nIf you have a dataset encoded using the pretrained GPT2 tokenizer, you can specify that like so:\n\n```json\n{\n    \"n_vocab\": 50257,\n    \"path\": \"gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords\",\n    \"eval_path\": \"gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords\",\n    \"tokenizer_is_pretrained\": true,\n    \"tokenizer_path\": \"gpt2\"\n}\n```\n\nor, if you've trained a custom tokenizer, like so:\n\n```json\n{\n    \"n_vocab\": 32768,\n    \"path\": \"./path/to/your/*.tfrecords\",\n    \"eval_path\": \"./path/to/your/eval/*.tfrecords\",\n    \"tokenizer_path\": \"./path/to/your/byte-level-bpe.tokenizer.json\"\n}\n```\n\nFinally, in your model config, add the filename that you created above to the `datasets` array.\n\nThe `<dataset id>` will be the filename, excluding the `.json`, that you created above.\n\n```\n\"datasets\": [[<dataset id>, <stitch>, <sampling_mode>, <weight>]] # datasets key defines at run time how each dataset is processed for training\n```\n\n## 4. 
Choose a model configuration\n\nOnce you have your datasets set up, find a suitable config in `./configs`.\n\nHere we use a GPT3-XL sized model as an example, but there are many more in `./configs`, all of which have short summaries in the Available Configs section.\n\nAll you need to do is edit the dataset id as described above, and edit `model_path` (where logs and checkpoints will be saved) to point to a cloud bucket you have write access to (or a local path, if using GPUs).\n\n```json\n{\n    \"n_head\": 32,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0.1,\n    \"lr\": 0.0002,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.1,\n    \"train_batch_size\": 512,\n    \"attn_dropout\": 0.1,\n    \"train_steps\": 286150,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0.1,\n    \"eval_batch_size\": 128,\n    \"predict_batch_size\": 1,\n    \"iterations\": 2500,\n    \"n_embd\": 2048,\n    \"datasets\": [[\"your_dataset_name\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_XL\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 24,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],24]],\n    \"mesh_shape\": \"x:128,y:2\",\n    \"layout\": \"batch:x,memory_length:y,embd:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": true,\n    \"gradient_clipping\": 1.0,\n    \"tokens_per_mb_per_replica\": 2048\n}\n```\n\n\n## 5. 
Run Training\n\n```\npython3 main.py --model <your_config_name> --steps_per_checkpoint <n> --tpu <tpu-name>\n```\n\n- `tpu`: Name of the TPU to use.\n- `steps_per_checkpoint`: The frequency in steps at which to save checkpoints.\n- `--auto_layout` and `--auto_layout_and_mesh_shape` (Optional): Disable training and instead auto-generate a memory-efficient `layout` (and `mesh_shape`).\n- `gpu_ids`: If training using GPUs, omit the `tpu` flag and pass in the ids of your GPUs. In the example below, we train on 2 GPUs, specifying their device ids delimited by spaces:\n\n```\npython3 main.py --model <your_config_name> --steps_per_checkpoint <n> --gpu_ids <device:GPU:0 device:GPU:1>\n```\n\n# Available Configs\n\nWe have several model sizes available, but some of our configs require large TPUs and will need tweaking to run on smaller machines or GPUs. Below is a short guide to each model in the configs directory:\n\nTODO\n\n# Extra Features: \n\n## Training (with Sacred)\n\n[Sacred](https://github.com/IDSIA/sacred) helps track experiments and is much nicer to work with than TensorBoard.\n\nTo set up:\n\n1. Install Docker and Docker Compose.\n\n2. Run `docker-compose up`\n\nTo use: \n\n1. Ensure model_dir doesn't have any metric logs in it (existing logs trip up TensorBoard's metric tracking, which assumes the new run is a continuation of the existing one). You can use `gsutil rm -r ...` to delete the model dir.\n\n2. Run `python3 run_experiment.py --tpu sometpuhere --model someconfig.json`. Options are the same as for `main.py`. \n\n3. You can go to http://server_ip_goes_here:8081/ to see the Omniboard overview. If you prefer to see a TensorBoard, the script also spins one up and automatically assigns it a port. The script should print out the TensorBoard port near the top of the log. \n\n## Peeking at a Dataset\n\nIf you are ever confused by the dataset of a particular config file, you can easily check the minimum and maximum token ids with a single command. 
This is useful for making sure that the vocabulary size of the model is larger than the maximum token id (valid ids run from `0` to `n_vocab - 1`). TensorFlow will not raise an error if you try to gather from a matrix with out-of-bounds indices, so you need to make sure your vocabulary size is sufficiently large.\n\n```bash\npython3 main.py --model {config_name} --check_dataset\n```\n\n## Masked Language Modeling\n\nIn addition to being able to train large GPTs, this repository also allows you to easily do masked language modeling (BERT, RoBERTa). In order to do so, you must follow two additional steps.\n\n1. When tokenizing your dataset, you must reserve a special id for the `[mask]` token.\n\n2. In the configs, you will have to define two additional fields:\n\n```python\n\"mlm_training\": true,                           # must be set to true\n\"mlm_mask_id\": <mask id>                        # the mask id that you reserved from above\n```\n\nThat's all you need to train a model with the MLM objective, good for any type of data that you have encoded properly. 
If you would like to tweak the other related hyperparameters, please continue reading.\n\n```python\n\"mlm_cls_token_id\": <cls token id>,                # auto append specified CLS token id on the left\n\"mlm_mask_prob\": 0.15,                             # the probability of masking a token, defaults to 15%\n\"mlm_same_token_prob\": 0.10,                       # probability of keeping the token the same, defaults to 10%\n\"mlm_random_token_prob\": 0.10,                     # probability of tokens that are replaced with random tokens, 10% was recommended by the BERT paper\n\"mlm_mask_ignore_ids\": [<cls token>, <sep token>]  # ignore masking other special tokens, if any\n```\n\n## Parameter Reference\n\nPick a valid config from `./configs` and tweak the parameters as needed:\n\n- `n_head`: The number of attention heads.\n- `n_embd`: Size of the hidden layers, must be divisible by `n_head`.\n- `n_vocab`: Vocabulary size.\n- `embed_dropout`, `res_dropout`, `attn_dropout`: Dropout probability for word embedding/residuals/attention.\n- `lr`: Learning rate.\n- `warmup_steps`: Number of steps before full learning rate is reached (linear ramp from `0` to `lr`).\n- `lr_decay`: `cosine` or `linear`.\n- `opt_name`: `adam` or `adafactor`.\n- `beta1`, `beta2` and `epsilon`: `adam` optimizer params.\n- `beta1`, `ada_epsilon1` and `ada_epsilon2`: `adafactor` optimizer params.\n- `weight_decay`: Weight decay parameter; if not present, no weight decay is used (the weight decay fix for Adam is used) (default: 0.01) (optional).\n- `train_batch_size`: Batch size during training.\n- `train_steps`: Number of training steps (batches), set to roughly 1 epoch for now (total number of tokens in your dataset / number of tokens per batch (= `train_batch_size` * `n_ctx`)).\n- `eval_steps`: Number of steps to run for each evaluation. Set to `0` for no eval. 
After every checkpoint, the model is evaluated for `eval_steps` steps.\n- `iterations`: Number of steps queued to the TPU, must be smaller than `steps_per_checkpoint`. (default: 500)\n- `datasets`: List of tfrecords datasets to use. Each dataset is a list with the following parameters: `[dataset_id, stitch, sampling_mode, weight]`. For example, for a single dataset (note the double list): `[[\"your_dataset_name\", 25, \"documents_random\", 1.0]]`\n    + `dataset_id`: The name of a dataset configuration file in `./configs/dataset_configs`\n    + `stitch`: If the `documents_random` sampling mode is used, the input pipeline concatenates this many documents into one to sample from. You must select `stitch` so that `stitch * minimum_size >= n_ctx`\n    + `sampling_mode`: `chunks` (tfrecords are preprocessed into the correct length and are read sequentially) or `documents_random` (`stitch` documents are concatenated and then an `n_ctx` chunk is randomly subsampled)\n    + `weight`: How much relative weight this dataset should have compared to others\n- `model`: Which model to train. Currently only `GPT` is supported, and it defaults to this if not present.\n- `model_path`: Google Storage bucket location (or local path, if using GPUs) to save model checkpoints and logs.\n- `n_ctx`: Size of context window. Default is 2048.\n- `n_layer`: Number of layers (blocks) in the model.\n- `scale_by_depth`: If true, the weight initialization of layers is scaled by their depth as in the GPT2 paper.\n- `scale_by_in`: If true, the weight initialization of layers is scaled by their number of inputs as in the GPT2 paper.\n- `mesh_shape`: A Mesh is an n-dimensional array of processors with named dimensions used for parallelism in the mesh-tensorflow library. Each Tensor is split evenly across mesh dimensions according to the layout (see below). The `mesh_shape` is the shape of this array, and the product of its dimension sizes must equal the number of processors. 
e.g., for a v3-128 TPU: `\"mesh_shape\": \"x:16,y:8\"`.\n- `layout`: A Tensor is laid out on its mesh with one slice on each processor. A Tensor \"layout\" is an injective partial map specifying which dimensions of the tensor are (evenly) split across which dimensions of the mesh. No dimension of a tensor may be split across two dimensions of its mesh, and no two dimensions of a tensor may be split across the same dimension of its mesh. The user defines a global set of layout rules in the form of (tensor-dimension-name, mesh-dimension-name) pairs. A dimension of a tensor is split across a dimension of its mesh if there is a matching rule, e.g. for the above example mesh_shape: `\"layout\": \"batch:x,heads:y\"`.\n- `activation_function`: `selu` (self-normalizing) or `gelu` (used by OpenAI), the activation function used in feed-forward passes. (default: gelu)\n- `attention_types`: The type of attention for each layer, as a list in the following format: [[[\"attention_type\"], n_layers]]. E.g. for a 12-layer net: [[[\"global\"], 12]] or [[[\"local\"], 10], [[\"global\"], 2]].\n    + Choose from: `linear`, `global`, `local` or `none`. We have found a 50/50 mix of `global` and `linear` to work well. `none` allows you to create feed-forward only layers for more efficient [PAR Transformer](https://arxiv.org/abs/2009.04534) models.\n- `precision`: `float32` or `bfloat16`.\n- `tokens_per_mb_per_replica`: If not None, will split the batch up into smaller microbatches containing `tokens_per_mb_per_replica` tokens to avoid OOMs. Gradients are accumulated locally and reduced once. IMPORTANT: mb refers to *minibatch*, not megabyte, here. \n\n**Mixture of Experts**\n\n- `moe_layers`: A list of layer numbers to append a [mixture of experts](https://arxiv.org/abs/1701.06538) layer onto. E.g.: `[2,4,6,8,10,12]`.\nWe have experimentally found a MoE layer for every two self-attention layers to work well.\n- `moe_params`: A dictionary of additional kwargs to pass in to the MoE layer. 
E.g.:\n    `{\"moe_dropout_rate\": 0.0 }`\n    \n**Experimental features** \n\n- `axial_pos_emb_`: If true, uses [axial positional embedding](https://arxiv.org/abs/1912.12180).\n- `mlp_glu`: If true, uses a gated linear unit variant of feed-forward layers.\n- `scalenorm`: If true, uses scalenorm instead of layernorm.\n- `rezero`: If true, uses [rezero](https://www.groundai.com/project/rezero-is-all-you-need-fast-convergence-at-large-depth/1) instead of layernorm.\n- `num_mem_kv`: Adds memory / key values from the [all-attention paper](https://arxiv.org/pdf/1907.01470.pdf). Param is an int with the number of desired mem/key values.\n- `macaron`: If true, uses a [macaron transformer](https://arxiv.org/pdf/1906.02762.pdf) for each layer block.\n\n## TODO: \n\n- [x] finalize documentation\n- [ ] update configs\n\n## Citing GPT-Neo\n\nIf you have found GPT-Neo helpful in your work, you can cite this repository as:\n\n```\n@software{gpt-neo,\n  author       = {Black, Sid and\n                  Gao, Leo and\n                  Wang, Phil and\n                  Leahy, Connor and\n                  Biderman, Stella},\n  title        = {{GPT-Neo: Large Scale Autoregressive Language \n                   Modeling with Mesh-Tensorflow}},\n  month        = mar,\n  year         = 2021,\n  note         = {{If you use this software, please cite it using \n                   these metadata.}},\n  publisher    = {Zenodo},\n  version      = {1.0},\n  doi          = {10.5281/zenodo.5297715},\n  url          = {https://doi.org/10.5281/zenodo.5297715}\n}\n\n```\nThe version number should be replaced with the version number you are using, and the year corresponds to the project's open-source release.\n\nIf you are specifically interested in citing the GPT-Neo models trained on [the Pile](https://arxiv.org/abs/2101.00027), we would appreciate also citing:\n```\n@article{gao2020pile,\n  title={The Pile: An 800GB Dataset of Diverse Text for Language Modeling},\n  author={Gao, Leo and Biderman, 
Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others},\n  journal={arXiv preprint arXiv:2101.00027},\n  year={2020}\n}\n```\n"
  },
  {
    "path": "configs/dataset_configs/example.json",
    "content": "{\n\t\"n_vocab\": 32768,\n\t\"path\": \"./tfrecords/openwebtext_*.tfrecords\",\n\t\"eval_path\": \"\",\n\t\"tokenizer_path\": \"./datasets/openwebtext/byte-level-bpe.tokenizer.json\",\n\t\"eos_id\": 1,\n\t\"padding_id\": 0\n}\n"
  },
  {
    "path": "configs/dataset_configs/openwebtext2_new_inputs.json",
    "content": "{\n\t\"n_vocab\": 50257,\n\t\"path\": \"gs://neo-datasets/openwebtext2_new_inputs/train/*.tfrecords\",\n\t\"eval_path\": \"gs://neo-datasets/openwebtext2_new_inputs/eval/*.tfrecords\",\n\t\"tokenizer_is_pretrained\": true,\n\t\"tokenizer_path\": \"gpt2\",\n\t\"eos_id\": 50256,\n\t\"padding_id\": 50257\n}\n"
  },
  {
    "path": "configs/dataset_configs/pile.json",
    "content": "{\n\t\"n_vocab\": 50257,\n\t\"path\": \"gs://neo-datasets/pile/pile_*.tfrecords\",\n\t\"eval_path\": \"gs://neo-datasets/pile_val.tfrecords\",\n\t\"tokenizer_is_pretrained\": true,\n\t\"tokenizer_path\": \"gpt2\",\n\t\"eos_id\": 50256,\n\t\"padding_id\": 50257\n}\n"
  },
  {
    "path": "configs/gpt2_small.json",
    "content": "{\n    \"n_head\": 6,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0.1,\n    \"lr\": 0.0006,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0,\n    \"train_batch_size\": 512,\n    \"attn_dropout\": 0.1,\n    \"train_steps\": 1000000,\n    \"lr_decay_end\": 300000,\n    \"eval_steps\": 30,\n    \"predict_steps\": 0,\n    \"res_dropout\": 0.1,\n    \"eval_batch_size\": 128,\n    \"predict_batch_size\": 8,\n    \"iterations\": 2500,\n    \"n_embd\": 768,\n    \"datasets\": [\"openwebtext2_new_inputs\"],\n    \"model_path\": \"gs://neo-models/GPT2_SMALL\",\n    \"n_ctx\": 1024,\n    \"n_layer\": 12,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],12]],\n    \"activation_function\": \"gelu\",\n    \"mesh_shape\": \"all:64\",\n    \"layout\": \"batch:all\",\n    \"recompute_grad\": false,\n    \"gradient_clipping\": 1.0\n}"
  },
  {
    "path": "configs/gpt3_13B_256.json",
    "content": "{\n    \"n_head\": 40,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0001,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"ada_epsilon1\": 1e-30,\n    \"ada_epsilon2\": 1e-3,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.10,\n    \"train_batch_size\": 1024,\n    \"attn_dropout\": 0,\n    \"train_steps\": 143075,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 128,\n    \"predict_batch_size\": 1,\n    \"iterations\": 500,\n    \"n_embd\": 5120,\n    \"datasets\": [[\"openwebtext-documents\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_13B\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 40,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\", \"local\"],20]],\n    \"mesh_shape\": \"x:16,y:16\",\n    \"layout\": \"batch:x,embd:y,memory_length:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": true,\n    \"gradient_clipping\": 1.0,\n    \"tokens_per_mb_per_replica\": 2048,\n    \"precision\": \"bfloat16\"\n}\n\n"
  },
  {
    "path": "configs/gpt3_13B_256_Pile.json",
    "content": "\n{\n    \"n_head\": 40,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0001,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.1,\n    \"train_batch_size\": 1024,\n    \"attn_dropout\": 0,\n    \"train_steps\": 286150,\n    \"eval_steps\": 10,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 512,\n    \"predict_batch_size\": 1,\n    \"iterations\": 500,\n    \"n_embd\": 5120,\n    \"datasets\": [[\"pile\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_13B_Pile\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 40,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],40]],\n    \"mesh_shape\": \"x:16,y:16\",\n    \"layout\": \"batch:x,memory_length:y,embd:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": true,\n    \"gradient_clipping\": 1.0,\n    \"tokens_per_mb_per_replica\": 2048,\n    \"precision\": \"bfloat16\"\n}\n"
  },
  {
    "path": "configs/gpt3_2-7B_256.json",
    "content": "{\n    \"n_head\": 32,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.00016,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"ada_epsilon1\": 1e-30,\n    \"ada_epsilon2\": 1e-3,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.10,\n    \"train_batch_size\": 512,\n    \"attn_dropout\": 0,\n    \"train_steps\": 286150,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 128,\n    \"predict_batch_size\": 1,\n    \"iterations\": 500,\n    \"n_embd\": 2560,\n    \"datasets\": [[\"openwebtext-documents\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_2-7B\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 32,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],32]],\n    \"mesh_shape\": \"x:128,y:2\",\n    \"layout\": \"embd:y,batch:x\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": true,\n    \"gradient_clipping\": 1.0\n}\n\n"
  },
  {
    "path": "configs/gpt3_6-7B_256.json",
    "content": "{\n    \"n_head\": 32,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.00012,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.10,\n    \"train_batch_size\": 1024,\n    \"attn_dropout\": 0,\n    \"train_steps\": 143075,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 128,\n    \"predict_batch_size\": 1,\n    \"iterations\": 500,\n    \"n_embd\": 4096,\n    \"datasets\": [[\"openwebtext-documents\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_6-7B\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 32,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],32]],\n    \"mesh_shape\": \"x:128,y:2\",\n    \"layout\": \"embd:y,batch:x\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": true,\n    \"gradient_clipping\": 1.0\n}\n\n"
  },
  {
    "path": "configs/gpt3_PAR_small_256.json",
    "content": "{\n    \"n_head\": 12,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0006,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.10,\n    \"train_batch_size\": 256,\n    \"attn_dropout\": 0,\n    \"train_steps\": 572300,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 64,\n    \"predict_batch_size\": 1,\n    \"iterations\": 1000,\n    \"n_embd\": 768,\n    \"datasets\": [[\"openwebtext-documents\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_PAR_SMALL\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 19,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\": [[[\"global\", \"none\", \"none\"],5], [[\"none\"], 4]],\n    \"mesh_shape\": \"x:64,y:4\",\n    \"layout\": \"batch:x,heads:y,vocab:y,intermediate_expanded:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": false,\n    \"gradient_clipping\": 1.0\n}\n\n"
  },
  {
    "path": "configs/gpt3_XL_256_Pile.json",
    "content": "{\n    \"n_head\": 32,\n    \"n_vocab\": 50257,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0002,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.1,\n    \"train_batch_size\": 512,\n    \"attn_dropout\": 0,\n    \"train_steps\": 286150,\n    \"eval_steps\": 10,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 512,\n    \"predict_batch_size\": 1,\n    \"iterations\": 500,\n    \"n_embd\": 2048,\n    \"datasets\": [[\"pile\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_XL_Pile\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 24,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],24]],\n    \"mesh_shape\": \"x:128,y:2\",\n    \"layout\": \"batch:x,memory_length:y,embd:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": true,\n    \"gradient_clipping\": 1.0,\n    \"tokens_per_mb_per_replica\": 2048,\n    \"precision\": \"bfloat16\"\n}\n"
  },
  {
    "path": "configs/gpt3_large_256.json",
    "content": "{\n    \"n_head\": 16,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.00025,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"ada_epsilon1\": 1e-30,\n    \"ada_epsilon2\": 1e-3,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.10,\n    \"train_batch_size\": 256,\n    \"attn_dropout\": 0,\n    \"train_steps\": 572300,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 64,\n    \"predict_batch_size\": 1,\n    \"iterations\": 2500,\n    \"n_embd\": 1536,\n    \"datasets\": [[\"openwebtext-documents\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_LARGE\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 24,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],24]],\n    \"mesh_shape\": \"x:64,y:4\",\n    \"layout\": \"batch:x,vocab:y,heads:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": true,\n    \"gradient_clipping\": 1.0,\n    \"tokens_per_mb_per_replica\": 2048\n}\n\n"
  },
  {
    "path": "configs/gpt3_medium_256.json",
    "content": "{\n    \"n_head\": 16,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0003,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.10,\n    \"train_batch_size\": 256,\n    \"attn_dropout\": 0,\n    \"train_steps\": 572300,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 64,\n    \"predict_batch_size\": 1,\n    \"iterations\": 2500,\n    \"n_embd\": 1024,\n    \"datasets\": [[\"openwebtext-documents\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_MEDIUM\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 24,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\" :  [[[\"global\"],24]],\n    \"mesh_shape\": \"x:64,y:4\",\n    \"layout\": \"batch:x,heads:y,vocab:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": false,\n    \"gradient_clipping\": 1.0\n}\n\n"
  },
  {
    "path": "configs/gpt3_small_256.json",
    "content": "{\n    \"n_head\": 12,\n    \"n_vocab\": 50304,\n    \"embed_dropout\": 0,\n    \"lr\": 0.0006,\n    \"lr_decay\": \"cosine\",\n    \"warmup_steps\": 3000,\n    \"beta1\": 0.9,\n    \"beta2\": 0.95,\n    \"epsilon\": 1e-8,\n    \"opt_name\": \"adam\",\n    \"weight_decay\": 0.10,\n    \"train_batch_size\": 256,\n    \"attn_dropout\": 0,\n    \"train_steps\": 572300,\n    \"eval_steps\": 0,\n    \"predict_steps\": 1,\n    \"res_dropout\": 0,\n    \"eval_batch_size\": 64,\n    \"predict_batch_size\": 1,\n    \"iterations\": 2500,\n    \"n_embd\": 768,\n    \"datasets\": [[\"openwebtext-documents\", 25, \"documents_random\", 1.0]],\n    \"model_path\": \"gs://neo-models/GPT3_SMALL\",\n    \"n_ctx\": 2048,\n    \"n_layer\": 12,\n    \"scale_by_depth\": true,\n    \"scale_by_in\": false,\n    \"attention_types\": [[[\"global\"],12]],\n    \"mesh_shape\": \"x:64,y:4\",\n    \"layout\": \"batch:x,heads:y,vocab:y,intermediate_expanded:y\",\n    \"activation_function\": \"gelu\",\n    \"recompute_grad\": false,\n    \"gradient_clipping\": 1.0\n}\n\n"
  },
  {
    "path": "configs.py",
    "content": "import json\nfrom pathlib import Path\nfrom collections import defaultdict\n\nDATASETS = {}\n\nfor path in Path(\"configs/dataset_configs\").glob(\"*.json\"):\n    dataset_id = path.stem\n    DATASETS[dataset_id] = json.loads(path.read_text())\n\n\ndef fetch_model_params(model):\n    model_path = model if model.endswith(\".json\") else f\"configs/{model}.json\"\n    with open(model_path) as f:\n        params = json.load(f)\n\n    dataset_ids = []\n    for d in params.get(\"datasets\", []):  # default to [] so a missing key reaches the assertion below\n        if isinstance(d, list):\n            dataset_ids.append(d[0])\n        else:\n            dataset_ids.append(d)\n    no_datasets = params.get(\"no_dataset\", False)\n    assert no_datasets or len(dataset_ids) > 0, \"You must specify at least one dataset id in the model config\"\n\n    datasets = {}\n    last_dataset = None\n    for dataset_id in dataset_ids:\n        assert dataset_id in DATASETS, f\"Dataset '{dataset_id}' was not found under the configs/dataset_configs/ folder. Please follow the example.json in that folder.\"\n        dataset = DATASETS[dataset_id]\n        assert params[\"n_vocab\"] >= dataset[\"n_vocab\"], f\"The embedding table size '{params['n_vocab']}' must be greater than or equal to the vocab size used to encode the dataset '{dataset_id}' ({dataset['n_vocab']})\"\n        datasets[dataset_id] = dataset\n        last_dataset = dataset\n\n    if last_dataset is not None:\n        params[\"padding_id\"] = last_dataset.get(\"padding_id\", 0)\n        params[\"eos_id\"] = last_dataset.get(\"eos_id\", 1)\n\n    params[\"dataset_configs\"] = datasets\n\n    # Set some other parameter defaults\n    params[\"mlm_training\"] = params.get(\"mlm_training\") == True\n    params[\"causal\"] = not params[\"mlm_training\"]\n\n    # Set all other parameter values to default to None\n    params = defaultdict(lambda: None, params)\n    return params\n"
  },
  {
    "path": "data/create_tfrecords.py",
    "content": "import argparse\nimport os\nfrom pathlib import Path\n\nimport ftfy\nimport tensorflow as tf\nfrom lm_dataformat import Reader\nfrom tokenizers import Tokenizer\nfrom transformers import GPT2TokenizerFast\nfrom tqdm import tqdm\nimport logging\nfrom multiprocessing import Pool, cpu_count\nfrom itertools import repeat\nimport re\n\nlogging.getLogger(\"transformers\").setLevel(logging.ERROR)\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--input_dir\", type=str, help=\"Path to where your files are located. Files ending in .zst are \"\n                                                  \"treated as archives, all others as raw text.\")\nparser.add_argument(\"--files_per\", type=int, default=100000, help=\"Text files per tfrecord\")\nparser.add_argument(\"--name\", type=str, default=\"openwebtext\",\n                    help=\"Name of output files will be name_i.tfrecords where i is the number of the file\")\nparser.add_argument(\"--output_dir\", type=str, default=\"./tfrecords\", help=\"Where to put tfrecords\")\nparser.add_argument(\"--encoder_path\", type=str,\n                    help=\"Path to encoder files, or leave unspecified to use GPT2 tokenizer\")\nparser.add_argument(\"--minimum_size\", type=int, default=100, help=\"Minimum size a document has to be to be included\")\nparser.add_argument(\"--ftfy\", action=\"store_false\", help=\"normalize with ftfy\")\nparser.add_argument(\"--wikitext-detokenize\", action=\"store_false\", help=\"use wikitext detokenizer\")\nparser.add_argument(\"--separator\", nargs=\"+\", type=int, default=[50256],\n                    help=\"separator to place between files in chunk mode\")\nparser.add_argument(\"--chunk_size\", type=int, default=2048, help=\"How big a chunk should be in chunk mode. 
\"\n                                                                 \"Should equal your model's context size\")\nparser.add_argument(\"--write_dataset_config\", action=\"store_true\", help=\"Write the dataset config file on completion\")\nparser.add_argument(\"--processes\", type=int, default=0, help=\"Number of processes to use. Defaults to cpu count.\")\n\nargs = parser.parse_args()\nif not args.output_dir.endswith(\"/\"):\n    args.output_dir = args.output_dir + \"/\"\nif not args.input_dir.endswith(\"/\"):\n    args.input_dir = args.input_dir + \"/\"\nassert len(args.separator) == 1\n\n\ndef wikitext_detokenizer(string):\n    # contractions\n    string = string.replace(\"s '\", \"s'\")\n    string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n    # number separators\n    string = string.replace(\" @-@ \", \"-\")\n    string = string.replace(\" @,@ \", \",\")\n    string = string.replace(\" @.@ \", \".\")\n    # punctuation\n    string = string.replace(\" : \", \": \")\n    string = string.replace(\" ; \", \"; \")\n    string = string.replace(\" . \", \". \")\n    string = string.replace(\" ! \", \"! \")\n    string = string.replace(\" ? \", \"? 
\")\n    string = string.replace(\" , \", \", \")\n    # double brackets\n    string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n    string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n    string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n    string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n    string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n    # miscellaneous\n    string = string.replace(\"= = = =\", \"====\")\n    string = string.replace(\"= = =\", \"===\")\n    string = string.replace(\"= =\", \"==\")\n    string = string.replace(\" \" + chr(176) + \" \", chr(176))\n    string = string.replace(\" \\n\", \"\\n\")\n    string = string.replace(\"\\n \", \"\\n\")\n    string = string.replace(\" N \", \" 1 \")\n    string = string.replace(\" 's\", \"'s\")\n\n    return string\n\n\ndef _int64_feature(value):\n    \"\"\"\n    Returns an int64_list from a bool / enum / int / uint.\n    \"\"\"\n    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))\n\n\ndef write_to_file(writer, data):\n    \"\"\"\n    writes data to tfrecord file\n    \"\"\"\n    feature = {\n        \"text\": _int64_feature(data)\n    }\n    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))\n    writer.write(tf_example.SerializeToString())\n\n\ndef get_tokenizer(args):\n    if args.encoder_path is None:\n        return GPT2TokenizerFast.from_pretrained('gpt2')\n    else:\n        return Tokenizer.from_file(args.encoder_path)\n\n\ndef split_list(l, n):\n    # splits list/string into n size chunks\n    return [l[i:i + n] for i in range(0, len(l), n)]\n\n\ndef archive_to_tokens(f, encoder, args, prefix=[]):\n    # Generator that yields the contents of the files in an archive\n    # if data_to_prepend is not None, prepend data_to_prepend + a EOS separator to the encoded data\n    reader = Reader(f)\n    for doc in reader.stream_data(threaded=False):\n        if args.ftfy:  # fix text 
with ftfy if specified\n            doc = ftfy.fix_text(doc, normalization='NFKC')\n        if args.wikitext_detokenize:\n            doc = wikitext_detokenizer(doc)\n        doc = encoder.encode(doc) + args.separator  # read document from lmd and append separator token\n        yield split_list(prefix + doc, args.chunk_size)  # split into n_ctx + 1 size chunks\n        prefix = []\n\n\ndef write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):\n    # writes a list of files to .tfrecords\n    if files == None:\n        return\n    chunks = split_list(files, files_per)\n    if not chunks:\n        return\n      \n    if len(chunks[-1]) != files_per and not write_remainder:  # pop the last file if it's length != files per\n        remainder = chunks.pop(-1)\n    else:\n        remainder = None  # assuming files = remainder from an old chunk here\n        files_per = len(chunks[-1])\n\n    for files in chunks:\n        fp = f\"{output_dir}/{out_name}_{start_no}\"\n        if process_no is not None:\n            fp += f\"_{process_no}\"\n        fp += f\"_{files_per}\"  # add number of files in tfrecord to end of fp\n        fp += \".tfrecords\"\n        with tf.io.TFRecordWriter(fp) as writer:\n            for f in files:\n                write_to_file(writer, f)\n        start_no += 1\n    return start_no, remainder\n\n\ndef get_files(input_dir, filetypes=None):\n    # gets all files of <filetypes> in input_dir\n    if filetypes == None:\n        filetypes = [\"jsonl.zst\", \".txt\", \".xz\", \".tar.gz\"]\n    files = [list(Path(input_dir).glob(f\"*{ft}\")) for ft in filetypes]\n    # flatten list of list -> list and stringify Paths\n    flattened_list = [str(item) for sublist in files for item in sublist]\n    if not flattened_list:\n        raise Exception(f\"\"\"did not find any files at this path {input_dir},\\\n please also ensure your files are in format {filetypes}\"\"\")\n    return flattened_list\n\n\ndef 
read_checkpoint(checkpoint_path, resume_from_checkpoint=True):\n    # init checkpointing\n    if resume_from_checkpoint and os.path.isfile(checkpoint_path):\n        try:\n            resume_files_processed, tfrecord_count = [int(i) for i in open(checkpoint_path, \"r\").read().split(\", \")]\n            print(f\"\\nResuming from tfrecord no. {tfrecord_count} / file no. {resume_files_processed}\")\n            return resume_files_processed, tfrecord_count\n        except:\n            pass\n    return 0, 0\n\n\ndef create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False,\n                     resume_from_checkpoint=False, display_pbar=False):\n    # iterates through files in input_dir, splitting into <args.chunk_size> chunks and saving a tfrecords file every <args.files_per> chunks.\n    files, args, process_no = params\n    enc = get_tokenizer(args)  # get tokenizer\n\n    # init metadata\n    discarded_files = 0\n    files_processed = 0\n    pbar = tqdm(desc=f\"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. 
files_written \",\n                disable=not display_pbar)\n    checkpoint_path = f\"{args.output_dir}/checkpoint.txt\"\n    resume_files_processed, tfrecord_count = read_checkpoint(checkpoint_path, resume_from_checkpoint)\n\n    data_to_prepend = []\n    tokenized_files_array = []\n\n    for f in files:\n        for tokenized_files in archive_to_tokens(f, enc, args, prefix=data_to_prepend):\n            files_processed += 1\n            if files_processed < resume_files_processed:\n                continue  # resume from checkpoint\n\n            # if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file\n            data_to_prepend = []\n            n_tokens = len(tokenized_files[-1])\n            if n_tokens < args.chunk_size:\n                data = tokenized_files.pop(-1)\n                if n_tokens >= args.minimum_size:\n                    data_to_prepend = data\n                else:\n                    discarded_files += 1\n\n            # add tokenized files > chunk size to main array\n            tokenized_files_array.extend(tokenized_files)\n\n            if len(tokenized_files_array) >= args.files_per * write_every_n_files:  # write every n files\n                _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,\n                                                         output_dir=args.output_dir, out_name=args.name,\n                                                         start_no=tfrecord_count, process_no=process_no)\n                pbar.update(_tfrecord_count - tfrecord_count)  # update progress bar\n                pbar.set_description(\n                    f\"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. 
files_written \")\n                tfrecord_count = _tfrecord_count\n                tokenized_files_array = remainder if remainder is not None else []  # add remaining files to next chunk\n                with open(checkpoint_path, \"w\") as checkpoint_file:\n                    checkpoint_file.write(f\"{files_processed}, {tfrecord_count}\")\n\n    if len(tokenized_files_array) >= args.files_per:  # also write at end\n        _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,\n                                                 output_dir=args.output_dir, out_name=args.name,\n                                                 start_no=tfrecord_count, process_no=process_no)\n        pbar.update(_tfrecord_count - tfrecord_count)\n        pbar.set_description(\n            f\"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written \")\n        tfrecord_count = _tfrecord_count\n        with open(checkpoint_path, \"w\") as checkpoint_file:\n            checkpoint_file.write(f\"{files_processed}, {tfrecord_count}\")\n    else:\n        remainder = tokenized_files_array  # add remaining to remainder\n\n    if write_remainder:\n        # write out the remaining files even if there's less than files_per\n        write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name,\n                    start_no=tfrecord_count, write_remainder=True)\n\n    successful_files = files_processed - discarded_files\n    return {\"discarded\": discarded_files, \"processed\": files_processed, \"successful\": successful_files}\n\n\ndef create_tfrecords_mp(files, args):\n    files = split_list(files, len(files) // args.processes)\n    with Pool(processes=args.processes) as pool:\n        pbar = tqdm(pool.imap(create_tfrecords, zip(files, repeat(args), range(len(files)))))\n        meta = {\"discarded\": 0, \"processed\": 0, \"successful\": 0}\n        for results in pbar:\n            
pbar.update()\n            for k, v in results.items():\n                meta[k] += v  # update metadata\n        return meta\n\n\nif __name__ == \"__main__\":\n    os.makedirs(args.output_dir, exist_ok=True)  # make output dir if it doesn't exist\n    files = get_files(args.input_dir)\n    args.chunk_size += 1  # we shift the data by 1 to the right for targets, so increment the chunk size here\n\n    if args.processes == 0:\n        args.processes = cpu_count()\n    if args.processes > 1:\n        results = create_tfrecords_mp(files, args)\n    else:\n        results = create_tfrecords((files, args, 0), display_pbar=True)\n    print(results)\n"
  },
  {
    "path": "data/encoders.py",
    "content": "from tokenizers import Tokenizer\nfrom transformers import GPT2Tokenizer, GPT2TokenizerFast\n\ndef fetch_encoder(params):\n    no_dataset = params.get('no_dataset', False)\n    if no_dataset:\n        return None\n\n    dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict\n    path = dataset[\"tokenizer_path\"]\n    is_pretrained = dataset.get(\"tokenizer_is_pretrained\", False)\n\n    if is_pretrained:\n        tok = GPT2TokenizerFast.from_pretrained(path)\n\n        # Will add a padding token id of 50257 at run-time\n        tok.add_special_tokens({'pad_token': '<|padding|>'})\n        return tok\n\n    return Tokenizer.from_file(path)\n\n\n# GPT2Tokenizer and Tokenizer have different ways of fetching token ids\ndef encode(encoder, text):\n    result = encoder.encode(text)\n    if isinstance(result, list):\n        return result\n    return result.ids\n"
  },
  {
    "path": "data/train_tokenizer.py",
    "content": "import os\nimport random\nimport argparse\nimport shutil\nfrom glob import glob\nfrom pathlib import Path\n\nfrom lm_dataformat import Reader\nfrom tokenizers import (Tokenizer, decoders, models, pre_tokenizers,\n                        processors, trainers)\nfrom tokenizers.normalizers import NFKC\nfrom tqdm import tqdm\n\n# parser\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--base_dir\", type=str, help=\"Path to where your files are located. Files ending in .zst are treated as \\\n                    archives, all others as raw text.\")\nparser.add_argument(\"--output_dir\", type=str, default=\"tokenizers\", help=\"Where to put the tokenizer\")\nparser.add_argument(\"--file_type\", type=str, choices=[\"xz\", \"txt\"], default=\"xz\", help=\"Extension of file to parse\")\nparser.add_argument(\"--vocab_size\", type=int, help=\"Size of vocabulary\", required = True)\nargs = parser.parse_args()\n\n# main script\n\ndata_path = Path(args.base_dir)\narchives = glob(str(data_path / f\"*.{args.file_type}\"))\n\nout_path = Path(args.output_dir)\n\nif os.path.exists(out_path):\n    shutil.rmtree(out_path)\n\nif not out_path.is_dir():\n    out_path.mkdir()\n\n    for arch in tqdm(archives):\n        name = os.path.basename(arch).split(\".\")[0] + \".txt\"\n        fp = out_path / name\n\n        if args.file_type == 'xz':\n            g = Reader(arch).stream_data()\n\n            with open(fp, \"w\") as f:\n                for s in g:\n                    f.write(s)\n                    f.write(\"\\n\\n\")\n        elif args.file_type == 'txt':\n            shutil.copyfile(str(arch), str(fp))\n\ndata_files = glob(str(out_path / \"*.txt\"))\ndata_files = random.sample(data_files, int(0.2 * len(data_files)))\n\nassert len(data_files) > 0, 'No data files found'\n\n# Initialize a tokenizer\ntokenizer = Tokenizer(models.BPE())\n\n# Customize pre-tokenization and decoding\ntokenizer.pre_tokenizer = 
pre_tokenizers.ByteLevel(add_prefix_space=True)\ntokenizer.decoder = decoders.ByteLevel()\ntokenizer.post_processor = processors.ByteLevel(trim_offsets=True)\ntokenizer.normalizer = NFKC()\n\n# And then train\ntrainer = trainers.BpeTrainer(vocab_size=args.vocab_size, min_frequency=2, special_tokens=[\"<|endoftext|>\", \"<|padding|>\"])\ntokenizer.train(trainer, data_files)\n\n# And Save it\ntokenizer_path = out_path / \"byte-level-bpe.tokenizer.json\"\ntokenizer.save(str(tokenizer_path), pretty=True)\n\nprint(f'tokenizer saved at {str(tokenizer_path)}')"
  },
  {
    "path": "docker-compose.yml",
    "content": "version: '3'\nservices:\n\n  mongo:\n    image: mongo\n    ports:\n      - 127.0.0.1:27017:27017\n    environment:\n      MONGO_INITDB_ROOT_USERNAME: user\n      MONGO_INITDB_ROOT_PASSWORD: password\n      MONGO_INITDB_DATABASE: db\n    expose:\n      - 27017\n    networks:\n      - omniboard\n    volumes:\n      - ./data:/data/db\n\n  mongoClientTemp:\n   image: mongo:latest\n   container_name: mongoClientTemp\n   links:\n    - mongo:mongo\n   command: mongo --host mongo -u user -p password --eval  \"db.getSiblingDB('db').createUser({user:'readonly', pwd:'password', roles:[{role:'read',db:'db'}]});\"\n   depends_on:\n    - mongo\n   networks:\n    - omniboard\n\n  omniboard_readonly:\n          #image: vivekratnavel/omniboard:latest\n    build: https://github.com/lucidrains/omniboard.git\n    command: [\"--mu\", \"mongodb://readonly:password@mongo:27017/db\"]\n    ports:\n            - 0.0.0.0:8081:9000\n    networks:\n      - omniboard\n    depends_on:\n      - mongo\n\n  omniboard:\n          #image: vivekratnavel/omniboard:latest\n    build: https://github.com/lucidrains/omniboard.git\n    command: [\"--mu\", \"mongodb://user:password@mongo:27017/db?authSource=admin\"]\n    expose:\n      - 9000\n    networks:\n      - omniboard\n    depends_on:\n      - mongo\n\n  nginx:\n    image: dhswt/nginx-basic-auth:1.3\n    environment:\n      - HTPASSWD=isaac: #put passwd here\n      - FORWARD_HOST=omniboard\n      - FORWARD_PORT=9000\n    networks:\n      - omniboard\n    depends_on:\n      - omniboard\n    ports:\n            - 0.0.0.0:8080:80\n    expose:\n      - 8080\nnetworks:\n  omniboard:\n"
  },
  {
    "path": "encoders.py",
    "content": "from tokenizers import Tokenizer\nfrom transformers import GPT2Tokenizer, GPT2TokenizerFast\n\ndef fetch_encoder(params):\n    no_dataset = params.get('no_dataset', False)\n    if no_dataset:\n        return None\n\n    dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict\n    path = dataset[\"tokenizer_path\"]\n    is_pretrained = dataset.get(\"tokenizer_is_pretrained\", False)\n\n    if is_pretrained:\n        tok = GPT2TokenizerFast.from_pretrained(path)\n\n        # Will add a padding token id of 50257 at run-time\n        tok.add_special_tokens({'pad_token': '<|padding|>'})\n        return tok\n\n    return Tokenizer.from_file(path)\n\n\n# GPT2Tokenizer and Tokenizer have different ways of fetching token ids\ndef encode(encoder, text, gpt=True):\n    result = encoder.encode(text)\n    if isinstance(result, list):\n        return result\n    return result.ids\n"
  },
  {
    "path": "export.py",
    "content": "import tensorflow.compat.v1 as tf\n\ndef export_model(estimator, export_dir, params,\n                 checkpoint_path=None):\n\n\n    def serving_input_receiver_fn():\n        t = tf.placeholder(dtype=tf.int64,\n                            shape=[1, params[\"n_ctx\"]],\n                            name='input_example_tensor')\n        return tf.estimator.export.ServingInputReceiver(t, t)\n\n    return estimator.export_saved_model(\n        export_dir, serving_input_receiver_fn, checkpoint_path=checkpoint_path)"
  },
  {
    "path": "inputs.py",
    "content": "import numpy as np\nimport tensorflow.compat.v1 as tf\nfrom functools import partial\nfrom data.encoders import encode\nimport random\nimport re\nimport logging\nfrom itertools import cycle\nfrom utils import natural_sort\n\n\n### IN USE ###\n\ndef _get_number_of_documents(filename):\n    # extracts number of files from a filename formatted \"<name>_<num_documents>.tfrecords.\"\n    # if no pattern is matched, returns None\n    match = re.search(\"_(\\d{1,}).tfrecords$\", filename)\n    return int(match.group(1)) if match is not None else match\n\n\ndef _get_number_of_documents_by_iteration(filename):\n    # extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename\n    # this could be very slow.\n    logging.warning(\n        \"inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length\")\n    count = 0\n    for item in tf.io.tf_record_iterator(filename):\n        count += 1\n    return count\n\n\ndef _get_skip_index(all_files, n_batches):\n    prev_cumsum = 0\n    cumsum = 0\n    global_n_documents = None\n    for count, f in cycle(enumerate(all_files)):\n        prev_cumsum = cumsum\n        if _get_number_of_documents(f) is not None:\n            cumsum += _get_number_of_documents(f)\n        elif global_n_documents is None:\n            global_n_documents = _get_number_of_documents_by_iteration(f)\n            cumsum += global_n_documents\n        else:\n            cumsum += global_n_documents\n        if cumsum == n_batches:\n            remainder = 0\n            skip_idx = count + 1\n        elif cumsum > n_batches:\n            remainder = n_batches - prev_cumsum\n            skip_idx = count\n            break\n    return skip_idx, remainder\n\n\ndef _parse_function(example_proto):\n    features = {\n        \"text\": tf.VarLenFeature(tf.int64)\n    }\n    parsed_features = tf.parse_single_example(example_proto, features)\n    return 
tf.sparse.to_dense(parsed_features[\"text\"], parsed_features[\"text\"].dense_shape[0])\n\n\ndef autoregressive_sample_text(params, x):\n    vals1 = x[:params[\"n_ctx\"]]\n    vals2 = x[1:params[\"n_ctx\"] + 1]\n\n    vals1 = tf.reshape(vals1, [params[\"n_ctx\"]])\n    vals2 = tf.reshape(vals2, [params[\"n_ctx\"]])\n    vals1 = tf.cast(vals1, dtype=tf.int32)\n    vals2 = tf.cast(vals2, dtype=tf.int32)\n    return vals1, vals2\n\n\ndef sequential_input(params, global_step=None, eval=False):\n    \"\"\"\n    Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either:\n\n        - has the number of documents for each tfrecord file encoded in the title in the format\n          <name>_<n_documents>.tfrecords.\n\n          OR\n\n        - has a fixed number of documents per tfrecord file.\n\n    If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read.\n    If this isn't the case, it may result in errors, or some samples being missed.\n\n    This means we can calculate the number of samples we've seen so far using the global step,\n    and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient.\n\n    If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model\n    performance, as it results in less repeated data.\n    \"\"\"\n    if not eval:\n        assert global_step is not None\n    logging.warning(\n        \"Changing batch size with sequential_input() will result in some data being skipped or repeated. 
Please ensure your batch size stays constant throughout training.\")\n    batch_size = params['eval_batch_size' if eval else 'train_batch_size']\n\n    filenames = []\n    for dataset_config in params['dataset_configs'].values():  # iterate through each dataset and read params\n        path_key = 'path' if not eval else 'eval_path'\n        path = dataset_config[path_key]\n        filenames.extend(\n            tf.io.gfile.glob(path))  # then glob all files that fit the pattern specified in dataset_configs\n\n    filenames = natural_sort(filenames)\n    shuffle_filenames = params.get(\"shuffle_input_filenames\", True)\n    if shuffle_filenames:\n        seed = params.get('seed', 1)  # shuffle deterministically\n        random.seed(seed)\n        random.shuffle(filenames)\n\n    dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()  # repeat filenames to infinity\n\n    if not eval:\n        # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files\n        skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[\n            \"train_batch_size\"])  # TODO: fix for > 1 epoch\n        dataset = dataset.skip(skip_idx)  # skip to skip idx\n\n        # read tfrecord examples and skip remainder\n        dataset = dataset.apply(tf.data.TFRecordDataset)\n        dataset = dataset.skip(remainder)\n    else:\n        # shuffle filenames if in eval mode\n        dataset = dataset.shuffle(len(filenames))\n        dataset = dataset.apply(tf.data.TFRecordDataset)\n\n    # parse the tokenized data from the tfrecord files and shuffle\n    dataset = dataset.map(_parse_function, num_parallel_calls=1)\n    dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)\n\n    # batch data and repeat to infinity\n    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params[\"iterations\"] * 2)\n    return dataset.repeat()\n\n\ndef pred_input(params, logger, 
enc=None,\n               path_to_prompt=\"\"):\n    unicorns = \"In a shocking finding, scientists discovered a herd of unicorns living in a remote, \" \\\n               \"previously unexplored valley, in the Andes Mountains. Even more surprising to the \" \\\n               \"researchers was the fact that the unicorns spoke perfect English.\"\n\n    text = unicorns if path_to_prompt == \"\" else open(path_to_prompt, \"r\").read()\n    tokens = encode(enc, text)\n\n    if len(tokens) > params[\"n_ctx\"]:\n        logger.info(\"The length of your input prompt is longer than the model's context length - truncating input.\")\n        tokens = tokens[len(tokens) - params[\"n_ctx\"]:]\n    if len(tokens) < params[\"n_ctx\"]:\n        tokens = tf.pad(tokens, [[0, params[\"n_ctx\"] - len(tokens)]], constant_values=params[\"padding_id\"])\n    t = tf.broadcast_to(tokens, [params[\"batch_size\"], params[\"n_ctx\"]])\n    dataset = tf.data.Dataset.from_tensors(t)\n\n    def _dummy_labels(x):\n        return x, x\n\n    dataset = dataset.map(_dummy_labels)\n    return dataset\n\n\ndef handle_pred_output(predictions, logger, enc, params, out_name=\"test\"):\n    with tf.gfile.Open(f\"{out_name}.txt\", \"w\") as f:\n        for i, p in enumerate(predictions):\n            p = p[\"outputs\"]\n\n            # remove eos + padding ids from output\n            idx = np.argmax(p == params['eos_id'])\n            if idx > 0:\n                p = p[:idx]\n            idx = np.argmax(p == params['padding_id'])\n            if idx > 0:\n                p = p[:idx]\n\n            text = enc.decode(p)\n            f.write(\"=\" * 40 + \" SAMPLE \" + str(i) + \" \" + \"=\" * 40 + \"\\n\")\n            f.write(text)\n            f.write(\"\\n\" + \"=\" * 80 + \"\\n\")\n\n            logger.info(\"=\" * 40 + \" SAMPLE \" + str(i) + \" \" + \"=\" * 40 + \"\\n\")\n            logger.info(text)\n            logger.info(\"\\n\" + \"=\" * 80 + \"\\n\")\n\n\n### DEPRECATED ###\n\ndef 
generic_text(params, eval=False, sample_text_fn=None, **kwargs):\n    logging.warning(\"DEPRECATION WARNING: generic_text will be phased out in future versions.\")\n    i = 0 if not eval else 1\n\n    weights = []\n    datasets = []\n\n    for dataset in params[\"datasets\"]:\n        dataset_id, stitch, datatype, weight = dataset\n\n        assert dataset_id in params[\n            'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'\n        dataset_config = params['dataset_configs'][dataset_id]\n\n        path_key = 'path' if not eval else 'eval_path'\n        path = dataset_config[path_key]\n\n        datasets.append(text_dataset(\n            tf.io.gfile.glob(path),\n            params,\n            stitch=stitch,\n            datatype=datatype,\n            batch=False,\n            sample_text_fn=sample_text_fn\n        ))\n\n        weights.append(weight)\n\n    batch_size = params['eval_batch_size' if eval else 'train_batch_size']\n\n    seed = params.get('seed', None)\n    dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)\n    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params[\"iterations\"] * 2)\n    return dataset\n\n\ndef text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):\n    seed = params.get('seed', None)\n    deterministic = seed is not None\n    num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE\n\n    dataset = tf.data.Dataset.from_tensor_slices(files)\n\n    if deterministic:\n        dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)\n    else:\n        dataset = dataset.apply(\n            tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))\n\n    if \"documents\" in datatype:\n        def _parse_function(example_proto):\n            features = {\n                # \"hash\": 
tf.VarLenFeature(tf.string),\n                \"text\": tf.VarLenFeature(tf.int64)\n            }\n            parsed_features = tf.parse_single_example(example_proto, features)\n            return parsed_features[\"text\"], parsed_features[\"text\"].dense_shape[0]\n    else:\n        def _parse_function(example_proto):\n            features = {\n                \"text\": tf.VarLenFeature(tf.int64)\n            }\n            parsed_features = tf.parse_single_example(example_proto, features)\n            return parsed_features[\"text\"]  # Assuming the text is not sparse\n\n    dataset = dataset.map(_parse_function, num_parallel_calls=1)\n\n    # Subsample method\n    if \"documents\" in datatype:\n        # Since samples can be less than the correct length, and TPUs don't like variable lengths, this function stitches together enough samples\n        # to have a text at least 1024 tokens long. For this to work the stitch parameter must be correctly tuned so that\n        # stitch * min(characters_in_text) >= amount\n        def _stitch_text(x, y):\n            x = tf.sparse.to_dense(x)\n\n            def _get_x(i):\n                return tf.gather(x[i], tf.range(y[i]))\n\n            out = _get_x(0)\n            eos_id = params['eos_id']\n\n            for i in range(1, stitch):\n                out = tf.concat([out, [eos_id], _get_x(i)], axis=0)  # text1<|endoftext|>text2\n\n            return out\n\n        # Hack-y way to stitch together multiple texts\n\n        dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True).map(_stitch_text,\n                                                                                                   num_parallel_calls=num_parallel_calls)\n\n        # Sample 1024(+1) tokens from the stitched together text\n        is_random_documents = datatype == \"documents_random\"\n        if sample_text_fn is not None:\n            _sample_text = partial(sample_text_fn, random_documents=is_random_documents)\n 
       else:\n            _sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text\n            _sample_text = partial(_sample_text, params)\n\n        dataset = dataset.map(_sample_text, num_parallel_calls=num_parallel_calls)\n\n    if batch:\n        dataset = dataset.batch(params[\"train_batch_size\"], drop_remainder=True).prefetch(params[\"iterations\"] * 2)\n\n    dataset = dataset.repeat()\n\n    return dataset\n\n\ndef autoregressive_sample_text_random_documents(params, x):\n    seed = params.get('seed', None)\n    s = tf.size(x)\n    r = tf.random.uniform([], maxval=s - (params[\"n_ctx\"] + 1), dtype=tf.dtypes.int32, seed=seed)\n    r1 = tf.range(r, r + params[\"n_ctx\"])\n    r2 = tf.range(r + 1, (r + 1) + params[\"n_ctx\"])\n    # TPUs want constant-sized input; these reshapes let the compiler recognize the static shape of the indices\n    r1 = tf.reshape(r1, [params[\"n_ctx\"]])\n    r2 = tf.reshape(r2, [params[\"n_ctx\"]])\n    vals1 = tf.gather(x, r1)\n    vals2 = tf.gather(x, r2)\n\n    vals1 = tf.reshape(vals1, [params[\"n_ctx\"]])\n    vals2 = tf.reshape(vals2, [params[\"n_ctx\"]])\n    vals1 = tf.cast(vals1, dtype=tf.int32)\n    vals2 = tf.cast(vals2, dtype=tf.int32)\n    return vals1, vals2\n\n\ndef mlm_sample_text(params, x, random_documents=False):\n    seed = params.get('seed', None)\n    ctx_len = params[\"n_ctx\"]\n    assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set in your config to do masked language model training, specifying the id of the reserved mask token'\n\n    mask_id = params['mlm_mask_id']\n    cls_token_id = params.get('mlm_cls_token_id', None)\n    num_tokens = params.get('n_vocab', None)\n\n    mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))\n    # only ignore the cls token if one is configured; tf.not_equal cannot compare against None\n    if cls_token_id is not None:\n        mask_ignore_ids.add(cls_token_id)\n\n    mask_prob = params.get('mlm_mask_prob', 0.15)\n    same_token_prob = 
params.get('mlm_same_token_prob', 0.10)\n    random_token_prob = params.get('mlm_random_token_prob', 0.)\n\n    seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)\n\n    if random_documents:\n        s = tf.size(x)\n        r = tf.random.uniform([], maxval=(s - seq_len), dtype=tf.dtypes.int32, seed=seed)\n        r1 = tf.range(r, r + seq_len)\n        r1 = tf.reshape(r1, [seq_len])\n        features = tf.gather(x, r1)\n    else:\n        features = x[:seq_len]\n\n    # add cls token id if specified by `mlm_cls_token_id`\n    if cls_token_id is not None:\n        features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)\n\n    features = tf.cast(features, dtype=tf.int32)\n    shape = features.shape\n\n    # determine which tokens are mask-able\n    can_mask = tf.not_equal(features, 0)\n    for ignore_id in mask_ignore_ids:\n        can_mask &= tf.not_equal(features, ignore_id)\n\n    # generate boolean mask for masking ids\n    mask_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), mask_prob)\n    mask_mask &= can_mask\n\n    # generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same\n    replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),\n                           1 - same_token_prob)\n\n    # randomly replace some tokens with random tokens before masking\n    if random_token_prob > 0:\n        random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),\n                                    random_token_prob)\n        random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed)\n\n        # make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids`\n        random_can_mask = tf.not_equal(random_tokens, 0)\n        for ignore_id in mask_ignore_ids:\n            random_can_mask &= 
tf.not_equal(random_tokens, ignore_id)\n\n        features = tf.where(random_token_mask & random_can_mask, random_tokens, features)\n\n    # mask the tokens\n    mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id\n    masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)\n\n    # labels keep the token id at masked positions and are set to 0 for all non-masked tokens\n    labels = tf.where(mask_mask, features, tf.zeros(shape, dtype=tf.int32))\n\n    masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels))\n    return masked_features, labels\n"
  },
  {
    "path": "main.py",
    "content": "\"\"\"GPT-like model in Mesh-Tensorflow\"\"\"\n\nfrom functools import partial\nimport mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.python.tpu import tpu_config, tpu_estimator\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom utils import save_config, expand_attention_types_params, yes_or_no, remove_gs_or_filepath, setup_logging, \\\n    check_dataset\nfrom inputs import sequential_input, pred_input, handle_pred_output, mlm_sample_text, generic_text\nfrom export import export_model\nfrom model_fns import model_fn\nfrom data.encoders import fetch_encoder\nfrom configs import fetch_model_params\nfrom tasks import task_descriptors\nimport argparse\nimport json\nimport numpy\n\n\ndef parse_args():\n    # Parse command line arguments\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--tpu\", type=str, help=\"Name of TPU to train on, if any.\")\n    parser.add_argument(\"--gpu_ids\", nargs=\"+\", type=str, default=[\"device:GPU:0\"],\n                        help=\"If training on GPU, can specify your GPU names in a list - i.e 'device:GPU:0 device:GPU:1'\")\n    parser.add_argument(\"--model\", type=str, default=None, help=\"JSON file that contains model parameters.\")\n    parser.add_argument(\"--steps_per_checkpoint\", type=int, default=5000, help=\"Save a model checkpoint every X steps.\")\n    parser.add_argument(\"--auto_layout\", action=\"store_true\", help=\"If set, generates and prints the most memory \"\n                                                                   \"efficient layout according to MTF auto layout.\")\n    parser.add_argument(\"--auto_layout_and_mesh_shape\", action=\"store_true\",\n                        help=\"If set, generates and prints the most memory efficient layout and mesh shape according to\"\n                             \" MTF auto layout.\")\n    parser.add_argument(\"--new\", action=\"store_true\", help=\"If set, deletes previous 
checkpoint, if it exists, and \"\n                                                           \"starts a new training run\")\n    parser.add_argument(\"--predict\", action=\"store_true\", help=\"If set, uses the model to predict rather than train.\")\n    parser.add_argument(\"--eval\", action=\"store_true\", help=\"If set, run model in evaluation mode.\")\n    parser.add_argument(\"--prompt\", type=str, help=\"path to .txt file containing a prompt for prediction. If empty, \"\n                                                   \"defaults to unicorns.\",\n                        default=\"\")\n    parser.add_argument(\"--check_dataset\", action=\"store_true\",\n                        help=\"If set, outputs sample from the dataset and quits.\")\n    parser.add_argument(\"--sacred_id\", type=str, default=\"nosacred\", help=\"Sacred run id.\")\n    parser.add_argument(\"--entmax_sampling\", action=\"store_true\", help=\"(experimental) use entmax sampling\")\n    parser.add_argument(\"--export\", action=\"store_true\", help=\"If set, will export the model.\")\n    args = parser.parse_args()\n    assert args.model is not None, \"Model must be set\"\n    return args\n\n\ndef main(args):\n    # Setup logging\n    logger = setup_logging(args)\n\n    # Read params of model\n    params = fetch_model_params(args.model)\n\n    # Fetch appropriate input functions\n    input_fn = params.get(\"input_fn\", \"sequential_input\")\n    if input_fn == \"sequential_input\":\n        input_fn = sequential_input\n    elif input_fn == \"generic_text\":\n        input_fn = generic_text\n    pred_input_fn = pred_input\n    handle_pred_output_fn = handle_pred_output\n\n    # get current step\n    current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params[\"model_path\"]))\n    logger.info(f\"Current step {current_step}\")\n\n    if params[\"mlm_training\"]:\n        mlm_sample_text_fn = partial(mlm_sample_text, params)\n        input_fn = partial(generic_text, 
sample_text_fn=mlm_sample_text_fn)\n        if args.check_dataset:\n            check_dataset(input_fn, params)\n\n\n    # Fetch encoder per params\n    encoder = fetch_encoder(params)\n\n    pred_input_fn = partial(pred_input_fn, path_to_prompt=args.prompt, logger=logger, enc=encoder)\n\n    # Sample from Dataset if check dataset flag is on\n    if args.check_dataset:\n        check_dataset(input_fn, params, global_step=current_step)\n\n    # Confirm deletion of checkpoint files if --new flag is set\n    if args.new:\n        if yes_or_no(f\"Are you sure you want to remove '{params['model_path']}' to start afresh?\"):\n            remove_gs_or_filepath(params[\"model_path\"])\n        else:\n            exit()\n\n    # Save config to logdir for experiment management\n    save_config(params, params[\"model_path\"])\n\n    # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores\n    mesh_shape = mtf.convert_to_shape(params[\"mesh_shape\"])\n    params[\"num_cores\"] = mesh_shape.size\n    params[\"auto_layout\"] = args.auto_layout\n    params[\"auto_layout_and_mesh_shape\"] = args.auto_layout_and_mesh_shape\n    params[\"use_tpu\"] = True if not args.tpu is None else False\n    params[\"gpu_ids\"] = args.gpu_ids\n    params[\"steps_per_checkpoint\"] = args.steps_per_checkpoint\n    # Expand attention types param\n    params[\"attention_types\"] = expand_attention_types_params(params[\"attention_types\"])\n    assert len(params[\"attention_types\"]) == params[\"n_layer\"]  # Assert that the length of expanded list = num layers\n    params[\"predict_batch_size\"] = params.get(\"predict_batch_size\", 1)  # Default to 1\n    params[\"predict\"] = args.predict\n    params['model'] = params.get(\"model\", \"GPT\") # Default model selection to GPT since it's the only option for now\n    params[\"export\"] = args.export\n    # Set sampling parameters\n    params[\"sampling_use_entmax\"] = args.entmax_sampling\n\n    # Sample quality of MoE models 
suffers when using the faster sampling method, so default to slow_sampling if\n    # moe layers are present\n    params[\"slow_sampling\"] = True if params[\"moe_layers\"] is not None else False\n\n    logger.info(f\"params = {params}\")\n\n    # Get eval tasks from params\n    eval_tasks = params.get(\"eval_tasks\", [])\n    has_predict_or_eval_steps_or_eval_tasks = params[\"predict_steps\"] > 0 or params[\"eval_steps\"] > 0 or len(\n        eval_tasks) > 0\n\n    for t in eval_tasks:\n        assert t in task_descriptors, f\"Eval task '{t}' is not known\"\n        task_descriptors[t][\"init_fn\"](params)\n\n    # Set up TPUs and Estimator\n    if args.tpu == \"colab\":\n        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if params[\"use_tpu\"] else None\n    else:\n        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params[\"use_tpu\"] else None\n\n    config = tpu_config.RunConfig(\n        cluster=tpu_cluster_resolver,\n        model_dir=params[\"model_path\"],\n        save_checkpoints_steps=None,  # Disable the default saver\n        save_checkpoints_secs=None,  # Disable the default saver\n        log_step_count_steps=params[\"iterations\"],\n        save_summary_steps=params[\"iterations\"],\n        tpu_config=tpu_config.TPUConfig(\n            num_shards=mesh_shape.size,\n            iterations_per_loop=params[\"iterations\"],\n            num_cores_per_replica=1,\n            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))\n\n    estimator = tpu_estimator.TPUEstimator(\n        use_tpu=params[\"use_tpu\"],\n        model_fn=model_fn,\n        config=config,\n        train_batch_size=params[\"train_batch_size\"],\n        eval_batch_size=params[\"train_batch_size\"],\n        predict_batch_size=params[\"predict_batch_size\"],\n        params=params)\n\n    def _make_task_estimator(task):\n        task_params = params.copy()\n        task_params[\"eval_task\"] = 
task\n        return tpu_estimator.TPUEstimator(\n            use_tpu=params[\"use_tpu\"],\n            model_fn=model_fn,\n            config=config,\n            train_batch_size=params[\"train_batch_size\"],\n            eval_batch_size=params[\"eval_batch_size\"],\n            predict_batch_size=params[\"predict_batch_size\"],\n            params=task_params)\n\n    eval_task_estimators = {\n        task: _make_task_estimator(task)\n        for task in eval_tasks\n    }\n\n    if args.export:\n        export_model(estimator, \"export\", params)\n        return\n\n    if args.predict:\n        # Predict\n        predictions = estimator.predict(input_fn=pred_input_fn)\n        logger.info(\"Predictions generated\")\n        enc = fetch_encoder(params)\n        handle_pred_output_fn(predictions, logger, enc, params, out_name=f\"predictions_{args.sacred_id}_{current_step}\")\n        return\n\n    def save_eval_results(task, eval_results):\n        def as_python(x):\n            if isinstance(x, numpy.generic):\n                return x.item()\n            return x\n        eval_results = {k: as_python(v) for k, v in eval_results.items()}\n        with open(f'eval_{args.sacred_id}.jsonl', 'a') as fh:\n            json.dump({'task': task, 'current_step': current_step, **eval_results}, fh)\n            fh.write('\\n')\n\n    def run_eval():\n        logger.info(\"Running evaluation...\")\n        eval_results = estimator.evaluate(\n                input_fn=partial(input_fn, eval=True),\n                steps=params[\"eval_steps\"])\n        logger.info(f\"Eval results: {eval_results}\")\n        save_eval_results('validation', eval_results)\n\n    def run_eval_tasks():\n        for task in eval_tasks:\n            logger.info(f\"Starting evaluation task '{task}'\")\n            task_info = task_descriptors[task][\"get_task_info_fn\"](params)\n            task_estimator = eval_task_estimators[task]\n            task_input_fn = task_descriptors[task][\"input_fn\"]\n    
        eval_results = task_estimator.evaluate(\n                input_fn=task_input_fn,\n                steps=task_info[\"n_steps\"],\n                name=task)\n            logger.info(f\"Eval task '{task}' results: {eval_results}\")\n            save_eval_results(task, eval_results)\n\n    if args.eval:\n        run_eval_tasks()\n        if params[\"eval_steps\"] > 0:\n            run_eval()\n        return\n\n    elif has_predict_or_eval_steps_or_eval_tasks:\n        # Eval and train - stop and predict and/or eval every checkpoint\n        while current_step < params[\"train_steps\"]:\n            next_checkpoint = min(current_step + args.steps_per_checkpoint,\n                                  params[\"train_steps\"])\n\n            estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=next_checkpoint)\n            current_step = next_checkpoint\n\n            if params[\"predict_steps\"] > 0:\n                logger.info(\"Running prediction...\")\n                predictions = estimator.predict(input_fn=pred_input_fn)\n                enc = fetch_encoder(params)\n                handle_pred_output_fn(predictions, logger, enc, params, out_name=f\"predictions_{args.sacred_id}_{current_step}\")\n\n            if params[\"eval_steps\"] > 0:\n                run_eval()\n\n            if eval_tasks:\n                run_eval_tasks()\n\n        return\n    else:\n        # Else, just train straight through to train_steps without stopping to predict or eval\n        while current_step < params[\"train_steps\"]:\n            estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=params[\"train_steps\"])\n            # Re-read the global step from the checkpoint; otherwise current_step never changes and the loop never exits\n            current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params[\"model_path\"]))\n\n\nif __name__ == \"__main__\":\n    tf.disable_v2_behavior()\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "model_fns.py",
    "content": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.python.tpu import tpu_estimator\nimport mesh_tensorflow.transformer as mtf_transformer\nfrom optimizers import get_optimizer\nfrom utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params,\n                   get_batch_size, auto_layout, auto_layout_and_mesh_shape)\nfrom models.utils import biasmask_attn_weights\nfrom tensorflow.python.ops import resources\nfrom sample import sample_autoregressive\nfrom models.gpt2 import gpt2\nimport math\n\n\ndef model_fn(features, labels, mode, params):\n    # Get global step\n    global_step = tf.train.get_global_step()\n\n    # Construct mtf graph + mesh from params\n    graph = mtf.Graph()\n    mesh_shape = mtf.convert_to_shape(params[\"mesh_shape\"])\n    layout_rules = mtf.convert_to_layout_rules(params[\"layout\"])\n\n    # Mesh setup\n    if params[\"use_tpu\"]:\n        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)\n    else:\n        var_placer = None\n        gpu_ids = params[\"gpu_ids\"]\n        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(\n            mesh_shape, layout_rules, gpu_ids)\n\n    # Trainable variable precision\n    # Store to checkpoints in master type, train in slice type, compute in activation type\n    if params[\"precision\"] == \"bfloat16\":\n        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,\n                                           activation_dtype=tf.bfloat16)\n    else:\n        variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)\n\n    # Build mtf mesh object\n    mesh = mtf.Mesh(graph, \"my_mesh\", var_placer)\n\n    # Build mtf_features & seq length dict for getting number of microbatches\n    # We need to pack inputs into a dict to pass into serialize_training_step\n    features_dict = {\"inputs\": 
features, \"labels\": labels}\n    sequence_length_dict = {\"inputs\": params[\"n_ctx\"], \"labels\": params[\"n_ctx\"]}\n\n    params = add_mode_to_params(params, mode)\n    batch_size = get_batch_size(params)\n\n    batch_dim = mtf.Dimension(\"batch\", batch_size)\n    batch_dims = [batch_dim]\n    feature_length = sequence_length_dict[\"inputs\"]\n    length_dim = mtf.Dimension(\"sequence\", feature_length)\n\n    mtf_features = {}\n    for key, x in features_dict.items():\n        if x is not None:\n            feature_shape = mtf.Shape(batch_dims + [length_dim])\n            if type(features_dict[key]) == dict:\n                features_dict[key] = features_dict[key][\"feature\"]\n            x = tf.cast(features_dict[key], tf.int32)\n            x = tf.reshape(x, feature_shape.to_integer_list)\n            mtf_features[key] = mtf.import_fully_replicated(\n                mesh, x, feature_shape, name=key)\n\n    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model\n    other_features = {}\n    memory_length_dim = mtf.Dimension(\"memory_length\", length_dim.size)\n\n    attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params[\"causal\"] else None\n\n    # Add attn_bias into mtf_features\n    other_features[\"attn_bias\"] = attn_bias\n\n    # Define other Dimensions that we'll need inside the model\n    embd_dim = mtf.Dimension(\"embd\", params[\"n_embd\"])\n    vocab_dim = mtf.Dimension(\"vocab\", params[\"n_vocab\"])\n    # We need this because gathering when both the args have the same dimension in them breaks things\n    # This dim is specifically for the weights\n    # This prevents the \"Einsum has lhs dimension without corresponding rhs or output dimension.\" error\n    embed_sequence_dim = mtf.Dimension(\"embed_sequence\", params[\"n_ctx\"])\n\n    other_features[\"embd_dim\"] = embd_dim\n    other_features[\"vocab_dim\"] = vocab_dim\n    
other_features[\"embed_sequence_dim\"] = embed_sequence_dim\n    other_features[\"memory_length_dim\"] = memory_length_dim\n\n    if mode == tf.estimator.ModeKeys.PREDICT:\n        # Set up the model for prediction\n        inputs = mtf_features[\"inputs\"]\n        if params[\"remove_partial_sequences\"] is None:\n            params[\"remove_partial_sequences\"] = False\n\n        export = params.get(\"export\", False)\n\n        if not export:\n            mtf_samples = sample_autoregressive(\n                inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,\n                remove_partial_sequences=params[\"remove_partial_sequences\"], stop_at_token=params[\"eos_id\"],\n                sampling_use_entmax=params['sampling_use_entmax'], max_steps=params[\"predict_max_steps\"])\n\n        else:\n            with mtf.utils.outside_all_rewrites():\n                with tf.variable_scope('gpt2'):\n                    mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,\n                                                               variable_dtype=variable_dtype, context=None)\n\n        mtf_samples = mtf.anonymize(mtf_samples)\n        inputs = mtf.anonymize(inputs)\n        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)\n        inputs = lowering.export_to_tf_tensor(inputs)\n        outputs = lowering.export_to_tf_tensor(mtf_samples)\n        predictions = {\n            \"inputs\": inputs,\n            \"outputs\": outputs}\n\n        def scaffold_fn():\n            return tf.train.Scaffold(\n                local_init_op=tf.group(\n                    tf.train.Scaffold.default_local_init_op(),\n                    lowering.copy_masters_to_slices(),\n                    name=\"mtf_local_init_op\"),\n                ready_op=tf.concat(\n                    [tf.report_uninitialized_variables(),\n                     resources.report_uninitialized_resources()],\n                    
axis=0,\n                    name=\"mtf_ready_op\"))\n\n        return tpu_estimator.TPUEstimatorSpec(\n            mode=tf.estimator.ModeKeys.PREDICT,\n            predictions=predictions,\n            scaffold_fn=scaffold_fn,\n            prediction_hooks=[mtf.MtfRestoreHook(lowering)])\n\n    # We're not predicting, so we better be training or evaluating\n    assert mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]\n\n    if mode == tf.estimator.ModeKeys.TRAIN:\n        # Gets number of microbatches per batch for serialized training\n        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed\n        num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=batch_dim,\n                                                                                sequence_length=sequence_length_dict,\n                                                                                mesh_shape=mesh_shape,\n                                                                                layout_rules=layout_rules,\n                                                                                tokens_per_microbatch_per_replica=\n                                                                                params[\"tokens_per_mb_per_replica\"]))\n    else:\n        num_microbatches = 1\n\n    params[\"num_microbatches\"] = num_microbatches  # Add num microbatches to params\n\n    if num_microbatches > 1:\n\n        # For serialize_training_step we need to modify the model to output results in a dict\n        def serialized_fn(mtf_features):\n            if params[\"model\"] == \"GPT\":\n                with tf.variable_scope('gpt2'):\n                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,\n                                                          variable_dtype=variable_dtype)\n                return {\"logits\": logits, \"loss\": loss, 
\"loss_batch\": loss_batch}\n            else:\n                raise Exception(f\"'{params['model']}' is not a valid model - please select from [GPT]\")\n\n        # Serialize the training step - Gradients are accumulated locally and reduced once.\n        var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)\n        loss = output_dict[\"loss\"]\n        loss_batch = output_dict[\"loss_batch\"]\n        logits = output_dict[\"logits\"]\n    else:\n        # If we're not splitting into microbatches, return logits & loss as is\n        if params[\"model\"] == \"GPT\":\n            with mtf.utils.outside_all_rewrites():\n                with tf.variable_scope('gpt2'):\n                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,\n                                                          variable_dtype=variable_dtype, context=None)\n        else:\n            raise Exception(f\"'{params['model']}' is not a valid model - please select from [GPT]\")\n\n    # Auto layout generation\n    if params[\"auto_layout\"]:\n        auto_layout(graph, mesh_shape, logits, loss)\n    if params[\"auto_layout_and_mesh_shape\"]:\n        auto_layout_and_mesh_shape(graph, params[\"num_cores\"], logits, loss)\n\n    if mode == tf.estimator.ModeKeys.TRAIN:\n        # In TRAIN mode, get optimizer\n        if params[\"num_microbatches\"] > 1:\n            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn\n            # So we pass them in here\n            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,\n                                                     inp_var_grads=var_grads)\n        else:\n            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank\n            _, update_ops, var_grads = get_optimizer(mesh, loss, params, 
variable_dtype=variable_dtype)\n        # Log summaries to tensorboard\n        mtf.scalar_summary(\"loss\", loss)\n        # Log gradients if in params\n        if params[\"log_grads\"] not in [None, False]:\n            for g in var_grads:\n                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))\n                mtf.scalar_summary(\"grads/norm\" + g.name[:-2], grad_norm)\n    else:\n        # For now, we can only export fully-replicated tensors.\n        # This has to be done before lowering or they will not be included in the graph\n        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)\n        max_logits = mtf.argmax(logits, vocab_dim)\n        del logits\n        fully_replicated_mean_logits = mtf.anonymize(mean_logits)\n        fully_replicated_max_logits = mtf.anonymize(max_logits)\n        fully_replicated_loss_batch = mtf.anonymize(loss_batch)\n\n    # Gets & prints info about no. trainable vars in the model & dimension names\n    get_graph_info(graph)\n\n    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors\n    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)\n    tf_loss = lowering.export_to_tf_tensor(loss)\n    tf_loss = tf.cast(tf_loss, tf.float32)\n\n    if mode == tf.estimator.ModeKeys.TRAIN:\n        # Use our patched version until mtf updates theirs\n        host_call = create_host_call(params['model_path'])\n        mtf.utils.remove_summaries()\n\n        # Creates train_op\n        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]\n        tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step\n        tf.logging.info(f\"tf_update_ops: {tf_update_ops}\")\n        train_op = tf.group(tf_update_ops)\n    else:\n        tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)\n        tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)\n        tf_loss_batch = 
tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))\n\n    with mtf.utils.outside_all_rewrites():\n        # Copy master variables to slices. Must be called first.\n        restore_hook = mtf.MtfRestoreHook(lowering)\n        if mode == tf.estimator.ModeKeys.TRAIN:\n            # Set up the checkpoint server and return the TPUEstimatorSpec\n            saver = tf.train.Saver(\n                tf.global_variables(),\n                sharded=True,\n                max_to_keep=10,\n                keep_checkpoint_every_n_hours=2,\n                defer_build=False,\n                save_relative_paths=True)\n            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)\n            saver_listener = mtf.MtfCheckpointSaverListener(lowering)\n            saver_hook = tf.train.CheckpointSaverHook(\n                params[\"model_path\"],\n                save_steps=params[\"steps_per_checkpoint\"],\n                saver=saver,\n                listeners=[saver_listener])\n\n            return tpu_estimator.TPUEstimatorSpec(\n                tf.estimator.ModeKeys.TRAIN,\n                loss=tf_loss,\n                host_call=host_call,\n                train_op=train_op,\n                training_hooks=[restore_hook, saver_hook])\n\n        elif mode == tf.estimator.ModeKeys.EVAL:\n            # Evaluation metrics\n            def _perplexity(loss):\n                perplexity = tf.exp(loss)\n                return tf.metrics.mean(perplexity)\n\n            def _bits_per_byte(loss):\n                bpb = loss * (0.29335 / math.log(2))\n                return tf.metrics.mean(bpb)\n\n            def _metric_fn(tf_mean_logits, tf_loss_batch):\n                mean_logits = tf.metrics.mean(tf_mean_logits)\n                loss = tf.reduce_mean(tf_loss_batch)\n                perp = _perplexity(loss)\n                bpb = _bits_per_byte(loss)\n                return {\"mean_logits\": mean_logits, \"perplexity\": perp, \"bits per byte\": bpb}\n\n           
 def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):\n                eos_token = params[\"eos_id\"]\n                answer_positions = tf.where(tf.math.not_equal(labels, eos_token))\n\n                correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)\n                accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))\n\n                # tf_loss_batch likely includes the z_loss term (and possibly other auxiliary terms),\n                # so the answer loss may need to be computed separately in the future\n                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)\n                log_perplexity = tf.metrics.mean(answer_loss)\n\n                return {\"lambada_acc\": accuracy, \"lambada_log_ppl\": log_perplexity}\n\n            eval_task = params[\"eval_task\"]\n            if eval_task == \"lambada\":\n                eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])\n            else:\n                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])\n\n            return tpu_estimator.TPUEstimatorSpec(\n                tf.estimator.ModeKeys.EVAL,\n                evaluation_hooks=[restore_hook],\n                loss=tf_loss,\n                eval_metrics=eval_metrics)\n"
  },
  {
    "path": "models/activations.py",
    "content": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nimport random\n\nBASE_FNS = {'gelu': mtf.gelu,\n            'relu': mtf.relu,\n            'sigmoid': mtf.sigmoid,\n            'tanh': mtf.tanh,\n            'selu': mtf.selu,\n            'elu': mtf.elu,\n            'abs': mtf.abs,\n            'sin': mtf.sin,\n            'cos': mtf.cos,\n            'sign': mtf.sign,\n            'silu': mtf.swish,\n            'softplus': mtf.softplus\n            }\n\n\ndef _arcsinh(x):\n    return mtf.log(x + mtf.sqrt(1 + x ** 2))\n\n\ndef _var(x, init):\n    return mtf.get_variable(x.mesh, f\"activation-{random.randint(0, 2 ** 32):x}\", [],\n                            initializer=tf.constant_initializer(init), dtype=x.dtype)\n\n\ndef _pos_var(x, val):\n    return mtf.softplus(_var(x, 0)) + val\n\n\ndef _rrelu(x):\n    negative_scale = random.random()\n    return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)\n\n\ndef _elish(x):\n    # ELiSH: x * sigmoid(x) for x > 0, (exp(x) - 1) * sigmoid(x) otherwise\n    cond = mtf.cast(mtf.greater(x, 0), x.dtype)\n    exp = mtf.exp(x)\n    return cond * x / (1 + 1 / exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)\n\n\nCUSTOM_FNS = {'lrelu001': lambda x: mtf.leaky_relu(x, alpha=0.01),\n              'lrelu020': lambda x: mtf.leaky_relu(x, alpha=0.20),\n              'id': lambda x: x,\n              'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49,\n              'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7,\n              'spike': lambda x: 1 / (1 + x ** 2),\n              'spike2': lambda x: mtf.exp(-x ** 2),\n              'tanhshrink': lambda x: x - mtf.tanh(x),\n              'softsign': lambda x: x / (mtf.abs(x) + 1),\n              'softmax': lambda x: mtf.softmax(x, x.shape[-1]),\n              'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]),\n              'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1,\n              'rrelu': _rrelu,\n              
'elish': _elish,\n              'arcsinh': _arcsinh,\n              'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (\n                          _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))),\n              'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),\n              'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),\n              'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),\n              'proottanh': lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x),\n              'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),\n              'cosid': lambda x: mtf.cos(x) - x,\n              'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),\n              'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),\n              'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),\n              'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),\n              'lisht': lambda x: x * mtf.tanh(x),\n              'seagull': lambda x: mtf.log(1 + x ** 2),\n              'snake': lambda x: x + mtf.sin(x) ** 2,\n              'roottanh': lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x),\n              'softplusmone': lambda x: mtf.softplus(x) - 1\n              }\n\n\ndef get_activation_fn(params):\n    if \"activation_fn\" in params:\n        activation_fn = params[\"activation_fn\"]\n    else:\n        print(\"Defaulting to GELU activation (see here: https://arxiv.org/abs/1606.08415)\")\n        activation_fn = \"gelu\"\n\n    if activation_fn in BASE_FNS:\n        return BASE_FNS[activation_fn]\n\n    if activation_fn in CUSTOM_FNS:\n        return CUSTOM_FNS[activation_fn]\n\n    raise ValueError(f'unknown activation function \"{activation_fn}\" in config')\n"
  },
  {
    "path": "models/gpt2/gpt2.py",
    "content": "\"\"\"GPT-like model in Mesh-Tensorflow\"\"\"\nimport tensorflow.compat.v1 as tf\nimport mesh_tensorflow.transformer as mtf_transformer\n\nfrom models.utils import parse_inputs, entmax_cross_entropy_with_logits\nfrom models.layers import *\n\n\n# --------------------------------------------------------------------------------\n# TRANSFORMER BLOCK:\n\ndef block(params, scope, layer_num, bias, sequence_dim, memory_length_dim, pos_emb, variable_dtype, context=None):\n    use_mlp_glu = params[\"mlp_glu\"] == True\n    use_scale_norm = params[\"scalenorm\"] == True\n    use_moe = exists(params[\"moe_layers\"]) and (layer_num in params[\"moe_layers\"])\n    use_rezero = params[\"rezero\"] == True\n    macaron_attention = params[\"macaron\"] == True\n\n    def fn(x):\n        with tf.variable_scope(scope):\n            nx = x.shape[-1]  # Grab last dimension from input\n\n            if use_rezero:\n                prenorm = identity\n            elif use_scale_norm:\n                prenorm = scale_norm\n            else:\n                prenorm = layer_norm\n\n            pre_residual_fn = rezero if use_rezero else identity\n\n            attention_type = params[\"attention_types\"][layer_num]\n\n            if macaron_attention:\n                mult = 0.5\n                mlp_fn = mlp_glu if use_mlp_glu else mlp\n                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)\n                # Define intermediate layer of mlp - to split\n                dim_intermediate_expanded = mtf.Dimension(\"intermediate_expanded\", intermediate_size)\n                m = mlp_fn(x, \"mlp_macaron\", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)\n\n                x = x + (m * mult)\n            else:\n                mult = 1\n\n            if attention_type != \"none\":\n                res_x = prenorm(x, \"norm_1\", variable_dtype=variable_dtype, params=params)\n                a = attn(res_x, \"attn\", nx, 
attention_type=attention_type,\n                         params=params, bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim,\n                         variable_dtype=variable_dtype, context=context, pos_emb=pos_emb)\n            else:\n                a = x\n\n            x = x + pre_residual_fn(a, \"norm_rezero_1\", dtype=variable_dtype)\n\n            res_x = prenorm(x, \"norm_2\", variable_dtype=variable_dtype, params=params)\n\n            if use_moe:\n                moe_params = mtf.transformer.moe.HParams()\n                mtf.transformer.moe.set_default_moe_hparams(moe_params)\n                moe_params.add_hparam(\"moe_min_expert_capacity\", 1)\n                moe_params.add_hparam(\"moe_use_experts_attention\", False)\n\n                # Override defaults\n                for k, v in params[\"moe_params\"].items():\n                    moe_params.add_hparam(k, v)\n\n                moe_train = params[\"mode\"] == \"train\"\n\n                m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(res_x, x.shape[-1], moe_params,\n                                                                           train=moe_train,\n                                                                           mesh_shape=params[\"mesh_shape\"],\n                                                                           layout=params[\"layout\"],\n                                                                           activation=params.get(\"moe_activation\",\n                                                                                                 \"relu\"),\n                                                                           variable_dtype=variable_dtype,\n                                                                           num_microbatches=params[\"num_microbatches\"])\n                m = mtf.dropout(m, rate=params[\"res_dropout\"], name=\"moe_dropout\")\n            else:\n\n                mlp_fn = mlp_glu if use_mlp_glu 
else mlp\n                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)\n\n                # Define intermediate layer of mlp - to split\n                dim_intermediate_expanded = mtf.Dimension(\"intermediate_expanded\", intermediate_size)\n\n                m = mlp_fn(res_x, \"mlp\", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)\n                aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype)\n\n            x = x + pre_residual_fn((m * mult), \"norm_rezero_2\", variable_dtype)\n            return x, aux_loss\n\n    return fn\n\n\n# --------------------------------------------------------------------------------\n# GPT2 MODEL:\n\ndef model(mtf_features, other_features, params, mesh, variable_dtype, context=None):\n    \"\"\"A GPT style model implemented in mesh tensorflow.\"\"\"\n\n    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)\n\n    if is_incremental_inference(context):\n        # reshape inputs if in inference mode\n        x = mtf.gather(x, context.position - 1, sequence_dim)\n        x = mtf.reshape(x, [batch_dim])\n\n    use_axial_pos_emb = exists(params[\"axial_pos_emb\"])\n    use_rotary_emb = exists(params[\"rotary_emb\"])\n\n    # Text encoding\n    wte = mtf.get_variable(mesh, \"wte\", mtf.Shape([vocab_dim, embd_dim]),\n                           initializer=tf.random_normal_initializer(stddev=0.02),\n                           master_dtype=variable_dtype.master_dtype,\n                           slice_dtype=variable_dtype.slice_dtype,\n                           activation_dtype=variable_dtype.activation_dtype)\n\n    with tf.variable_scope(\"token_embd\"):\n        # Text embedding\n        h = mtf.gather(wte, x, vocab_dim)\n        if params[\"embed_dropout\"] > 0 and params[\"mode\"] == \"train\":\n            h = mtf.dropout(h, rate=params[\"embed_dropout\"], name=\"wte_dropout\")\n\n    # Position 
 encoding\n\n    if use_rotary_emb:\n        wpe = None\n        layer_pos_emb = rotary_positional_emb(mesh, sequence_dim, params, variable_dtype)\n    elif use_axial_pos_emb:\n        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)\n        layer_pos_emb = None\n    else:\n        # Use standard position encoding\n        wpe = mtf.get_variable(mesh, \"wpe\", mtf.Shape([embed_sequence_dim, embd_dim]),\n                               initializer=tf.random_normal_initializer(stddev=0.01),\n                               master_dtype=variable_dtype.master_dtype,\n                               slice_dtype=variable_dtype.slice_dtype,\n                               activation_dtype=variable_dtype.activation_dtype)\n        layer_pos_emb = None\n\n    if exists(wpe):\n        with tf.variable_scope(\"pos_embd\"):\n            # Positional embedding\n            position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (\n                    context.position - 1)\n            pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])\n            if params[\"embed_dropout\"] > 0 and params[\"mode\"] == \"train\":\n                pos_emb = mtf.dropout(pos_emb, rate=params[\"embed_dropout\"], name=\"wpe_dropout\")\n            h += pos_emb\n\n    aux_losses = 0  # accumulates auxiliary losses (only used by MoE layers)\n\n    for layer in range(params[\"n_layer\"]):\n        # attn blocks\n        share_parameters = exists(params[\"share_parameters\"]) and params[\"share_parameters\"] == True\n        block_scope = f\"h{layer}\" if not share_parameters else \"\"\n\n        block_fn = block(params=params, scope=block_scope, layer_num=layer,\n                         bias=other_features[\"attn_bias\"],\n                         sequence_dim=sequence_dim,\n                         memory_length_dim=other_features[\"memory_length_dim\"],\n                         pos_emb=layer_pos_emb,\n                         
variable_dtype=variable_dtype,\n                         context=context)\n\n        # Enable gradient checkpointing only when requested and in train mode\n        recompute_grad = params[\"recompute_grad\"] and params[\"mode\"] == \"train\"\n        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])\n        aux_losses += loss\n\n    no_weight_tie_emb = params[\"no_weight_tie\"] == True\n    if no_weight_tie_emb:\n        with tf.variable_scope(\"wte_final_linear\"):\n            logits = linear(h, \"linear_out\", vocab_dim, variable_dtype=variable_dtype, params=params)\n    else:\n        # Layer normalize & affine transform\n        h = layer_norm(h, \"ln_f\", variable_dtype=variable_dtype)\n        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension(\"sequence\", 1)\n        with tf.variable_scope(\"wte_final_einsum\"):\n            # Equivalent to tf.matmul\n            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])\n\n    if params[\"mode\"] in [\"train\", \"eval\"]:\n        labels = mtf_features[\"labels\"]\n        z_loss = params.get(\"z_loss\", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy\n\n        # Go to full precision for the logits\n        logits = mtf.cast(logits, tf.float32)\n\n        use_entmax_loss = params.get(\"entmax_loss\", False)\n        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits\n\n        with tf.variable_scope(\"xentropy_final\"):\n            loss_batch = loss_fn(logits=logits, targets=labels,\n                                 vocab_dim=logits.shape[-1], z_loss=z_loss)\n\n        # For non-autoregressive models (masked language modeling training)\n        # Make sure labels with padding tokens are not counted in the loss\n        if not params[\"causal\"]:\n            padding_id = params.get(\"padding_id\", 0)\n            loss_batch = 
mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))\n\n        with tf.variable_scope(\"reduce_mean_final\"):\n            loss = mtf.reduce_mean(loss_batch)\n\n        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)\n        loss /= params[\"num_microbatches\"]\n        # Convert to train dtype\n        loss = mtf.cast(loss, variable_dtype.slice_dtype)\n    else:\n        loss = None\n        loss_batch = None\n\n    # Cast back to checkpoint dtype\n    logits = mtf.cast(logits, variable_dtype.master_dtype)\n    return logits, loss, loss_batch\n"
  },
  {
    "path": "models/layers.py",
    "content": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nimport math\nimport mesh_tensorflow.transformer as mtf_transformer\n\nfrom models.activations import get_activation_fn\n\n\n# --------------------------------------------------------------------------------\n# LAYERS:\n\nsentinel = object()\n\n\ndef exists(x):\n    return x is not None\n\n\ndef identity(x, *args, **kwargs):\n    return x\n\n\ndef is_incremental_inference(context):\n    return exists(context) and context.mode == \"incremental\"\n\n\ndef norm(x, axis, epsilon=1e-8):\n    x -= mtf.reduce_mean(x, reduced_dim=axis, name=\"norm_reduce_mean_u\")\n    s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name=\"norm_reduce_mean_s\")\n    return x * mtf.rsqrt(s + epsilon)\n\n\ndef rezero(x, scope, dtype):\n    with tf.variable_scope(scope):\n        g = mtf.get_variable(x.mesh, \"g\", [], initializer=tf.constant_initializer(0), dtype=dtype)\n        return x * g\n\n\ndef scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):\n    if axis is sentinel:\n        axis = x.shape[-1]\n\n    with tf.variable_scope(scope):\n        g = mtf.get_variable(x.mesh, \"g\", [], initializer=tf.constant_initializer(1),\n                             master_dtype=variable_dtype.master_dtype,\n                             slice_dtype=variable_dtype.slice_dtype,\n                             activation_dtype=variable_dtype.activation_dtype)\n\n        x = norm(x, axis, epsilon)\n        x = x * g\n        return x\n\n\ndef layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):\n    \"\"\"Normalize to mean = 0, std = 1, then do a diagonal affine transform.\"\"\"\n    if axis is sentinel:\n        axis = x.shape[-1]\n\n    with tf.variable_scope(scope):\n        n_state = x.shape[-1]\n\n        g = mtf.get_variable(x.mesh, \"g\", [n_state], initializer=tf.constant_initializer(1),\n                             
master_dtype=variable_dtype.master_dtype,\n                             slice_dtype=variable_dtype.slice_dtype,\n                             activation_dtype=variable_dtype.activation_dtype)\n        b = mtf.get_variable(x.mesh, \"b\", [n_state], initializer=tf.constant_initializer(0),\n                             master_dtype=variable_dtype.master_dtype,\n                             slice_dtype=variable_dtype.slice_dtype,\n                             activation_dtype=variable_dtype.activation_dtype)\n\n        x = norm(x, axis, epsilon)\n        x = x * g + b\n        return x\n\n\ndef linear_attention(q, k, v):\n    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])\n    q = mtf.rename_dimension(q, \"features_per_head\", \"features_per_head_in\")\n    k = mtf.rename_dimension(k, \"features_per_head\", \"features_per_head_in\")\n\n    dim_in = k.shape[-1]\n\n    q = mtf.softmax(q, dim_in)\n    k = mtf.softmax(k, seq_dim)\n\n    context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])\n    attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])\n    return attn\n\n\ndef causal_linear_attention(q, k, v, eps = 1e-6):\n    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])\n    q = mtf.rename_dimension(q, \"features_per_head\", \"features_per_head_in\")\n    k = mtf.rename_dimension(k, \"features_per_head\", \"features_per_head_in\")\n\n    dim_in = k.shape[-1]\n\n    q = mtf.softmax(q, dim_in)\n    k = mtf.exp(k)\n\n    cumulative_k = mtf.cumsum(k, seq_dim) + eps\n    D_inv = 1. 
/ mtf.einsum([q, cumulative_k], output_shape=[batch_dim, seq_dim, head_dim])\n\n    context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])\n    cumulative_context = mtf.cumsum(context, seq_dim)\n\n    attn = mtf.einsum([q, cumulative_context, D_inv], output_shape=[batch_dim, seq_dim, head_dim, dim_out])\n    return attn\n\n\ndef linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False):\n    # nf = number of features\n    if params[\"scale_by_depth\"] and scale:\n        # Scale by sqrt(num_layers), only happens at the final projection before a res block output\n        w_init_stdev = w_init_stdev * (1. / math.sqrt(params[\"n_layer\"]))\n    if params[\"scale_by_in\"]:  # Scale by sqrt(num_input_features)\n        w_init_stdev = w_init_stdev * (1. / math.sqrt(x.shape[-1].size))  # Dimension is a namedtuple of (name, size)\n    # Not in the variable_scope because mtf already has a variable_scope in it\n    with tf.variable_scope(\"conv1d_main\"):\n        c = mtf.layers.dense(x, new_dims=[nf], reduced_dims=[x.shape[-1]], name=scope, use_bias=True,\n                             kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev),\n                             variable_dtype=variable_dtype,\n                             )\n        return c\n\n\ndef memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):\n    \"\"\"memory / key values from all attention paper\"\"\"\n\n    dim_mem_kv = mtf.Dimension(\"mem_kv_sequence\", num_mem_kv)\n    emb_dim = k.shape[-1]\n    mem_std = 1 / math.sqrt(emb_dim.size)\n\n    mem_k = mtf.get_variable(mesh, \"mem_k\", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),\n                             initializer=tf.random_normal_initializer(stddev=mem_std),\n                             master_dtype=variable_dtype.master_dtype,\n                             slice_dtype=variable_dtype.slice_dtype,\n                             
activation_dtype=variable_dtype.activation_dtype,\n                             )\n    mem_v = mtf.get_variable(mesh, \"mem_v\", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),\n                             initializer=tf.random_normal_initializer(stddev=mem_std),\n                             master_dtype=variable_dtype.master_dtype,\n                             slice_dtype=variable_dtype.slice_dtype,\n                             activation_dtype=variable_dtype.activation_dtype)\n\n    mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),\n                       (mem_k, mem_v))\n    mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, \"mem_kv_sequence\", \"sequence\"),\n                       (mem_k, mem_v))\n\n    k = mtf.concat([mem_k, k], \"sequence\")\n    v = mtf.concat([mem_v, v], \"sequence\")\n    return k, v\n\n\ndef attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None, pos_emb=None):\n    # x :: [batch, seq, n_embd]\n    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh\n\n    # n_state is the same as config[\"n_embd\"], which is also the same as dim_embd.\n    assert n_state.size % params[\"n_head\"] == 0\n\n    dim_heads = mtf.Dimension(\"heads\", params[\"n_head\"])\n\n    num_mem_kv = params.get(\"num_mem_kv\", 0)\n    use_num_mem_kv = num_mem_kv > 0\n\n    with tf.variable_scope(scope):\n        # Compute attention inputs\n        dim_kv = mtf.Dimension(\"features_per_head\", params[\"n_embd\"] // params[\"n_head\"])\n        mtfparams = mtf.transformer.attention.attention_params_simple(\n            x.mesh,\n            io_dim=dim_embd,\n            kv_dim=dim_kv,\n            heads_dim=dim_heads,\n            variable_dtype=variable_dtype\n        )\n        q = mtfparams.compute_q(x)\n        k = mtfparams.compute_k(x)\n        v = mtfparams.compute_v(x)\n\n        if is_incremental_inference(context):\n            one_hot = 
mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)\n            inv_one_hot = 1.0 - one_hot\n            old_k, old_v = context.get_states(2)\n            k = old_k * inv_one_hot + k * one_hot\n            v = old_v * inv_one_hot + v * one_hot\n\n        if exists(context):\n            context.record_new_states([k, v])\n\n        if exists(pos_emb):\n            cos, sin = pos_emb\n            k = apply_rotary_emb(k, cos, sin)\n\n            if is_incremental_inference(context):\n                seq_dim = cos.shape.get_dim_by_name('sequence')\n                cos = mtf.gather(cos, context.position - 1, seq_dim)\n                sin = mtf.gather(sin, context.position - 1, seq_dim)\n\n            q = apply_rotary_emb(q, cos, sin)\n\n        with tf.variable_scope(\"attention\"):\n            if attention_type == \"local\":\n                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.\n                radius = params.get(\"local_attention_radius\", 256)\n\n                if is_incremental_inference(context):\n                    q *= one_hot\n\n                a = mtf_transformer.attention.local_attention_1d(\n                    q, k, v,\n                    length_dim=k.shape[1],\n                    key_dim=dim_kv,\n                    value_dim=dim_kv,\n                    radius=radius,\n                    length_dim_num_splits=1,\n                    fully_autoregressive=params[\"causal\"],\n                    attention_kwargs={},\n                )\n\n                if is_incremental_inference(context):\n                    a = mtf.gather(a, context.position - 1, dim_seq)\n\n            elif attention_type == \"global\":\n\n                # TODO: pass in fake context\n                # Broadcast mask bias across batch and heads\n                if exists(bias):\n                    if not is_incremental_inference(context):\n                        broadcasted_bias = 
mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])\n                    else:\n                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position\n                        bias = mtf.gather(bias, context.position - 1, dim_seq)\n                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])\n\n                # memory key / values, from all-attention paper\n                if use_num_mem_kv:\n                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)\n\n                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)\n                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)\n\n                attn_dropout_rate = params[\"attn_dropout\"] if params[\"mode\"] == \"train\" else 0\n\n                a = mtf_transformer.attention.attention(\n                    q, k, v,\n                    memory_length_dim=memory_length_dim,\n                    key_dim=dim_kv,\n                    value_dim=dim_kv,\n                    bias=broadcasted_bias,\n                    dropout_rate=attn_dropout_rate\n                )\n\n            elif attention_type == \"linear\":\n                linear_attn_fn = causal_linear_attention if params[\"causal\"] else linear_attention\n                a = linear_attn_fn(q, k, v)\n\n            else:\n                raise NotImplementedError(\"Unknown attention type {}!\".format(attention_type))\n\n        with tf.variable_scope(\"compute_output\"):\n            a = mtfparams.compute_output(a, x_shape)\n\n        with tf.variable_scope(\"compute_output_bias\"):\n            b = mtf.get_variable(x.mesh, \"o_b\", [dim_embd], initializer=tf.constant_initializer(0),\n                                 master_dtype=variable_dtype.master_dtype,\n                                 slice_dtype=variable_dtype.slice_dtype,\n     
                             activation_dtype=variable_dtype.activation_dtype)\n            a += b\n\n        if params[\"mode\"] == \"train\" and params[\"res_dropout\"] > 0:\n            a = mtf.dropout(a, rate=params[\"res_dropout\"], name=\"res_dropout\")\n        return a\n\n\ndef mlp(x, scope, n_state, *, variable_dtype, params):\n    activation_fn = get_activation_fn(params)\n    with tf.variable_scope(scope):\n        nx = x.shape[-1]\n        h = activation_fn(linear(x, \"c_fc\", n_state, variable_dtype=variable_dtype, params=params))\n        h2 = linear(h, \"c_proj\", nx, variable_dtype=variable_dtype, params=params, scale=True)\n        if params[\"mode\"] == \"train\" and params[\"res_dropout\"] > 0:\n            h2 = mtf.dropout(h2, rate=params[\"res_dropout\"], name=\"mlp_dropout\")\n        return h2\n\n\ndef mlp_glu(x, scope, n_state, *, variable_dtype, params):\n    activation_fn = get_activation_fn(params)\n    with tf.variable_scope(scope):\n        nx = x.shape[-1]\n        # variable_dtype is a required keyword-only argument of linear()\n        h = linear(x, \"c_fc\", n_state, variable_dtype=variable_dtype, params=params)\n\n        # Split the expanded projection into a value half and a gate half\n        h, gate = mtf.split(h, h.shape[-1], 2)\n        h *= activation_fn(gate)\n\n        h2 = linear(h, \"c_proj\", nx, variable_dtype=variable_dtype, params=params, scale=True)\n        if params[\"mode\"] == \"train\" and params[\"res_dropout\"] > 0:\n            h2 = mtf.dropout(h2, rate=params[\"res_dropout\"], name=\"mlp_dropout\")\n        return h2\n\n\ndef axial_positional_emb(embd_dim, mesh, params, variable_dtype):\n    # Use axial position encoding\n    axial_dim_1, axial_dim_2 = params[\"axial_pos_emb\"]\n\n    axial_dim = mtf.Dimension(\"axial_dim\", axial_dim_1 * axial_dim_2)\n    dim_axials = [mtf.Dimension(f\"axial_dim_{i}\", t) for i, t in enumerate((axial_dim_1, axial_dim_2))]\n\n    axial_wpe_1 = mtf.get_variable(mesh, \"axial_wpe_1\", mtf.Shape([dim_axials[0], embd_dim]),\n                                   initializer=tf.random_normal_initializer(stddev=0.01),\n                                   
master_dtype=variable_dtype.master_dtype,\n                                   slice_dtype=variable_dtype.slice_dtype,\n                                   activation_dtype=variable_dtype.activation_dtype)\n\n    axial_wpe_2 = mtf.get_variable(mesh, \"axial_wpe_2\", mtf.Shape([dim_axials[1], embd_dim]),\n                                   initializer=tf.random_normal_initializer(stddev=0.01),\n                                   master_dtype=variable_dtype.master_dtype,\n                                   slice_dtype=variable_dtype.slice_dtype,\n                                   activation_dtype=variable_dtype.activation_dtype)\n\n    axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),\n                                   (axial_wpe_1, axial_wpe_2))\n    wpe = (axial_wpe_1 + axial_wpe_2) / 2\n\n    wpe = mtf.reshape(wpe, [axial_dim, embd_dim])\n\n    return wpe\n\ndef rotary_positional_emb(mesh, sequence_dim, params, variable_dtype):\n    dtype = variable_dtype.master_dtype\n    dim_head = params[\"n_embd\"] // params[\"n_head\"]\n\n    dim_head = mtf.Dimension(\"features_per_head\", dim_head)\n    half_dim_head = mtf.Dimension(\"half_features_per_head\", dim_head.size // 2)\n\n    dim_range = mtf.range(mesh, half_dim_head, dtype) * 2 / dim_head.size\n    half_freqs = 1. 
/ mtf.pow(mtf.constant(mesh, 10000, dtype = dtype), dim_range)\n\n    seq = mtf.range(mesh, sequence_dim, dtype)\n    half_freqs = mtf.einsum([half_freqs, seq], [sequence_dim, half_dim_head])\n\n    freqs = mtf.concat((half_freqs, half_freqs), half_dim_head.name)\n    freqs = mtf.rename_dimension(freqs, half_dim_head.name, dim_head.name)\n    return mtf.cos(freqs), mtf.sin(freqs)\n\ndef rotate_half(x):\n    dim_head_name = \"features_per_head\"\n    dim_head = x.shape.get_dim_by_name(dim_head_name)\n    half_dim_head_size = dim_head.size // 2\n    x1 = mtf.slice(x, 0, half_dim_head_size, dim_head_name)\n    x2 = mtf.slice(x, half_dim_head_size, half_dim_head_size, dim_head_name)\n    return mtf.concat((-x2, x1), dim_head.name)\n\ndef apply_rotary_emb(x, cos, sin):\n    rotated_x = rotate_half(x)\n    return x * cos + rotated_x * sin\n"
  },
  {
    "path": "models/utils.py",
    "content": "import tensorflow as tf\nimport mesh_tensorflow as mtf\nfrom functools import partial\n\n\ndef entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None,\n                    n_iter=50):\n    x, = explicit_inputs\n    y, = outputs\n    dY, = output_grads\n\n    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))\n    dX = dY * gppr\n\n    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)\n    dX = dX - q * gppr\n\n    return dX,\n\n\ndef entmax_forward(x, alpha=1.3, dim=None, n_iter=50):\n    assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'\n\n    _gp = lambda x, alpha: x ** (alpha - 1)\n    _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))\n    _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)\n\n    dim = x.shape[-1] if dim is None else dim\n    d = dim.size\n\n    x = x * (alpha - 1)\n\n    max_val = mtf.reduce_max(x, reduced_dim=dim)\n\n    tau_lo = max_val - _gp(1, alpha)\n    tau_hi = max_val - _gp(1 / d, alpha)\n\n    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1\n\n    dm = tau_hi - tau_lo\n\n    for _ in range(n_iter):\n        dm = dm / 2\n        tau_m = tau_lo + dm\n        p_m = _p(x - tau_m, alpha)\n        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1\n\n        mask = mtf.greater_equal((f_m * f_lo), 0)\n        tau_lo = mtf.where(mask, tau_m, tau_lo)\n\n    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)\n    return p_m\n\n\ndef entmax(x, alpha=1.3, dim=None, n_iter=50):\n    kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter)\n\n    return mtf.custom_gradient(\n        partial(entmax_forward, **kwargs),\n        partial(entmax_backward, **kwargs),\n        [x]\n    )\n\n\ndef entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):\n    if targets.dtype.is_integer:\n        # hard targets\n        if (set(targets.shape.dims) != 
set(logits.shape.dims).difference([vocab_dim])):\n            raise ValueError(\n                \"entmax_cross_entropy_with_logits with hard targets \"\n                \"dims in targets=%s should be dims in logits=%s other than \"\n                \"vocab_dim=%s\" % (targets, logits, vocab_dim))\n        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)\n    elif set(targets.shape.dims) != set(logits.shape.dims):\n        raise ValueError(\n            \"entmax_cross_entropy_with_logits with soft targets \"\n            \"dims in targets=%s should be dims in logits=%s\" % (targets, logits))\n\n    if vocab_dim not in logits.shape.dims:\n        raise ValueError(\"vocab_dim must be in logits.shape.dims\")\n\n    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))\n\n    loss = mtf.negative(\n        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))\n\n    return loss\n\n\ndef sample_categorical(x, dim=None):\n    dim = x.shape[-1] if dim is None else dim\n\n    cdf = mtf.cumsum(x, dim)\n    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)\n    mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)\n    return mtf.argmax(mask, dim)\n\n\ndef biasmask_attn_weights(mesh, nd, ns, variable_dtype):\n    # The old mask_attn_weights applied directly to the QK;\n    # this returns a bias that the attention code from mtf adds to the attention matrix.\n    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.\n    # n_src and n_dest are both the same, i.e. equal to sequence length\n    # We rename ns because we want bias to have shape [batch, heads, memory_length, sequence] to match up with QK^T\n    # Information flows from k and v (memory_length) to q (sequence)\n    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size\n    j = mtf.range(mesh, ns, tf.int32)\n    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))\n    dtype = variable_dtype.activation_dtype\n    return 
mtf.cast(mtf.less(i, j), dtype) * -1e10\n\n\ndef parse_inputs(mtf_features, other_features):\n    # Parse inputs and labels from the mtf_features / other_features input dicts\n    # All dimensions are defined inside model_fn for efficiency\n    x = mtf_features[\"inputs\"]\n\n    batch_dim = x.shape[0]\n    sequence_dim = x.shape[1]\n    embd_dim = other_features[\"embd_dim\"]\n    vocab_dim = other_features[\"vocab_dim\"]\n    embed_sequence_dim = other_features[\"embed_sequence_dim\"]\n\n    return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim\n"
  },
  {
    "path": "optimizers.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\nimport mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\n\ndef clip_by_global_norm(grads, clip_norm):\n    \"\"\"Clip the grads by global norm.\"\"\"\n    global_norm = mtf.sqrt(mtf.add_n([mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))\n    multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)\n    clipped_grads = [None if t is None else t * multiplier for t in grads]\n    return clipped_grads, global_norm\n\ndef get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):\n    \"\"\"Creates and returns an optimizer training op.\"\"\"\n    global_step = tf.train.get_or_create_global_step()\n\n    learning_rate = tf.constant(value=params[\"lr\"], shape=[], dtype=variable_dtype.slice_dtype)\n    clip_value = mtf.constant(mesh, params[\"gradient_clipping\"], dtype=variable_dtype.slice_dtype)\n\n    if inp_var_grads is None:\n        var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])\n    else:\n        var_grads = inp_var_grads\n\n    # Cast to full precision\n    var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]\n\n    # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps\n    end_step = params.get(\"lr_decay_end\", params[\"train_steps\"]) \n\n    if params[\"lr_decay\"] == \"linear\":\n        learning_rate = tf.train.polynomial_decay(\n            learning_rate,\n            global_step,\n            end_step,\n            end_learning_rate=params[\"lr\"]*0.1, # Decrease to 10% of initial LR according to GPT-3 paper\n            power=1.0,\n            cycle=False)\n    elif params[\"lr_decay\"] == \"cosine\":\n        learning_rate = tf.train.cosine_decay(\n            learning_rate,\n            global_step,\n            end_step,\n            alpha=0.1  # Alpha is min lr value as a fraction of 
init lr.\n        )\n\n    if params[\"warmup_steps\"] > 0:\n        global_steps_int = tf.cast(global_step, tf.int32)\n        warmup_steps_int = tf.constant(params[\"warmup_steps\"], dtype=tf.int32)\n\n        dtype = variable_dtype.slice_dtype\n\n        global_steps_float = tf.cast(global_steps_int, dtype)\n        warmup_steps_float = tf.cast(warmup_steps_int, dtype)\n\n        warmup_percent_done = global_steps_float / warmup_steps_float\n        warmup_learning_rate = learning_rate * warmup_percent_done\n\n        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)\n        learning_rate = ((1.0 - is_warmup) * learning_rate +\n                       is_warmup * warmup_learning_rate)\n\n    learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name=\"learning_rate\")\n    mtf.scalar_summary(\"lr\", learning_rate)\n\n    if params[\"opt_name\"].lower() == \"adam\":\n        optimizer = AdamWeightDecayOptimizer(\n            learning_rate=learning_rate,\n            weight_decay_rate=params[\"weight_decay\"],\n            beta_1=params[\"beta1\"],\n            beta_2=params[\"beta2\"],\n            epsilon=params[\"epsilon\"],\n            exclude_from_weight_decay=[\"norm\", \"bias\"],\n            variable_dtype=variable_dtype\n        )\n    else:\n        optimizer = mtf.optimize.AdafactorOptimizer(\n            learning_rate=params[\"lr\"],\n            decay_rate=params[\"weight_decay\"],\n            beta1=params[\"beta1\"],\n            epsilon1=params[\"ada_epsilon1\"],\n            epsilon2=params[\"ada_epsilon2\"]\n        )\n\n    if params[\"gradient_clipping\"] is not None:\n        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)\n\n    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)\n    return learning_rate, update_ops, var_grads_fp\n\n\nclass AdamWeightDecayOptimizer(mtf.optimize.Optimizer):\n  \"\"\"A basic Adam optimizer that includes 
\"correct\" L2 weight decay.\"\"\"\n\n  def __init__(self,\n               learning_rate,\n               weight_decay_rate=0.0,\n               beta_1=0.9,\n               beta_2=0.999,\n               epsilon=1e-6,\n               exclude_from_weight_decay=None,\n               variable_dtype=None):\n    \"\"\"Constructs a AdamWeightDecayOptimizer.\"\"\"\n\n    self.learning_rate = learning_rate\n    self.weight_decay_rate = weight_decay_rate\n    self.beta_1 = beta_1\n    self.beta_2 = beta_2\n    self.epsilon = epsilon\n    self.exclude_from_weight_decay = exclude_from_weight_decay\n    self.variable_dtype = variable_dtype\n\n  def apply_grad(self, grad, var):\n    \"\"\"See base class.\"\"\"\n    if grad is None:\n      tf.logging.warning(\"Gradient is None for variable %s\" % var.name)\n      return []\n    \n    grad = mtf.to_float(grad)\n\n    assignments = []\n\n    m = mtf.get_variable(\n        var.mesh, var.name + \"/adam_m\", var.shape,\n        initializer=tf.zeros_initializer(), \n        # master_dtype=self.variable_dtype.master_dtype, \n        # slice_dtype=self.variable_dtype.slice_dtype, \n        # activation_dtype=self.variable_dtype.activation_dtype, \n        trainable=False)\n\n    v = mtf.get_variable(\n        var.mesh, var.name + \"/adam_v\", var.shape,\n        initializer=tf.zeros_initializer(), \n        # master_dtype=self.variable_dtype.master_dtype, \n        # slice_dtype=self.variable_dtype.slice_dtype, \n        # activation_dtype=self.variable_dtype.activation_dtype, \n        trainable=False)\n\n    # Standard Adam update.\n    next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad\n    next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)\n\n    update = next_m / (mtf.sqrt(next_v) + self.epsilon)\n\n    # Just adding the square of the weights to the loss function is *not*\n    # the correct way of using L2 regularization/weight decay with Adam,\n    # since that will interact with the m and v parameters in strange 
ways.\n    #\n    # Instead we want to decay the weights in a manner that doesn't interact\n    # with the m/v parameters. This is equivalent to adding the square\n    # of the weights to the loss with plain (non-momentum) SGD.\n    if self._do_use_weight_decay(var.name):\n      update += mtf.to_float(var.value) * self.weight_decay_rate \n\n    update_with_lr = self.learning_rate * update\n\n    var_update = mtf.assign_sub(var, update_with_lr)\n\n    assignments.extend(\n        [var_update,\n         mtf.assign(m, next_m),\n         mtf.assign(v, next_v)])\n    return assignments\n\n  def _do_use_weight_decay(self, param_name):\n    \"\"\"Whether to use L2 weight decay for `param_name`.\"\"\"\n    if not self.weight_decay_rate:\n      return False\n    if self.exclude_from_weight_decay:\n      for r in self.exclude_from_weight_decay:\n        if re.search(r, param_name) is not None:\n          return False\n    return True"
  },
  {
    "path": "requirements.txt",
    "content": "google-api-python-client\njsonlines\nlm_dataformat\nmesh-tensorflow==0.1.18\nnumpy\noauth2client\nortools\npytest\nsacred\ntensorflow==2.5.1\ntensorflow-datasets==3.2.1\ntokenizers==0.9.4\ntransformers==4.1.1\ntpunicorn\nabsl-py\nftfy\nsacred\npymongo\n"
  },
  {
    "path": "run_experiment.py",
    "content": "import atexit\nimport sacred\nimport argparse\nimport time\nimport math\nimport subprocess\nimport shutil\nimport os\nimport json\nimport threading\nimport requests\nimport glob\nfrom configs import fetch_model_params\nimport socket\nimport subprocess\nimport queue\nimport sys\nimport signal\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--tpu', type=str, required=True) # Name of TPU to train on, if any\nparser.add_argument('--model', type=str, required=True) # JSON file that contains model parameters\nparser.add_argument('--experiment_name', type=str, required=True) # name of experiment (will show up in omniboard)\nparser.add_argument('--steps_per_checkpoint', type=int, default=5000)\nparser.add_argument('--autostack', action=\"store_false\")\nparser.add_argument('--auto_layout', action=\"store_true\")\nparser.add_argument('--auto_layout_and_mesh_shape', action=\"store_true\")\nparser.add_argument('--new', action='store_true')\nparser.add_argument('--test', action='store_true')\nparser.add_argument('--eval', action='store_true')\nparser.add_argument('--predict', action='store_true')\nparser.add_argument('--no_delete_tpu', action='store_true')\nparser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)\nparser.add_argument('--heartbeat_timeout', type=int, default=1800) # kill and restart if nothing logged to tensorboard in this many seconds\nargs = parser.parse_args()\n\nparams = fetch_model_params(args.model)\n\nex = sacred.Experiment(args.experiment_name)\nex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))\n\n\ndef get_open_port(lo=8000, hi=8100):\n    for i in range(lo, hi):\n        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n            if s.connect_ex(('localhost', i)) != 0:\n                return i\n\n\ndef train_thread(args, tpu, id, q):\n    print('starting training on', tpu)\n\n    # pass binary flags 
through\n    opts = ''\n    for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval', ]:\n        if args.__getattribute__(flag):\n            opts += ' --' + flag\n\n    for flag in ['autostack', ]:\n        if not args.__getattribute__(flag):\n            opts += ' --' + flag\n\n    cmd = \"python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}\".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)\n    print('Running:', cmd)\n    proc = subprocess.Popen(cmd, shell=True)\n\n    # poll until it's exited\n    while proc.poll() is None:\n        time.sleep(60)\n        try:\n            nq, *nargs = q.get_nowait()\n            if nq == 'kill':\n                print('train thread received kill signal from logging thread')\n                # first send SIGTERM\n                proc.terminate()\n\n                time.sleep(60)\n                \n                # if it still hasn't exited, we send SIGKILL\n                if proc.poll() is None: \n                    print('SIGTERM not successful, sending SIGKILL')\n                    proc.kill()\n\n        except queue.Empty:\n            pass\n\n    print('exited training!')\n    if proc.returncode == 0:\n        print('exited gracefully')\n        os.kill(os.getpid(), signal.SIGINT)\n        return\n    \n    if args.no_delete_tpu:\n        print('no_delete_tpu is set, exiting train_thread without recreating tpu')\n        return\n    print(\"Recreating {} in 60sec...\".format(tpu))\n    time.sleep(60)\n    os.system(\"pu recreate {} --yes --retry 3600 --retry-randomness 1.5\".format(tpu))\n    print('recreate done, exiting train_thread')\n    \n    # clear out queue\n    while True:\n        try:\n            q.get_nowait()\n            print('dropped request in queue after pu recreate')\n        except queue.Empty:\n            break\n\n\ndef get_json(uri, 
params=None, timeout=15):\n    resp = requests.get(uri, params=params, timeout=timeout)\n    resp.raise_for_status()\n    return resp.json()\n\n\ndef get_tag_sets(base_uri):\n    j = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})\n    assert isinstance(j, dict)\n    return {\n        run: j[run].keys()\n        for run in j.keys()\n    }\n\n\ndef get_scalar_data(base_uri, run, tag):\n    j = get_json(f'{base_uri}/data/plugin/scalars/scalars', {'experiment': '', 'run': run, 'tag': tag})\n    assert isinstance(j, list)\n    return j\n\n\ndef get_run_data(port):\n    base_uri = f'http://localhost:{port}/'\n    r = {}\n    try:\n        tag_sets = get_tag_sets(base_uri)\n        runs = tag_sets.keys()\n        if '.' in runs:\n            if 'loss' in tag_sets['.']:\n                r['loss'] = get_scalar_data(base_uri, '.', 'loss')\n        if 'eval' in runs:\n            if 'loss' in tag_sets['eval']:\n                r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')\n        if 'eval_lambada' in runs:\n            if 'lambada_acc' in tag_sets['eval_lambada']:\n                r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')\n            if 'lambada_log_ppl' in tag_sets['eval_lambada']:\n                r['lambada_ppl'] = [\n                    [t, s, math.exp(lp)]\n                    for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')\n                ]\n    except:\n        import traceback\n        traceback.print_exc()\n    return r\n\n\n@ex.main\ndef main(_run):\n    print('Starting run', _run._id)\n    print('experiment main invoked with argv:', \" \".join(sys.argv))\n    print('WARNING: please remember to remove old metric log files from the model directory.')\n\n    os.makedirs('run_configs', exist_ok=True)\n    shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))\n\n    
tensorboard_port = get_open_port()\n    print('Tensorboard at port:', tensorboard_port)\n    print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))\n    os.system(\"screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'\".format(_run._id, params[\"model_path\"], tensorboard_port,params[\"model_path\"], tensorboard_port,))\n    atexit.register(goodbye, _run._id)\n\n    curr_step = {}\n    seen_predictions = set()\n\n    heartbeat_timeout = args.initial_heartbeat_timeout * 2\n    while True:\n        last_tb_log_time = time.time()\n        start_time = time.time()\n        q = queue.Queue()\n        trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))\n        trainthd.start()\n\n        while trainthd.is_alive():\n            time.sleep(60)\n\n            if start_time + args.initial_heartbeat_timeout < time.time():\n                # after initial args.initial_heartbeat_timeout grace period, now we want to set the timeout threshold much lower\n                heartbeat_timeout = args.heartbeat_timeout\n\n            print('Polling tensorboard for metrics...')\n            data = get_run_data(tensorboard_port)\n            for k in data.keys():\n                for ts, step, val in data[k]:\n                    if step <= curr_step.get(k, -1):\n                        continue\n                    _run.log_scalar(k, val, step)\n                    if k == 'loss':\n                        _run.log_scalar('tb_ts', ts, step)\n                        print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))\n                    \n                    # found something new, so logging!\n                    last_tb_log_time = time.time()\n\n                    curr_step[k] = step\n\n            for f in glob.glob('predictions_{}_*'.format(_run._id)):\n                if f in 
seen_predictions:\n                    continue\n                print('collecting prediction file', f)\n                ex.add_artifact(f)\n                \n                seen_predictions.add(f)\n            \n            # collect eval metrics from jsonl\n            if os.path.exists(f'eval_{_run._id}.jsonl'):\n                with open(f'eval_{_run._id}.jsonl') as fh:\n                    for line in fh:\n                        ob = json.loads(line)\n                        val_step = ob['global_step']\n                        val_task = ob['task']\n                        for metr in ob.keys():\n                            k = 'fs.' + val_task + '.' + metr\n                            if metr in ['task', 'global_step']: continue\n                            if val_step <= curr_step.get(k, -1): continue\n                            _run.log_scalar(k, ob[metr], val_step)\n                            curr_step[k] = val_step\n\n            if time.time() - last_tb_log_time > heartbeat_timeout:\n                # the run hasn't logged in a while, so we restart it\n                q.put(('kill',))\n\n                # give training thread some time to do its thing and recreate tpu\n                while trainthd.is_alive():\n                    print('logging thread waiting for killing stalled run and for tpu recreate to finish')\n                    time.sleep(60)\n                \n                # reset heartbeat timeout to initial\n                heartbeat_timeout = args.initial_heartbeat_timeout\n                last_tb_log_time = time.time()\n\n\n        if args.no_delete_tpu:\n            break\n\n\ndef goodbye(id):\n    print(\"You are now leaving the Python sector.\")\n    print(\"Sie verlassen den pythonischen Sektor.\")\n\n    os.system(\"screen -S tensorboard_{} -X quit\".format(id))\n\n        \nif __name__ == '__main__':\n    for file in glob.glob(\"**/*\", recursive=True):\n        if file.split('.')[-1] in ['py']:\n            print('Adding', 
file, 'to sacred')\n            ex.add_source_file(file)\n\n    ex.add_config({\n        'tpu_name': args.tpu,\n        **params\n    })\n\n    ex.run()\n"
  },
  {
    "path": "sample.py",
    "content": "import mesh_tensorflow as mtf\nimport tensorflow.compat.v1 as tf\nimport mesh_tensorflow.transformer as mtf_transformer\n\nfrom models.utils import entmax, sample_categorical\nfrom models.gpt2 import gpt2\n\ndef sample_autoregressive(partial_sequences,\n                          other_features,\n                          params,\n                          stop_at_token=50256,\n                          max_steps=None,\n                          temperature=0.9,\n                          variable_dtype=mtf.VariableDType(tf.float32),\n                          encoder_output=None,\n                          encoder_sequence_id=None,\n                          encoder_inputs=None,\n                          shared_params=None,\n                          has_partial_sequences=True,\n                          encoder_layer_outputs=None,\n                          never_end=False,\n                          remove_partial_sequences=False,\n                          sampling_keep_top_k=-1,\n                          sampling_use_entmax = False,\n                          bos_id=50256,\n                          ):\n    \"\"\"Sample randomly one token at a time.\n\n    The partial_sequences represent partial sequences to be continued.  The\n    first tokens of each sequence are nonzero representing the given partial\n    sequences and the last tokens of each sequence are zeros, representing what\n    needs to be filled in.\n\n    If there are no partial sequences (you want to sample from the beginning),\n    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and\n    has_partial_sequences=False (so we can skip computation).\n\n    Args:\n        partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]\n        stop_at_token: an optional integer eos id.  
Stop when we produce it.\n        max_steps: an optional integer, the max number of steps to decode.\n        temperature: an optional floating point value between 0.0 and 1.0;\n        0.0 means argmax, 1.0 means sample according to the predicted distribution.\n        variable_dtype: a mtf.VariableDType\n        encoder_output: an optional Tensor\n        encoder_sequence_id: an optional Tensor\n        encoder_inputs: an optional Tensor\n        shared_params: an optional dictionary\n        has_partial_sequences: a boolean\n        encoder_layer_outputs: optional - readonly list of tensor activations when\n        decoding, one per input layer + the embedding layer\n        never_end: a boolean - if set, then avoid generating stop_at_token\n        remove_partial_sequences: a boolean - whether to remove the partial\n        sequences from the output\n        sampling_keep_top_k: an integer - if not -1, only sample from the top k\n        logits; -2 keeps the top 10% of the vocabulary.\n        sampling_use_entmax: a boolean - if set, sample from entmax(logits)\n        instead of using temperature / top-k sampling.\n        bos_id: beginning of sequence id\n\n    Returns:\n        a Tensor with shape [<batch_dims>, length_dim]\n    \"\"\"\n\n    inputs = partial_sequences  # Partial sequences to fill in\n    batch_dims = inputs.shape.dims[:-1]\n    length_dim = inputs.shape.dims[-1]\n    padding_id = params.get(\"padding_id\", 0)\n    slow_sampling = params.get(\"slow_sampling\", False)\n\n\n    initial_position = mtf.reduce_sum(\n        mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim)  # Gets position where zero padding starts\n\n    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)\n    input_full_attention = True  # for now hardcode this to true bc lazy\n    if input_full_attention:\n        # Vanilla autoregressive model - each position can see previous positions.\n        # Think this feeds in to the loop fn and tells each position where it can attend to?\n        read_priority = write_priority = length_range * mtf.to_int32(\n            mtf.greater(length_range, initial_position))\n    
else:\n        read_priority = write_priority = length_range\n\n    # Builds context to pass around internally\n    # The 'first part' context records initial states of k / v / x\n\n    if not slow_sampling:\n        context_first_part = mtf_transformer.transformer.Context(\n            model=None,\n            mesh=inputs.mesh,\n            batch_dims=batch_dims,\n            length_dim=length_dim,\n            variable_dtype=variable_dtype,\n            mode=\"first_part\",\n            position=length_range,\n            position_is_default=True,\n            new_states=[],\n            initial_position=initial_position,\n            sequence_id=None,\n            encoder_output=encoder_output,\n            encoder_sequence_id=encoder_sequence_id,\n            constant_states=[],\n            shared_params=shared_params,\n            encoder_layer_outputs=encoder_layer_outputs,\n            write_priority=write_priority,\n            read_priority=read_priority,\n            inputs=inputs,\n            encoder_inputs=encoder_inputs)\n\n        with tf.variable_scope(\"gpt2\"):\n            logits, _, _ = gpt2.model({\"inputs\": inputs}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part)\n\n        if not has_partial_sequences:\n            initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]\n        else:\n            initial_states = context_first_part.new_states\n    else:\n        initial_states = []\n\n    if not has_partial_sequences:\n        partial_sequences_eos_count = 0\n\n    if stop_at_token is not None:\n        partial_sequences_eos_count = mtf.reduce_sum(\n            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),\n            reduced_dim=length_dim)\n\n    def cond_fn(position, ids, *unused_states):\n        \"\"\"Should we run another loop iteration?\"\"\"\n        past_end = mtf.greater_equal(position, length_dim.size)\n        if max_steps:\n            past_end = 
mtf.logical_or(\n                past_end, mtf.greater_equal(position - initial_position, max_steps))\n\n        is_done = past_end\n        if stop_at_token is not None:\n            eos_count = mtf.reduce_sum(\n                mtf.to_int32(mtf.equal(ids, stop_at_token)),\n                reduced_dim=length_dim)\n            has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)\n            is_done = mtf.logical_or(is_done, has_additional_eos)\n        all_done = mtf.reduce_all(is_done)\n        return mtf.logical_not(all_done)\n\n    def body_fn(position, ids, *states):\n        \"\"\"One step in the decode loop.\"\"\"\n        nonlocal sampling_keep_top_k\n\n        context = mtf_transformer.transformer.Context(\n            model=None,\n            mesh=inputs.mesh,\n            batch_dims=batch_dims,\n            length_dim=length_dim,\n            variable_dtype=variable_dtype,\n            mode=\"incremental\",\n            position=position,\n            position_is_default=True,\n            states=states,\n            new_states=[],\n            initial_position=position,\n            sequence_id=None,\n            encoder_output=encoder_output,\n            encoder_sequence_id=encoder_sequence_id,\n            shared_params=shared_params,\n            encoder_layer_outputs=encoder_layer_outputs,\n            write_priority=write_priority,\n            read_priority=read_priority,\n            inputs=ids,\n            encoder_inputs=encoder_inputs) if not slow_sampling else None\n\n        with tf.variable_scope(\"gpt2\", reuse=tf.AUTO_REUSE):\n            logits, _, _ = gpt2.model({\"inputs\": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context = context)\n\n        if not sampling_use_entmax:\n            # sampling_keep_top_k == -2 means: keep only the top 10% of the vocabulary\n            if sampling_keep_top_k == -2:\n                sampling_keep_top_k = int(logits.shape[-1].size * 0.1)\n\n            if sampling_keep_top_k != -1:\n    
            if sampling_keep_top_k <= 0:\n                    raise ValueError(\"sampling_keep_top_k must either be -1 or positive.\")\n                k_largest = mtf.nth_largest_element(\n                    logits, n=sampling_keep_top_k,\n                    reduced_dim=other_features[\"vocab_dim\"])\n                logits = mtf.where(mtf.less_equal(logits, k_largest),\n                                   mtf.ones_like(logits) * -1e6, logits)\n\n            ids_this_step = mtf.sample_with_temperature(\n                logits, other_features[\"vocab_dim\"], temperature)\n        else:\n            ids_this_step = sample_categorical(entmax(logits))\n\n        if slow_sampling:\n            ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False)\n        else:\n            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))\n\n        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)\n        one_new_id = ids_this_step * one_hot\n        new_ids = (1 - one_hot) * ids + one_new_id\n        new_position = position + 1\n\n        ret = [new_position, new_ids]\n        if context is not None:\n            ret += context.new_states\n        return ret\n\n    while_loop_inputs = [initial_position, inputs] + initial_states\n    final_position, outputs = mtf.while_loop(\n        cond_fn, body_fn, while_loop_inputs)[:2]\n    del final_position\n    if has_partial_sequences and remove_partial_sequences:\n        # Remove partial sequences from outputs\n        partial_length = mtf.reduce_sum(\n            mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),\n            reduced_dim=length_dim)\n        outputs = mtf.dynamic_shift(\n            outputs, -partial_length, length_dim, wrap=False)\n    return outputs\n"
  },
  {
    "path": "tasks.py",
    "content": "import os.path\nimport json\nimport requests\nimport numpy as np\nimport ftfy\nfrom data.encoders import fetch_encoder, encode\nimport tensorflow as tf\n\nlambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'\nnormalization = 'NFKC'\n\n\n# Note: this task is called \"lambada\" but it really refers to OpenAI's version\n# of the task, which differs in some ways from the task described in\n# the original paper. Strictly speaking, accuracy values from this task\n# should therefore not be compared with accuracy values from the original LAMBADA task.\n# For more information, see\n#   https://github.com/openai/gpt-2/issues/131\n\ndef lambada_create_tokens_data(params, path):\n    # Download and tokenize before opening `path`, so that a failed request\n    # does not leave behind an empty cache file that later reads would fail on.\n    req = requests.get(lambada_src_uri)\n    req.raise_for_status()\n    jsons = [json.loads(l) for l in req.iter_lines()]\n    texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]\n    enc = fetch_encoder(params)\n    arrays = [encode(enc, t) for t in texts]\n    with open(path, 'w') as f:\n        json.dump(arrays, f)\n    return arrays\n\n\ndef lambada_read_or_create_tokens_data(params, path):\n    # If no token cache exists at `path` yet, download and tokenize the data and write it there\n    if not os.path.exists(path):\n        return lambada_create_tokens_data(params, path)\n    with open(path) as f:\n        return json.load(f)\n\n\ndef bin_pack(params, tokens_data):\n    # Greedily pack each tokenized passage, followed by eos, into rows of n_ctx tokens\n    eos_token = params['eos_id']\n    n_ctx = params['n_ctx']\n    dummy_token = 1  # fills the unused positions at the end of each row\n    pad_batch_size = params['eval_batch_size']\n    bins = []\n    for a in tokens_data:\n        if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:\n            bins.append([])\n        bins[-1] += a\n        bins[-1].append(eos_token)\n    # Pad with empty rows so that the number of rows is a multiple of the eval batch size\n    while len(bins) % pad_batch_size != 0:\n        bins.append([])\n    bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)\n    for i, b in enumerate(bins):\n        bins_array[i, 0:len(b)] = b\n    return bins_array\n\n\ndef lambada_init(params):\n    ds_configs = params['dataset_configs']\n    lt_paths = [\n        ds_configs[ds_id].get('lambada_tokens_path', \"./lambada.json\")\n        for ds_id, _, _, _ in params['datasets']\n    ]\n    assert len(lt_paths) > 0, 'no datasets configured, so lambada_tokens_path cannot be determined'\n    lt_path = lt_paths[0]\n    assert lt_path.endswith('.json'), 'lambada_tokens_path must have extension json'\n\n    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)\n    bins_array = bin_pack(params, tokens_data)\n    params['lambada_tokens_path'] = lt_path\n    params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']\n\n\ndef lambada_get_task_info(params):\n    return {\n        'n_steps': params['lambada_n_steps'],\n    }\n\n\n# The LAMBADA evaluation code looks at the logits of each position just before an eos_token\ndef lambada_input(params):\n    # Use the same eos id as bin_pack, so that the packed rows and the target mask agree\n    eos_token = params['eos_id']\n    n_ctx = params['n_ctx']\n    lt_path = params['lambada_tokens_path']\n    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)\n    bins_array = bin_pack(params, tokens_data)\n    dataset = tf.data.Dataset.from_tensor_slices(bins_array)\n\n    def _get_output(bin):\n        bin = tf.cast(bin, dtype=tf.int32)\n        indexes = tf.range(n_ctx)\n        # The target at position i is the next token, i + 1\n        results = tf.gather(bin, (indexes + 1) % n_ctx)\n        # Keep a target only where the token after it is eos, i.e. where the target is the\n        # last token of a passage; every other position gets eos as a placeholder target\n        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)\n        output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))\n        bin = tf.reshape(bin, [n_ctx])\n        output = tf.reshape(output, [n_ctx])\n        return bin, output\n\n    dataset = dataset.map(_get_output, num_parallel_calls=tf.data.AUTOTUNE)\n    dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)\n    dataset = dataset.repeat()\n    return dataset\n\n\ntask_descriptors = {\n    'lambada': {\n        'init_fn': lambada_init,\n        'get_task_info_fn': lambada_get_task_info,\n        'input_fn': lambada_input,\n    }\n}\n"
  },
  {
    "path": "utils.py",
    "content": "import re\nfrom urllib.parse import urlparse\nfrom shutil import rmtree\nimport logging\nimport os\nfrom pathlib import Path\nimport sys\nimport tensorflow.compat.v1 as tf\nimport tensorflow.compat.v2 as tf2\nimport mesh_tensorflow as mtf\nimport mesh_tensorflow.auto_mtf\nfrom data.encoders import fetch_encoder\n\n\ndef setup_logging(args):\n    Path(\"logs\").mkdir(exist_ok=True)\n    tf.logging.set_verbosity(logging.INFO)\n    tf.get_logger().propagate = False  # Remove double log on console\n    name = os.path.splitext(os.path.basename(args.model))[0]\n    handlers = [\n        logging.FileHandler(f\"logs/{name}.log\"),\n        logging.StreamHandler(sys.stdout)\n    ]\n    logger = logging.getLogger(\"tensorflow\")\n    logger.handlers = handlers\n    return logger\n\n\ndef get_batch_size(params):\n    return params[f\"{params['mode']}_batch_size\"]\n\n\ndef add_mode_to_params(params, mode):\n    if mode == tf.estimator.ModeKeys.PREDICT:\n        params[\"mode\"] = \"predict\"\n    elif mode == tf.estimator.ModeKeys.EVAL:\n        params[\"mode\"] = \"eval\"\n    elif mode == tf.estimator.ModeKeys.TRAIN:\n        params[\"mode\"] = \"train\"\n    else:\n        raise ValueError(f\"Invalid mode {mode}\")\n    return params\n\n\ndef simd_mesh_setup(params, mesh_shape, layout_rules):\n    \"\"\"Builds the variable placer and the SimdMeshImpl, which describes how tensors are split evenly across all TPU cores\"\"\"\n\n    num_hosts = params[\"context\"].num_hosts\n    host_placement_fn = params[\"context\"].tpu_host_placement_function\n    device_list = [host_placement_fn(host_id=i) for i in range(num_hosts)]\n    tf.logging.info(f\"device_list = {device_list}\")\n\n    # TODO: Better estimation of replica cache size?\n    replica_cache_size = 300 * 1000000  # 300M per replica\n\n    # Worker 0 caches all the TPU binaries\n    worker0_mem = replica_cache_size * params[\"context\"].num_replicas\n    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)\n    var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage)\n    mesh_devices = [\"\"] * mesh_shape.size\n    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(\n        mesh_shape, layout_rules, mesh_devices, params[\"context\"].device_assignment)\n\n    return var_placer, mesh_impl\n\n\ndef remove_batch_from_layout(layout):\n    \"\"\"\n    Removes the rules that split the batch dimension from a tf-mesh layout.\n    Useful for prediction steps, when large batches are no longer wanted.\n\n    :param layout: string describing tf-mesh layout\n    :return: layout without the batch-dimension rules\n    \"\"\"\n    return \",\".join(rule for rule in layout.split(',') if \"batch\" not in rule)\n\n\ndef yes_or_no(question):\n    while True:\n        reply = str(input(question + ' (y/n): ')).lower().strip()\n        if reply[:1] == 'y':\n            return True\n        if reply[:1] == 'n':\n            return False\n\n\ndef remove_gs_or_filepath(path):\n    parsed_url = urlparse(path)\n    if parsed_url.scheme == \"gs\":\n        os.system(f\"gsutil rm -rf {path}\")\n        return\n    rmtree(path)\n\n\ndef save_config(params_dict, logdir):\n    print(f\"Saving config to {logdir}\")\n    text = \"{\\n\\n\"\n    total_params = len(params_dict)\n    for count, key in enumerate(params_dict):\n        config_value = str(params_dict[key])\n        # Quote values that contain letters, except booleans and lists.\n        # TODO: epsilon is special-cased because it is currently the only number in\n        #       scientific notation. Should fix this.\n        if (re.search('[a-zA-Z]', config_value)\n                and config_value.lower() not in ('true', 'false')\n                and config_value[0] != '['\n                and key != \"epsilon\"):\n            config_value = f'\"{config_value}\"'\n        if count == total_params - 1:\n            text += f'\"{str(key)}\"' + ' : ' + config_value + '\\n\\n'\n        else:\n            text += f'\"{str(key)}\"' + ' : ' + config_value + ',\\n\\n'\n    text += '\\n\\n}'\n    sess = tf.InteractiveSession()\n    summary_op = tf.summary.text(\"run_config\", tf.convert_to_tensor(text))\n    summary_writer = tf.summary.FileWriter(f\"{logdir}/config\", sess.graph)\n    summary = sess.run(summary_op)\n    summary_writer.add_summary(summary, 0)\n    summary_writer.flush()\n    summary_writer.close()\n    tf.reset_default_graph()\n    print('Done!')\n\n\ndef expand_attention_types_params(params_list):\n    newlist = []\n    for item in params_list:\n        for _ in range(item[1]):\n            newlist.extend(item[0])\n    return newlist\n\n\ndef get_n_trainable_vars(graph):\n    \"\"\"\n    Prints the number of trainable parameters in an MTF graph.\n\n    :param graph: Mesh-Tensorflow graph\n    :return: None\n    \"\"\"\n    total_parameters = 0\n    for variable in graph.trainable_variables:\n        shape = variable.shape.dims\n        variable_parameters = 1\n        for dim in shape:\n            variable_parameters *= dim.size\n        total_parameters += variable_parameters\n    print(f\"\\n\\nN TRAINABLE VARS:\\n{total_parameters:,}\\n\\n\")\n\n\ndef print_dim_names(graph):\n    \"\"\"\n    Prints the unique dimension names of all variables in an MTF graph.\n\n    :param graph: Mesh-Tensorflow graph\n    :return: None\n    \"\"\"\n    all_dim_names = []\n    for variable in graph.all_variables:\n        names = variable.shape.dimension_names\n        all_dim_names.append(names)\n\n    # Print all unique dim names in the graph\n    all_dim_names = [item for sublist in all_dim_names for item in sublist]  # Flatten all dims\n    unique_dims = list(set(all_dim_names))\n    print(\"ALL DIM NAMES:\")\n    for dim_name in unique_dims:\n        print(dim_name)\n    print('\\n')\n\n\ndef get_graph_info(graph):\n    \"\"\"\n    Wrapper fn that prints the number of trainable parameters in an MTF graph and all of its dim names.\n    TODO: how to get un-trainable dim-names too, batch etc.\n\n    :param graph: Mesh-Tensorflow graph\n    :return: None\n    \"\"\"\n    get_n_trainable_vars(graph)\n    print_dim_names(graph)\n\n\ndef loss_denominator(targets, num_microbatches):\n    \"\"\"Denominator applied to losses.\n\n    This is the size of the targets tensor multiplied by the number of microbatches.\n\n    Args:\n      targets: a mtf.Tensor\n      num_microbatches: an integer - greater than one if the step has been\n        serialized into multiple microbatches to save memory.\n    Returns:\n      a float\n    \"\"\"\n    return float(targets.shape.size) * num_microbatches\n\n\ndef check_dataset(input_fn, params, global_step=None):\n    tf.enable_eager_execution()\n    if global_step is not None:\n        dataset = input_fn(params, global_step=global_step)\n    else:\n        dataset = input_fn(params)\n    dataset_iter = dataset.make_one_shot_iterator()\n    tensor, _ = next(dataset_iter)\n    enc = fetch_encoder(params)\n\n    # Decode the first sequence of the first batch and print its start and end\n    txt = enc.decode(tensor[0])\n\n    print('-' * 50)\n    print(txt[:500], '\\n\\n...\\n\\n', txt[-500:])\n    print('-' * 50)\n    exit()\n\n\ndef auto_layout(graph, mesh_shape, logits, loss):\n    layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])\n    print(f\"Auto-selected layout:\\n{layout_rules}\\nRe-initialize graph with selected layout\")\n    quit()\n\n\ndef auto_layout_and_mesh_shape(graph, num_cores, logits, loss):\n    layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores,\n                                                                    [logits, loss], max_mesh_shape_dimensions=4)\n    print(f\"Num cores:\\n{num_cores}\\nAuto-selected layout:\\n{layout_rules}\\nAuto-selected mesh shape:\\n{mesh_shape}\" \\\n            f\"\\nRe-initialize graph with selected layout & mesh shape\")\n    quit()\n\n\ndef create_host_call(model_dir):\n    \"\"\"Construct a host_call writing scalar summaries.\n\n    Borrowed from t2t.\n\n    Args:\n        model_dir: String containing the path of the directory the summaries are written to\n    Returns:\n        (fn, args) Pair to be called by TPUEstimator as the host_call.\n    \"\"\"\n\n    graph = tf.get_default_graph()\n    # A list of (name, lowered tensor) tuples\n    summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)\n\n    def maybe_cast(tensor):\n        assert tensor.shape.is_compatible_with([]), tensor.name\n        if tensor.dtype == tf.int64:\n            return tf.to_int32(tensor)\n        if tensor.dtype == tf.bfloat16:\n            return tf.cast(tensor, tf.float32)\n        return tensor\n\n    reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]\n\n    # When no supported summaries are found, don't create host_call. Otherwise,\n    # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue\n    # it, eventually causing a hang.\n    if not reshaped_tensors:\n        return None\n\n    def host_call_fn(global_step, *args):\n        \"\"\"Training host call. Creates scalar summaries for training metrics.\"\"\"\n        # This function is executed on the CPU and should not directly reference\n        # any Tensors in the rest of the `model_fn`. To pass Tensors from the\n        # model to this function, provide them as part of the `host_call`.\n        global_step = tf.cast(global_step[0], tf.int64)\n        with tf2.summary.create_file_writer(model_dir).as_default():\n            # We cannot directly use any tensor from summaries, because each\n            # tensor here must be a concat of multiple tensors from all shards.\n            # Therefore, we rely on the assumption that args will have the same\n            # length as summaries, and that all tensors in args will be in the\n            # same order as summaries.\n            assert len(args) == len(summaries)\n            for i, tensor in enumerate(args):\n                name = summaries[i][0]\n                tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step)\n        return tf.summary.all_v2_summary_ops()\n\n    global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])\n    return host_call_fn, [global_step_t] + reshaped_tensors\n\n\ndef natural_sort(l):\n    \"\"\"Sorts strings so that embedded numbers compare numerically, e.g. 'step2' before 'step10'.\"\"\"\n    def convert(text):\n        return int(text) if text.isdigit() else text.lower()\n\n    def alphanum_key(key):\n        return [convert(c) for c in re.split('([0-9]+)', key)]\n\n    return sorted(l, key=alphanum_key)\n"
  }
]