[
  {
    "path": ".gitignore",
    "content": "outputs\ntables/*/*.csv\ntables/*/*.csv#\ntables/*.csv\ntables/*.csv#\ntables/*.ods\n*.png\n*.pdf\n\n# torchdynamo debug\nisolate\nrepro.py\n\ncheckpoints\nwandb-metadata.json\n\ntorch_compile_debug/\n\ndedup\n\n.vs/\n\n*.pdf\nimages\n\n*.temp.sh\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n*.csv\n*.txt\n*.pth\n\ncramming-data/\nsanity.sh\nlog/\ndel.sh\ndel.py\nsort_plots/"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "# precommit hooks from https://github.com/ashleve/lightning-hydra-template\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v3.4.0\n    hooks:\n      # list of supported hooks: https://pre-commit.com/hooks.html\n      - id: trailing-whitespace\n      - id: end-of-file-fixer\n      - id: check-yaml\n      - id: check-added-large-files\n      - id: debug-statements\n      - id: detect-private-key\n\n  # python code formatting\n  - repo: https://github.com/psf/black\n    rev: 22.3.0\n    hooks:\n      - id: black\n        args: [--line-length, \"140\", \"--fast\"] # ;>\n\n  # yaml formatting\n  - repo: https://github.com/pre-commit/mirrors-prettier\n    rev: v2.3.0\n    hooks:\n      - id: prettier\n        types: [yaml]\n\n  # python code analysis\n  - repo: https://github.com/PyCQA/flake8\n    rev: 4.0.1\n    hooks:\n      - id: flake8\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Sean McLeish, Jonas Geiping\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "# added by check-manifest\ninclude *.py\ninclude *.yaml\nrecursive-include cramming *.md\nrecursive-include cramming *.yaml\nglobal-exclude *.pyc\nglobal-exclude __pycache__\n"
  },
  {
    "path": "README.md",
    "content": "# Transformers Can Do Arithmetic with the Right Embeddings! [Link to arXiv paper](https://arxiv.org/abs/2405.17399)\n\nA joint project by: Sean McLeish, Arpit Bansal, Alex Stein,  Neel Jain, John Kirchenbauer, Brian R. Bartoldson, Bhavya Kailkhura, Abhinav Bhatele, Jonas Geiping, Avi Schwarzschild and Tom Goldstein\n\n\n\nThis repository contains code to replicate our research. It is a fork of the language model training framework [cramming](https://github.com/JonasGeiping/cramming) edited to for a next token prediction objective.\n\nWe provide a standalone implementation of Abacus Embeddings in [abacus.py](abacus.py).\n\n## Citing Our Work\nTo cite our work, please use this bibtex.\n```\n@article{mcleish2024transformers,\n    title={Transformers Can Do Arithmetic with the Right Embeddings}, \n    author={Sean McLeish and Arpit Bansal and Alex Stein and Neel Jain and John Kirchenbauer and Brian R. Bartoldson and Bhavya Kailkhura and Abhinav Bhatele and Jonas Geiping and Avi Schwarzschild and Tom Goldstein},\n    journal={arXiv preprint arXiv:2405.17399},\n    year={2024}\n}\n```\n\n# Getting Started\nWe developed in Python 3.10.4, to install run:\n```\ngit clone git@github.com:mcleish7/arithmetic.git\ncd arithmetic\npip install .\n```\n\nOn some machines you will need to run:\n1. `pip install multiprocess -U`\n2. `pip install dill -U`\n3. `pip install apache-beam -U`\n\n# Arithmetic\n## Datasets\nWe release our datasets on [Google Drive](https://drive.google.com/drive/folders/1DqjCrUM1cNV7069Zl25_qBw2Px2xAw9j?usp=sharing) both in zipped format. We recommend you work with the zipped version until it is correctly placed in your file system.\n\nAlternatively, you can make your own datasets using [create_data_split.py](create_data_split.py) using the commands from [shells/generate_and_tokenize_data.sh](shells/generate_and_tokenize_data.sh).\n\n## File Structure\nWe recommend creating another directory `cramming-data` inside of arithmetic. This is where the models, logs and data will be stored.\n\nYou can either export you cramming base directory path to your `.bashrc` or you can replace `$cramming_base_dir` manually in the provided shells.\n```\ncd arithmetic\nmkdir cramming-data\necho 'export cramming_base_dir=MY_BASE_DIR' >> ~/.bashrc\nsource ~/.bashrc\n```\nFor example, this may look like: `echo 'export cramming_base_dir=~/arithmetic/cramming-data' >> ~/.bashrc`\n\nFor example our file system looks like:\n```\ncramming-generative\n└── cramming-data\n    ├── addition-train-one\n    │    ├── pretrain/<DATE>/<TIME>\n    │    │    ├── .hydra\n    │    │    │   ├── config.yaml\n    │    │    │   ├── hydra.yaml\n    │    │    │   └── overrides.yaml\n    │    │    └── addition-train-one_pretrain.log\n    │    ├── checkpoints/FINAL_<LOSS_VAL>\n    │    │    ├── model_config.json\n    │    │    ├── model.safetensors\n    │    │    └── state_dict.pth\n    │    └── downstream\n    └── data\n        └── arithmetic_data\n            ├── +_grid_eval_dataset_reverse_all_tokenized\n            └── ... other datasets ...\n```\n\n## Training\nExample commands are in the [shells](shells) directory, organised by task.\n\n### Explanation of Some Commands\n1. Give samples instead of tokens equal importance in loss: `arch.loss_reduction=none`\n2. Divide the gradients in the recurrent block by the number of recurrences: `arch.throttle=True`\n3. Mask before the equals sign: `arch.mask_before_equals=True`\n4. Skip connections inside of the recurrent block: `arch.forward_only_model_with_skip=True`\n5. Multi-GPU: `python` -> `torchrun --nproc_per_node=<NUM GPUS> --standalone ` and add `impl.fullgraph=false`\n\n### Positional Embeddings:\n#### Absolute\n1. Learned: `arch.embedding.pos_embedding=learned`\n2. Abacus: `arch.embedding.pos_embedding=abacus`\n* If you want the maximum k in abacus to be larger: `arch.embedding.max_abacus_len=100`, be default this value is 100. Abacus is also implemented in a standalone manner in [abacus.py](abacus.py).\n\n#### Relative\n1. NoPE: `arch.embedding.pos_embedding=None`\n2. FIRE: `arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\"`\n3. FIRE randomised: e.g:`arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\" arch.attention.max_length=128` by default `arch.attention.max_length=0` so setting this longer than the max sequence length gives some randomness in the embedding.\n4. RoPE: `arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=true`\n\n### Checkpointing\nWe have implemented *single* GPU training checkpointing, to do this use:\n`impl.save_every_n_minutes=60 impl.save_intermediate_model_name='last'`\nThis saves a checkpoint every 60 minutes under the name 'last'\n\nCaution: This feature is not fully tested for multi-GPU cases. We also cannot currently train models which have used their full budget for longer.\n\n### WandB\nYou can log runs to your weights&biases account. To do so, simply modify `wandb.entity` and `wandb.project` on the command line or at [cramming/config/wandb/default.yaml](cramming/config/wandb/default.yaml).\n\n## Testing\nWe show examples in [shells/evaluation.sh](shells/evaluation.sh). \n\nWe provide a very basic automation in [gen_eval_script.py](gen_eval_script.py), this prints the basic commands you may need to further edit these.\n\n### Addition\nFor addition we have a very large possible evaluation set, we do a grid search over a 100x100 grid which we split into 20 pieces with the aim of balancing the number of forward calls across all 20 pieces.\nWe then have a further eval for operand lengths 100->160.\n\n### Multiplication\nWe only evaluate up to 25x25, which we do in a single job.\n\n### Sorting\nSorting uses a separate evaluation file [sort_eval.py](sort_eval.py), this is because the evaluation calls cannot be parallelised, making evaluation much longer.\nThe evaluation cannot be parallelised because the place of the equals sign is not fixed for a batch.\nWe currently evaluate across 30 jobs for a 30x30 grid but this can be reduced to a smaller number of jobs using these flags: `max_size_given, start_ind_1_given, start_ind_2_given`\n\n### Bitwise OR\nWe use the same framework as for addition but the process is quicker as some of the batches do not contain 100 samples as there are not 100 possibilities for some batches. Unlike addition we do not sample with replacement for this task.\n\n# Analysis\n1. We provide [pretty_plotter.py](pretty_plotter.py) to combine the small evaluation grids together into one plot.\nUse this by putting the model name into the string at the top of the `main` function.\n2. For the large 100x100 grids we provide [pretty_plotter_big.py](pretty_plotter_big.py).\nThese are designed to be as flexible as possible but may need to be edited to fit your file set up.\n3. For sorting, we provide [pretty_plotter_sort.py](pretty_plotter_sort.py), this allows us to read the individual `.txt` files created during testing and merge them all together into a nice plot.\n\n# Contact\nPlease, feel free to contact us with any questions, or open an issue on Github."
  },
  {
    "path": "abacus.py",
    "content": "\"\"\"Implementation of abacus embeddings\"\"\"\n# Example of how to extract digit tokens to pass into constructor\n# digit_tokens = tokenizer.convert_tokens_to_ids(['0','1','2','3','4','5','6','7','8','9'])\n\nclass Abacus(torch.nn.Module):\n    \"\"\"\n    Abacus Embeddings, learned emebddings resued for each digit.\n    Integers must be reversed for this to work correctly.\n    Transformers Can Do Arithmetic with the Right Embeddings, McLeish et al. (2024)\n    \"\"\"\n    def __init__(self, digit_tokens, embedding_dim, max_seq_length=1024, max_k=99):\n        \"\"\"\n        digit_tokens (list): list of the tokens for each of the 10 digits, `digit_tokens = tokenizer.convert_tokens_to_ids(['0','1','2','3','4','5','6','7','8','9'])`\n        embedding_dim (int): dimension to embed into\n        max_seq_length (int): maximum number of embeddings that can be trained\n        max_k (int): maximum k value which we randomly shift by during training\n        \"\"\"\n        super().__init__()\n        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)\n        self.register_buffer(\"digits\", torch.tensor(digit_tokens), persistent=False)\n\n        self.max_k = max_k\n\n    def helper(self, mask, device):\n        \"\"\"\n        Converts a binary mask of digit locations into spans of consecutive digits\n        \"\"\"\n        mask_shape = mask.shape\n        \n        # Create a shifted version of the mask to detect changes from 0 to 1\n        shifted_mask = torch.cat([torch.zeros((mask_shape[0], 1), device=device, dtype=mask.dtype), mask[:, :-1]], dim=1)\n        starts = (shifted_mask != mask) & mask\n        \n        # Generate IDs for each segment of 1s, processing row-wise\n        segment_ids = torch.cumsum(starts, dim=1)\n        \n        # Generate an index array row-wise\n        index = torch.arange(mask.size(1)).repeat(mask.size(0), 1).to(device)\n        \n        # Reset index at the start of each segment\n        reset_index = torch.zeros_like(mask).long()\n        second_term = index * starts.long()\n        reset_index = reset_index.scatter_add(1, segment_ids, second_term)\n        \n        # Calculate positions in segment\n        positions = index - reset_index.gather(1, segment_ids) + 1\n        \n        # Ensure only values within 1-segments are non-zero\n        result = positions * mask\n\n        return result\n\n    def forward(self, input_ids):\n        \"\"\"\n        input_ids (tensor): a batch of inputs, each row is a sample\n        \"\"\"\n        mask = torch.isin(input_ids, self.digits)\n        output = self.helper(mask, input_ids.device)\n\n        k=0\n        if self.training:\n            k = random.randint(0, self.max_k)\n            output[output>0] += k # as we already have ones in the tensor, the tensor values will be k+1\n\n        return self.embedding(output)\n"
  },
  {
    "path": "arithmetic_eval_quicker.py",
    "content": "import logging\nimport hydra\nfrom omegaconf import OmegaConf\nimport cramming\nimport torch\nfrom safetensors.torch import load_file\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport json\nimport numpy as np\nimport re\nimport pandas as pd\nimport datasets\nimport os\nfrom typing import List, Dict\nfrom cramming.data.tokenizer_preparation import get_tokenizer\nimport random\n\nlog = logging.getLogger(__name__)\n\ndef grid_plotter(data, type=\"accs\", name='_large', extra_path=None):\n    \"\"\"plot a 2d accuracy grid\"\"\"\n    data = np.array(data)*100\n    df = pd.DataFrame(data)\n\n    # Create the heatmap\n    plt.figure(figsize=(10, 8))\n    sns.heatmap(df, cmap=\"YlGnBu\", fmt=\".1f\", annot_kws={'size': 8,'rotation':0})\n    \n    # Customize the plot\n    plt.title(\"Accuracy - percetange, rounded to 1dp\")\n    plt.ylabel(\"1st Number Length\")\n    plt.xlabel(\"2nd Number Length\")\n    size = data.shape[0]\n    plt.xticks(np.arange(0.5, size+0.5, 1), labels=np.arange(1, size+1, 1))\n    plt.yticks(np.arange(0.5, size+0.5, 1), labels=np.arange(1, size+1, 1))\n\n    if extra_path is not None:\n        plt.savefig(f\"{extra_path}{type}{name}_grid_plot\", bbox_inches='tight')\n    else:\n        plt.savefig(f\"{type}{name}_grid_plot\", bbox_inches='tight')\n    plt.clf()\n\ndef index_hints_helper(num, tokenizer):\n    \"\"\"Add index hints into a tokenized number\"\"\"\n    char_set = tokenizer.char_set\n    shape1 = num.shape[1]\n    for i in range(shape1):\n        this_char_token = tokenizer._convert_token_to_id(char_set[i])\n        char_to_insert = this_char_token * torch.ones((num.shape[0], 1), dtype=num.dtype, device=num.device)\n        num = torch.cat((num[:,:(2*i)], char_to_insert, num[:,(2*i):]), dim=1)\n    return num\n\ndef grid_logic(cfg):\n    \"\"\"logic to select function to control which part of a 2d grid this run should be responsible for evaling\"\"\"\n\n    # origional testing\n    def logic_func_large(data_size_1, data_size_2):\n        return (data_size_1 <= 23 or data_size_2 <=23)\n    logic_func = logic_func_large\n    name = '_large'\n    max_size = 23+1\n    \n    if cfg.ood_only:\n        def logic_func_ood(data_size_1, data_size_2):\n            return (data_size_1 >=24 or data_size_2 >=24) and (data_size_1 <= 30 or data_size_2 <=30)\n        logic_func = logic_func_ood\n        name = '_ood_only'\n        max_size = 30+1\n        \n    if cfg.up_to_40:\n        def logic_func_40(data_size_1, data_size_2):\n            return (data_size_1 >=31 or data_size_2 >=31) and (data_size_1 <=40 or data_size_2 <=40)\n        logic_func = logic_func_40\n        name = '_up_to_40'\n        max_size = 40+1\n        \n    if cfg.up_to_50:\n        def logic_func_50(data_size_1, data_size_2):\n            return (data_size_1 >=41 or data_size_2 >=41) and (data_size_1 <=50 or data_size_2 <=50)\n        logic_func = logic_func_50\n        name = '_up_to_50'\n        max_size = 50+1\n\n    # checkerboarding: for the large eval we can checkerboard:\n\n    if cfg.checkerboard is not None:\n        if cfg.checkerboard == 'even':\n            def checkerboard_even(data_size_1, data_size_2):\n                return ((data_size_1+data_size_2)%2 ==0)\n            checkerboard_func = checkerboard_even\n            checkerboard_str = \"_even\"\n        elif cfg.checkerboard == 'odd':\n            def checkerboard_odd(data_size_1, data_size_2):\n                return ((data_size_1+data_size_2)%2 ==1)\n            checkerboard_func = checkerboard_odd\n            checkerboard_str = \"_odd\"\n        else:\n            print(\"checkerboard config not allowed\")\n            exit()\n    else:\n        def always_true(data_size_1, data_size_2):\n            return True\n        checkerboard_func = always_true\n        checkerboard_str = \"\"\n\n\n    # if we are testing up to 100, split into 10 steps each of approximately equal number of forward passes required\n    if cfg.big_eval_step_1: # 1 -> 46\n        def logic_func_big_1(data_size_1, data_size_2):\n            return (data_size_1 <= 46 and data_size_2 <= 46) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_1\n        name = '_big_eval_1'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_2: # 47 -> 58\n        def logic_func_big_2(data_size_1, data_size_2):\n            return (data_size_1 >=47 or data_size_2 >=47) and (data_size_1 <=58 and data_size_2 <=58) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_2\n        name = '_big_eval_2'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_3: # 59 -> 67\n        def logic_func_big_3(data_size_1, data_size_2):\n            return (data_size_1 >=59 or data_size_2 >=59) and (data_size_1 <=67 and data_size_2 <=67) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_3\n        name = '_big_eval_3'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_4: # 68 -> 74\n        def logic_func_big_4(data_size_1, data_size_2):\n            return (data_size_1 >=68 or data_size_2 >=68) and (data_size_1 <=74 and data_size_2 <=74) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_4\n        name = '_big_eval_4'+checkerboard_str\n        max_size = 100+1\n      \n    if cfg.big_eval_step_5: # 75 -> 80\n        def logic_func_big_5(data_size_1, data_size_2):\n            return (data_size_1 >= 75 or data_size_2 >=75) and (data_size_1 <=80 and data_size_2 <=80) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_5\n        name = '_big_eval_5'+checkerboard_str\n        max_size = 100+1\n\n    if cfg.big_eval_step_6: # 81 -> 85\n        def logic_func_big_6(data_size_1, data_size_2):\n            return (data_size_1 >= 81 or data_size_2 >=81) and (data_size_1 <=85 and data_size_2 <=85) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_6\n        name = '_big_eval_6'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_7: # 86 -> 90\n        def logic_func_big_7(data_size_1, data_size_2):\n            return (data_size_1 >= 86 or data_size_2 >=86) and (data_size_1 <=90 and data_size_2 <=90) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_7\n        name = '_big_eval_7'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_8: # 91 -> 94\n        def logic_func_big_8(data_size_1, data_size_2):\n            return (data_size_1 >= 91 or data_size_2 >=91) and (data_size_1 <=94 and data_size_2 <=94) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_8\n        name = '_big_eval_8'+checkerboard_str\n        max_size = 100+1\n    \n    if cfg.big_eval_step_9: # 95 -> 97\n        def logic_func_big_9(data_size_1, data_size_2):\n            return (data_size_1 >= 95 or data_size_2 >=95) and (data_size_1 <=97 and data_size_2 <=97) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_9\n        name = '_big_eval_9'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_10: # 98 -> 100\n        def logic_func_big_10(data_size_1, data_size_2):\n            return (data_size_1 >= 98 or data_size_2 >=98) and (data_size_1 <=100 and data_size_2 <=100) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_10\n        name = '_big_eval_10'+checkerboard_str\n        max_size = 100+1\n\n    # boolean_list_precidence = [large, ood_only, up_to_40, up_to_50, big_eval_step_1, big_eval_step_2, big_eval_step_3, big_eval_step_4, big_eval_step_5]\n\n    log.info(f\"large = {cfg.large}\")\n    log.info(f\"ood only = {cfg.ood_only}\")\n    log.info(f\"up to 40 = {cfg.up_to_40}\")\n    log.info(f\"up to 50 = {cfg.up_to_50}\")\n    log.info(f\"big eval 1 = {cfg.big_eval_step_1}\")\n    log.info(f\"big eval 2 = {cfg.big_eval_step_2}\")\n    log.info(f\"big eval 3 = {cfg.big_eval_step_3}\")\n    log.info(f\"big eval 4 = {cfg.big_eval_step_4}\")\n    log.info(f\"big eval 5 = {cfg.big_eval_step_5}\")\n    log.info(f\"big eval 6 = {cfg.big_eval_step_6}\")\n    log.info(f\"big eval 7 = {cfg.big_eval_step_7}\")\n    log.info(f\"big eval 8 = {cfg.big_eval_step_8}\")\n    log.info(f\"big eval 9 = {cfg.big_eval_step_9}\")\n    log.info(f\"big eval 10 = {cfg.big_eval_step_10}\")\n    log.info(f\"the last true value in the above list will be run, mul and pos arith can take control after this\")\n\n    return logic_func, name, max_size\n\ndef main(cfg):\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    local_checkpoint_folder = os.path.join(cfg.base_dir, cfg.name, \"checkpoints\")\n    tokenizer, cfg_arch, model_file = cramming.utils.find_pretrained_checkpoint(cfg.eval.checkpoint,\n                                                                                local_checkpoint_folder,\n                                                                                cfg.eval.arch_modifications)\n    if cfg.max_rec is not None: # can have more/less recurrences for eval\n        cfg_arch.maximal_recurrence_in_eval = cfg.max_rec\n    else:\n        cfg_arch.maximal_recurrence_in_eval = cfg_arch.maximal_recurrence\n    log.info(f\"cfg_arch.maximal_recurrence_in_eval changed to {cfg_arch.maximal_recurrence_in_eval}\")\n    cfg_arch.throttle = False # turn throttle off\n\n    logic_func, name, max_size = grid_logic(cfg)\n\n    if cfg.mul: # multiplication\n        def logic_func_for_mul(data_size_1, data_size_2):\n            return (data_size_1 <= 25 or data_size_2 <= 25)\n        logic_func = logic_func_for_mul\n        name = '_large'\n        max_size = 25+1\n    log.info(f\"mul = {cfg.mul}\")\n\n    if cfg.pos_arth: # bitwise OR\n        def logic_func_for_pos(data_size_1, data_size_2):\n            return (data_size_1 <= 25 or data_size_2 <= 25)\n        logic_func = logic_func_for_pos\n        name = '_large'\n        max_size = 25+1\n    log.info(f\"pos_arth = {cfg.pos_arth}\")\n\n    if cfg.pos_arth_ood:\n        def logic_func_for_pos_ood(data_size_1, data_size_2):\n            return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)\n        logic_func = logic_func_for_pos_ood\n        name = '_ood_only'\n        max_size = 40+1\n    log.info(f\"pos_arth_ood = {cfg.pos_arth_ood}\")\n\n    # import tokeniser\n    cfg_data_sources_values_list = list(cfg.data.sources.values())[0]\n    if cfg_data_sources_values_list[\"provider\"] == \"arithmetic\":\n        tokenizer = get_tokenizer(cfg_data_sources_values_list[\"tokenizer_type\"])\n    else: \n        log.info(\"exiting as this is only for arithmetic\")\n        exit()\n    vocab = tokenizer.ids_to_tokens\n    EOS_token = tokenizer._convert_token_to_id(tokenizer.eos_token)\n    PAD_token = tokenizer._convert_token_to_id(tokenizer.pad_token)\n    assert PAD_token == 0, \"PAD token must be token zero for our code to work\"\n\n    # Load model\n    if 'alpha' not in cfg_arch:\n        cfg_arch['alpha'] = 1.0\n    model = cramming.construct_model(cfg_arch, tokenizer).to(device)\n    model = cramming.backend.load_model_checkpoint(model, model_file)\n    model.to(device)\n    model.eval()\n\n    log.info(f\"greedy = {cfg.greedy}, note: if greedy = True this overrides any temperature arguments\")\n    ## Greedy decoding will overide any temperature arguments\n\n    if cfg.max_size_given is not None: # allows unique splits for eval\n        max_size = max_size_given\n\n    # Grid plots - grid search from 1x1 to 12x12 data\n    data_sizes = list(range(1, max_size))\n    acc_grid = np.zeros((len(data_sizes),len(data_sizes)))\n    start_ind_1 = 0\n    start_ind_2 = 0\n    tuple_method = False\n    completed_one = False\n    if \"big_eval\" in name:\n        tuple_method = True\n        # go up two layers and search for grid\n        try:\n            with open(f\"../../accs_grid_quick{name}.json\", 'r') as file:\n                data = json.load(file)\n            start_ind_1 = data[1]\n            start_ind_2 = data[2]\n            acc_grid = np.array(data[0])\n            log.info(\"loaded grid from previous run\")\n        except:\n            pass\n\n    if cfg.start_ind_1_given is not None: # allows unique splits for eval\n        start_ind_1 = cfg.start_ind_1_given\n    if cfg.start_ind_2_given is not None:\n        start_ind_2 = cfg.start_ind_2_given\n    log.info(f\"start_ind_1 = {start_ind_1}, start_ind_2 = {start_ind_2}\")\n\n    os.makedirs(\"outputs\", exist_ok=True)\n\n    if not cfg.extended_eval:\n        # main 2d loop\n        for data_size_1 in data_sizes:\n            for data_size_2 in data_sizes:\n                if (data_size_1 < start_ind_1 or data_size_2 < start_ind_2) and not completed_one:\n                    continue\n                else:\n                    proceed = False\n                    # if both data sizes are less than the start indices, then dont proceed\n                    # but if one of them is greater than the start indices, then proceed\n                    if data_size_1 >= start_ind_1 or data_size_2 >= start_ind_2:\n                        proceed = True\n                        \n                    if not proceed:\n                        continue\n\n                print(f\"evaluating for {data_size_1} and {data_size_2}\")\n\n                if logic_func(data_size_1, data_size_2):\n                    completed_one = True\n                    log.info(f\"Starting iteration in grid eval for size: {data_size_1} and {data_size_2}\")\n                    correct_total = 0\n\n                    # get the correct dataset, these names may need to be changed if you make new datasets\n                    file_path = f\"../../../../data/arithmetic_data/+_grid_eval_dataset_padded_tokenized/+_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_seed_42/hf_tokenized_dataset\"\n                    if cfg.reverse_inputs:\n                        file_path = f\"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized/+_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_seed_42/hf_tokenized_dataset\"\n                    if cfg.mul:\n                        file_path = f\"../../../../data/arithmetic_data/x_grid_eval_dataset_2_reverse_all_tokenized/x_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_exact_seed_91/hf_tokenized_dataset\"\n                    if cfg.pos_arth or cfg.pos_arth_ood:\n                        file_path = f\"../../../../data/arithmetic_data/pos_or_one_vec_zeros_eval/or_one_vec_zeros_{data_size_1}_{data_size_2}/hf_tokenized_dataset\"\n                    tokenized_dataset = datasets.load_from_disk(file_path)[\"test\"]\n                    data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=100, shuffle=False)\n                    equals_tensor = data_size_1+data_size_2+6\n                    if cfg.pos_arth or cfg.pos_arth_ood:\n                        equals_tensor = data_size_1+data_size_2+2\n\n                    for batch in data_loader:\n                        # split prompt and answer\n                        tokenized_prompts = batch[\"input_ids\"][:equals_tensor]\n                        tokenized_prompts = torch.stack(tokenized_prompts).to(device)\n                        tokenized_prompts = torch.transpose(tokenized_prompts, 0, 1)\n                        tokenized_answers = batch[\"input_ids\"][equals_tensor:]\n                        tokenized_answers = torch.stack(tokenized_answers).to(device)\n                        tokenized_answers = torch.transpose(tokenized_answers, 0, 1)\n   \n                        if cfg.remove_padding and (cfg_data_sources_values_list[\"tokenizer_type\"] != \"index\"):\n                            # removes the padding from the eval data\n                            num1 = tokenized_prompts[:,:data_size_1]\n                            op = tokenized_prompts[:,data_size_1+1:data_size_1+2]\n                            num2 = tokenized_prompts[:,data_size_1+3:data_size_1+data_size_2+3]\n                            equals = tokenized_prompts[:,data_size_1+data_size_2+4:data_size_1+data_size_2+5]\n                            tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)\n \n                        if cfg_data_sources_values_list[\"tokenizer_type\"] == \"index\":\n                            # adding in the index hints to the input numbers\n                            num1 = tokenized_prompts[:,:data_size_1]\n                            num1 = index_hints_helper(num1, tokenizer)\n                            op = tokenized_prompts[:,data_size_1+1:data_size_1+2]\n                            num2 = tokenized_prompts[:,data_size_1+3:data_size_1+data_size_2+3]\n                            num2 = index_hints_helper(num2, tokenizer)\n                            equals = tokenized_prompts[:,data_size_1+data_size_2+4:data_size_1+data_size_2+5]\n                            tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)\n\n                            predicted_ids = None\n\n                            ## below inserts the characters for the model, we decided against this in the end\n                            predicted_ids = model._generate(tokenized_prompts, token_limit=(tokenized_answers.shape[1]*2), temperature=cfg.temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=cfg.greedy, quick=True)\n                            predicted_ids = torch.transpose(predicted_ids, 0, 1)\n\n                            new_tensor = torch.zeros_like(predicted_ids)\n                            for i in range(predicted_ids.size(0)): # inefficient!!\n                                # Filter out values greater than 17\n                                filtered_values = predicted_ids[i][predicted_ids[i] <= 17]\n                                # Place filtered values in new tensor and pad with zeros\n                                new_tensor[i, :len(filtered_values)] = filtered_values\n\n                            predicted_ids = new_tensor[:, :tokenized_answers.shape[1]] # trim off the excess\n                            predicted_ids = torch.transpose(predicted_ids, 0, 1)\n\n                        else:\n                            predicted_ids = model._generate(tokenized_prompts, token_limit=tokenized_answers.shape[1], temperature=cfg.temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=cfg.greedy, quick=True)\n                        \n                        if len(predicted_ids.shape) > 1: # i.e. we have a batch of more than one\n                            predicted_ids = torch.transpose(predicted_ids, 0, 1)\n                        else:\n                            predicted_ids = predicted_ids.reshape((1,-1)) # add a batch dim otherwise\n                            \n                    # ignore everything after EOS on eval but replacing all after EOS with PAD\n                    eval_tensor = predicted_ids.clone()\n                    input_tensor_EOS = (eval_tensor == EOS_token).int()\n                    indices_of_EOS = torch.argmax(input_tensor_EOS, dim=1)\n                    mask = torch.arange(eval_tensor.size(1)).to(device) > indices_of_EOS[:, None]\n                    eval_tensor[mask] = PAD_token\n                    \n                    # compare eval tensor to correct outputs\n                    elementwise_equal = torch.eq(eval_tensor, tokenized_answers)\n                    rows_equal = torch.all(elementwise_equal, dim=1)\n                    num_equal_rows = torch.sum(rows_equal).item()\n                    correct_total += (num_equal_rows/tokenized_prompts.shape[0])\n                    log.info(f\"accuracy for {data_size_1}, {data_size_2}: {num_equal_rows} = {correct_total*100}%\")\n\n                    # combine the prompts and outputs\n                    complete_lines = torch.cat((tokenized_prompts,predicted_ids), dim=1)\n                    tokens_list = complete_lines.tolist()\n                    decoded_batch = list(map(lambda seq: list(map(lambda token: vocab[token], seq)), tokens_list)) # map token ids to tokens\n                    log.info(f\"example for {data_size_1}, {data_size_2}: {decoded_batch[0]}\")\n                    # save the answers down so we don't eval twice ever\n                    with open(f\"outputs/+_n_{data_size_1}_m_{data_size_2}.json\", 'w') as json_file:\n                        json.dump(decoded_batch, json_file)\n\n                    acc_grid[(data_size_1-1),(data_size_2-1)] = correct_total\n\n                    if tuple_method:\n                        with open(f\"../../accs_grid_quick{name}.json\", \"w\") as file:\n                            tuple_to_save = (acc_grid.tolist(),data_size_1,data_size_2)\n                            json.dump(tuple_to_save, file)\n\n        log.info(f\"acc grid: {acc_grid}\")\n\n        with open(f\"accs_grid_quick{name}.json\", \"w\") as file:\n            json.dump(acc_grid.tolist(), file)\n        \n        # Grid plots - one for accs one for contains\n        grid_plotter(acc_grid, name=name)\n\n    if cfg.extended_eval:\n        # extended eval to eval large numbers easily, used the large eval numebers to split up into multiple parts\n\n        number = int(re.findall(r'\\d+', name)[0])\n        log.info(\"starting extended eval\")\n        # this is hard coded for reverse all, addition past 100x100 grid, removing the padding\n\n        accs = dict()\n        batch_size_extended_eval = 100\n\n        old_data_path = None\n        for root, dirs, files in os.walk(\"../..\"):\n            if f\"over_100_{number}.json\" in files:\n                old_data_path = os.path.join(root, f\"over_100_{number}.json\")\n\n        if number == 1:\n            start = 101\n            list_to_do = range(start,161)\n        elif number == 2:\n            list_to_do = [1000, 800]\n        elif number == 3:\n            list_to_do = [200, 700, 900]\n        elif number == 4:\n            list_to_do = [300, 400, 500, 600]\n        else:\n            print(\"number too high\")\n            exit()\n\n        if old_data_path is not None: # read the old accs dict and don't repeat what we have already done\n            with open(old_data_path, 'r') as file:\n                data = json.load(file)\n            accs = {int(k): v for k, v in data.items()}\n            to_do = set(list_to_do).difference(set(accs.keys()))\n            list_to_do = list(to_do)\n\n        log.info(f\"In extended eval with number {number}\")\n\n        for data_size in list_to_do:\n            log.info(f\"Extended eval {data_size}\")\n            correct_total = 0\n            file_path = f\"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized_over_100/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_exact_seed_42/hf_tokenized_dataset\"\n            tokenized_dataset = datasets.load_from_disk(file_path)[\"test\"]\n            data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=batch_size_extended_eval, shuffle=False)\n            equals_tensor = data_size+data_size+6\n\n            for batch in data_loader:\n                # get prompt and answer\n                tokenized_prompts = batch[\"input_ids\"][:equals_tensor]\n                tokenized_prompts = torch.stack(tokenized_prompts).to(device)\n                tokenized_prompts = torch.transpose(tokenized_prompts, 0, 1)\n                tokenized_answers = batch[\"input_ids\"][equals_tensor:]\n                tokenized_answers = torch.stack(tokenized_answers).to(device)\n                tokenized_answers = torch.transpose(tokenized_answers, 0, 1)\n\n                # remove the padding\n                num1 = tokenized_prompts[:,:data_size]\n                op = tokenized_prompts[:,data_size+1:data_size+2]\n                num2 = tokenized_prompts[:,data_size+3:data_size+data_size+3]\n                equals = tokenized_prompts[:,data_size+data_size+4:data_size+data_size+5]\n                tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)\n\n                # get the output from the model\n                predicted_ids = model._generate(tokenized_prompts, token_limit=tokenized_answers.shape[1], temperature=cfg.temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=cfg.greedy, quick=True)\n                predicted_ids = torch.transpose(predicted_ids, 0, 1) # add a batch dim\n\n                eval_tensor = predicted_ids.clone()\n                input_tensor_EOS = (eval_tensor == EOS_token).int()\n                indices_of_EOS = torch.argmax(input_tensor_EOS, dim=1)\n                mask = torch.arange(eval_tensor.size(1)).to(device) > indices_of_EOS[:, None]\n                eval_tensor[mask] = PAD_token\n                elementwise_equal = torch.eq(eval_tensor, tokenized_answers)\n                \n                rows_equal = torch.all(elementwise_equal, dim=1)\n                num_equal_rows = torch.sum(rows_equal).item()\n                correct_total += (num_equal_rows/tokenized_prompts.shape[0])\n                log.info(f\"accuracy for {data_size}, {data_size}: {num_equal_rows} = {correct_total*100}%\")\n\n                # combine the prompts and outputs\n                complete_lines = torch.cat((tokenized_prompts,predicted_ids), dim=1)\n                tokens_list = complete_lines.tolist()\n                decoded_batch = list(map(lambda seq: list(map(lambda token: vocab[token], seq)), tokens_list)) # map token ids to tokens\n                log.info(f\"example for {data_size}, {data_size}: {decoded_batch[0]}\")\n                # save the answers down so we don't eval twice ever\n\n            accs[data_size] = correct_total\n            with open(f\"over_100_{number}.json\", 'w') as json_file:\n                    json.dump(accs, json_file)\n                    \n    log.info(\"Eval complete\")\n\n@hydra.main(config_path=\"cramming/config\", config_name=\"cfg_eval\", version_base=\"1.3\")\ndef launch(cfg):\n    log.info(\"calling main launch\")\n    cfg = cramming.utils.pathfinder(cfg)\n    log.info(OmegaConf.to_yaml(cfg, resolve=True))\n    main(cfg)\n\nif __name__ == \"__main__\":\n    launch()"
  },
  {
    "path": "cramming/__init__.py",
    "content": "\"\"\"Initialize cramming\"\"\"\n\nfrom cramming import utils\nfrom cramming.architectures import construct_model\nfrom cramming.backend import load_backend\nfrom cramming.data import load_pretraining_corpus, prepare_dataloaders\n\n\n__all__ = [\n    \"construct_model\",\n    \"load_backend\",\n    \"prepare_dataloaders\",\n    \"load_pretraining_corpus\",\n    \"utils\",\n]\n\n\nimport hydra\n\n\"\"\"Construct interfaces to some cfg folders for use in packaged installations:\"\"\"\n\n\ndef get_config(overrides=[]):\n    \"\"\"Return default hydra config.\"\"\"\n    with hydra.initialize(config_path=\"config\"):\n        cfg = hydra.compose(config_name=\"cfg\", overrides=overrides)\n        print(f\"Loading default config {cfg.name}.\")\n    return cfg\n\n\ndef get_model_config(arch=\"hf-bert-tiny\", overrides=[]):\n    \"\"\"Return default hydra config for a given attack.\"\"\"\n    with hydra.initialize(config_path=\"config/arch\"):\n        cfg = hydra.compose(config_name=arch, overrides=overrides)\n        print(f\"Loading model configuration {cfg.architecture}.\")\n    return cfg\n\n\ndef get_backend_config(backend=\"torch-default\", overrides=[]):\n    \"\"\"Return default hydra config for a given attack.\"\"\"\n    with hydra.initialize(config_path=\"config/impl\"):\n        cfg = hydra.compose(config_name=backend, overrides=overrides)\n        print(f\"Loading backend {cfg.name}.\")\n    return cfg\n"
  },
  {
    "path": "cramming/architectures/__init__.py",
    "content": "\"\"\"This module handles all questions of model architecture.\"\"\"\n\nfrom .construction import construct_model\n\n__all__ = [\"construct_model\"]\n"
  },
  {
    "path": "cramming/architectures/attention.py",
    "content": "\"\"\"Attention modules. Most code heavily stolen from the GPT-neoX implementation\"\"\"\nimport torch\nfrom transformers.models.bert.modeling_bert import BertSelfAttention\n\nfrom .embeddings import Rotary, RotarySanityCheck, RotaryEleutherAI, RotaryLLAMA, FIRE\nfrom typing import Optional\n\nfrom torch.nn.modules.linear import NonDynamicallyQuantizableLinear  # use to mark output projections of attn while it exists\n\n\ndef get_attention_mechanism(idx, hidden_size, cfg_attention, norm_fn: torch.nn.Identity):\n    # ########## main implementation\n    if cfg_attention.type == \"self-attention\":\n        mechanism = SeqFirstSelfAttention(hidden_size, cfg_attention, norm_fn)  # neox\n    # ########## other things:\n    elif cfg_attention.type == \"pytorch\":\n        mechanism = SelfAttentionPyTorch(hidden_size, cfg_attention)  # torch default\n    elif cfg_attention.type == \"pytorch-seqfirst\":\n        mechanism = SeqFirstSelfAttentionPyTorch(hidden_size, cfg_attention)  # torch default\n    elif cfg_attention.type == \"huggingface\":\n        mechanism = BertAttentionWrapper(hidden_size, cfg_attention)  # always includes bias!\n    elif cfg_attention.type == \"fourier\":\n        mechanism = FourierMixing(hidden_size, cfg_attention)\n    elif cfg_attention.type == \"none\":\n        mechanism = Identity(hidden_size)\n    elif cfg_attention.type == \"rn\":\n        mechanism = RandomNoise(hidden_size) # i.e. no signal on where to look\n    else:\n        raise ValueError(f\"Invalid attention type {cfg_attention.type} given.\")\n    return mechanism\n\n\nclass Identity(torch.nn.Module):\n    \"\"\"mini wrapper around BERT attention from huggingface for sanity checks.\"\"\"\n\n    __constants__ = [\"LAYOUT\"]\n    LAYOUT = \"[B S H]\"\n\n    def __init__(self, hidden_size):\n        super().__init__()\n        self.output_dim = hidden_size\n\n    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):\n        return hidden_states\n\nclass RandomNoise(torch.nn.Module):\n    \"\"\"mini wrapper around BERT attention from huggingface for sanity checks.\"\"\"\n\n    __constants__ = [\"LAYOUT\"]\n    LAYOUT = \"[B S H]\"\n\n    def __init__(self, hidden_size):\n        super().__init__()\n        self.output_dim = hidden_size\n\n    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):\n        print(\"using rn\")\n        return hidden_states + torch.normal(0, 0.1, hidden_states.shape).to(hidden_states.device)\n\nclass BertAttentionWrapper(BertSelfAttention):\n    \"\"\"mini wrapper around BERT attention from huggingface for sanity checks.\"\"\"\n\n    __constants__ = [\"LAYOUT\"]\n    LAYOUT = \"[B S H]\"\n\n    def __init__(self, hidden_size, cfg_attention):\n        class config:\n            pass\n\n        config.hidden_size = hidden_size\n        config.num_attention_heads = cfg_attention.num_attention_heads\n        config.attention_probs_dropout_prob = 0.0\n        config.is_decoder = True\n\n        super().__init__(config)\n        if cfg_attention.skip_output_projection:\n            self.dense = torch.nn.Identity()\n        else:\n            self.dense = torch.nn.Linear(hidden_size, hidden_size, bias=cfg_attention.bias_in_proj)\n\n    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):\n        return self.dense(super().forward(hidden_states, attention_mask)[0])\n\n\nclass SelfAttentionPyTorch(torch.nn.Module):\n    \"\"\"Minimal wrapper around pytorch self attention.\"\"\"\n\n    __constants__ = [\"LAYOUT\"]\n    LAYOUT = \"[B S H]\"\n\n    def __init__(self, hidden_size, cfg_attention):\n        super().__init__()\n        self.attn = torch.nn.MultiheadAttention(\n            hidden_size,\n            cfg_attention.num_attention_heads,\n            dropout=0.0,\n            batch_first=True,\n            bias=cfg_attention.bias_in_proj,\n            add_bias_kv=cfg_attention.qkv_bias,\n        )\n\n    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):\n        return self.attn(\n            hidden_states,\n            hidden_states,\n            hidden_states,\n            attn_mask=attention_mask[0, 0, :, :],\n            need_weights=False,\n            is_causal=True,\n        )[0]\n\n\nclass SeqFirstSelfAttentionPyTorch(torch.nn.Module):\n    \"\"\"Minimal wrapper around pytorch self attention.\"\"\"\n\n    __constants__ = [\"LAYOUT\"]\n    LAYOUT = \"[S B H]\"\n\n    def __init__(self, hidden_size, cfg_attention):\n        super().__init__()\n        self.attn = torch.nn.MultiheadAttention(\n            hidden_size,\n            cfg_attention.num_attention_heads,\n            dropout=0.0,\n            batch_first=False,\n            bias=cfg_attention.bias_in_proj,\n            add_bias_kv=cfg_attention.qkv_bias,\n        )\n\n    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):\n        return self.attn(\n            hidden_states,\n            hidden_states,\n            hidden_states,\n            attn_mask=attention_mask[0, 0, :, :],\n            need_weights=False,\n            is_causal=True,\n        )[0]\n\n\nclass SeqFirstSelfAttention(torch.nn.MultiheadAttention):\n    \"\"\"Self-attention layer.\n\n    This is the gpt neo-x implementation from:\n    https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py (which is a megatron variant)\n\n    This is a modified version of the neo-x implementation that I can manage to compile without graph breaks.\n\n    Inherits from MultiheadAttention to catch the same initialization\n    \"\"\"\n\n    __constants__ = [\"LAYOUT\"]\n    LAYOUT: str = \"[S B H]\"\n\n    def __init__(self, hidden_size: int, cfg_attention, norm_module=torch.nn.Identity):\n        torch.nn.Module.__init__(self)\n        self.hidden_size = hidden_size\n        self.num_attention_heads = cfg_attention.num_attention_heads\n        self.hidden_per_head = self.hidden_size // cfg_attention.num_attention_heads\n        self.register_buffer(\"norm_factor\", torch.tensor(self.hidden_per_head).rsqrt())\n        self.cfg_attention = cfg_attention\n        self.use_fire = False\n\n        self.norm = norm_module()\n\n        # Strided linear layer.\n        self.in_proj_weight = torch.nn.Parameter(torch.randn(3 * self.hidden_size, self.hidden_size))\n        if cfg_attention.qkv_bias:\n            self.in_proj_bias = torch.nn.Parameter(torch.zeros(3 * self.hidden_size))\n        else:\n            self.in_proj_bias = None\n        self.bias_k, self.bias_v = None, None  # for compat with MultiheadAttention\n\n        self.output_dim = hidden_size\n        if cfg_attention.rotary_embedding == \"sanity\":\n            self.rotary_emb = RotarySanityCheck(self.hidden_per_head, seq_dim=0)\n        elif cfg_attention.rotary_embedding == \"v2\":\n            self.rotary_emb = RotaryEleutherAI(self.hidden_per_head)\n        elif cfg_attention.rotary_embedding == \"llama\":\n            self.rotary_emb = RotaryLLAMA(self.hidden_per_head)\n        elif cfg_attention.rotary_embedding == \"fire\":\n            self.rotary_emb = FIRE(cfg_attention.num_attention_heads, max_length=cfg_attention.max_length)\n            self.use_fire = True\n        elif cfg_attention.rotary_embedding:\n            self.rotary_emb = Rotary(self.hidden_per_head, seq_dim=0)\n        else:\n            self.rotary_emb = None\n            \n        if cfg_attention.sequence_op == \"torch-softmax\":\n            self.sequence_op = TorchSoftmax(cfg_attention.seq_op_in_fp32)\n        elif cfg_attention.sequence_op == \"shaped-attention\":\n            self.sequence_op = TorchShaped(cfg_attention.seq_op_in_fp32, hidden_size=self.hidden_size)\n        elif cfg_attention.sequence_op == \"swin-cosine\":\n            self.sequence_op = SwinCosine(cfg_attention.seq_op_in_fp32)\n        elif cfg_attention.sequence_op == \"torch-norm\":\n            self.sequence_op = TorchNormalize(self.num_attention_heads, cfg_attention.seq_op_in_fp32)\n        elif cfg_attention.sequence_op == \"none\":\n            self.sequence_op = ScaledIdentity(cfg_attention.seq_op_in_fp32)\n        elif cfg_attention.sequence_op == \"cumsum\":\n            self.sequence_op = Cumsum(cfg_attention.seq_op_in_fp32)\n        elif cfg_attention.sequence_op == \"cumsumexp\":\n            self.sequence_op = CumsumExp(cfg_attention.seq_op_in_fp32)\n        else:\n            raise ValueError(f\"Invalid sequence operation {cfg_attention.sequence_op} given.\")\n\n        if cfg_attention.skip_output_projection:\n            self.out_proj = torch.nn.Identity()\n        else:\n            self.out_proj = NonDynamicallyQuantizableLinear(hidden_size, hidden_size, bias=cfg_attention.bias_in_proj)\n\n        self.attention_func = self.attention\n\n    def attention(self, query_layer, key_layer, value_layer, attention_mask: Optional[torch.Tensor] = None, training: bool = False, fire: Optional[torch.Tensor] = None):\n        # ===================================\n        # Raw attention scores. [b, np, s, s]\n        # ===================================\n\n        # [b, np, sq, sk]\n        output_size = (query_layer.shape[1], query_layer.shape[2], query_layer.shape[0], key_layer.shape[0])\n\n        # [sq, b, np, hn] -> [sq, b * np, hn]\n        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)\n        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)\n\n        # this better be fused in a clever way:\n        matmul_result = torch.bmm(query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2)) * self.norm_factor\n\n        # change view to [b, np, sq, sk]\n        attention_scores = matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])\n        if fire is not None:\n            attention_scores += fire\n\n        # ===========================\n        # Attention probs\n        # ===========================\n        # attention scores and attention mask [b, np, sq, sk]\n        attention_probs = self.sequence_op(attention_scores, attention_mask)\n\n        # =========================\n        # Context layer. [sq, b, hp]\n        # =========================\n\n        # value_layer -> context layer.\n        # [sk, b, np, hn] --> [b, np, sq, hn]\n\n        # context layer shape: [b, np, sq, hn]\n        output_size = (value_layer.shape[1], value_layer.shape[2], query_layer.shape[0], value_layer.shape[3])\n\n        # change view [sk, b * np, hn]\n        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)\n\n        # change view [b * np, sq, sk]\n        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)\n\n        # matmul: [b * np, sq, hn]\n        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))\n\n        # change view [b, np, sq, hn]\n        context_layer = context_layer.view(*output_size)\n        return context_layer\n    \n    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):\n        # =====================\n        # hidden_states: [sq, b, h]\n        # Query, Key, and Value\n        # =====================\n        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]\n        mixed_x_layer = torch.nn.functional.linear(hidden_states, self.in_proj_weight, self.in_proj_bias)\n\n        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]\n        # new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, 3 * self.hidden_per_head)\n        mixed_x_layer = mixed_x_layer.view(\n            hidden_states.shape[0], hidden_states.shape[1], self.num_attention_heads, 3 * self.hidden_per_head\n        )\n        # print(\"mixed shape \",mixed_x_layer.shape) (82, 24, 16, 192)\n\n        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]\n        (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [self.hidden_per_head] * 3, dim=3)\n        \n        fire = None\n        if self.rotary_emb is not None:\n            if self.use_fire:\n                fire = self.rotary_emb(query_layer.size(0), query_layer.device)\n            else:\n                query_layer, key_layer = self.rotary_emb(query_layer, key_layer)\n                # print(query_layer.shape)\n\n        # ==================================\n        # Attention computation\n        # ==================================\n        context_layer = self.attention_func(query_layer, key_layer, value_layer, attention_mask, self.training, fire)\n\n        # [b, np, sq, hn] --> [sq, b, np, hn]\n        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()\n\n        # [sq, b, np, hn] --> [sq, b, hp]\n        # new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)\n        context_layer = context_layer.view(context_layer.shape[0], context_layer.shape[1], self.hidden_size)\n        return self.out_proj(self.norm(context_layer))\n\n\nclass FourierMixing(torch.nn.Module):\n    \"\"\"Fourier mixing layer as described in the FNet paper.\n    Layer takes input with size [Batch, Seq, Hidden] and returns output of the same size.\n    This function can take an attention mask as input, but will ignore it.\n    \"\"\"\n\n    __constants__ = [\"LAYOUT\"]\n    LAYOUT = \"[B S H]\"\n\n    def __init__(self, hidden_size, cfg_attention):\n        super().__init__()\n        self.fft_op_in_fp32 = True  # Always necessary (atleast on pytorch 1.12)\n        self.output_dim = hidden_size\n        if cfg_attention.rotary_embedding:\n            if cfg_attention.low_level_fusion:\n                self.rotary_emb = torch.jit.script(Rotary(hidden_size, seq_dim=1))\n            else:\n                self.rotary_emb = Rotary(hidden_size, seq_dim=0)\n        else:\n            self.rotary_emb = None\n\n    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):\n        \"\"\"Forward will take an attention mask but ignore it!\"\"\"\n\n        if self.rotary_emb is not None:\n            # full rotary (mostly on for compatibility, no guarantees on this being non-terrible)\n            cos, sin = self.rotary_emb.get_cos_sin_cache(hidden_states)\n            hidden_states = (hidden_states * cos[:, 0]) + (self.rotary_emb.rotate_half(hidden_states) * sin[:, 0])\n\n        if self.fft_op_in_fp32:\n            hidden_state_dtype = hidden_states.dtype\n            hidden_states = hidden_states.float()\n        else:\n            hidden_state_dtype = None\n\n        # Implementation 1:\n        # hidden_states = torch.fft.fft(torch.fft.fft(hidden_states, dim=0, , norm=\"ortho\"), dim=2, , norm=\"ortho\").real\n        # Implementation 2:\n        hidden_states = torch.fft.fftn(hidden_states, dim=(1, 2), norm=\"ortho\").real  # could also cast into angle?\n\n        if self.fft_op_in_fp32:\n            hidden_states = hidden_states.to(hidden_state_dtype)\n\n        return hidden_states\n\n\nclass TorchSoftmax(torch.nn.Module):\n    def __init__(self, seq_op_in_fp32=False):\n        super().__init__()\n        self.seq_op_in_fp32 = seq_op_in_fp32\n\n    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):\n        input_dtype = inputs.dtype\n        if self.seq_op_in_fp32:\n            inputs = inputs.to(dtype=torch.float)\n        if attention_mask is not None:\n            inputs = inputs.masked_fill_(attention_mask, -10000.0)\n        probs = torch.softmax(inputs, dim=-1).to(dtype=input_dtype)\n        return probs\n\n\nclass TorchShaped(torch.nn.Module):\n    \"\"\"Noci et al.\"\"\"\n\n    def __init__(self, seq_op_in_fp32=False, hidden_size=768):\n        super().__init__()\n        self.seq_op_in_fp32 = seq_op_in_fp32\n        self.register_buffer(\"nfactor\", torch.tensor(hidden_size).rsqrt())\n\n    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):\n        input_dtype = inputs.dtype\n        breakpoint()\n        if self.seq_op_in_fp32:\n            inputs = inputs.to(dtype=torch.float)\n        if attention_mask is not None:\n            inputs = inputs.masked_fill_(attention_mask, -10000.0)\n        probs = torch.softmax(inputs * self.nfactor, dim=-1).to(dtype=input_dtype)\n        I = torch.eye(probs.shape[-1], dtype=probs.dtype, device=probs.device)[None, None, :, :]\n        shaped_outputs = probs + I - 1 / probs.shape[-1]\n        return shaped_outputs\n\n\nclass SwinCosine(torch.nn.Module):\n    \"\"\"kind of SwinCosine, but not quite (normalizations scaled by mean(q) and mean(k))\"\"\"\n\n    def __init__(self, seq_op_in_fp32=False, tau=0.1, eps=1e-8):\n        super().__init__()\n        self.seq_op_in_fp32 = seq_op_in_fp32\n        self.tau = 0.1\n        self.eps = eps\n\n    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):\n        \"\"\"inputs are q_i, k_j -> o_ij. Normalize\"\"\"\n        input_dtype = inputs.dtype\n        if self.seq_op_in_fp32:\n            inputs = inputs.to(dtype=torch.float)\n        row_norm = inputs.mean(dim=-1, keepdim=True).norm(dim=-2, keepdim=True)\n        col_norm = inputs.mean(dim=-2, keepdim=True).norm(dim=-1, keepdim=True)\n        outputs = inputs / torch.clamp(row_norm * col_norm * self.tau, min=self.eps)\n\n        if attention_mask is not None:\n            outputs[:, :, attention_mask[0, 0]] = 0\n\n        return outputs.to(dtype=input_dtype)\n\n\nclass TorchNormalize(torch.nn.Module):\n    def __init__(self, num_attention_heads=1, seq_op_in_fp32=False):\n        \"\"\"Normalized attention pooling as described in Richter&Wattenhofer, 2020.\"\"\"\n        super().__init__()\n        self.seq_op_in_fp32 = seq_op_in_fp32\n        self.seq_gamma = torch.nn.Parameter(torch.ones(1, num_attention_heads, 1, 1))\n        self.seq_beta = torch.nn.Parameter(torch.zeros(1, num_attention_heads, 1, 1))\n\n    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):\n        # Inputs are [b, np, sq, sk]\n        input_dtype = inputs.dtype\n        if self.seq_op_in_fp32:\n            inputs = inputs.to(dtype=torch.float)\n\n        if attention_mask is not None:\n            inputs.masked_fill_(attention_mask, 0.0)\n\n        norms = torch.nn.functional.layer_norm(inputs, inputs.shape[1:], eps=1e-05)\n        norms = (norms * self.seq_gamma + self.seq_beta).to(dtype=input_dtype)\n        return norms\n\n\nclass ScaledIdentity(torch.nn.Module):\n    def __init__(self, seq_op_in_fp32):\n        super().__init__()\n        self.seq_op_in_fp32 = seq_op_in_fp32\n\n    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):\n        \"\"\"Sequence-scaled input.\"\"\"\n        input_dtype = inputs.dtype\n        if self.seq_op_in_fp32:\n            inputs = inputs.to(dtype=torch.float)\n        return (inputs * torch.as_tensor(inputs.shape[2]).rsqrt()).to(dtype=input_dtype)\n\n\nclass Cumsum(torch.nn.Module):\n    def __init__(self, seq_op_in_fp32):\n        super().__init__()\n        self.seq_op_in_fp32 = seq_op_in_fp32\n\n    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):\n        \"\"\"Sequence-scaled input cumulative sum.\"\"\"\n        input_dtype = inputs.dtype\n        if self.seq_op_in_fp32:\n            inputs = inputs.to(dtype=torch.float)\n        return (inputs.cumsum(dim=-1) * pow(inputs.shape[2], -0.5)).to(dtype=input_dtype)\n\n\nclass CumsumExp(torch.nn.Module):\n    def __init__(self, seq_op_in_fp32):\n        super().__init__()\n        self.seq_op_in_fp32 = True  # Required as of pytorch 1.13\n\n    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):\n        \"\"\"Sequence-scaled input cumulative sum.\"\"\"\n        input_dtype = inputs.dtype\n        if self.seq_op_in_fp32:\n            inputs = inputs.to(dtype=torch.float)\n        return (inputs.logcumsumexp(dim=-1) * pow(inputs.shape[2], -0.5)).to(dtype=input_dtype)\n"
  },
  {
    "path": "cramming/architectures/components.py",
    "content": "\"\"\"Basic transformer components.\"\"\"\n\nimport torch\n\nfrom typing import Tuple\nfrom functools import partial\n\nfrom .embeddings import SinusoidalPositional, LearnablePositional, ScaledSinosoidal, Abacus\nfrom torch.nn.modules.linear import NonDynamicallyQuantizableLinear  # use to mark output projections of attn while it exists\n\nINPLACE = False\n\n\nclass EmbeddingComponent(torch.nn.Module):\n    \"\"\"Absolute Embeddings and sine embeddings\"\"\"\n    def __init__(self, cfg_embedding, norm, norm_eps):\n        super().__init__()\n\n        self.word_embedding = torch.nn.Embedding(cfg_embedding.vocab_size, cfg_embedding.embedding_dim)\n        if cfg_embedding.pos_embedding == \"learned\":\n            self.pos_embedding = LearnablePositional(cfg_embedding.embedding_dim, cfg_embedding.max_seq_length)\n        elif cfg_embedding.pos_embedding == \"learned_rand\":\n            self.pos_embedding = LearnablePositionalRand(cfg_embedding.embedding_dim, cfg_embedding.max_seq_length)\n        elif cfg_embedding.pos_embedding == \"sinusoidal\":\n            self.pos_embedding = SinusoidalPositional(cfg_embedding.embedding_dim, cfg_embedding.max_seq_length)\n        elif cfg_embedding.pos_embedding == \"scaled-sinusoidal\":\n            self.pos_embedding = ScaledSinosoidal(cfg_embedding.embedding_dim, cfg_embedding.max_seq_length)\n        elif cfg_embedding.pos_embedding == \"abacus\":\n            self.pos_embedding = Abacus(cfg_embedding.embedding_dim, cfg_embedding.max_seq_length, max_k=cfg_embedding.max_abacus_len)\n        else:\n            self.pos_embedding = None\n\n        if cfg_embedding.normalization:\n            self.stabilize_low_precision = cfg_embedding.get(\"stable_low_precision\", False)\n            self.norm = _get_norm_fn(norm)(cfg_embedding.embedding_dim, eps=norm_eps)\n        else:\n            self.stabilize_low_precision = False\n            self.norm = torch.nn.Identity()\n\n    def forward(self, input_ids):\n        embeds = self.word_embedding(input_ids)\n\n        if self.pos_embedding is not None:\n            embeds += self.pos_embedding(input_ids)\n        \n\n        if self.stabilize_low_precision:\n            # Stabilize as in bnb StableEmbedding\n            return self.norm(embeds.to(torch.get_default_dtype())).to(embeds.dtype)\n        else:\n            return self.norm(embeds)\n\n\nclass PredictionHeadComponent(torch.nn.Module):\n    def __init__(self, cfg_arch):\n        super().__init__()\n\n        if cfg_arch.embedding.embedding_dim == cfg_arch.hidden_size:\n            output_size = cfg_arch.hidden_size\n        else:\n            output_size = cfg_arch.embedding.embedding_dim\n\n        self.dense = torch.nn.Linear(cfg_arch.hidden_size, output_size, bias=cfg_arch.use_bias)\n        self.nonlin = _get_nonlin_fn(cfg_arch.nonlin, use_gating=False)()\n        self.norm = _get_norm_fn(cfg_arch.norm)(output_size, eps=cfg_arch.norm_eps)\n\n    def forward(self, hidden_states):\n        hidden_states = self.norm(self.nonlin(self.dense(hidden_states)))\n        return hidden_states\n\n\nclass NormalizedResidualConnection(torch.nn.Module):\n    \"\"\"Implement variations on residual connection types, especially stabilized versions and deep/shaped propagation.\"\"\"\n\n    def __init__(self, input_dim, cfg_arch, output_dim=None, dropout=0.0):\n        super().__init__()\n        output_dim = input_dim if output_dim is None else output_dim\n        self.dropout = torch.nn.Dropout(dropout) if dropout > 0 else torch.nn.Identity()\n        if cfg_arch.norm_scheme == \"pre\":\n            self.norm = _get_norm_fn(cfg_arch.norm)(input_dim, eps=cfg_arch.norm_eps)\n            self._chosen_forward_impl = self._prenormalization_residual\n        elif cfg_arch.norm_scheme == \"post\":\n            self.norm = _get_norm_fn(cfg_arch.norm)(output_dim, eps=cfg_arch.norm_eps)\n            self._chosen_forward_impl = self._postnormalization_residual\n        elif cfg_arch.norm_scheme == \"simple\":\n            self._chosen_forward_impl = self._simple_residual\n        elif cfg_arch.norm_scheme == \"deepnorm\":\n            self.norm = _get_norm_fn(cfg_arch.norm)(output_dim, eps=cfg_arch.norm_eps)\n            if \"num_transformer_layers\" in cfg_arch:\n                self.alpha = (2.0 * cfg_arch.num_transformer_layers) ** 0.25\n            elif \"layers_in_recurrent_block\" in cfg_arch:\n                self.alpha = (2.0 * cfg_arch.layers_in_recurrent_block * cfg_arch.maximal_recurrence) ** 0.25\n            else:\n                raise ValueError(\"Need to define `num_transformer_layers` in config for deepnorm.\")\n            self._chosen_forward_impl = self._deepnorm_residual\n        elif cfg_arch.norm_scheme == \"shaped\":\n            self.norm = _get_norm_fn(cfg_arch.norm)(input_dim, eps=cfg_arch.norm_eps)\n            self.gamma = 0.214  # Noci et al., could make this into a parameter\n            self.alpha = torch.as_tensor(1 - self.gamma**2).sqrt().item()\n            self._chosen_forward_impl = self._prenorm_equalized_residual\n        elif cfg_arch.norm_scheme == \"sandwich\":\n            self.norm = _get_norm_fn(cfg_arch.norm)(input_dim, eps=cfg_arch.norm_eps)\n            self.norm2 = _get_norm_fn(cfg_arch.norm)(output_dim, eps=cfg_arch.norm_eps)\n            self._chosen_forward_impl = self._sandwich_residual\n        else:\n            raise ValueError(f\"Invalid type of residual connection {cfg_arch.norm_scheme} given.\")\n\n    def _simple_residual(self, residual, layer, states, *args, **kwargs):\n        return residual + self.dropout(layer(states, *args, **kwargs))\n\n    def _prenormalization_residual(self, residual, layer, states, *args, **kwargs):\n        return residual + self.dropout(layer(self.norm(states), *args, **kwargs))\n\n    def _postnormalization_residual(self, residual, layer, states, *args, **kwargs):\n        return self.norm(residual + layer(states, *args, **kwargs))\n\n    def _deepnorm_residual(self, residual, layer, states, *args, **kwargs):\n        return self.norm(residual * self.alpha + self.dropout(layer(states, *args, **kwargs)))\n\n    def _prenorm_equalized_residual(self, residual, layer, states, *args, **kwargs):\n        return residual * self.alpha + self.dropout(layer(self.norm(states), *args, **kwargs)) * self.gamma\n\n    def _sandwich_residual(self, residual, layer, states, *args, **kwargs):\n        return self.norm2(residual + self.dropout(layer(self.norm(states), *args, **kwargs)))\n\n    def forward(self, residual: torch.Tensor, layer_callable: torch.nn.Module, states: torch.Tensor, *args, **kwargs):\n        \"\"\"Argument might look weird here, but I find it nicer because it reads like the pre/post schemes from left to right,\n        as\n        residual + layer ( state )\n\n        Additional args are passed directly into the layer callable\n        \"\"\"\n        return self._chosen_forward_impl(residual, layer_callable, states, *args, **kwargs)\n\n\ndef _get_norm_fn(norm_name):\n    if norm_name == \"ScaleNorm\":\n        norm_fn = ScaleNorm\n    elif norm_name == \"RMSNorm\":\n        norm_fn = RMSNorm\n    elif norm_name == \"ApexLayerNorm\":\n        from apex.normalization import FusedLayerNorm\n\n        norm_fn = FusedLayerNorm\n    else:\n        norm_fn = getattr(torch.nn, norm_name)\n    return norm_fn\n\n\ndef _get_nonlin_fn(nonlin_name, use_gating=True):\n    if \"glu\" in nonlin_name.lower():\n        nonlin_name = nonlin_name.split(\"glu\")[0]\n        wrap_in_glu = use_gating\n    else:\n        wrap_in_glu = False\n    nonlin_fn = getattr(torch.nn, nonlin_name)  # dont mess this up :<\n    try:\n        nonlin_fn = partial(nonlin_fn, inplace=INPLACE)\n        nonlin_fn()\n    except TypeError:\n        nonlin_fn = getattr(torch.nn, nonlin_name)\n\n    if wrap_in_glu:\n        return partial(GLU, nonlin_fn)\n    else:\n        return nonlin_fn\n\n\nclass GLU(torch.nn.Module):\n    \"\"\"*-GLU activation functions.\n\n    Implementation mostly following megatron\n    \"\"\"\n\n    def __init__(self, sub_activation):\n        super().__init__()\n        self.sub_activation = sub_activation()\n\n    def forward(self, inputs):\n        x, gate = inputs.chunk(2, dim=-1)\n        return self.sub_activation(gate) * x\n\n\nclass ScaleNorm(torch.nn.Module):\n    \"\"\"Quick and simple scale norm implementation. \"elementwise_affine\" is not the ideal name but for compat with LayerNorm\n\n    Do we also need FixNorm (cosine in the last layer)? It's a maybe here:\n    https://github.com/lucidrains/performer-pytorch/issues/55#issuecomment-762544686\n    \"\"\"\n\n    def __init__(self, hidden_size: int, eps: float = 1e-5, elementwise_affine: bool = True):\n        super().__init__()\n        self.eps = eps\n        if elementwise_affine:\n            self.learnable_scale = torch.nn.Parameter(torch.tensor(float(hidden_size) ** -0.5))\n        else:\n            self.register_buffer(\"learnable_scale\", torch.tensor(float(hidden_size) ** -0.5))\n\n    def forward(self, inputs):\n        \"\"\"This is the same eps clipping as in the original ScaleNorm implementation.\"\"\"\n        return inputs * self.learnable_scale / torch.norm(inputs, dim=-1, keepdim=True).clamp(min=self.eps)\n\n\nclass RMSNorm(torch.nn.Module):\n    \"\"\"The RMS variant of scaling norms.  \"elementwise_affine\" is not the ideal name but for compat with LayerNorm\"\"\"\n\n    def __init__(self, hidden_size: int, eps: float = 1e-6, elementwise_affine: bool = True):\n        super().__init__()\n        self.eps = eps\n        if elementwise_affine:\n            self.learnable_scale = torch.nn.Parameter(torch.ones(hidden_size) ** -0.5)\n        else:\n            self.register_buffer(\"learnable_scale\", torch.ones(hidden_size) ** -0.5)\n\n    def _legacy_forward(self, inputs):\n        \"\"\"This is the same eps clipping as in the original ScaleNorm implementation.\"\"\"\n        return inputs * self.learnable_scale / torch.norm(inputs, dim=-1, keepdim=True).clamp(min=1e-8)\n\n    def _norm(self, x):\n        \"\"\"LLama implementation\"\"\"\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float()).type_as(x)\n        return output * self.learnable_scale\n\n\ndef get_causal_attention_mask(input_ids) -> torch.Tensor:\n    \"\"\"Simplified triangular causal mask. Adapted for multiple heads.\"\"\"\n    seq_length = input_ids.shape[1]  # not transposed yet\n    device = input_ids.device\n    # lower triangular attention mask\n    mask = torch.tril(torch.ones((1, 1, seq_length, seq_length), device=device)).view(1, 1, seq_length, seq_length)\n\n    # convert to binary\n    return mask < 0.5\n\n\ndef get_extended_attention_mask(attention_mask: torch.Tensor, input_shape: Tuple[int], causal_attention: bool = False) -> torch.Tensor:\n    \"\"\"\n    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n    Arguments:\n        attention_mask (`torch.Tensor`):\n            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n        input_shape (`Tuple[int]`):\n            The shape of the input to the model.\n    Returns:\n        `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.\n\n    Method stolen from huggingface :)\n    \"\"\"\n    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n    # ourselves in which case we just need to make it broadcastable to all heads.\n    if attention_mask.dim() == 3:\n        extended_attention_mask = attention_mask[:, None, :, :]\n    elif attention_mask.dim() == 2:\n        # Provided a padding mask of dimensions [batch_size, seq_length]\n        # - if the model is a decoder, apply a causal mask in addition to the padding mask\n        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if causal_attention:\n            batch_size, seq_length = input_shape\n            seq_ids = torch.arange(seq_length, device=attention_mask.device)\n            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]\n            # in case past_key_values are used we need to add a prefix ones mask to the causal mask\n            # causal and attention masks must have same type with pytorch version < 1.3\n            causal_mask = causal_mask.to(attention_mask.dtype)\n\n            if causal_mask.shape[1] < attention_mask.shape[1]:\n                prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]\n                causal_mask = torch.cat(\n                    [\n                        torch.ones((batch_size, seq_length, prefix_seq_len), device=attention_mask.device, dtype=causal_mask.dtype),\n                        causal_mask,\n                    ],\n                    axis=-1,\n                )\n            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]\n        else:\n            extended_attention_mask = attention_mask[:, None, None, :]\n    else:\n        raise ValueError(f\"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})\")\n\n    # extended_attention_mask = extended_attention_mask.to(dtype=self.setup[\"dtype\"])  # fp16 compatibility\n    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n    return extended_attention_mask\n\n\n\"\"\"Collect inits.\"\"\"\n\n\n@torch.no_grad()\ndef _init_module(module, init_method=\"normal\", init_std=0.02, hidden_size=768, num_layers=12):\n    \"\"\"Todo: refactor this insanity\"\"\"\n    if \"deepnorm\" in init_method:  # This is a xavier init with changes in the MHA inits\n        if \"normal\" in init_method:\n            gain = init_std\n        elif \"subln\" in init_method:\n            gain = torch.as_tensor(2 * num_layers).log().sqrt()  # foundation transformer paper, use with subln\n        elif \"straight\" in init_method:\n            gain = torch.as_tensor(8 * num_layers).pow(-0.25)  # deepnorm paper, use with deepnorm\n        elif \"as-is\" in init_method:  # use locally defined inits for each module\n            gain = 1.0\n        else:\n            raise ValueError(f\"Invalid init method {init_method} given.\")\n\n        if isinstance(module, torch.nn.Linear):\n            if isinstance(module, NonDynamicallyQuantizableLinear):\n                # This is handled below in the MultiheadAttention section\n                pass\n            else:\n                if module.weight is not None:\n                    torch.nn.init.xavier_normal_(module.weight, gain=gain)\n                if module.bias is not None:\n                    module.bias.data.zero_()\n        elif isinstance(module, torch.nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0, std=module.weight.shape[1] ** -0.5)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, torch.nn.LayerNorm):\n            if module.weight is not None:\n                module.bias.data.zero_()\n                module.weight.data.fill_(1.0)\n        elif isinstance(module, torch.nn.MultiheadAttention):  # be careful with other transformer definitions!\n            if \"mimetic\" in init_method:\n                if module.in_proj_weight is not None:\n                    h = module.in_proj_weight.shape[1]\n                    Z1 = module.in_proj_weight.new_empty([h, h])\n                    torch.nn.init.xavier_normal_(Z1, gain=gain)  # as per deepnorm prescription\n                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)\n                    U1, S1, V1 = torch.linalg.svd(Z1 + I, full_matrices=False)\n                    V = U1 @ torch.diag_embed(S1.sqrt())\n                    O = V1 @ torch.diag_embed(S1.sqrt())\n\n                    k = module.head_dim\n                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)\n                    Qlist, Klist = [], []\n                    for head in range(module.num_heads):\n                        Z2 = module.in_proj_weight.new_empty([h, h])\n                        torch.nn.init.xavier_normal_(Z2, gain=1.0)  # as per deepnorm prescription\n                        U2, S2, V2 = torch.linalg.svd(Z2 + I, full_matrices=False)\n                        Qlist.append(U2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))\n                        Klist.append(V2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))\n                    Q, K = torch.cat(Qlist, dim=-1), torch.cat(Klist, dim=-1)\n                    module.in_proj_weight.data.copy_(torch.cat([Q, K, V], dim=0).contiguous())\n                    if module.out_proj is not None:\n                        module.out_proj.weight.data.copy_(O)\n            else:\n                if module.in_proj_weight is not None:\n                    h = module.in_proj_weight.shape[1]\n                    Q, K, V = (\n                        module.in_proj_weight.new_empty([h, h]),\n                        module.in_proj_weight.new_empty([h, h]),\n                        module.in_proj_weight.new_empty([h, h]),\n                    )\n                    torch.nn.init.xavier_normal_(Q, gain=1.0)  # as per deepnorm prescription\n                    torch.nn.init.xavier_normal_(K, gain=1.0)\n                    torch.nn.init.xavier_normal_(V, gain=gain)\n                    module.in_proj_weight.data.copy_(torch.cat([Q, K, V], dim=0).contiguous())\n                # init outproj:\n                if module.out_proj is not None:\n                    torch.nn.init.xavier_normal_(module.out_proj.weight, gain=gain)\n                    if module.out_proj.bias is not None:\n                        module.out_proj.bias.data.zero_()\n            if module.in_proj_bias is not None:\n                module.in_proj_bias.data.zero_()\n            if module.bias_k is not None:\n                module.bias_k.data.zero_()\n            if module.bias_v is not None:\n                module.bias_v.data.zero_()\n            if module.out_proj is not None and module.out_proj.bias is not None:\n                module.out_proj.bias.data.zero_()\n    else:\n        if \"normal\" in init_method:\n            std = init_std\n        elif init_method == \"small\" in init_method:\n            # Transformers without Tears: Improving\n            # the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010)\n            std = torch.as_tensor(2 / (5 * hidden_size)).sqrt()\n        elif \"megatron\" in init_method:\n            std = torch.as_tensor(1 / (3 * hidden_size)).sqrt()\n            # Megatron init is near-equal to normal if hidden=768, but otherwise smaller\n        elif \"wang\" in init_method:\n            std = 2 / num_layers / torch.as_tensor(hidden_size).sqrt()\n        elif \"as-is\" in init_method:  # use locally defined inits for each module\n            return\n        else:\n            raise ValueError(f\"Invalid init method {init_method} given.\")\n        if isinstance(module, torch.nn.Linear):\n            if isinstance(module, NonDynamicallyQuantizableLinear):\n                # This is handled below in the MultiheadAttention section\n                pass\n            else:\n                # Slightly different from the TF version which uses truncated_normal for initialization\n                # cf https://github.com/pytorch/pytorch/pull/5617\n                if module.weight is not None:\n                    module.weight.data.normal_(mean=0.0, std=std)\n                if module.bias is not None:\n                    module.bias.data.zero_()\n        elif isinstance(module, torch.nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, torch.nn.LayerNorm):\n            if module.weight is not None:\n                module.bias.data.zero_()\n                module.weight.data.fill_(1.0)\n        elif isinstance(module, torch.nn.MultiheadAttention):  # be careful with other transformer definitions!\n            if \"mimetic\" in init_method:\n                if module.in_proj_weight is not None:\n                    h = module.in_proj_weight.shape[1]\n                    Z1 = module.in_proj_weight.new_empty([h, h]).normal_() / h\n                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)\n                    U1, S1, V1 = torch.linalg.svd(0.2 * Z1 + 0.2 * I, full_matrices=False)\n                    V = U1 @ torch.diag_embed(S1.sqrt())\n                    O = V1 @ torch.diag_embed(S1.sqrt())\n\n                    k = module.head_dim\n                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)\n                    Qlist, Klist = [], []\n                    for head in range(module.num_heads):\n                        # Z2 = module.in_proj_weight.new_empty([h, h]).normal_() / h\n                        U2, S2, V2 = torch.linalg.svd(0 + 0.5 * I, full_matrices=False)  # alpha1 =0 from Trockman\n                        Qlist.append(U2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))  # this is a bit pointless, ...\n                        Klist.append(V2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))  # ... I've left it here for alpha1 not zero\n                    Q, K = torch.cat(Qlist, dim=-1), torch.cat(Klist, dim=-1)\n                    module.in_proj_weight.data.copy_(torch.cat([Q, K, V], dim=0).contiguous())\n                    if module.out_proj is not None:\n                        module.out_proj.weight.data.copy_(O)\n            else:\n                if module.in_proj_weight is not None:\n                    module.in_proj_weight.data.normal_(mean=0.0, std=std)\n                if module.out_proj is not None:\n                    module.out_proj.weight.data.normal_(mean=0.0, std=std)\n            if module.in_proj_bias is not None:\n                module.in_proj_bias.data.zero_()\n            if module.bias_k is not None:\n                module.bias_k.data.zero_()\n            if module.bias_v is not None:\n                module.bias_v.data.zero_()\n            # init outproj:\n            if module.out_proj is not None and module.out_proj.bias is not None:\n                module.out_proj.bias.data.zero_()\n"
  },
  {
    "path": "cramming/architectures/construction.py",
    "content": "\"\"\"Interface to construct models.\"\"\"\n\nfrom .huggingface_interface import construct_huggingface_model\nfrom .sanity_check import SanityCheckforPreTraining\nfrom .crammed_transformer import construct_crammed_transformer\nfrom .crammed_depthrecurrent import construct_crammed_recurrent\n\nimport logging\nfrom ..utils import is_main_process\n\nlog = logging.getLogger(__name__)\n\n\ndef construct_model(cfg_arch, tokenizer):\n    model = None\n    eos_token_id = tokenizer.eos_token  # tokenizer.vocab[\"<eot>\"]\n    if \"model_type\" in cfg_arch:\n        # attempt to solve locally\n        if \"SanityCheckLM\" in cfg_arch.model_type:\n            model = SanityCheckforPreTraining(cfg_arch.width, tokenizer.vocab_size)\n        elif \"ScriptableCrammedTransformer\" in cfg_arch.model_type:\n            model = construct_crammed_transformer(cfg_arch, tokenizer.vocab_size)\n        elif \"ScriptableCrammedDepthRecurrent\" in cfg_arch.model_type:\n            equals_token = tokenizer.vocab[\"=\"]\n            model = construct_crammed_recurrent(cfg_arch, tokenizer.vocab_size, equals_token)\n\n    if model is not None:  # Return local model arch\n        num_params = sum([p.numel() for p in model.parameters()])\n        if is_main_process():\n            log.info(f\"Model with architecture {cfg_arch.model_type} loaded with {num_params:,} parameters.\")\n        return model\n\n    try:  # else try on HF\n        model = construct_huggingface_model(cfg_arch, tokenizer.vocab_size)\n        num_params = sum([p.numel() for p in model.parameters()])\n        if is_main_process():\n            log.info(f\"Model with config {cfg_arch} loaded with {num_params:,} parameters.\")\n        return model\n    except Exception as e:\n        raise ValueError(f\"Invalid model architecture {cfg_arch.model_type} given. Error: {e}\")\n"
  },
  {
    "path": "cramming/architectures/crammed_depthrecurrent.py",
    "content": "\"\"\"Variant for modifications of the transformer architecture that are depth-recurrent\"\"\"\nimport torch\nfrom transformers import PretrainedConfig, PreTrainedModel\nfrom transformers import AutoConfig, AutoModel, AutoModelForCausalLM\n\nfrom typing import Optional\nfrom omegaconf import OmegaConf\n\nfrom .components import (\n    _get_norm_fn,\n    _get_nonlin_fn,\n    EmbeddingComponent,\n    GLU,\n    get_causal_attention_mask,\n    _init_module,\n    NormalizedResidualConnection,\n)\nfrom .attention import get_attention_mechanism\n\n\nclass crammedDepthRecurrentConfig(PretrainedConfig):\n    model_type = \"crammedDepthRecurrent\"\n\n    def __init__(self, cfg_arch_container: dict = {}, **kwargs):\n        self.arch = cfg_arch_container\n        super().__init__(**kwargs)\n\n\ndef construct_crammed_recurrent(cfg_arch, vocab_size, equals_token):\n    \"\"\"See the config file for details on what is possible.\"\"\"\n    cfg_arch.embedding.vocab_size = vocab_size\n\n    config = crammedDepthRecurrentConfig(OmegaConf.to_container(cfg_arch, resolve=True))\n    if config.arch[\"objective_layout\"] in [\"fixed\", \"albert\"]:\n        model = ScriptableRecurrentLMForPreTraining(config)\n    elif config.arch[\"objective_layout\"] in [\"TBPTT\", \"deepthinking\"]:\n        model = ScriptableRecurrentLMBPTT(config, equals_token)\n    else:\n        raise ValueError(f\"Invalid layout {config.arch['objective_layout']} of training objective given.\")\n\n    return model\n\n\nclass FFNComponent(torch.nn.Module):\n    \"\"\"Note: The FF layer is not auto-scaled when using a GLU type activation.\n    Better do this manually and choose a sensible intermed_size that is nicely divisible.\n\n    The neox suggestion for approx. equal parameter count is int(4 * 2 / 3 * hidden_size) * 2 [this is ~5.33]\n    \"\"\"\n\n    def __init__(self, hidden_size, intermed_size, cfg_arch, output_size=None):\n        super().__init__()\n        self.dense_in = torch.nn.Linear(hidden_size, intermed_size, bias=cfg_arch.use_bias)\n        self.nonlin = _get_nonlin_fn(cfg_arch.nonlin)()\n        if isinstance(self.nonlin, GLU):\n            intermed_output_size = intermed_size // 2\n        else:\n            intermed_output_size = intermed_size\n        if cfg_arch.sub_normalization:\n            self.norm = _get_norm_fn(cfg_arch.norm)(intermed_output_size, eps=cfg_arch.norm_eps)\n        else:\n            self.norm = torch.nn.Identity()\n        output_size = hidden_size if output_size is None else output_size\n        self.dense_out = torch.nn.Linear(intermed_output_size, output_size, bias=cfg_arch.use_bias)\n\n    def forward(self, hidden_states):\n        return self.dense_out(self.norm(self.nonlin(self.dense_in(hidden_states))))\n\n\nclass TransformerLayer(torch.nn.Module):\n    \"\"\"A transformer structure based on the components from above.\"\"\"\n\n    def __init__(self, idx, cfg_arch):\n        super().__init__()\n        self.residual1 = NormalizedResidualConnection(cfg_arch.hidden_size, cfg_arch)\n        self.residual2 = NormalizedResidualConnection(cfg_arch.hidden_size, cfg_arch)\n        if cfg_arch.attention.sub_normalization:\n            sub_norm_fn = lambda: _get_norm_fn(cfg_arch.norm)(cfg_arch.hidden_size, eps=cfg_arch.norm_eps)  # noqa\n        else:\n            sub_norm_fn = torch.nn.Identity\n        self.attn = get_attention_mechanism(idx, cfg_arch.hidden_size, cfg_arch.attention, sub_norm_fn)\n        self.ffn = FFNComponent(cfg_arch.hidden_size, cfg_arch.intermed_size, cfg_arch)\n        self.LAYOUT = self.attn.LAYOUT\n\n    def forward(self, states, attention_mask: Optional[torch.Tensor] = None):\n        states = self.residual1(states, self.attn, states, attention_mask)\n        states = self.residual2(states, self.ffn, states)\n        return states\n\n\nclass TransformerBlock(torch.nn.Module):\n    \"\"\"A transformer block of multiple layers (without weightsharing).\"\"\"\n\n    def __init__(self, layers, cfg_arch):\n        super().__init__()\n        self.layers = torch.nn.ModuleList(layers)\n        self.seq_first = self.layers[0].LAYOUT == \"[S B H]\" if len(self.layers) > 0 else False\n        self.injection_type = cfg_arch.input_injection_type\n        if self.injection_type == \"linear\":\n            self.adapter = torch.nn.Linear(cfg_arch.hidden_size * 2, cfg_arch.hidden_size, bias=False)\n        elif self.injection_type == \"ffn\":\n            self.ffn = FFNComponent(cfg_arch.hidden_size * 2, cfg_arch.intermed_size, cfg_arch, cfg_arch.hidden_size)\n\n    def forward(self, states, injected_state, attention_mask: Optional[torch.Tensor] = None):\n        if self.injection_type == \"none\":\n            states = states\n        elif self.injection_type == \"add\": # this is the deafault in the config\n            states = states + injected_state\n        elif self.injection_type == \"linear\":\n            combined_inputs = torch.cat([states, injected_state], dim=-1)\n            states = self.adapter(combined_inputs)\n        elif self.injection_type == \"ffn\":\n            combined_inputs = torch.cat([states, injected_state], dim=-1)\n            states = self.ffn(combined_inputs)\n        for layer in self.layers:\n            states = layer(states, attention_mask)\n        return states\n\n\nclass TransposedAdapter(torch.nn.Linear):  # steal init\n    def __init__(self, embedding_dim, hidden_size, original_adapter, tie_weights=True):\n        torch.nn.Module.__init__(self)\n        # self.adapter.weight = self.encoder.adapter.weight.T # this would be nice but cannot assign like this\n        if tie_weights:\n            self.weight = original_adapter.weight\n        else:\n            self.adapter_active = False\n            self.weight = torch.nn.Parameter(torch.randn([hidden_size, embedding_dim]))  # transposed\n        self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def forward(self, inputs):\n        return torch.nn.functional.linear(inputs, self.weight.T)\n\n\nclass ScriptableRecurrentLM(PreTrainedModel):\n    \"\"\"Depth-recurrent model. Trying to include most reasonable variations of this concept\"\"\"\n\n    config_class = crammedDepthRecurrentConfig\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.cfg = OmegaConf.create(config.arch)\n\n        self.embedding = EmbeddingComponent(self.cfg.embedding, self.cfg.norm, self.cfg.norm_eps)\n        if self.cfg.embedding.embedding_dim != self.cfg.hidden_size:\n            self.adapter = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.hidden_size, bias=False)\n        else:\n            self.adapter = torch.nn.Identity()\n        self.state_init = self.cfg.state_init\n        self.recurrent_block = torch.compile(\n            TransformerBlock([TransformerLayer(idx, self.cfg) for idx in range(self.cfg.layers_in_recurrent_block)], self.cfg),\n            mode=\"default\",\n            disable=not self.cfg.local_compilation,\n        )\n        self.seq_first = self.recurrent_block.seq_first\n        if self.cfg.head == \"identity\":\n            self.head = torch.nn.Identity()\n        elif self.cfg.head == \"ffn\":\n            self.head = FFNComponent(self.cfg.hidden_size, self.cfg.intermed_size, self.cfg)\n        elif self.cfg.head == \"linear\":\n            self.head = torch.nn.Linear(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.use_bias)\n        else:\n            raise ValueError(f\"Invalid head layout {self.cfg.head} given.\")\n\n        if self.cfg.final_norm:\n            self.final_norm = _get_norm_fn(self.cfg.norm)(self.cfg.hidden_size, eps=self.cfg.norm_eps)\n        else:\n            self.final_norm = torch.nn.Identity()\n        self.register_buffer(\"attention_mask\", torch.ones([0, 0, 0, 0], dtype=torch.bool), persistent=False)\n\n    def forward(self, input_ids: torch.Tensor, num_steps_no_grad: int = None, num_steps_with_grad: int = None):\n        if input_ids.shape[1] != self.attention_mask.shape[1]:\n            self.attention_mask = get_causal_attention_mask(input_ids)\n        hidden_states = self.adapter(self.embedding(input_ids))\n        if self.seq_first:\n            hidden_states = hidden_states.transpose(0, 1).contiguous()\n        injected_state = hidden_states.clone()\n\n        num_steps_prefix = 0 if num_steps_no_grad is None else num_steps_no_grad\n        hidden_states = self.initialize_state(hidden_states)\n\n        # Recurr without gradients\n        with torch.no_grad():\n            for repeat in range(num_steps_prefix):\n                hidden_states = self.recurrent_block(hidden_states, injected_state, self.attention_mask).clone()\n\n        num_steps_active = self.cfg.maximal_recurrence if num_steps_with_grad is None else num_steps_with_grad\n        # Recur with gradients\n        for repeat in range(num_steps_active):\n            hidden_states = self.recurrent_block(hidden_states, injected_state, self.attention_mask).clone()\n        return self.final_norm(self.head(hidden_states))\n\n    def initialize_state(self, hidden_states):\n        if self.cfg.initial_hidden_randomized:\n            batch_size = hidden_states.shape[0]\n            if self.state_init == \"normal\":\n                hidden_states = torch.randn_like(hidden_states)\n            elif self.state_init == \"embed\":  # initialized like a BERT embedding\n                hidden_states = torch.randn_like(hidden_states).mul(0.02)\n            elif self.state_init == \"zero\":\n                hidden_states = torch.zeros_like(hidden_states)\n            elif self.state_init == \"unit\":\n                hidden_states = torch.randn_like(hidden_states)\n                std, mean = torch.std_mean(hidden_states, dim=-1, keepdim=True)\n                hidden_states = (hidden_states - mean) / std\n        return hidden_states\n\n\nclass ScriptableRecurrentLMReplicaConcat(PreTrainedModel):\n    \"\"\"Depth-recurrent model. with skips inside block \n    This is nearly the same as ScriptableRecurrentLM but has skips inside block too\"\"\"\n\n    config_class = crammedDepthRecurrentConfig\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.cfg = OmegaConf.create(config.arch)\n\n        self.embedding = EmbeddingComponent(self.cfg.embedding, self.cfg.norm, self.cfg.norm_eps)\n        if self.cfg.embedding.embedding_dim != self.cfg.hidden_size:\n            self.adapter = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.hidden_size, bias=False)\n        else:\n            self.adapter = torch.nn.Identity()\n        self.state_init = self.cfg.state_init\n\n\n        self.max_recurs = self.cfg.layers_in_recurrent_block\n        self.recurrent_blocks = []\n        print(\"Initializing feedforward blocks with recall connections\")\n        for _ in range(self.max_recurs):\n            self.recurrent_blocks.append(\n                torch.compile(TransformerBlock([TransformerLayer(1, self.cfg)], self.cfg),\n                              mode=\"default\",\n                              disable=not self.cfg.local_compilation,)\n            )\n        self.recurrent_blocks = torch.nn.ModuleList(self.recurrent_blocks)\n        print(f\"Initialized feedforward blocks with recall connections. \"\n              f\"It has the depth of {self.max_recurs}\")\n\n        self.seq_first = self.recurrent_blocks[0].seq_first\n        if self.cfg.head == \"identity\":\n            self.head = torch.nn.Identity()\n        elif self.cfg.head == \"ffn\":\n            self.head = FFNComponent(self.cfg.hidden_size, self.cfg.intermed_size, self.cfg)\n        elif self.cfg.head == \"linear\":\n            self.head = torch.nn.Linear(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.use_bias)\n        else:\n            raise ValueError(f\"Invalid head layout {self.cfg.head} given.\")\n\n        if self.cfg.final_norm:\n            self.final_norm = _get_norm_fn(self.cfg.norm)(self.cfg.hidden_size, eps=self.cfg.norm_eps)\n        else:\n            self.final_norm = torch.nn.Identity()\n        self.register_buffer(\"attention_mask\", torch.ones([0, 0, 0, 0], dtype=torch.bool), persistent=False)\n\n\n    def apply_recurrent_block(self, hidden_states, injected_state, attention_mask):\n        for block in self.recurrent_blocks:\n            hidden_states = block(hidden_states, injected_state, attention_mask)\n        return hidden_states\n\n\n    def forward(self, input_ids: torch.Tensor, num_steps_no_grad: int = None, num_steps_with_grad: int = None):\n        if input_ids.shape[1] != self.attention_mask.shape[1]:\n            self.attention_mask = get_causal_attention_mask(input_ids)\n        hidden_states = self.adapter(self.embedding(input_ids))\n        if self.seq_first:\n            hidden_states = hidden_states.transpose(0, 1).contiguous()\n        injected_state = hidden_states.clone()\n\n        num_steps_prefix = 0 if num_steps_no_grad is None else num_steps_no_grad\n        hidden_states = self.initialize_state(hidden_states)\n\n        # Recurr without gradients\n        with torch.no_grad():\n            for repeat in range(num_steps_prefix):\n                hidden_states = self.apply_recurrent_block(hidden_states, injected_state, self.attention_mask).clone()\n\n        num_steps_active = self.cfg.maximal_recurrence if num_steps_with_grad is None else num_steps_with_grad\n        # Recur with gradients\n        for repeat in range(num_steps_active):\n            hidden_states = self.apply_recurrent_block(hidden_states, injected_state, self.attention_mask).clone()\n        return self.final_norm(self.head(hidden_states))\n\n    def initialize_state(self, hidden_states):\n        if self.cfg.initial_hidden_randomized:\n            batch_size = hidden_states.shape[0]\n            if self.state_init == \"normal\":\n                hidden_states = torch.randn_like(hidden_states)\n            elif self.state_init == \"embed\":  # initialized like a BERT embedding\n                hidden_states = torch.randn_like(hidden_states).mul(0.02)\n            elif self.state_init == \"zero\":\n                hidden_states = torch.zeros_like(hidden_states)\n            elif self.state_init == \"unit\":\n                hidden_states = torch.randn_like(hidden_states)\n                std, mean = torch.std_mean(hidden_states, dim=-1, keepdim=True)\n                hidden_states = (hidden_states - mean) / std\n        return hidden_states\n\n\n\"\"\"Generator fn for these models.\"\"\"\n@torch.no_grad()\ndef _generate(self, input_ids, token_limit=100, temperature=1.0, steps_at_generation_time=None, track_steps=False, greedy=False, quick=False, **kwargs):\n    \"\"\"Generate token_limit many tokens from input_ids prompt. \n    track_steps = for making thinking plots\n    \"\"\"\n    predicted_ids = []\n    tracking = []\n    num_steps = self.cfg.maximal_recurrence_in_eval if steps_at_generation_time is None else steps_at_generation_time\n    logit_tensor = torch.zeros(token_limit, num_steps, self.cfg.embedding.vocab_size)\n    for gen_idx in range(token_limit):\n        if input_ids.shape[1] != self.encoder.attention_mask.shape[1]:\n            self.encoder.attention_mask = get_causal_attention_mask(input_ids)\n        hidden_states = self.encoder.adapter(self.encoder.embedding(input_ids))\n        if self.encoder.seq_first:\n            hidden_states = hidden_states.transpose(0, 1).contiguous()\n        injected_state = hidden_states\n        hidden_states = self.encoder.initialize_state(hidden_states)\n        # Recur without gradient\n        step = []\n        with torch.no_grad():\n            for repeat in range(num_steps):\n                if hasattr(self.encoder, 'recurrent_blocks'):\n                    for block in self.encoder.recurrent_blocks:\n                        hidden_states = block(hidden_states, injected_state, self.encoder.attention_mask)\n                else:\n                    hidden_states = self.encoder.recurrent_block._orig_mod(hidden_states, injected_state,\n                                                                           self.encoder.attention_mask)\n                if track_steps:\n                    # keep track of the intermediate probs\n                    output_states = self.encoder.final_norm(self.encoder.head(hidden_states.clone()))\n                    logits = self.decoder(self.adapter(output_states))\n                    logits = logits[-1, :, :] if self.encoder.seq_first else logits[:, -1, :]\n                    if greedy:\n                        probs = torch.softmax(logits, dim=-1)\n                        predicted_token = torch.argmax(logits, dim=1).unsqueeze(dim=0)\n                    else:\n                        probs = torch.softmax(logits * temperature, dim=-1)\n                        predicted_token = torch.multinomial(probs, 1)\n                    logit_tensor[gen_idx, repeat, :] = probs\n                    step.append(predicted_token)\n        if track_steps:\n            predicted_token = step[-1]\n        else:\n            # calcualte the probs if we haven't already\n            output_states = self.encoder.final_norm(self.encoder.head(hidden_states.clone()))\n            logits = self.decoder(self.adapter(output_states))\n            logits = logits[-1, :, :] if self.encoder.seq_first else logits[:, -1, :]\n            if greedy:\n                predicted_token = torch.argmax(logits, dim=1).unsqueeze(dim=0)\n            else:\n                predicted_token = torch.multinomial(torch.softmax(logits * temperature, dim=-1), 1)\n\n        if quick:\n            input_ids = torch.cat((input_ids, torch.transpose(predicted_token, 0, 1)), dim=1)\n        else:\n            input_ids = torch.cat([input_ids, predicted_token], dim=-1)\n        predicted_ids += [predicted_token]\n        tracking.append(step)\n\n    if quick:\n        generated_ids = torch.stack(predicted_ids, dim=1).squeeze()\n    else:\n        generated_ids = torch.cat(predicted_ids, dim=-1)\n\n    if track_steps:\n        return generated_ids, tracking, logit_tensor # tracking is a [num generated tokens, num recurrences] list of lists of tensors of which each tensor is a token id\n    return generated_ids\n\n\nclass ScriptableRecurrentLMForPreTraining(PreTrainedModel):\n    \"\"\"Pretraining version\"\"\"\n\n    config_class = crammedDepthRecurrentConfig\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.cfg = OmegaConf.create(config.arch)\n\n        self.encoder = ScriptableRecurrentLM(config)\n        if self.cfg.embedding.embedding_dim != self.cfg.hidden_size:\n            self.adapter = TransposedAdapter(\n                self.cfg.embedding.embedding_dim, self.cfg.hidden_size, self.encoder.adapter, self.cfg.tie_weights\n            )\n        else:\n            self.adapter = torch.nn.Identity()\n        self.decoder = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.embedding.vocab_size, bias=self.cfg.decoder_bias)\n        if self.cfg.tie_weights:\n            self.decoder.weight = self.encoder.embedding.word_embedding.weight\n\n        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100) # size_average defaults to True so when using masking loss is calculated correctly\n\n        self._init_weights()\n\n    def _init_weights(self, module=None):\n        modules = self.modules() if module is None else [module]\n        for module in modules:\n            _init_module(\n                module,\n                self.cfg.init.type,\n                self.cfg.init.std,\n                self.cfg.hidden_size,\n                self.cfg.layers_in_recurrent_block * self.cfg.maximal_recurrence,\n            )\n\n    def forward(self, input_ids: torch.Tensor, *args, **kwargs):\n        outputs = self.decoder(self.adapter(self.encoder(input_ids, num_steps_no_grad=0, num_steps_with_grad=self.cfg.maximal_recurrence)))\n\n        if self.encoder.seq_first:\n            shifted_outputs = outputs[:-1]\n            shifted_labels = input_ids.transpose(0, 1)[1:].contiguous()\n            outputs = outputs.detach().transpose(0, 1)\n        else:\n            shifted_outputs = outputs[..., :-1, :].contiguous()\n            shifted_labels = input_ids[..., 1:].contiguous()\n            outputs = outputs.detach()\n\n        # Flatten the tokens and compute loss\n        loss = self.loss_fn(shifted_outputs.view(-1, shifted_outputs.shape[-1]), shifted_labels.view(-1))\n\n        return {\"loss\": loss, \"logits\": outputs[:, -1, :], \"log_perplexity\": loss.clone().detach()}\n\n    def _generate(self, input_ids, token_limit=100, temperature=0.7, steps_at_generation_time=None):\n        return _generate(self, input_ids, token_limit, temperature, steps_at_generation_time)\n\n\nclass ScriptableRecurrentLMBPTT(PreTrainedModel):\n    \"\"\"Pretraining version with stochastic depth / trunc. BPTT\"\"\"\n\n    config_class = crammedDepthRecurrentConfig\n\n    def __init__(self, config, equals_token):\n        super().__init__(config)\n        self.cfg = OmegaConf.create(config.arch)\n        self.equals_token = equals_token\n\n        self.max_recurrences_for_training = self.cfg.maximal_recurrence\n        self.max_backprop = max(self.cfg.maximal_recurrence // 2 if self.cfg.max_backprop is None else self.cfg.max_backprop, 1)\n        try:\n            self.forward_only_model_with_skip = self.cfg.forward_only_model_with_skip\n            if self.cfg.forward_only_model_with_skip:\n                print(\"Using forward only model with skip\")\n                self.encoder = ScriptableRecurrentLMReplicaConcat(config)\n            else:\n                self.encoder = ScriptableRecurrentLM(config)\n        except:\n            self.encoder = ScriptableRecurrentLM(config)\n\n        self.adapter = TransposedAdapter(self.cfg.embedding.embedding_dim, self.cfg.hidden_size, self.encoder.adapter, self.cfg.tie_weights)\n        self.decoder = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.embedding.vocab_size, bias=self.cfg.decoder_bias)\n        if self.cfg.tie_weights:\n            self.decoder.weight = self.encoder.embedding.word_embedding.weight\n\n        self.throttle = self.cfg.throttle\n        self.alpha = self.cfg.alpha\n        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction=self.cfg.loss_reduction) # size_average defaults to True so when using masking loss is calculated correctly\n        self._init_weights()\n\n        self.mask_before_equals = self.cfg.mask_before_equals\n        self.model_call = self.prog_model_call_with_masking # moved the logic for masking before equals into this function\n\n    def _init_weights(self, module=None):\n        modules = self.modules() if module is None else [module]\n        for module in modules:\n            _init_module(\n                module,\n                self.cfg.init.type,\n                self.cfg.init.std,\n                self.cfg.hidden_size,\n                self.cfg.layers_in_recurrent_block * self.cfg.maximal_recurrence,\n            )\n\n    def set_max_recurrences_for_training(self, new_max):\n        \"\"\"Can play around with recurrences during training\"\"\"\n        self.max_recurrences_for_training = new_max\n        self.max_backprop = max(self.max_recurrences_for_training // 2 if self.cfg.max_backprop is None else self.cfg.max_backprop, 1)\n\n    def forward(self, input_ids: torch.Tensor, *args, **kwargs):\n        \"\"\"\n        WARNING: max iters outputs is used for logits and entropy calcs\n        \"\"\"\n        if self.training:\n            loss, outputs = self.forward_progressive(input_ids)\n            if self.throttle:\n                Ek = 1 + min(self.max_recurrences_for_training / 4, self.max_backprop / 2)\n                loss = loss * (Ek / self.max_backprop)\n        else:\n            loss, outputs = self.model_call(input_ids, n=self.cfg.maximal_recurrence_in_eval, k=0)\n\n        return {\"loss\": loss, \"logits\": outputs[:, -1, :], \"log_perplexity\": loss.clone().detach()}\n    \n    def forward_progressive(self, input_ids):\n        \"\"\"Implements progressive loss\"\"\"\n        if self.alpha != 1:\n            # max iters forward pass\n            n = self.max_recurrences_for_training-self.max_backprop\n            k = self.max_backprop # i.e. maxmimise the number of layers we back prop through\n            loss_max_iters, outputs_max_iters = self.model_call(input_ids, n=n, k=k)\n        else:\n            loss_max_iters = torch.zeros(1, dtype=torch.float32).to(input_ids.get_device())\n\n        if self.alpha != 0:\n            # stochastic forward pass\n            n = torch.randint(low=0, high=self.max_recurrences_for_training, size=(1,))\n            k = torch.randint(low=1, high=1 + min(self.max_recurrences_for_training - n, self.max_backprop), size=(1,))\n            loss_progressive, outputs_progressive = self.model_call(input_ids, n=n, k=k)\n            if self.alpha == 1:\n                outputs_max_iters = outputs_progressive\n        else:\n            loss_progressive = torch.zeros(1, dtype=torch.float32).to(input_ids.get_device())\n        \n        loss = (1 - self.alpha) * loss_max_iters + self.alpha * loss_progressive\n        # Returning outputs max_iters to be used for logits, could try outputs_progressive\n        return loss, outputs_max_iters\n\n    def prog_model_call_with_masking(self, input_ids, n, k):\n        if self.mask_before_equals: # mask before equals\n            indices_of_equals = (input_ids == self.equals_token).nonzero()[:, 1] # gets the index of equals sign for each tensor in the batch\n            max_indices = torch.arange(input_ids.size(1), device=input_ids.device) # tensor for mask\n            masks = max_indices.unsqueeze(0) > indices_of_equals.unsqueeze(1) # fill tensor after including index of = sign for each row\n        else: # mask only the random padding\n            masks = input_ids != 0\n        \n        outputs = self.decoder(self.adapter(self.encoder(input_ids, num_steps_no_grad=n, num_steps_with_grad=k)))\n\n        if self.encoder.seq_first:\n            shifted_outputs = outputs[:-1]\n            shifted_labels = input_ids.transpose(0, 1)[1:].contiguous()\n            outputs = outputs.detach().transpose(0, 1)\n            masked = torch.mul(shifted_labels, masks[..., 1:].transpose(0, 1))\n        else:\n            shifted_outputs = outputs[..., :-1, :].contiguous()\n            shifted_labels = input_ids[..., 1:].contiguous()\n            outputs = outputs.detach()\n            masked = torch.mul(shifted_labels, masks[..., 1:])\n        masked[masked == 0] = -100 # mask all 0's in loss\n\n        shifted_outputs_shape = shifted_outputs.shape\n        \n        loss = self.loss_fn(shifted_outputs.view(-1, shifted_outputs.shape[-1]), masked.view(-1)) # CE_Loss(Input, Target)\n        if self.cfg.loss_reduction=='none': # giving all output samples equal weighting\n            loss = loss.view(shifted_outputs_shape[0],shifted_outputs_shape[1])\n            loss = torch.mean(loss, dim=1)\n            loss = torch.mean(loss)\n        return loss, outputs\n\n    def _generate(self, input_ids, token_limit=100, temperature=1.0, steps_at_generation_time=None, track_steps=False, greedy=False, quick=False):\n        return _generate(self, input_ids, token_limit, temperature, steps_at_generation_time, track_steps, greedy=greedy, quick=quick)\n\n\n# ###### HF registry here? ############### #\n\nAutoConfig.register(\"crammedDepthRecurrent\", crammedDepthRecurrentConfig)\nAutoModel.register(crammedDepthRecurrentConfig, ScriptableRecurrentLM)\nAutoModelForCausalLM.register(crammedDepthRecurrentConfig, ScriptableRecurrentLMForPreTraining)\n"
  },
  {
    "path": "cramming/architectures/crammed_transformer.py",
    "content": "\"\"\"Base file for modifications of the transformer architecture\"\"\"\nimport torch\nfrom transformers import PretrainedConfig, PreTrainedModel\nfrom transformers import AutoConfig, AutoModel, AutoModelForCausalLM\n\nfrom typing import Optional\nfrom omegaconf import OmegaConf\n\nfrom .components import (\n    _get_norm_fn,\n    _get_nonlin_fn,\n    NormalizedResidualConnection,\n    EmbeddingComponent,\n    GLU,\n    get_causal_attention_mask,\n    _init_module,\n)\nfrom .attention import get_attention_mechanism\n\n\nclass crammedTransformerConfig(PretrainedConfig):\n    model_type = \"crammedTransformer\"\n\n    def __init__(self, cfg_arch_container: dict = {}, **kwargs):\n        self.arch = cfg_arch_container\n        super().__init__(**kwargs)\n\n\ndef construct_crammed_transformer(cfg_arch, vocab_size):\n    \"\"\"See the config file for details on what is possible.\"\"\"\n    cfg_arch.embedding.vocab_size = vocab_size\n\n    config = crammedTransformerConfig(OmegaConf.to_container(cfg_arch, resolve=True))\n    model = ScriptableLMForPreTraining(config)\n\n    return model\n\n\nclass FFNComponent(torch.nn.Module):\n    \"\"\"Note: The FF layer is not auto-scaled when using a GLU type activation.\n    Better do this manually and choose a sensible intermed_size that is nicely divisible.\n\n    The neox suggestion for approx. equal parameter count is int(4 * 2 / 3 * hidden_size) * 2 [this is ~5.33]\n    \"\"\"\n\n    def __init__(self, hidden_size, intermed_size, cfg_arch, output_size=None):\n        super().__init__()\n        self.dense_in = torch.nn.Linear(hidden_size, intermed_size, bias=cfg_arch.use_bias)\n        self.nonlin = _get_nonlin_fn(cfg_arch.nonlin)()\n        if isinstance(self.nonlin, GLU):\n            intermed_output_size = intermed_size // 2\n        else:\n            intermed_output_size = intermed_size\n        if cfg_arch.sub_normalization:\n            self.norm = _get_norm_fn(cfg_arch.norm)(intermed_output_size, eps=cfg_arch.norm_eps)\n        else:\n            self.norm = torch.nn.Identity()\n        output_size = hidden_size if output_size is None else output_size\n        self.dense_out = torch.nn.Linear(intermed_output_size, output_size, bias=cfg_arch.use_bias)\n\n    def forward(self, hidden_states):\n        return self.dense_out(self.norm(self.nonlin(self.dense_in(hidden_states))))\n\n\nclass TransformerLayer(torch.nn.Module):\n    \"\"\"A transformer structure based on the components from above.\"\"\"\n\n    def __init__(self, idx, cfg_arch):\n        super().__init__()\n        self.residual1 = NormalizedResidualConnection(cfg_arch.hidden_size, cfg_arch)\n        self.residual2 = NormalizedResidualConnection(cfg_arch.hidden_size, cfg_arch)\n        if cfg_arch.attention.sub_normalization:\n            sub_norm_fn = lambda: get_norm_fn(cfg_arch.norm)(cfg_arch.hidden_size, eps=cfg_arch.norm_eps)  # noqa\n        else:\n            sub_norm_fn = torch.nn.Identity\n        self.attn = get_attention_mechanism(idx, cfg_arch.hidden_size, cfg_arch.attention, sub_norm_fn)\n        self.ffn = FFNComponent(cfg_arch.hidden_size, cfg_arch.intermed_size, cfg_arch)\n        self.LAYOUT = self.attn.LAYOUT\n\n    def forward(self, states, attention_mask: Optional[torch.Tensor] = None):\n        states = self.residual1(states, self.attn, states, attention_mask)\n        states = self.residual2(states, self.ffn, states)\n        return states\n\n\nclass ScriptableLM(PreTrainedModel):\n    \"\"\"Simplified transformer wrapper.\"\"\"\n\n    config_class = crammedTransformerConfig\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.cfg = OmegaConf.create(config.arch)\n\n        self.embedding = EmbeddingComponent(self.cfg.embedding, self.cfg.norm, self.cfg.norm_eps)\n        self.layers = torch.nn.ModuleList([TransformerLayer(idx, self.cfg) for idx in range(self.cfg.num_transformer_layers)])\n        self.seq_first = self.layers[0].LAYOUT == \"[S B H]\" if len(self.layers) > 0 else False\n\n        if self.cfg.final_norm:\n            self.final_norm = _get_norm_fn(self.cfg.norm)(self.cfg.hidden_size, eps=self.cfg.norm_eps)\n        else:\n            self.final_norm = torch.nn.Identity()\n\n        self.register_buffer(\"attention_mask\", torch.ones([0, 0, 0, 0], dtype=torch.bool), persistent=False)\n\n    def forward(self, input_ids: torch.Tensor):\n        if input_ids.shape[1] != self.attention_mask.shape[1]:\n            self.attention_mask = get_causal_attention_mask(input_ids)\n        hidden_states = self.embedding(input_ids)\n\n        if self.seq_first:\n            hidden_states = hidden_states.transpose(0, 1).contiguous()\n\n        for i, layer_module in enumerate(self.layers):\n            hidden_states = layer_module(hidden_states, self.attention_mask)\n\n        # if self.seq_first:\n        #     hidden_states = hidden_states.transpose(0, 1).contiguous()\n        # this happens only in the output if necessary\n\n        return self.final_norm(hidden_states)\n\n\nclass ScriptableLMForPreTraining(PreTrainedModel):\n    \"\"\"Pretraining version with optional prediction head and variant for sparse prediction.\"\"\"\n\n    config_class = crammedTransformerConfig\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.cfg = OmegaConf.create(config.arch)\n\n        self.encoder = ScriptableLM(config)\n\n        self.decoder = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.embedding.vocab_size, bias=self.cfg.decoder_bias)\n        self.decoder.weight = self.encoder.embedding.word_embedding.weight\n\n        self.loss_fn = torch.nn.CrossEntropyLoss()\n        self._init_weights()\n\n    def _init_weights(self, module=None):\n        modules = self.modules() if module is None else [module]\n        for module in modules:\n            _init_module(\n                module,\n                self.cfg.init.type,\n                self.cfg.init.std,\n                self.cfg.hidden_size,\n                self.cfg.num_transformer_layers,\n            )\n\n    def forward(self, input_ids: torch.Tensor, *args, **kwargs):\n        outputs = self.decoder(self.encoder(input_ids))\n\n        if self.encoder.seq_first:\n            shifted_outputs = outputs[:-1]\n            shifted_labels = input_ids.transpose(0, 1)[1:].contiguous()\n            outputs = outputs.detach().transpose(0, 1)\n        else:\n            shifted_outputs = outputs[..., :-1, :].contiguous()\n            shifted_labels = input_ids[..., 1:].contiguous()\n            outputs = outputs.detach()\n        # Flatten the tokens and compute loss\n        loss = self.loss_fn(shifted_outputs.view(-1, shifted_outputs.shape[-1]), shifted_labels.view(-1))\n\n        return {\"loss\": loss, \"logits\": outputs[:, -1, :], \"log_perplexity\": loss.clone().detach()}\n\n\n# ###### HF registry here? ############### #\n\nAutoConfig.register(\"crammedTransformer\", crammedTransformerConfig)\nAutoModel.register(crammedTransformerConfig, ScriptableLM)\nAutoModelForCausalLM.register(crammedTransformerConfig, ScriptableLMForPreTraining)\n"
  },
  {
    "path": "cramming/architectures/embeddings.py",
    "content": "\"\"\"Non-standard embedding implementations.\"\"\"\n\nimport torch\nimport math\n\nfrom typing import Tuple\nfrom einops import repeat\nimport random\n\n\nclass PositionalEmbedding(torch.nn.Module):\n    # https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15C1-L31C37\n    def __init__(self, demb):\n        super(PositionalEmbedding, self).__init__()\n\n        self.demb = demb\n\n        inv_freq = (1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))).float()\n        self.register_buffer(\"inv_freq\", inv_freq)\n\n    def forward(self, pos_seq, bsz=None):\n        # sinusoid_inp = torch.ger(pos_seq, self.inv_freq)\n        tensor_24_17_1 = pos_seq.float().unsqueeze(2)\n\n        vector_512_expanded = self.inv_freq.unsqueeze(0).unsqueeze(1)\n\n        result = torch.matmul(tensor_24_17_1, vector_512_expanded)\n\n        sinusoid_inp = result.squeeze(2)\n\n        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)\n        return pos_emb\n\n\nclass RandomNoise(torch.nn.Module):\n\n    def __init__(self, embedding_dim, max_seq_length=5000):\n        super().__init__()\n        self.embedding_dim = embedding_dim\n\n    def forward(self, input_ids):\n        return torch.normal(0, 0.1, size=(input_ids.size(0), input_ids.size(1), self.embedding_dim)).to(input_ids.device)\n\n\nclass RPE(torch.nn.Module):\n    # https://jaketae.github.io/study/relative-positional-encoding/\n    # def __init__(self, embedding_dim, max_seq_length=5000):\n    #     super().__init__()\n\n    # def forward(self, input_ids):\n    #     return torch.normal(0, 0.1, size=input_ids.shape)\n    def __init__(self, d_model, num_heads, max_len=1024, dropout=0.1):\n        super().__init__()\n        d_head, remainder = divmod(d_model, num_heads)\n        if remainder:\n            raise ValueError(\"incompatible `d_model` and `num_heads`\")\n        self.max_len = max_len\n        self.d_model = d_model\n        self.num_heads = num_heads\n        self.key = torch.nn.Linear(d_model, d_model)\n        self.value = torch.nn.Linear(d_model, d_model)\n        self.query = torch.nn.Linear(d_model, d_model)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.Er = torch.nn.Parameter(torch.randn(max_len, d_head))\n        self.register_buffer(\"mask\", torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0))\n        # self.mask.shape = (1, 1, max_len, max_len)\n\n    def forward(self, x):\n        # x.shape == (batch_size, seq_len, d_model)\n        batch_size, seq_len, _ = x.shape\n\n        if seq_len > self.max_len:\n            raise ValueError(\"sequence length exceeds model capacity\")\n\n        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)\n        # k_t.shape = (batch_size, num_heads, d_head, seq_len)\n        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)\n        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)\n        # shape = (batch_size, num_heads, seq_len, d_head)\n\n        start = self.max_len - seq_len\n        Er_t = self.Er[start:, :].transpose(0, 1)\n        # Er_t.shape = (d_head, seq_len)\n        QEr = torch.matmul(q, Er_t)\n        # QEr.shape = (batch_size, num_heads, seq_len, seq_len)\n        Srel = self.skew(QEr)\n        # Srel.shape = (batch_size, num_heads, seq_len, seq_len)\n\n        QK_t = torch.matmul(q, k_t)\n        # QK_t.shape = (batch_size, num_heads, seq_len, seq_len)\n        attn = (QK_t + Srel) / math.sqrt(q.size(-1))\n        mask = self.mask[:, :, :seq_len, :seq_len]\n        # mask.shape = (1, 1, seq_len, seq_len)\n        attn = attn.masked_fill(mask == 0, float(\"-inf\"))\n        # attn.shape = (batch_size, num_heads, seq_len, seq_len)\n        attn = torch.nn.functional.softmax(attn, dim=-1)\n        out = torch.matmul(attn, v)\n        # out.shape = (batch_size, num_heads, seq_len, d_head)\n        out = out.transpose(1, 2)\n        # out.shape == (batch_size, seq_len, num_heads, d_head)\n        out = out.reshape(batch_size, seq_len, -1)\n        # out.shape == (batch_size, seq_len, d_model)\n        return self.dropout(out)\n\n    def skew(self, QEr):\n        # QEr.shape = (batch_size, num_heads, seq_len, seq_len)\n        padded = torch.nn.functional.pad(QEr, (1, 0))\n        # padded.shape = (batch_size, num_heads, seq_len, 1 + seq_len)\n        batch_size, num_heads, num_rows, num_cols = padded.shape\n        reshaped = padded.reshape(batch_size, num_heads, num_cols, num_rows)\n        # reshaped.size = (batch_size, num_heads, 1 + seq_len, seq_len)\n        Srel = reshaped[:, :, 1:, :]\n        # Srel.shape = (batch_size, num_heads, seq_len, seq_len)\n        return Srel\n\n\n# module partially stolen from pytorch examples:\nclass SinusoidalPositional(torch.nn.Module):\n    r\"\"\"Inject some information about the relative or absolute position of the tokens\n    in the sequence. The positional encodings have the same dimension as\n    the embeddings, so that the two can be summed. Here, we use sine and cosine\n    functions of different frequencies.\n    \"\"\"\n\n    def __init__(self, embedding_dim, max_seq_length=5000):\n        super().__init__()\n\n        pe = torch.zeros(max_seq_length, embedding_dim)\n        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n\n        pe = pe.unsqueeze(0)\n        self.register_buffer(\"pe\", pe, persistent=False)\n\n    def forward(self, input_ids):\n        r\"\"\"Inputs of forward function\n        Args:\n            x: the sequence fed to the positional encoder model (required).\n        Shape:\n            x: [batch size, sequence length, embed dim]\n            output: [batch size, sequence length, embed dim]\n        Examples:\n            >>> output = pos_encoder(x)\n        \"\"\"\n        return self.pe[:, : input_ids.shape[1], :]\n\n\nclass ScaledSinosoidal(SinusoidalPositional):\n    \"\"\"Sinusoidal with scaling (see FLASH paper).\"\"\"\n\n    def __init__(self, embedding_dim, max_seq_length):\n        super().__init__(embedding_dim, max_seq_length)\n        self.scale_factor = torch.nn.Parameter(torch.tensor([1.0 / embedding_dim**0.5]))\n\n    def forward(self, input_ids):\n        r\"\"\"Inputs of forward function\n        Args:\n            x: the sequence fed to the positional encoder model (required).\n        Shape:\n            x: [batch size, sequence length, embed dim]\n            output: [batch size, sequence length, embed dim]\n        Examples:\n            >>> output = pos_encoder(x)\n        \"\"\"\n        return self.scale_factor * self.pe[:, : input_ids.shape[1], :]\n\n\nclass LearnablePositional(torch.nn.Module):\n    \"\"\"Shorthand for a learnable embedding.\"\"\"\n\n    def __init__(self, embedding_dim, max_seq_length=1024):\n        super().__init__()\n        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)\n        self.register_buffer(\"position_ids\", torch.arange(max_seq_length).expand((1, -1)))\n\n    def forward(self, input_ids):\n        \"\"\"This is a batch-first implementation\"\"\"\n        position_ids = self.position_ids[:, : input_ids.shape[1]]\n        return self.embedding(position_ids)\n\n\nclass LearnablePositionalRand(torch.nn.Module):\n    \"\"\"Shorthand for a learnable embedding.\"\"\"\n\n    def __init__(self, embedding_dim, max_seq_length=1024):\n        super().__init__()\n        self.max_length = max_seq_length\n        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)\n        self.register_buffer(\"position_ids\", torch.arange(max_seq_length).expand((1, -1)))\n\n    def forward(self, input_ids):\n        \"\"\"This is a batch-first implementation\"\"\"\n        seq_length = input_ids.shape[1]\n        device = input_ids.device\n        if seq_length > self.max_length:  # max length will be increased to max sequnece length if max length is short\n            max_length = seq_length\n        else:\n            max_length = self.max_length\n        position_ids = self.position_ids[:, : input_ids.shape[1]]\n        position_ids = torch.sort(torch.randperm(max_length, dtype=torch.long, device=device)[:seq_length]).values\n        return self.embedding(position_ids)\n\n# Code stolen from GPT-X:\nclass Rotary(torch.nn.Module):\n    def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int = 0):\n        super().__init__()\n        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=True)\n        self.seq_len_cached = def_seq_length\n        self.seq_dim = seq_dim\n        cos_cache, sin_cache = self._get_cos_sin()\n        self.register_buffer(\"cos_cached\", cos_cache, persistent=False)\n        self.register_buffer(\"sin_cached\", sin_cache, persistent=False)\n\n        # Force fusions on batched version\n        def rotate_half(x: torch.Tensor):\n            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster\n            return torch.cat((-x2, x1), dim=-1)\n\n        def rope_fn(cos: torch.Tensor, sin: torch.Tensor, query_layer: torch.Tensor, key_layer: torch.Tensor):\n            QK = torch.cat([query_layer, key_layer], dim=1)\n            rotated = QK * cos[: QK.shape[0]] + rotate_half(QK) * sin[: QK.shape[0]]\n            return torch.split(rotated, query_layer.shape[1], dim=1)\n\n        self.rope_fn = rope_fn  # handle fusion on module level\n\n    @torch.no_grad()\n    def get_cos_sin_cache(self, x: torch.Tensor):\n        seq_len = x.shape[self.seq_dim]\n        if seq_len != self.seq_len_cached:\n            self.seq_len_cached = x.shape[self.seq_dim]\n            cos_cache, sin_cache = self._get_cos_sin()\n            self.cos_cached = cos_cache.to(x.device)\n            self.sin_cached = sin_cache.to(x.device)\n        return self.cos_cached, self.sin_cached\n\n    def _get_cos_sin(self):\n        t = torch.arange(self.seq_len_cached).type_as(self.inv_freq)\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        emb = torch.cat((freqs, freqs), dim=-1)\n        if self.seq_dim == 0:\n            return emb.cos()[:, None, None, :].detach(), emb.sin()[:, None, None, :].detach()\n        else:\n            return emb.cos()[None, :, None, :].detach(), emb.sin()[None, :, None, :].detach()\n\n    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):\n        cos_cached, sin_cached = self.get_cos_sin_cache(query_layer)\n        return self.rope_fn(cos_cached, sin_cached, query_layer, key_layer)\n\n    @torch.jit.export\n    def single_forward(self, inputs: torch.Tensor):\n        \"\"\"For cases where shapes of Q and K do not match.\"\"\"\n        cos, sin = self.cos_cached[: inputs.shape[0]], self.sin_cached[: inputs.shape[0]]\n        return inputs * cos + self.rotate_half(inputs) * sin\n\n    def rotate_half(self, x: torch.Tensor):\n        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]\n        return torch.cat((-x2, x1), dim=-1)  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster\n\nclass RotarySanityCheck(torch.nn.Module):\n    \"\"\"not again...\"\"\"\n\n    def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int = 0):\n        super().__init__()\n        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=True)\n        self.seq_len_cached = def_seq_length\n        self.seq_dim = seq_dim\n        cos_cache, sin_cache = self._get_cos_sin()\n        self.register_buffer(\"cos_cached\", cos_cache, persistent=False)\n        self.register_buffer(\"sin_cached\", sin_cache, persistent=False)\n\n    @torch.no_grad()\n    def get_cos_sin_cache(self, x: torch.Tensor):\n        seq_len = x.shape[self.seq_dim]\n        if seq_len != self.seq_len_cached:\n            self.seq_len_cached = x.shape[self.seq_dim]\n            cos_cache, sin_cache = self._get_cos_sin()\n            self.cos_cached = cos_cache.to(x.device)\n            self.sin_cached = sin_cache.to(x.device)\n        return self.cos_cached, self.sin_cached\n\n    def _get_cos_sin(self):\n        t = torch.arange(self.seq_len_cached).type_as(self.inv_freq)\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        emb = torch.cat((freqs, freqs), dim=-1)\n        if self.seq_dim == 0:\n            return emb.cos()[:, None, None, :].detach(), emb.sin()[:, None, None, :].detach()\n        else:\n            return emb.cos()[None, :, None, :].detach(), emb.sin()[None, :, None, :].detach()\n\n    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):\n        # cos, sin = self.get_cos_sin_cache(key_layer)\n        # cos, sin = (cos[offset : query_layer.shape[0] + offset, ...], sin[offset : query_layer.shape[0] + offset, ...])\n        cos, sin = self.cos_cached, self.sin_cached\n        return (query_layer * cos) + (self.rotate_half(query_layer) * sin), (key_layer * cos) + (self.rotate_half(key_layer) * sin)\n\n    def rotate_half(self, x: torch.Tensor):\n        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]\n        return torch.cat((-x2, x1), dim=-1)  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster\n\n    @torch.jit.export\n    def single_forward(self, inputs: torch.Tensor):\n        \"\"\"For cases where shapes of Q and K do not match.\"\"\"\n        cos, sin = self.cos_cached[: inputs.shape[0]], self.sin_cached[: inputs.shape[0]]\n        return inputs * cos + self.rotate_half(inputs) * sin\n\n\n# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/rotary.py who adapted from\n# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py\nclass RotaryEleutherAI(torch.nn.Module):\n    \"\"\"\n    The rotary position embeddings from RoFormer_ (Su et. al).\n    A crucial insight from the method is that the query and keys are\n    transformed by rotation matrices which depend on the relative positions.\n    Other implementations are available in the Rotary Transformer repo_ and in\n    GPT-NeoX_, GPT-NeoX was an inspiration\n    .. _RoFormer: https://arxiv.org/abs/2104.09864\n    .. _repo: https://github.com/ZhuiyiTechnology/roformer\n    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox\n    \"\"\"\n\n    _seq_len_cached: int\n    # _cos_cached: Optional[torch.Tensor]\n    # _sin_cached: Optional[torch.Tensor]\n\n    def __init__(self, dim_model: int, *_, **__):\n        super().__init__()\n        # Generate and save the inverse frequency buffer (non trainable)\n        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))\n        self.register_buffer(\"inv_freq\", inv_freq)\n\n        _cos_cached, _sin_cached = self._update_cos_sin_tables(torch.randn(1, 128, 1), seq_dimension=-2)\n        self.register_buffer(\"_cos_cached\", _cos_cached, persistent=False)\n        self.register_buffer(\"_sin_cached\", _sin_cached, persistent=False)\n\n    @torch.jit.ignore\n    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:\n        seq_len = x.shape[seq_dimension]\n\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        # if seq_len != self._seq_len_cached:  # or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:\n        self._seq_len_cached = seq_len\n        t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=self.inv_freq.dtype)\n        # Don't do einsum, it converts fp32 to fp16\n        # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        freqs = torch.outer(t, self.inv_freq)\n        cos_cached = repeat(torch.cos(freqs).to(x.dtype), \"... d -> ... (d 2)\")\n        sin_cached = repeat(torch.sin(freqs).to(x.dtype), \"... d -> ... (d 2)\")\n\n        return cos_cached, sin_cached\n\n    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:\n        # assert seq_dimension in [-2, -3]  # Either (bs, h, s, d) or (bs, s, h, d)\n        # self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=seq_dimension)\n\n        return (\n            self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),\n            self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),\n        )\n\n    def rotate_half(self, x: torch.Tensor):\n        x = x.unflatten(dim=-1, sizes=(-1, 2))\n        x1, x2 = x.unbind(dim=-1)\n        rotated_x = torch.stack((-x2, x1), dim=-1)\n        return rotated_x.flatten(start_dim=-2)\n\n    def apply_rotary_pos_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seq_dimension: int = -2):\n        # NOTE: This could probably be moved to Triton\n\n        # Handle a possible sequence length mismatch in between q and k\n        cos = cos[: x.shape[seq_dimension], :]\n        sin = sin[: x.shape[seq_dimension], :]\n        if seq_dimension == -3:\n            cos = cos[:, None, :]\n            sin = sin[:, None, :]\n        return (x * cos) + (self.rotate_half(x) * sin)\n\n\nclass RotaryLLAMA(torch.nn.Module):\n    \"\"\"Facebook implementation of rotary embeddings.\"\"\"\n\n    def __init__(self, hidden_per_head, base=10000, max_seq_length=512, seq_dim: int = 0):\n        super().__init__()\n        self.seq_dim: int = seq_dim\n        freqs_cis = self.precompute_freqs_cis(dim=hidden_per_head, end=max_seq_length * 2, theta=base)\n        self.register_buffer(\"freqs_cis\", freqs_cis)\n\n    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):\n        return self.apply_rotary_emb(query_layer, key_layer, freqs_cis=self.freqs_cis)\n\n    def apply_rotary_emb(self, xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))\n        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))\n        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_)\n\n        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)\n        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)\n        return xq_out.type_as(xq), xk_out.type_as(xk)\n\n    def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):\n        freqs_cis = freqs_cis[: x.shape[self.seq_dim]]\n        # shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]\n        # shape = [1, seq_length, 1, hidden_per_head]\n        shape = [s if i == self.seq_dim or i == x.ndim - 1 else 1 for i, s in enumerate(x.shape)]\n        return freqs_cis.view(*shape)\n\n    @staticmethod\n    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):\n        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n        t = torch.arange(end, device=freqs.device)  # type: ignore\n        freqs = torch.outer(t, freqs).float()  # type: ignore\n        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64\n        return freqs_cis\n\nclass FIRE(torch.nn.Module):\n    def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512.0, eps=1e-6, max_length=0):\n        \"\"\"\n        FIRE attention bias module (https://arxiv.org/abs/2310.04418).\n\n        Args:\n            num_heads: number of attention heads.\n            mlp_width: Width of MLP.\n            init_c: initial value of log transformation parameter\n            init_L: initial value of thresholding parameter\n            eps: small constant for numerical stability\n        \"\"\"\n        super(FIRE, self).__init__()\n        self.max_length = max_length  # using random PE\n\n        # Define the MLP layers\n        self.mlp = torch.nn.Sequential(torch.nn.Linear(1, mlp_width), torch.nn.ReLU(), torch.nn.Linear(mlp_width, num_heads))\n\n        # Initialize c (log transformation parameter)\n        self.c = torch.nn.Parameter(torch.tensor(init_c))\n\n        # Initialize L (threshold)\n        self.init_L = torch.nn.Parameter(torch.tensor(init_L), requires_grad=False)\n        self.L_multiplier = torch.nn.Parameter(torch.tensor(1.0))  # learn a multiplier to L\n\n        self.eps = eps\n\n    def forward(self, seq_length, device):\n        \"\"\"\n        Compute FIRE attention bias (https://arxiv.org/abs/2310.04418).\n\n        Args:\n            x: input sequence, shape [bsz, num_heads, seq_len, hidden_dim]\n\n        Returns:\n            attention bias of shape [1, num_heads, seq_len, seq_len]\n        \"\"\"\n        if (seq_length > self.max_length) or (\n            not self.training\n        ):  # max length will be increased to max sequnece length if max length is short\n            max_length = seq_length\n        else:\n            max_length = self.max_length\n\n        # take a subset (of length seq_length) of a random permutation of length max_length, then sort it to\n        positions = torch.sort(torch.randperm(max_length, dtype=torch.float, device=device)[:seq_length]).values\n        relative_distances = positions[:, None] - positions[None, :]\n        \n        # Thresholding the normalizer for short sequence modeling\n        threshold = torch.abs(self.L_multiplier * self.init_L)\n        position_normalizer = torch.max(positions, threshold)[:, None]\n\n        # Amplifying differences among local positions with log transform\n        relative_distances = torch.log(torch.abs(self.c * relative_distances) + 1)\n        position_normalizer = torch.log(torch.abs(self.c * position_normalizer) + 1)\n\n        # Progressive interpolation\n        normalized_distances = relative_distances / (position_normalizer + self.eps)\n        fire_bias = self.mlp(normalized_distances.unsqueeze(-1)).unsqueeze(0)\n        fire_bias = fire_bias.permute(0, 3, 1, 2)\n        \n        return fire_bias\n\nclass Abacus(torch.nn.Module):\n    \"\"\"Abacus Embeddings, learned emebddings resued for each digit\"\"\"\n\n    def __init__(self, embedding_dim, max_seq_length=1024, max_k=99):\n        super().__init__()\n        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)\n        self.register_buffer(\"position_ids\", torch.arange(max_seq_length).expand((1, -1)))\n        self.max_k = max_k # the max_k here by default is 99 as we add it on after istead of generate with it\n\n    def helper(self, mask, device):\n        mask_shape = mask.shape\n        \n        # Create a shifted version of the mask to detect changes from 0 to 1\n        shifted_mask = torch.cat([torch.zeros((mask_shape[0], 1), device=device, dtype=mask.dtype), mask[:, :-1]], dim=1)\n        starts = (shifted_mask != mask) & mask\n        \n        # Generate IDs for each segment of 1s, processing row-wise\n        segment_ids = torch.cumsum(starts, dim=1)\n        \n        # Generate an index array row-wise\n        index = torch.arange(mask.size(1)).repeat(mask.size(0), 1).to(device)\n        \n        # Reset index at the start of each segment\n        reset_index = torch.zeros_like(mask).long()\n        second_term = index * starts.long()\n        reset_index = reset_index.scatter_add(1, segment_ids, second_term)\n        \n        # Calculate positions in segment\n        positions = index - reset_index.gather(1, segment_ids) + 1\n        \n        # Ensure only values within 1-segments are non-zero\n        result = positions * mask\n\n        return result\n\n    def forward(self, input_ids):\n        \"\"\"This is a batch-first implementation\"\"\"\n        \"\"\"\n        This is a batch-first implementation\n        designed to work with our tokenizers, for a more versatile implementation, look at the abacus.py file\n        sort tokenizer: '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13\n\n        {'0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, 'D': 14, ',': 15, ':': 16, '=': 17, ' ': 18, 'A': 19, 'B': 20, 'C': 21, 'E': 22, 'F': 23, 'G': 24, 'H': 25, 'I': 26, 'J': 27, 'K': 28, 'L': 29, 'M': 30, 'N': 31, 'O': 32, 'P': 33, 'Q': 34, 'R': 35, 'S': 36, 'T': 37, 'U': 38, 'V': 39, 'W': 40, 'X': 41, 'Y': 42, 'Z': 43, 'a': 44, 'b': 45, 'c': 46, 'd': 47, 'e': 48, 'f': 49, 'g': 50, 'h': 51, 'i': 52, 'j': 53, 'k': 54, 'l': 55, 'm': 56, 'n': 57, 'o': 58, 'p': 59, 'q': 60, 'r': 61, 's': 62, 't': 63, 'u': 64, 'v': 65, 'w': 66, 'y': 67, 'z': 68, '!': 69, '@': 70, '£': 71, '#': 72, '$': 73, '%': 74, '^': 75, '&': 76, '*': 77, '(': 78, ')': 79, '~': 80, '?': 81, '.': 82, '<': 83, '>': 84, '{': 85, '}': 86, '[': 87, ']': 88, ';': 89, '/': 90, '|': 91, 'β': 92, 'Γ': 93, 'Δ': 94, 'δ': 95, 'ε': 96, 'ζ': 97, 'η': 98, 'θ': 99, 'κ': 100, 'Λ': 101, 'λ': 102, 'μ': 103, 'Ξ': 104, 'ξ': 105, 'Π': 106, 'π': 107, 'Σ': 108, 'ς': 109, 'τ': 110, 'Φ': 111, 'φ': 112, 'χ': 113, 'Ψ': 114, 'ψ': 115, 'Ω': 116, 'ω': 117, '[PAD]': 0, '[UNK]': 1, '[BOS]': 2, '[EOS]': 3}\n        \"\"\"\n        mask = (input_ids >= 4) & (input_ids <= 13)\n        output = self.helper(mask, input_ids.device)\n        \n        k=0\n        if self.training:\n            k = random.randint(0, self.max_k)\n            output[output>0] += k # as we already have ones in the tensor, the tensor values will be k+1\n\n        return self.embedding(output)"
  },
  {
    "path": "cramming/architectures/huggingface_interface.py",
    "content": "\"\"\"HF model variations based on reconfiguring their huggingface implementations.\"\"\"\n\nimport transformers\n\n\ndef construct_huggingface_model(cfg_arch, vocab_size):\n    \"\"\"construct model from given configuration. Only works if this arch exists on the hub.\"\"\"\n\n    if isinstance(cfg_arch, transformers.PretrainedConfig):\n        configuration = cfg_arch\n    else:\n        model_type = cfg_arch[\"model_type\"]\n        configuration = transformers.AutoConfig.from_pretrained(pretrained_model_name_or_path=model_type, **cfg_arch)\n    configuration.vocab_size = vocab_size\n    model = transformers.AutoModelForPreTraining.from_config(configuration)\n    model.vocab_size = model.config.vocab_size\n\n    old_forward = model.forward\n\n    def modified_forward(input_ids, attention_mask=None, **kwargs):\n        return old_forward(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)\n\n    model.forward = modified_forward\n\n    return model\n"
  },
  {
    "path": "cramming/architectures/losses.py",
    "content": "import torch\nimport math\n\n\nclass CosineLoss(torch.nn.Module):\n    __constants__ = [\"reduction\"]\n    reduction: str\n\n    def __init__(self, reduction: str = \"mean\", dim=-1, eps=1e-8) -> None:\n        super().__init__()\n        self.reduction = reduction\n        assert self.reduction == \"mean\"\n        self.dim = dim\n        self.eps = eps\n\n    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:\n        return 1 - torch.nn.functional.cosine_similarity(x1, x2, self.dim, self.eps).mean()\n\n\nclass CrossEntropyWithZLoss(torch.nn.Module):\n    \"\"\"Cross Entropy plus logit regularization via z_loss.\"\"\"\n\n    __constants__ = [\"ignore_index\", \"z_loss_factor\"]\n    ignore_index: int\n    z_loss_factor: float\n\n    def __init__(self, ignore_index=-100, z_loss_factor=1e-4):\n        super().__init__()\n        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)\n        self.z_loss_factor = z_loss_factor\n        self.ignore_index = ignore_index\n\n    def forward(self, inputs, labels):\n        \"\"\"Is this is the optimal implementation? Is this even what is meant?\n        I wish there were more answers or code for PaLM\n\n        This implementation assumes that log(Z) is log(sum(exp(logits))).\n        The usage of log2 here is also a bit wild...\n        \"\"\"\n        z_reg = inputs.exp().sum(dim=-1).log2().sum() * self.z_loss_factor\n        return self.loss_fn(inputs, labels) + z_reg\n\n\nclass MSELoss(torch.nn.Module):\n    \"\"\"MSE Loss as a drop-in replacement for Cross Entropy Loss.\n\n    This implementation includes a mean reduction in batch dimension and a 1/num_classes/M reduction in classes.\"\"\"\n\n    def __init__(self, ignore_index=-100):\n        \"\"\"Parameters as in Hui&Belkin, 2021, but k=1, and M=sqrt(C) (so maybe not really Hui&Belkin?)\"\"\"\n        super().__init__()\n        self.ignore_index = ignore_index\n\n    def forward(self, inputs, labels):\n        \"\"\"Is this is the optimal implementation? Could also do an index_select variation...\"\"\"\n        num_classes = inputs.shape[-1]\n        valid_mask = labels != self.ignore_index\n        M = math.sqrt(num_classes)\n        onehot_labels = self._label_to_onehot(labels[valid_mask], M, num_classes=num_classes)\n        return 1 / (2 * M * num_classes) * (inputs[valid_mask] - onehot_labels).pow(2).sum()\n\n    @staticmethod\n    @torch.jit.script\n    def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):\n        onehot_target = torch.zeros(target.shape[0], num_classes, device=target.device)\n        onehot_target.scatter_(1, target.view(-1, 1), M)\n        return onehot_target\n\n\nclass MSELossFast(torch.nn.Module):\n    \"\"\"MSE Loss as a drop-in replacement for Cross Entropy Loss. Only for 2dim inputs and 1dim labels\n\n    This implementation includes a mean reduction in batch dimension and a 1/num_classes/M reduction in classes.\"\"\"\n\n    def __init__(self, ignore_index=-100):\n        \"\"\"Parameters as in Hui&Belkin, 2021, but k=1, and M=sqrt(C) (so maybe not really Hui&Belkin?)\"\"\"\n        super().__init__()\n        self.ignore_index = ignore_index\n\n    def forward(self, inputs, labels):\n        \"\"\"Is this is the optimal implementation? This at least circumvents literal 1-hot labels\"\"\"\n        num_examples, num_classes = inputs.shape\n        valid_mask = labels != self.ignore_index\n        M = math.sqrt(num_classes)\n\n        inputs = inputs[valid_mask]\n        labels = labels[valid_mask]\n\n        x_i = inputs.pow(2).sum()\n        x_j = inputs[torch.arange(labels.shape[-1]), labels].sum()\n        return 1 / (2 * M * num_classes) * (x_i - 2 * M * x_j + labels.shape[-1] * M**2)\n\n\nclass L1Loss(torch.nn.Module):\n    \"\"\"L1 Loss as a drop-in replacement for Cross Entropy Loss. Only for 2dim inputs and 1dim labels\n\n    This implementation includes a mean reduction in batch dimension and a 1/num_classes reduction in classes.\"\"\"\n\n    def __init__(self, ignore_index=-100):\n        \"\"\".\"\"\"\n        super().__init__()\n        self.ignore_index = ignore_index\n\n    def forward(self, inputs, labels):\n        \"\"\"Optimal scaling is less clear for L1\"\"\"\n        num_classes = inputs.shape[-1]\n        valid_mask = labels != self.ignore_index\n        M = math.sqrt(num_classes)\n        onehot_labels = self._label_to_onehot(labels[valid_mask], float(num_classes), num_classes=num_classes)\n        return 1 / inputs.shape[0] / M * (inputs[valid_mask] - onehot_labels).abs().sum()\n\n    @staticmethod\n    @torch.jit.script\n    def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):\n        onehot_target = torch.zeros(target.shape[0], num_classes, device=target.device)\n        onehot_target.scatter_(1, target.view(-1, 1), M)\n        return onehot_target\n\n\nclass SzegedyLoss(torch.nn.Module):\n    \"\"\"Regression directly back to input embedding. Remove the decoding layer if using this loss.\n\n    As mentioned at https://twitter.com/ChrSzegedy/status/1533322132368728064?t=xz00T1YT3-WiE0id-h3MEA&s=19\n    \"\"\"\n\n    def __init__(self, embedding_layer, ignore_index=-100, overrelaxation=2.0):\n        \"\"\"Overrelax parameter is quite a bit speculative...\"\"\"\n        super().__init__()\n        self.embedding = embedding_layer\n        self.ignore_index = ignore_index\n        self.overrelaxation = overrelaxation\n\n    def forward(self, inputs, labels):\n        \"\"\"This really just does L2(DNN(embed(x[:,:-1]), 2.0 * stop_gradient(embed(x[:,1:]))) as quoted above\"\"\"\n        num_examples, num_classes = inputs.shape\n        valid_mask = labels != self.ignore_index\n        M = math.sqrt(num_classes)\n\n        inputs = inputs[valid_mask]\n        with torch.no_grad():\n            embedded_labels = self.overrelaxation * self.embedding(labels)[valid_mask]\n\n        return (inputs - embedded_labels).pow(2).sum() / labels.shape[-1] / num_classes\n\n\n\"\"\"Focal Loss from https://github.com/clcarwin/focal_loss_pytorch (minimally modernized into pytorch 1.12)\"\"\"\n\n\"\"\"\nMIT License\n\nCopyright (c) 2017 carwin\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\n\nclass FocalLoss(torch.nn.Module):\n    def __init__(self, gamma: float = 5.0, size_average: bool = True, ignore_index: int = -100):\n        super().__init__()\n        self.register_buffer(\"gamma\", torch.as_tensor(gamma, dtype=torch.float), persistent=False)\n        self.size_average = size_average\n        self.ignore_index = ignore_index\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        valid_mask = target != self.ignore_index\n\n        log_probs = torch.nn.functional.log_softmax(input[valid_mask]).gather(1, target[None, valid_mask])\n        loss = -1 * (1 - log_probs.exp()) ** self.gamma * log_probs\n        if self.size_average:\n            return loss.mean()\n        else:\n            return loss.sum()\n\n\nclass IncorrectCrossEntropyLoss(torch.nn.CrossEntropyLoss):\n    \"\"\"CrossEntropyLoss, but only on incorrectly classified examples.\"\"\"\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        with torch.no_grad():\n            incorrect_preds = input.argmax(dim=-1) != target\n        return torch.nn.functional.cross_entropy(\n            input[incorrect_preds],\n            target[incorrect_preds],\n            weight=self.weight,\n            ignore_index=self.ignore_index,\n            reduction=self.reduction,\n            label_smoothing=self.label_smoothing,\n        )\n"
  },
  {
    "path": "cramming/architectures/sanity_check.py",
    "content": "\"\"\"Sanity Check architecture.\"\"\"\nimport torch\nfrom typing import Optional\n\n\nclass SanityCheckforPreTraining(torch.nn.Module):\n    \"\"\"Make big go fast.\"\"\"\n\n    def __init__(self, width, vocab_size):\n        super().__init__()\n        self.word_embedding = torch.nn.Embedding(vocab_size, width, padding_idx=0)\n        self.transform = torch.nn.Linear(width, width, bias=False)\n\n    def forward(\n        self,\n        input_ids,\n        attention_mask: Optional[torch.Tensor] = None,\n        labels: Optional[torch.Tensor] = None,\n        token_type_ids: Optional[torch.Tensor] = None,\n    ) -> dict[str, torch.Tensor]:\n\n        embeds = self.word_embedding(input_ids)\n        outputs = self.transform(embeds)\n        loss = outputs.mean()\n\n        return {\"logits\": outputs, \"loss\": loss}\n"
  },
  {
    "path": "cramming/backend/__init__.py",
    "content": "\"\"\"This module implements interfaces to the various backends.\"\"\"\n\nfrom .prepare_backend import load_backend\nfrom .utils import load_model_checkpoint, get_model_engine_tokenizer_dataloaders\n\n__all__ = [\n    \"load_backend\",\n    \"load_model_checkpoint\",\n    \"get_model_engine_tokenizer_dataloaders\",\n]\n"
  },
  {
    "path": "cramming/backend/optimizers/__init__.py",
    "content": "from .progressive_batching import ProgressiveBatching\nfrom .optimizer_modifiers import SAM, LARS\nfrom .schedulers import get_schedule_fn\n"
  },
  {
    "path": "cramming/backend/optimizers/optimizer_modifiers.py",
    "content": "\"\"\"This is the apex LARS implementation, from the apex repository.\n\nIt implements LARS + optional clipping\n\nhttps://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py\n\n\nI did rename it to \"LARS\".\n\"\"\"\n\nimport torch\n\n\nclass MetaOptimizer(torch.optim.Optimizer):\n    \"\"\"base class for a meta optimizer that wraps and modifies an existing pytorch optimizer.\"\"\"\n\n    def __init__(self, optimizer):\n        self.param_groups = optimizer.param_groups\n        self.optim = optimizer\n\n    def __getstate__(self):\n        return self.optim.__getstate__()\n\n    def __setstate__(self, state):\n        self.optim.__setstate__(state)\n\n    def __repr__(self):\n        return self.__class__.__name__ + self.optim.__repr__()\n\n    def __getattr__(self, name):\n        \"\"\"Call this only if all other attributes are exhausted.\"\"\"\n        return getattr(self.optim, name)\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        return self.optim.step(closure)\n\n\nclass LARS(MetaOptimizer):\n    \"\"\"\n    :class:`LARS` [LARC in apex] is a pytorch implementation of both the scaling and clipping variants of LARS,\n    in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive\n    local learning rate for each individual parameter. The algorithm is designed to improve\n    convergence of large batch training.\n\n    See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.\n\n    In practice it modifies the gradients of parameters as a proxy for modifying the learning rate\n    of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.\n\n    ```\n    model = ...\n    optim = torch.optim.Adam(model.parameters(), lr=...)\n    optim = LARS(optim)\n    ```\n\n    Args:\n        optimizer: Pytorch optimizer to wrap and modify learning rate for.\n        trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888\n        clip: Decides between clipping or scaling mode of LARC [LARS + clip].\n              If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter.\n              If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.\n        eps: epsilon kludge to help with numerical stability while calculating adaptive_lr\n    \"\"\"\n\n    def __init__(self, optimizer, trust_coefficient=0.02, clip=False, eps=1e-8):\n        self.param_groups = optimizer.param_groups\n        self.optim = optimizer\n        self.trust_coefficient = trust_coefficient\n        self.eps = eps\n        self.clip = clip\n\n    def step(self, closure=None):\n        loss = None\n        with torch.no_grad():\n            weight_decays = []\n            for group in self.optim.param_groups:\n                # absorb weight decay control from optimizer\n                weight_decay = group[\"weight_decay\"] if \"weight_decay\" in group else 0\n                weight_decays.append(weight_decay)\n                group[\"weight_decay\"] = 0\n                for p in group[\"params\"]:\n                    if p.grad is None:\n                        continue\n                    param_norm = torch.norm(p.data)\n                    grad_norm = torch.norm(p.grad.data)\n\n                    if param_norm != 0 and grad_norm != 0:\n                        # calculate adaptive lr + weight decay\n                        adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)\n\n                        # clip learning rate for LARC\n                        if self.clip:\n                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`\n                            adaptive_lr = min(adaptive_lr / group[\"lr\"], 1)\n\n                        p.grad.data += weight_decay * p.data\n                        p.grad.data *= adaptive_lr\n\n        loss = self.optim.step(closure)\n        # return weight decay control to optimizer\n        for i, group in enumerate(self.optim.param_groups):\n            group[\"weight_decay\"] = weight_decays[i]\n\n        return loss\n\n\n\"\"\"This the SAM pytorch implementation from https://github.com/davda54/sam\nwith a minor modification \"\"\"\n\n\"\"\"\nMIT License\nCopyright (c) 2021 David Samuel\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\n\nclass SAM(MetaOptimizer):\n    def __init__(self, base_optimizer_instance, rho=0.05):\n        assert rho >= 0.0, f\"Invalid rho, should be non-negative: {rho}\"\n        self.rho = rho\n\n        self.optim = base_optimizer_instance\n        self.param_groups = base_optimizer_instance.param_groups\n\n    @torch.no_grad()\n    def first_step(self, zero_grad=False):\n        grad_norm = self._grad_norm()\n        for group in self.param_groups:\n            scale = self.rho / (grad_norm + 1e-12)\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                e_w = p.grad * scale.to(p)\n                p.add_(e_w)  # climb to the local maximum \"w + e(w)\"\n                self.state[p][\"e_w\"] = e_w\n\n        if zero_grad:\n            self.zero_grad()\n\n    @torch.no_grad()\n    def second_step(self, zero_grad=False):\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                p.sub_(self.state[p][\"e_w\"])  # get back to \"w\" from \"w + e(w)\"\n\n        self.optim.step()  # do the actual \"sharpness-aware\" update\n\n        if zero_grad:\n            self.zero_grad()\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        assert closure is not None, \"Sharpness Aware Minimization requires closure, but it was not provided\"\n        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass\n\n        closure()\n        self.first_step(zero_grad=True)\n        loss = closure()\n        self.second_step()\n        return loss\n\n    def _grad_norm(self):\n        # put everything on the same device, in case of model parallelism\n        shared_device = self.param_groups[0][\"params\"][0].device\n        norm = torch.norm(\n            torch.stack([p.grad.norm(p=2).to(shared_device) for group in self.param_groups for p in group[\"params\"] if p.grad is not None]),\n            p=2,\n        )\n        return norm\n"
  },
  {
    "path": "cramming/backend/optimizers/progressive_batching.py",
    "content": "\"\"\"Implementation of a progressive batching meta optimizer.\nThe optimizer may defer an optimization step until gradient variance is small enough\n\"\"\"\n\nimport torch\n\nfrom collections import defaultdict\nfrom .optimizer_modifiers import MetaOptimizer\n\n\nimport logging\n\nlog = logging.getLogger(__name__)\nDEBUG = False\n\n\nclass ProgressiveBatching(MetaOptimizer):\n    def __init__(self, optimizer, progress_rule=\"norm-based\", theta=0.9, monotone=False, min_sample_guard=2, max_sample_guard=128):\n        super().__init__(optimizer)\n\n        self.progress_rule = progress_rule\n        self.theta = theta\n        self.monotone = monotone\n\n        self.min_sample_guard = min_sample_guard\n        self.max_sample_guard = max_sample_guard\n\n        self.progress_state = defaultdict(dict)\n        self.accumulated_steps = 0\n        self.reset_sample_statistics()\n\n    @torch.no_grad()\n    def step(self):\n        \"\"\"(Maybe) performs a single optimization step.\"\"\"\n        self.update_sample_statistics()\n        if self.accumulated_steps < self.min_sample_guard:\n            rule_check = False\n        else:\n            if self.accumulated_steps > self.max_sample_guard:\n                rule_check = True\n            else:\n                if self.progress_rule == \"norm-based\":\n                    rule_check = self.norm_test()\n                elif self.progress_rule == \"inner-product\":\n                    rule_check = self.inner_product_test()\n                elif self.progress_rule == \"cov\":\n                    rule_check = self.coefficient_of_variation()\n                elif self.progress_rule == \"cosine\":\n                    rule_check = self.cosine_test()\n                else:\n                    raise ValueError(f\"Invalid progress rules {self.progress_rule} given.\")\n\n        if rule_check:\n            self.copy_mean_grad()  # reference running mean in p.grad attributes\n            if self.monotone:\n                self.min_sample_guard = self.accumulated_steps  # raise lower limit if forcing monotone batch sizes\n            self.reset_sample_statistics()  # reset running mean\n            super().step()\n        else:\n            # otherwise defer the step and accumulate more gradients\n            pass\n\n    def inner_product_test(self):\n        \"\"\"Inner product similar to description in Bollapragada,Byrd,Nocedal, \"Adaptive Sampling Strategies for Stochastic Optimization\".\n\n        This is only a zero-memory inner product test.\n        \"\"\"\n\n        global_inner_product, global_variance = 0, 0\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                state = self.progress_state[p]\n                ndivn1 = self.accumulated_steps / (self.accumulated_steps - 1)\n                corrected_mean = (state[\"running_mean\"] - p.grad / self.accumulated_steps) * ndivn1\n                global_inner_product += (p.grad * corrected_mean).sum()\n                global_variance += corrected_mean.pow(2).sum()\n        final_v = (global_inner_product - global_variance).pow(2)\n\n        if DEBUG:\n            inequality_repr = f\"{final_v / (self.accumulated_steps - 1):10.2f} < {self.theta * global_variance**2:10.2f}\"\n            log.info(f\"{self.accumulated_steps} - {inequality_repr}\")\n\n        return final_v / (self.accumulated_steps - 1) < self.theta * global_variance**2\n\n    def norm_test(self):\n        \"\"\"Sohams version.\"\"\"\n\n        sample_var, mean_norm = 0, 0\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                state = self.progress_state[p]\n                sample_var += state[\"running_variance\"].sum() / (self.accumulated_steps - 1)  # bessel-corrected variance\n                mean_norm += state[\"running_mean\"].pow(2).sum()\n\n        if DEBUG:\n            log.info(f\"{self.accumulated_steps} -  {sample_var / self.accumulated_steps:10.2f} < {self.theta * mean_norm:10.2f}\")\n\n        return sample_var / self.accumulated_steps < self.theta * mean_norm  # divide by |B| as in bigbatch, original version is theta=1\n\n    def cosine_test(self):\n        \"\"\"Experimental.\"\"\"\n\n        total_angles, num_params = 0, 0\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                state = self.progress_state[p]\n                ndivn1 = self.accumulated_steps / (self.accumulated_steps - 1)\n                corrected_mean = (state[\"running_mean\"] - p.grad / self.accumulated_steps) * ndivn1\n                total_angles += (p.grad * corrected_mean).sum() / corrected_mean.norm() / p.grad.norm()\n                num_params += 1\n\n        average_angle = total_angles / num_params  # rather the average cosine, this not (yet) the angle\n\n        if DEBUG:\n            log.info(f\"{self.accumulated_steps} -  {average_angle:10.2f} > {self.theta:10.2f}\")\n\n        return average_angle > self.theta\n\n    def coefficient_of_variation(self):\n        \"\"\"unbiased cov test.\"\"\"\n        cov, mean_norm, num_params = 0, 0, 0\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                state = self.progress_state[p]\n                cov += (state[\"running_variance\"].sum() / (self.accumulated_steps - 1)).sqrt() / (state[\"running_mean\"].pow(2).sum() + 1e-6)\n                mean_norm += state[\"running_mean\"].pow(2).sum()\n                num_params += 1\n\n        unbiased_avg_cov = (1 + 1 / (4 * self.accumulated_steps)) * cov / num_params / self.accumulated_steps\n\n        if DEBUG:\n            log.info(f\"{self.accumulated_steps} -  {unbiased_avg_cov:10.2f} < {self.theta * 100:10.2f}\")\n\n        return unbiased_avg_cov < self.theta * 100\n\n    def update_sample_statistics(self):\n        \"\"\"Update sample statistics based on welford accumulation. At any step variance can be finalized via running_variance / count\"\"\"\n        self.accumulated_steps += 1\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                state = self.progress_state[p]\n                current_delta = p.grad - state[\"running_mean\"]\n                state[\"running_mean\"] += current_delta / self.accumulated_steps\n                corrected_delta = p.grad - state[\"running_mean\"]\n                state[\"running_variance\"] += current_delta * corrected_delta\n\n    def reset_sample_statistics(self):\n        \"\"\"Allocate new tensors, old references are still required for the optimizer step.\"\"\"\n        self.last_full_step_accumulation = self.accumulated_steps + 1\n        self.accumulated_steps = 0\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                state = self.progress_state[p]\n                state[\"running_mean\"] = torch.zeros_like(p, memory_format=torch.preserve_format)\n                state[\"running_variance\"] = torch.zeros_like(p, memory_format=torch.preserve_format)\n\n    def copy_mean_grad(self):\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                p.grad = self.progress_state[p][\"running_mean\"]\n"
  },
  {
    "path": "cramming/backend/optimizers/schedulers.py",
    "content": "\"\"\"Misc. optimizer implementations.\"\"\"\nimport transformers\nimport math\n\nfrom torch.optim.lr_scheduler import LambdaLR\nimport time\nfrom functools import partial\n\n\ndef get_schedule_fn(cfg_train, elapsed_time: float=0.0, true_budget: float = -1):\n    \"\"\"Returns a callable scheduler_fn(optimizer).\n\n    Todo: Sanitize and unify these schedulers...\n    \"\"\"\n    if true_budget <= 0:\n        true_budget = cfg_train.budget\n    if (cfg_train.warmup_steps) > 0 and (cfg_train.warmup_steps < 1):\n        # warmup could be a percentage in which case this line converts to steps again\n        cfg_train.warmup_steps = int(cfg_train.warmup_steps * cfg_train.steps)\n\n    if (cfg_train.cooldown_steps) > 0 and (cfg_train.cooldown_steps < 1):\n        # cooldown could be a percentage in which case this line converts to steps again\n        cfg_train.cooldown_steps = int(cfg_train.cooldown_steps * cfg_train.steps)\n\n    # Load huggingface schedulers based on total steps\n    if cfg_train.scheduler == \"polynomial-decay\":\n        scheduler_fn = partial(\n            transformers.get_polynomial_decay_schedule_with_warmup,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_training_steps=cfg_train.steps,\n            lr_end=1e-7,\n            power=1.0,\n        )\n    elif cfg_train.scheduler == \"cosine-decay\":\n        scheduler_fn = partial(\n            transformers.get_cosine_schedule_with_warmup,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_training_steps=cfg_train.steps,\n            num_cycles=0.5,\n        )\n    elif cfg_train.scheduler == \"inverse-sqrt\":\n        scheduler_fn = partial(\n            get_inverse_sqrt_scheduler,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_cooldown_steps=cfg_train.cooldown_steps,\n            num_training_steps=cfg_train.steps,\n        )\n    elif cfg_train.scheduler == \"one-cycle\":  # this is a simplified one-cycle\n        scheduler_fn = partial(\n            get_one_cycle,\n            num_training_steps=cfg_train.steps,\n        )\n    elif cfg_train.scheduler == \"ramp\":  # this is a simplified one-cycle\n        scheduler_fn = partial(\n            get_ramp,\n            num_cooldown_steps=cfg_train.cooldown_steps,\n            num_training_steps=cfg_train.steps,\n        )\n        \"\"\"Budget Schedulers from here: \"\"\"\n    elif cfg_train.scheduler == \"budget-inverse-sqrt\":\n        scheduler_fn = partial(\n            get_budget_inv_sqrt_scheduler,\n            hour_budget=true_budget,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_cooldown_steps=cfg_train.cooldown_steps,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-constant\":\n        scheduler_fn = partial(\n            get_budget_constant_scheduler,\n            hour_budget=true_budget,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_cooldown_steps=cfg_train.cooldown_steps,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-cosine-decay\":\n        scheduler_fn = partial(\n            get_budget_cosine_schedule_with_warmup,\n            hour_budget=true_budget,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_training_steps=cfg_train.steps,\n            num_cycles=0.5,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-cosine-annealing\":\n        scheduler_fn = partial(\n            get_budget_cosine_half_cycles_with_warmup,\n            hour_budget=true_budget,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_training_steps=cfg_train.steps,\n            num_cycles=4,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-linear\":\n        scheduler_fn = partial(\n            get_budget_linear_schedule_with_warmup,\n            hour_budget=true_budget,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-polynomial\":\n        scheduler_fn = partial(\n            get_budget_polynomial_decay_with_warmup,\n            hour_budget=true_budget,\n            num_warmup_steps=cfg_train.warmup_steps,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-one-cycle\":  # this is a simplified one-cycle\n        scheduler_fn = partial(\n            get_budget_one_cycle,\n            hour_budget=true_budget,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-multi-cycle\":\n        scheduler_fn = partial(\n            get_budget_multi_cycle,\n            hour_budget=true_budget,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-ramp\":\n        scheduler_fn = partial(\n            get_budget_ramp,\n            hour_budget=true_budget,\n            num_cooldown_steps=cfg_train.cooldown_steps,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-inv-cosine\":\n        scheduler_fn = partial(\n            get_budget_inv_cosine_schedule,\n            hour_budget=true_budget,\n            num_cooldown_steps=cfg_train.cooldown_steps,\n            num_training_steps=cfg_train.steps,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-dive\":\n        scheduler_fn = partial(\n            get_budget_dive,\n            hour_budget=true_budget,\n            num_training_steps=cfg_train.steps,\n            num_warmup_steps=cfg_train.warmup_steps,\n            falloff=0.5,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-dive-slow\":\n        scheduler_fn = partial(\n            get_budget_dive,\n            hour_budget=true_budget,\n            num_training_steps=cfg_train.steps,\n            num_warmup_steps=cfg_train.warmup_steps,\n            falloff=0.75,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-dive-fast\":\n        scheduler_fn = partial(\n            get_budget_dive,\n            hour_budget=true_budget,\n            num_training_steps=cfg_train.steps,\n            num_warmup_steps=cfg_train.warmup_steps,\n            falloff=0.25,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-triangle1\":\n        scheduler_fn = partial(\n            get_budget_triangle,\n            hour_budget=true_budget,\n            num_training_steps=cfg_train.steps,\n            falloff=0.25,\n            base_percentage=0.5,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler == \"budget-triangle2\":\n        scheduler_fn = partial(\n            get_budget_triangle,\n            hour_budget=true_budget,\n            num_training_steps=cfg_train.steps,\n            falloff=0.25,\n            base_percentage=0.25,\n            elapsed_time=elapsed_time,\n        )\n    elif cfg_train.scheduler in [\n        \"linear\",\n        \"cosine\",\n        \"cosine_with_restarts\",\n        \"polynomial\",\n        \"constant\",\n        \"constant_with_warmup\",\n        \"get_cosine_with_hard_restarts_schedule_with_warmup\",\n        \"get_polynomial_decay_schedule_with_warmup\",\n    ]:\n\n        def scheduler_fn(optimizer):\n            return transformers.get_scheduler(\n                name=cfg_train.scheduler,\n                optimizer=optimizer,\n                num_warmup_steps=cfg_train.warmup_steps,\n                num_training_steps=cfg_train.steps,\n            )\n\n    elif cfg_train.scheduler == \"none\" or cfg_train.scheduler is None:\n        scheduler_fn = DumbScheduler\n    else:\n        raise ValueError(f\"Invalid schedule {cfg_train.scheduler} given.\")\n    return scheduler_fn\n\n\nclass DumbScheduler:\n    def __init__(self, *args, **kwargs):\n        self._step_count = 0\n\n    def step(self, *args, **kwargs):\n        self._step_count += 1\n\n    def _initial_step(self):\n        self.optimizer._step_count = 0\n        self._step_count = 0\n        self.step()\n\n    def state_dict(self):\n        return {}\n\n    def load_state_dict(self, state_dict):\n        self.__dict__.update(state_dict)\n\n    def get_last_lr(self):\n        \"\"\"Return last computed learning rate by current scheduler.\"\"\"\n        return float(\"NaN\")\n\n    def get_lr(self):\n        return float(\"NaN\")\n\n    def print_lr(self, is_verbose, group, lr, epoch=None):\n        print(float(\"NaN\"))\n\n\n\"\"\"FairSeq-like inverse-square-root scheduler:\"\"\"\n\n\ndef get_inverse_sqrt_scheduler(optimizer, num_warmup_steps, num_cooldown_steps, num_training_steps):\n    \"\"\"Decay the LR based on the inverse square root of the update number.\n    We also support a warmup phase where we linearly increase the learning rate\n    from some initial learning rate (`--warmup-init-lr`) until the configured\n    learning rate (`--lr`). Thereafter we decay proportional to the number of\n    updates, with a decay factor set to align with the configured learning rate.\n    During warmup:\n      lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)\n      lr = lrs[update_num]\n    After warmup:\n      lr = decay_factor / sqrt(update_num)\n    where\n      decay_factor = args.lr * sqrt(args.warmup_updates)\n    \"\"\"\n    # linearly warmup for the first args.warmup_updates\n    lr_step = 1 / num_warmup_steps\n    # then, decay prop. to the inverse square root of the update number\n    decay_factor = num_warmup_steps**0.5\n    decayed_lr = decay_factor * (num_training_steps - num_cooldown_steps) ** -0.5\n\n    def lr_lambda(current_step: int):\n        if current_step < num_warmup_steps:\n            return float(current_step * lr_step)\n        elif current_step > (num_training_steps - num_cooldown_steps):\n            return max(0.0, float(decayed_lr * (num_training_steps - current_step) / num_cooldown_steps))\n        else:\n            return float(decay_factor * current_step**-0.5)\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)\n\n\ndef get_one_cycle(optimizer, num_training_steps):\n    \"\"\"Simple single-cycle scheduler. Not including paper/fastai three-phase things or asymmetry.\"\"\"\n\n    def lr_lambda(current_step):\n        if current_step < num_training_steps / 2:\n            return float(current_step / (num_training_steps / 2))\n        else:\n            return float(2 - current_step / (num_training_steps / 2))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_ramp(optimizer, num_cooldown_steps, num_training_steps):\n    \"\"\"to the MOON.\"\"\"\n    max_lr = (num_training_steps - num_cooldown_steps) / num_training_steps\n\n    def lr_lambda(current_step):\n        if current_step > (num_training_steps - num_cooldown_steps):\n            return max(0.0, float(max_lr * (num_training_steps - current_step) / num_cooldown_steps))\n        else:\n            return float(current_step / num_training_steps)\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\n\"\"\"Wallclock time schedulers.\"\"\"\ndef _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, prev_elapsed_time: float = 0.0):\n    elapsed_hours = (time.time() - initial_time + prev_elapsed_time) / 60 / 60\n    if current_step == 0:\n        fake_step = 0\n    else:\n        fake_step = int(elapsed_hours / hour_budget * num_training_steps)\n        # Warning: denominator could be bigger than 1 if passed original budget, so be careful with checkpointing\n    return fake_step\n\n\ndef get_budget_inv_sqrt_scheduler(optimizer, hour_budget, num_warmup_steps, num_cooldown_steps, num_training_steps, elapsed_time: float = 0.0):\n    \"\"\"Time-based scheduler as described in Iszak et al. plus inv_sqrt.\n    Takes in num_warmup_steps and num_training_steps as normal, but actually squeezes the planned schedule into the\n    budget given by hour_budget, based on wallclock measurements.\n\n    Reference: https://github.com/IntelLabs/academic-budget-bert/blob/main/pretraining/schedules.py\n    \"\"\"\n    decay_factor = num_warmup_steps**0.5\n    decayed_lr = decay_factor * (num_training_steps - num_cooldown_steps) ** -0.5\n    initial_time = time.time()\n\n    def lr_lambda(current_step: int):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if fake_step < num_warmup_steps:\n            return float(fake_step / num_warmup_steps)\n        elif fake_step > (num_training_steps - num_cooldown_steps):\n            return max(0.0, float(decayed_lr * (num_training_steps - fake_step) / num_cooldown_steps))\n        else:\n            return float(decay_factor * fake_step**-0.5)\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)\n\n\ndef get_budget_constant_scheduler(optimizer, hour_budget, num_warmup_steps, num_cooldown_steps, num_training_steps, elapsed_time: float = 0.0):\n    \"\"\"Time-based scheduler with optional warmup and cooldown (so technically a trapezoidal shape)\"\"\"\n    initial_time = time.time()\n\n    def lr_lambda(current_step: int):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if fake_step < num_warmup_steps:\n            return float(fake_step / num_warmup_steps)\n        elif fake_step > (num_training_steps - num_cooldown_steps):\n            return max(0.0, float((num_training_steps - fake_step) / num_cooldown_steps))\n        else:\n            return 1.0\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)\n\n\ndef get_budget_linear_schedule_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):\n    \"\"\"Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget\"\"\"\n    initial_time = time.time()\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if fake_step < num_warmup_steps:\n            return float(fake_step) / float(max(1, num_warmup_steps))\n        return max(0.0, float(num_training_steps - fake_step) / float(max(1, num_training_steps - num_warmup_steps)))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_cosine_schedule_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):\n    \"\"\"Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget\"\"\"\n    initial_time = time.time()\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if fake_step < num_warmup_steps:\n            return float(fake_step) / float(max(1, num_warmup_steps))\n        progress = float(fake_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_cosine_half_cycles_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):\n    \"\"\"Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget\"\"\"\n    initial_time = time.time()\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if fake_step < num_warmup_steps:\n            return float(fake_step) / float(max(1, num_warmup_steps))\n        progress = float(fake_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_one_cycle(optimizer, hour_budget, num_training_steps, elapsed_time: float = 0.0):\n    \"\"\"Simple single-cycle scheduler. Not including paper/fastai three-phase things or asymmetry.\"\"\"\n    initial_time = time.time()\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if fake_step < num_training_steps / 2:\n            return float(fake_step / (num_training_steps / 2))\n        else:\n            return float(2 - fake_step / (num_training_steps / 2))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_multi_cycle(optimizer, hour_budget, num_training_steps, num_cycles=8, elapsed_time: float = 0.0):\n    \"\"\"Simple multi-cycle scheduler. Not including paper/fastai three-phase things or asymmetry.\"\"\"\n    initial_time = time.time()\n    cycle_length = int(num_training_steps / num_cycles)\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps) % cycle_lengt, elapsed_timeh\n        if fake_step < cycle_length / 2:\n            return float(fake_step / (cycle_length / 2))\n        else:\n            return float(2 - fake_step / (cycle_length / 2))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_ramp(optimizer, hour_budget, num_cooldown_steps, num_training_steps, elapsed_time: float = 0.0):\n    \"\"\"to the moon.\"\"\"\n    initial_time = time.time()\n    max_lr = (num_training_steps - num_cooldown_steps) / num_training_steps\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if fake_step > (num_training_steps - num_cooldown_steps):\n            return max(0.0, float(max_lr * (num_training_steps - fake_step) / num_cooldown_steps))\n        else:\n            return float(fake_step / num_training_steps)\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_inv_cosine_schedule(optimizer, hour_budget, num_cooldown_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):\n    \"\"\"An inverse cosine schedule, with limited budget.\"\"\"\n    initial_time = time.time()\n    ult_step = num_training_steps - num_cooldown_steps\n    max_lr = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * (1 - ult_step / float(max(1, num_training_steps))))))\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n\n        progress = 1 - fake_step / float(max(1, num_training_steps))\n        if fake_step > (num_training_steps - num_cooldown_steps):\n            return max(0.0, float(max_lr * (num_training_steps - fake_step) / num_cooldown_steps))\n        else:\n            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_triangle(optimizer, hour_budget, num_training_steps, base_percentage=0.5, falloff=0.5, elapsed_time: float = 0.0):\n    \"\"\"Linear increase from a percentage of the base learning rate, then linear decay.\n\n    plot min(0.5 + x * (1 - 0.5)/(1-0.25) / 1000, 1/0.25 - x / (1000 * 0.25)) from 0 to 1000 in the plot range 0 to 1\n    \"\"\"\n    initial_time = time.time()\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        return min(\n            base_percentage + fake_step * (1 - base_percentage) / (1 - falloff) / num_training_steps,\n            float(1 / falloff - fake_step / (num_training_steps * falloff)),\n        )\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_dive(optimizer, hour_budget, num_training_steps, num_warmup_steps=0, falloff=0.5, elapsed_time: float = 0.0):\n    \"\"\"Constant, then linear decay.\n    plot min(1, 1/0.5 - x / (1000 * 0.5)) from 0 to 1000 in the plot range 0 to 1\n    \"\"\"\n    initial_time = time.time()\n\n    def lr_lambda(current_step):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n        if current_step < num_warmup_steps:\n            return float(fake_step) / float(max(1, num_warmup_steps))\n        else:\n            return min(1.0, float(1 / falloff - fake_step / (num_training_steps * falloff)))\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n\n\ndef get_budget_polynomial_decay_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, lr_end=0.0, power=1.0, elapsed_time: float = 0.0):\n    \"\"\"Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget\"\"\"\n    initial_time = time.time()\n    lr_init = optimizer.defaults[\"lr\"]\n\n    def lr_lambda(current_step: int):\n        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)\n\n        if fake_step < num_warmup_steps:\n            return float(fake_step) / float(max(1, num_warmup_steps))\n        elif fake_step > num_training_steps:\n            return lr_end / lr_init  # as LambdaLR multiplies by lr_init\n        else:\n            lr_range = lr_init - lr_end\n            decay_steps = num_training_steps - num_warmup_steps\n            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps\n            decay = lr_range * pct_remaining**power + lr_end\n            return decay / lr_init  # as LambdaLR multiplies by lr_init\n\n    return LambdaLR(optimizer, lr_lambda, -1)\n"
  },
  {
    "path": "cramming/backend/prepare_backend.py",
    "content": "\"\"\"Instantiate backend objects in a congruent format.\"\"\"\nimport torch\n\nfrom .torch_default import initialize_torch\n\n_default_setup = dict(device=torch.device(\"cpu\"), dtype=torch.float)\n\n\ndef load_backend(model, tokenizer, cfg_train, cfg_impl, setup=_default_setup, init_compile_and_distribute=True):\n    if cfg_impl.name == \"torch-default\":\n        return initialize_torch(model, tokenizer, cfg_train, cfg_impl, setup=setup, init_compile_and_distribute=init_compile_and_distribute)\n    else:\n        raise ValueError(f\"Invalid backend {cfg_impl.name} given.\")\n"
  },
  {
    "path": "cramming/backend/torch_default.py",
    "content": "\"\"\"Basic training backend engine for pytorch training with all bells and whistles.\n\nInterface set up to be compliant with the deepspeed engine interface.\n\n\nThere are two versions here, the TorchEngineMinimal, which is the default, and TorchEngineFull which contains a few training variations\nthat were tested but ultimately discarded, so read that part only if you're interested.\n\n\"\"\"\n\nimport json\nimport logging\nimport os\nimport time\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Any, Dict, Union\n\nimport torch\nimport torch._inductor.utils\nimport transformers\nfrom omegaconf import OmegaConf\nfrom safetensors.torch import save_file\nfrom torch.distributed.optim import ZeroRedundancyOptimizer\nfrom transformers.utils.generic import working_or_temp_dir\n\nfrom .optimizers import LARS, SAM, ProgressiveBatching\nfrom .optimizers.schedulers import get_schedule_fn\n\n# from .utils import group_parameters, prepare_pretraining_dataloader, prepare_validation_dataloader\nfrom .utils import group_parameters, load_model_checkpoint\n\nlog = logging.getLogger(__name__)\n_default_setup = dict(device=torch.device(\"cpu\"), dtype=torch.float)\nimport warnings\nfrom ..utils import flatten\n\nwarnings.filterwarnings(\"ignore\", \"Detected call of \", UserWarning)  # schedulers are deliberately used differently\n\n\ndef initialize_torch(model, tokenizer, cfg_train, cfg_impl, setup=_default_setup, init_compile_and_distribute=True):\n    \"\"\"initialize a torch engine.\"\"\"\n    model_engine = TorchEngine(\n        model,\n        cfg_train,\n        cfg_impl,\n        setup=setup,\n        seq_length=tokenizer.model_max_length,\n        init_compile_and_distribute=init_compile_and_distribute,\n    )\n    model_engine.train()\n    return model_engine\n\n\nclass TorchEngine(torch.nn.Module):\n    \"\"\"This class mirrors deepspeed functionality and hides variable batch sizes, microbatching, AMP details and compilation\"\"\"\n\n    def __init__(self, model, cfg_train, cfg_impl, setup=_default_setup, seq_length=128, init_compile_and_distribute=True):\n        \"\"\"Load Engine. The model will be compiled by default.\n        init_compile_and_distribute=False => In the case we are loading in a checkpoint we might aswell not send it across GPUs as this will be redone later\n        \"\"\"\n\n        super().__init__()\n\n        self.cfg_train = cfg_train\n        self.cfg_impl = cfg_impl\n        if self.cfg_impl.microbatch_size is None:\n            self.cfg_impl.microbatch_size = self.cfg_train.batch_size\n        if self.cfg_impl.microbatch_size > self.cfg_train.batch_size:\n            raise ValueError(f\"MBS is {self.cfg_impl.microbatch_size}, but BS is only {self.cfg_train.batch_size}.\")\n        self.current_seq_length = seq_length\n\n        # Mixed Precision:\n        enabled = self.cfg_impl.mixed_precision if setup[\"device\"].type != \"cpu\" else False\n        # Modules like LN are unsupported on CPU amp, so mixed precision args are disregarded on CPU\n        # See https://pytorch.org/docs/stable/amp.html#cpu-op-specific-behavior and check for layer_norm\n        enable_scaling = self.cfg_impl.grad_scaling and self.cfg_impl.mixed_precision and setup[\"device\"].type != \"cpu\"\n        self.scaler = torch.cuda.amp.GradScaler(enabled=enable_scaling)\n        amp_dtype = getattr(torch, self.cfg_impl.mixed_precision_target_dtype) if setup[\"device\"].type != \"cpu\" else torch.bfloat16\n        self.amp_settings = dict(device_type=setup[\"device\"].type, enabled=enabled, dtype=amp_dtype)\n\n        # Choose setup and move model\n        self.setup = setup\n        model.to(**self.setup)\n        self._original_model = model\n        log.info(\"Compiling model, in the Constructor of TorchEngine\")\n        model = torch.compile(\n            model,\n            mode=self.cfg_impl.mode,\n            dynamic=self.cfg_impl.dynamic,\n            fullgraph=self.cfg_impl.fullgraph,\n            backend=self.cfg_impl.backend,\n            disable=not cfg_impl.compile_torch,\n            # detailed options; cannot be given at the same time as mode:\n            options=flatten(cfg_impl._inductor_vars, parent_key=\"\", sep=\".\") if cfg_impl._inductor_vars is not None else None,\n        )\n\n        if torch.distributed.is_initialized():\n            if init_compile_and_distribute:\n                log.info(\"Distributing model, in the Constructor of TorchEngine\")\n                self.model = self._init_distributed(model)\n            else:\n                log.info(\n                    \"<WARNING> NOT Distirbuting model in the Constructor of TorchEngine, we will attempt to do this later as we are loading in a checkpoint\"\n                )\n                self.model = model\n            self.num_machines = torch.distributed.get_world_size()\n        else:\n            self.model = model\n            self.model.no_sync = nullcontext\n            self.num_machines = 1\n\n        # Microbatch accumulation settings and counters\n        self.effective_mbs = self.cfg_impl.microbatch_size * self.num_machines  # across machines\n        self.current_batch_size = self.cfg_train.batch_size if self.cfg_train.batch_size_ramp == 0 else self.effective_mbs\n        self.accumulation_steps_expected = self.current_batch_size // self.effective_mbs\n        self.accumulated_samples = 0  # Record the number of samples seen, reset after triggering gradient update\n        self.steps = 0  # Record the number of times \"step\" has been triggered\n        self.steps_since_reset = 0  # Record the number of times \"step\" has been triggered\n\n        self.initial_time = time.time()\n        self.previous_elapsed_time = 0.0\n        self.optimizer, self.scheduler = _load_optimizer(model, cfg_train, cfg_impl, self.previous_elapsed_time, self.get_true_budget())\n\n    def get_true_budget(self):\n        return (\n            min(self.cfg_train.budget, self.cfg_train.overall_budget - self.previous_elapsed_time / 3600)\n            + self.previous_elapsed_time / 3600\n        )\n\n    def step(self, batch: dict[str, torch.Tensor]):\n        loss = self.forward(**batch)[\"loss\"]\n        self.backward(loss)\n        self.optimizer_step()\n        return loss.detach()\n\n    def to_device(self, batch: dict[str, torch.Tensor], keys: list[str] = [\"input_ids\"]):\n        \"\"\"Move batch of data into device memory.\"\"\"\n        device_batch = {\n            k: v.to(device=self.setup[\"device\"], dtype=torch.long if k == \"input_ids\" else None, non_blocking=True)\n            for k, v in batch.items()\n            if k in keys  # Add more keywords here if needed\n        }\n        return device_batch\n\n    def forward(self, *inputs, **kwargs):\n        self.accumulated_samples += self.effective_mbs\n        context = self.model.no_sync if self.accumulated_samples < self.current_batch_size else nullcontext\n        with context():\n            with torch.autocast(**self.amp_settings):\n                return self.model(*inputs, **kwargs)\n\n    def backward(self, loss):\n        context = self.model.no_sync if self.accumulated_samples < self.current_batch_size else nullcontext\n        with context():\n            return self.scaler.scale(loss / self.accumulation_steps_expected).backward()\n\n    @torch.no_grad()\n    @torch._dynamo.disable()\n    def forward_inference(self, *inputs, **kwargs):\n        with torch.autocast(**self.amp_settings):\n            outputs = self.model(*inputs, **kwargs)[\"logits\"]\n        predictions = outputs.argmax(dim=-1)\n        return outputs, predictions\n\n    @torch._dynamo.disable()\n    @torch.inference_mode()\n    def dynamic_generation(self, *inputs, temperature=0.7, token_limit=100, **kwargs):\n        with torch.autocast(**self.amp_settings):\n            try:\n                if hasattr(self._original_model, \"_generate\"):  # my signature\n                    outputs = self._original_model._generate(*inputs, temperature=temperature, token_limit=token_limit, **kwargs)\n                elif hasattr(self._original_model, \"generate\"):  # hf signature\n                    outputs = self._original_model.generate(\n                        *inputs, do_sample=True, num_beams=1, temperature=temperature, max_new_tokens=token_limit, **kwargs\n                    )\n                else:\n                    raise NotImplementedError()\n            except Exception as e:  # Fallback\n                log.info(f\"Falling back to default generation scheme due to error {e} in model._generate or model.generate.\")\n                # Generate new tokens the dumb way as a fall-back\n                # need to implement the improved way for transformers eventually\n                device_inputs = inputs[0]\n                predicted_ids = []\n                for gen_idx in range(token_limit):\n                    logits = self._original_model(device_inputs, *inputs[1:], **kwargs)[\"logits\"]\n                    predicted_token = torch.multinomial(torch.softmax(logits * temperature, dim=-1), 1)\n                    device_inputs = torch.cat([device_inputs, predicted_token], dim=-1)\n                    predicted_ids += [predicted_token]\n                outputs = torch.cat(predicted_ids, dim=-1)\n        return outputs\n\n    def optimizer_step(self):\n        \"\"\"Requires a scheduler that is based on iterations instead of epochs.\"\"\"\n        self.steps += 1\n        self.steps_since_reset += 1\n        if self.accumulated_samples >= self.current_batch_size:\n            self.accumulated_samples = 0\n\n            if self.cfg_train.gradient_clipping is not None:\n                self.scaler.unscale_(self.optimizer)\n                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg_train.gradient_clipping, norm_type=2.0)\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n            self.optimizer.zero_grad()\n            self.schedule_batch_size()\n        self.scheduler.step()  # Trigger in every step, otherwise things get annoying with grad accumulation\n\n    def set_train_batch_size(self, batch_size):\n        \"\"\"Allow dynamic modifications of batch size.\"\"\"\n        self.current_batch_size = batch_size\n        self.accumulation_steps_expected = self.current_batch_size // self.effective_mbs\n\n    def schedule_batch_size(self):\n        \"\"\"Optionally implement linear batch size ramp-ups.\"\"\"\n        mbs = self.effective_mbs\n\n        if (self.cfg_train.batch_size_ramp > 0) and (self.cfg_train.batch_size_ramp < 1):\n            # interpret as percentage of total budget\n            elapsed_time = (time.time() - self.initial_time) + self.previous_elapsed_time\n            elapsed_hours = elapsed_time / 60 / 60\n            fake_step = int(elapsed_hours / self.get_true_budget() * self.cfg_train.steps)\n            # WARNING: this does not correctly pick up from checkpoint if elapsed>budget i.e. going over the orginal budget may cause a problem here\n\n            batch_size_step = self.cfg_train.batch_size / (self.cfg_train.steps * self.cfg_train.batch_size_ramp)\n\n            new_batch_size = min(int(fake_step * batch_size_step // mbs + 1) * mbs, self.cfg_train.batch_size)\n        elif self.steps < self.cfg_train.batch_size_ramp:\n            batch_size_step = self.cfg_train.batch_size / self.cfg_train.batch_size_ramp\n            new_batch_size = int(self.steps * batch_size_step // mbs + 1) * mbs\n        else:\n            new_batch_size = self.cfg_train.batch_size\n        self.set_train_batch_size(new_batch_size)\n\n    def record_batch_size(self):\n        if self.cfg_train.optim_mod.name != \"progressive-batching\":\n            return self.current_batch_size\n        else:\n            return self.optimizer.last_full_step_accumulation * self.current_batch_size\n\n    def record_tokens_per_step(self):\n        \"\"\"Tokens in each microbatch step.\"\"\"\n        return self.current_seq_length * self.effective_mbs\n\n    @torch.no_grad()\n    def retrieve_model_state_dict(self):\n        if self.cfg_impl.compile_torch:\n            if torch.distributed.is_initialized():\n                state_dict = self.model.module._orig_mod.state_dict()  # ughhhh\n            else:\n                state_dict = self.model._orig_mod.state_dict()  # ugh\n        else:\n            if torch.distributed.is_initialized():\n                state_dict = self.model.module.state_dict()\n            else:\n                state_dict = self.model.state_dict()\n\n        state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}\n        return state_dict\n\n    def _init_distributed(self, model):\n        model = torch.nn.parallel.DistributedDataParallel(\n            model,\n            device_ids=[self.setup[\"device\"]] if self.setup[\"device\"].type == \"cuda\" else None,\n            output_device=self.setup[\"device\"] if self.setup[\"device\"].type == \"cuda\" else None,\n            broadcast_buffers=self.cfg_impl.broadcast_buffers,\n            bucket_cap_mb=self.cfg_impl.bucket_cap_mb,\n            gradient_as_bucket_view=self.cfg_impl.gradient_as_bucket_view,\n            static_graph=self.cfg_impl.static_graph,\n        )\n        return model\n\n    def load_checkpoint(self, cfg_arch, file, skip_optim_state=False) -> Dict[str, Any]:\n        \"\"\"Load list of states from checkpoint file. Not generally compatible with any other engine?\"\"\"\n        self.optimizer.zero_grad()\n        # defaults\n        metadata = {\"epochs\": 0, \"steps\": 0, \"loss\": 0, \"data_idx\": 0, \"elapsed_time\": 0.0}\n        if file.startswith(\"hf://\"):\n            if file.endswith(\"-untrained\"):\n                log.info(\"Loading NO pretrained model as a sanity check ...\")\n            else:\n                self.model = self.model.from_pretrained(file.split(\"hf://\")[1], config=cfg_arch).to(**self.setup)\n                # reinit optimizer:\n                self.optimizer, self.scheduler = _load_optimizer(\n                    self.model, self.cfg_train, self.cfg_impl, metadata.get(\"elapsed_time\", 0.0), self.get_true_budget()\n                )\n        else:\n            # we load back into original model as we want to redistribute the weights across ranks to be super sure!!\n            model = load_model_checkpoint(self._original_model, file)\n            model.to(**self.setup)\n            # reinitialising the model as we are losing the compile speed otherwise\n            model = torch.compile(\n                model,\n                mode=self.cfg_impl.mode,\n                dynamic=self.cfg_impl.dynamic,\n                fullgraph=self.cfg_impl.fullgraph,\n                backend=self.cfg_impl.backend,\n                disable=not self.cfg_impl.compile_torch,\n                # detailed options; cannot be given at the same time as mode:\n                options=flatten(self.cfg_impl._inductor_vars, parent_key=\"\", sep=\".\") if self.cfg_impl._inductor_vars is not None else None,\n            )\n            if torch.distributed.is_initialized():\n                self.model = self._init_distributed(model)\n                log.info(\"Recompiled and distributed\")\n            else:\n                self.model = model\n                log.info(\"Recompiled\")\n\n            if not skip_optim_state:\n                state_file = os.path.join(file, \"state_dict.pth\")\n                try:\n                    loaded = torch.load(state_file)\n                    optim_state = loaded[\"optim_state\"]\n                    scheduler_state = loaded[\"scheduler_state\"]\n                    scaler_state = loaded[\"scaler_state\"]\n                    metadata = loaded[\"metadata\"]\n                    self.load_metadata(metadata)\n\n                    # this is mainly so that the scheduler knows about the elapsed time\n                    self.optimizer, self.scheduler = _load_optimizer(\n                        self.model, self.cfg_train, self.cfg_impl, self.previous_elapsed_time, self.get_true_budget()\n                    )\n                    self.optimizer.load_state_dict(optim_state)\n                    self.scheduler.load_state_dict(scheduler_state)\n\n                    self.scaler.load_state_dict(scaler_state)\n                    log.info(f\"Successfully loaded state with metadata {metadata}\")\n                except Exception as e:\n                    raise ValueError(f\"Error loading optimizer and scheduler states from {state_file}. {e}\")\n        return metadata\n\n    def load_metadata(self, metadata: Dict[str, Any]):\n        self.steps = metadata.get(\"steps\", 0)\n        self.previous_elapsed_time = metadata.get(\"elapsed_time\", 0.0)\n        # add other state things here\n\n    def save_training_checkpoint(self, checkpoint_directory: str, checkpoint_name: Union[str, float], metadata: Dict[str, Any]):\n        \"\"\"Path, identifier and additional client state. This checkpoint can be used to resume training.\n        The default behavior is to save this checkpoint relative to the training working directory.\n        \"\"\"\n\n        os.makedirs(checkpoint_directory, exist_ok=True)\n        full_path = os.path.join(checkpoint_directory, checkpoint_name)\n\n        optim_state = self.optimizer.state_dict()\n        model_state = self.retrieve_model_state_dict()\n        scheduler_state = self.scheduler.state_dict()\n        scaler_state = self.scaler.state_dict()\n        state_dict = {\n            \"metadata\": metadata,\n            \"optim_state\": optim_state,\n            \"scaler_state\": scaler_state,\n            \"scheduler_state\": scheduler_state,\n        }\n        safetensor_name = f\"{full_path}_model_state.pth\"\n        save_file(model_state, safetensor_name)\n        other_name = f\"{full_path}_non_model.pth\"\n        torch.save(state_dict, other_name)\n\n    def save_final_model(self, base_directory, identifier, tokenizer, cfg_arch, dryrun=False):\n        \"\"\"This checkpoint can be used for downstream tasks.\n        The default behavior is to save this checkpoint to a checkpoints folder under base_directory/name/checkpoints\"\"\"\n        try:\n            identifier_str = f\"{identifier:2.4f}\"\n        except ValueError:\n            identifier_str = str(identifier)\n        full_path = os.path.join(base_directory, \"checkpoints\", identifier_str)\n        os.makedirs(full_path, exist_ok=True)\n        # This saves tokenizer_config.json, tokenizer.json and special_tokens_map.json to this folder\n        if not dryrun:\n\n            # Save model.safetensors, model_config.json\n            save_file(self.retrieve_model_state_dict(), os.path.join(full_path, \"model.safetensors\"))\n            # legacy save: torch.save(self.retrieve_model_state_dict(), os.path.join(full_path, \"model.pth\"))\n            with open(os.path.join(full_path, \"model_config.json\"), \"w\") as file:\n                json.dump(OmegaConf.to_container(cfg_arch, resolve=True), file)\n\n    def save_model(\n        self,\n        checkpoint_directory: str,\n        checkpoint_name: Union[str, float],\n        cfg_arch,\n        metadata: Dict[str, Any],\n        tokenizer=None,\n        save_safe: bool = False,\n    ):\n        \"\"\"This checkpoint can be used for downstream tasks.\n        The default behavior is to save this checkpoint to a checkpoints folder under base_directory/name/checkpoints\"\"\"\n        full_path = os.path.join(checkpoint_directory, checkpoint_name)\n        os.makedirs(full_path, exist_ok=True)\n\n        with open(os.path.join(full_path, \"model_config.json\"), \"w\") as file:\n            json.dump(OmegaConf.to_container(cfg_arch, resolve=True), file)\n\n        model_state = self.retrieve_model_state_dict()\n        state_dict = {\n            \"model_state\": model_state,\n        }\n\n        if save_safe:\n            # this is like the final checkpoint, saves as safetensor but doesn't save state\n            model_state = state_dict.pop(\"model_state\")\n            save_file(model_state, os.path.join(full_path, \"model.safetensors\"))\n\n        if metadata is not None:\n            optim_state = self.optimizer.state_dict()\n            scheduler_state = self.scheduler.state_dict()\n            scaler_state = self.scaler.state_dict()\n            state_dict[\"metadata\"] = metadata\n            state_dict[\"optim_state\"] = optim_state\n            state_dict[\"scheduler_state\"] = scheduler_state\n            state_dict[\"scaler_state\"] = scaler_state\n        if len(state_dict) > 0:\n            # if save_safe this will only save non-model stuff\n            state_dict_path = os.path.join(full_path, \"state_dict.pth\")\n            torch.save(state_dict, state_dict_path)\n\n        return full_path\n\n    def push_to_hub(self, tokenizer, cfg, dryrun=False):\n        \"\"\"Analogous to save_final_model, but save model to hugginface hub.\"\"\"\n        from huggingface_hub import HfApi\n        from io import BytesIO\n\n        api = HfApi()\n\n        if not dryrun:\n            log.info(f\"Pushing model to hub repository {cfg.impl.hf_directoy_name}.\")\n            final_state_dict = self.retrieve_model_state_dict()\n            self.model.load_state_dict(final_state_dict)\n\n            # Push model with safetensors:\n            # This is a manual modification of model.push_to_hub which doesn't support safetensors yet\n            repo_id = cfg.impl.hf_directoy_name\n            if os.path.isdir(repo_id):\n                working_dir = repo_id\n                repo_id = repo_id.split(os.path.sep)[-1]\n            else:\n                working_dir = repo_id.split(\"/\")[-1]\n            repo_id = self.model._create_repo(repo_id)\n            use_temp_dir = not os.path.isdir(working_dir)\n            with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:\n                files_timestamps = self.model._get_files_timestamps(work_dir)\n                # Save all files.\n                self.model.save_pretrained(\n                    work_dir,\n                    max_shard_size=\"10GB\",\n                    safe_serialization=True,\n                    state_dict=self.retrieve_model_state_dict(),\n                )\n                self.model._upload_modified_files(\n                    work_dir,\n                    repo_id,\n                    files_timestamps,\n                    commit_message=None,\n                    token=None,\n                    create_pr=None,\n                )\n            # Push tokenizer:\n            tokenizer.push_to_hub(cfg.impl.hf_directoy_name)\n            # Push config files:\n            for config_group, config_name in zip([cfg.arch, cfg.data, cfg.train], [\"arch\", \"data\", \"train\"]):\n                buffer = BytesIO()\n                buffer.write(json.dumps(OmegaConf.to_container(config_group, resolve=True), indent=4).encode())\n                api.upload_file(\n                    path_or_fileobj=buffer,\n                    path_in_repo=f\"{config_name}_budget_hours_{cfg.budget}.json\",\n                    repo_id=f\"{api.whoami()['name']}/{cfg.impl.hf_directoy_name}\",\n                    # there has to be a better way to do this, but ...\n                    repo_type=\"model\",\n                )\n        else:\n            log.info(f\"Skipping huggingface upload in dryrun state. Would upload to {cfg.impl.hf_directoy_name}.\")\n\n\ndef _load_optimizer(model, cfg_train, cfg_impl, elapsed_time=0.0, true_budget=-1):\n    # Filter some parameters\n    grouped_parameters = group_parameters(model, cfg_train)\n\n    # Select optimizer implementation\n    if cfg_train.optim.type == \"AdamW\":\n        optimizer_class = torch.optim.AdamW\n    elif cfg_train.optim.type == \"Adam\":\n        optimizer_class = torch.optim.Adam\n    elif cfg_train.optim.type == \"RAdam\":\n        optimizer_class = torch.optim.RAdam\n    elif cfg_train.optim.type == \"SGD\":\n        optimizer_class = torch.optim.SGD\n    elif cfg_train.optim.type == \"Adafactor\":\n        optimizer_class = transformers.Adafactor\n    elif cfg_train.optim.type == \"Shampoo\":\n        optimizer_class = Shampoo\n    elif cfg_train.optim.type == \"AdaHessian\":\n        optimizer_class = Adahessian\n    elif cfg_train.optim.type == \"AdamWScale\":\n        optimizer_class = AdamWScale\n    elif cfg_train.optim.type == \"Sophia-G\":\n        optimizer_class = Sophia\n    elif cfg_train.optim.type == \"Lion\":\n        from lion_pytorch import Lion\n\n        optimizer_class = Lion\n\n    elif cfg_train.optim.type == \"Adam8bit\":\n        import bitsandbytes as bnb\n\n        optimizer_class = bnb.optim.Adam8bit\n    elif cfg_train.optim.type == \"AGD\":\n        depth = len(list(model.parameters()))\n        optimizer_class = partial(AGD, depth=depth)\n    else:\n        raise ValueError(f\"Invalid optimizer {cfg_train.optim.type} given.\")\n    optimizer_args = {k: v for k, v in cfg_train.optim.items() if k != \"type\"}\n    if cfg_impl.foreach_optimizer and cfg_train.optim.type != \"Shampoo\":\n        optimizer_args[\"foreach\"] = True\n\n    if torch.distributed.is_initialized() and cfg_impl.zero_redundancy_optimizer:\n        # The overlap option is a whole bucket of problems in itself for now...\n        optimizer = ZeroRedundancyOptimizer(\n            grouped_parameters,\n            optimizer_class=optimizer_class,\n            parameters_as_bucket_view=True,\n            overlap_with_ddp=False,\n            **optimizer_args,\n        )\n    else:\n        optimizer = optimizer_class(grouped_parameters, **optimizer_args)\n\n    if cfg_train.optim_mod.name == \"none\":\n        optimizer_to_schedule = optimizer\n    else:\n        optim_params = {k: v for k, v in cfg_train.optim_mod.items() if k != \"name\"}\n        if cfg_train.optim_mod.name == \"LARS\":\n            optimizer = LARS(optimizer, **optim_params)\n        elif cfg_train.optim_mod.name == \"LARC\":\n            optimizer = LARS(optimizer, **optim_params)\n        elif cfg_train.optim_mod.name == \"SAM\":\n            optimizer = SAM(optimizer, **optim_params)\n        elif cfg_train.optim_mod.name == \"progressive-batching\":\n            optimizer = ProgressiveBatching(optimizer, **optim_params)\n\n        optimizer_to_schedule = optimizer.optim\n\n    scheduler = get_schedule_fn(cfg_train, elapsed_time=elapsed_time, true_budget=true_budget)(optimizer_to_schedule)\n\n    return optimizer, scheduler\n"
  },
  {
    "path": "cramming/backend/utils.py",
    "content": "import logging\nimport os\nimport torch\n\nimport logging\n\nfrom safetensors.torch import load_file, save_file\nimport cramming\n\nlog = logging.getLogger(__name__)\n\n\n\"\"\"Utilities common to several backends.\"\"\"\ndef group_parameters(model, cfg_train):\n    model_parameters = list(model.named_parameters())\n    if len(cfg_train.limited_decay_keys) > 0:\n        grouped_parameters = optimizer_grouped_parameters = [\n            {\n                \"params\": [p for n, p in model_parameters if not any(nd in n for nd in cfg_train.limited_decay_keys)],\n                \"weight_decay\": cfg_train.optim.weight_decay,\n            },\n            {\n                \"params\": [p for n, p in model_parameters if any(nd in n for nd in cfg_train.limited_decay_keys)],\n                \"weight_decay\": 0.0,\n            },\n        ]\n    else:\n        grouped_parameters = [p for n, p in model_parameters]\n    return grouped_parameters\n\n\ndef get_model_engine_tokenizer_dataloaders(cfg, setup, train_eval: bool = True):\n    \"\"\"This function gets the model, model engine (if needed), tokenizer, and data\"\"\"\n    if train_eval:\n        train_eval_cfg = cfg.train\n    else:\n        train_eval_cfg = cfg.eval\n\n    tokenizer_model = None\n    cfg_arch = cfg.arch  # if not loading from checkpoint, need architecture config\n    checkpoint_path = None\n    try:\n        # attempt to load latest in case of preemption\n        prev_checkpoint_path = os.path.join(cfg.model_dir, cfg.name, \"checkpoints\")\n        tokenizer_model, cfg_arch, checkpoint_path = cramming.utils.find_pretrained_checkpoint(\n            \"latest\",\n            local_checkpoint_folder=str(prev_checkpoint_path),\n            arch_modifications=train_eval_cfg.arch_modifications\n        )\n        log.info(f\"Getting latest checkpoint at {prev_checkpoint_path}\")\n\n    except:\n        # no previous checkpoint saved.  Checking separate model directory\n        if train_eval_cfg.checkpoint is not None:\n            try:\n                tokenizer_model, cfg_arch, checkpoint_path = cramming.utils.find_pretrained_checkpoint(\n                    train_eval_cfg.checkpoint,\n                    local_checkpoint_folder=cfg.model_dir,\n                    arch_modifications=train_eval_cfg.arch_modifications\n                )\n                log.info(f\"Found checkpoint at {cfg.model_dir} or {train_eval_cfg.checkpoint}\")\n                # importantly, if checkpoint is found, we will use that model arch, modifications doesnt seem to work.\n            except Exception as e:\n                log.info(f\"Unable to load checkpoint {train_eval_cfg.checkpoint} or in directory {cfg.model_dir}.\"\n                         f\"  Initializing model from scratch!\")\n\n    log.info(f\"Loading Data\")\n    datasets, tokenizer = cramming.load_pretraining_corpus(cfg.data, cfg.impl, cfg.data_dir)\n\n    real_dataset_sample_length = len(datasets['train'][0]['input_ids']) # for arithmetic datasets\n\n    if tokenizer_model is not None:\n        # todo consider if we even need to return the tokenizer with the checkpoint (only HF?)\n        tokenizer = tokenizer_model\n    dataloaders = cramming.prepare_dataloaders(datasets, tokenizer, train_eval_cfg, cfg.impl)\n\n    log.info(f\"Constructing Model\")\n    model = cramming.construct_model(cfg_arch, tokenizer)\n\n    metadata = {}\n\n    if train_eval:\n        # if in train mode, need engine\n        fully_init_model_to_begin = False if checkpoint_path is not None else True\n        model_engine = cramming.load_backend(\n            model,\n            tokenizer,\n            cfg.train,\n            cfg.impl,\n            setup=setup,\n            init_compile_and_distribute=fully_init_model_to_begin, # false if we are planning to load a checkpoint in later\n        )\n\n        if checkpoint_path is not None:\n            # load checkpoint, engine handles loaded model\n            metadata = model_engine.load_checkpoint(cfg_arch, checkpoint_path)\n            for k, v in dataloaders.items():\n                try:\n                    # for dataloaders with epochs (RuntimeInfiniteDataLoader) set that epoch to start here\n                    v.set_epoch(metadata.get(\"epoch\", 0))\n                except:\n                    pass\n\n        model_engine.train(train_eval_cfg.pretrain_in_train_mode)\n        model_engine.current_seq_length = real_dataset_sample_length # setting the number of tokens seen correctly for arithmetic data\n    else:\n        if checkpoint_path is not None:\n            model = load_model_checkpoint(model, checkpoint_path)\n        model_engine = None\n    return model, model_engine, tokenizer, dataloaders, metadata\n\n\ndef load_model_checkpoint(model, model_dir, forward_only_model_with_skip=False):\n    ext = \"model.safetensors\"\n    try:\n        model_file = os.path.join(model_dir, ext)\n        model_state = load_file(model_file)\n    except:\n        ext = \"state_dict.pth\"\n        model_file = os.path.join(model_dir, ext)\n        loaded = torch.load(model_file)\n        model_state = loaded.get(\"model_state\", None)\n\n    if model_state is None:\n        raise ValueError(f\"No model found in directory {model_dir} (in '/state_dict.pth' or '/model.safetensors')\")\n    else:\n        log.info(f\"Loading Model from {model_file}\")\n\n    if \"encoder.embedding.word_embedding.weight\" not in model_state:\n        # Hack to save space when saving the model, more clever though would be save the right one in the first place\n        model_state[\"encoder.embedding.word_embedding.weight\"] = model_state[\"decoder.weight\"]\n    sanitized_state = {}\n    try:\n        for k, v in model_state.items():\n            if k.startswith(\"module.\"):\n                k = k[7:]\n            if forward_only_model_with_skip:\n                if \"_orig_mod\" in k: # we load in original model to here so we can drop this\n                    k = k.replace(\"._orig_mod\", \"\")\n            sanitized_state[k] = v\n\n        model.load_state_dict(sanitized_state, strict=True)\n        log.info(\"finished loading state dict\")\n    except RuntimeError as e:\n        log.info(f\"State dict difference is {str(e).split('Error(s) in loading state_dict for')[1]}... Ok?\")\n        exit()\n\n    return model\n"
  },
  {
    "path": "cramming/config/__init__.py",
    "content": ""
  },
  {
    "path": "cramming/config/arch/__init__.py",
    "content": ""
  },
  {
    "path": "cramming/config/arch/albert.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This is set up to be as close to ALBERT-large (Lan et al.) as reasonable for a decoder-based model\n\nmodel_type: ScriptableCrammedDepthRecurrent\n\nlayers_in_recurrent_block: 1\nmaximal_recurrence: 24\nmax_backprop: # use half of maximal_recurrence if not given, minimal is 1 # only valid for TBTT\nmaximal_recurrence_in_eval: 24\n\nhidden_size: 1024\nintermed_size: 4096\ninput_injection_type: none\ninitial_hidden_randomized: False\nstate_init:\n\nnorm: LayerNorm\nnorm_eps: 1e-12\nnorm_scheme: post # can be \"pre\", \"post\"\nnonlin: GELU\nsub_normalization: False\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: True # Whether to include a bias in the decoding step\nuse_bias: True # Whether to learn biases on all dense layers\nfinal_norm: False # Add a final norm layer before the end\nhead: identity\n\nobjective_layout: fixed\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding: learned\n  max_seq_length: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n  embedding_dim: 128\n  normalization: True\n  stable_low_precision: False\n\nattention:\n  type: pytorch # also works with \"pytorch\"\n  num_attention_heads: 16 # for flash\n  skip_output_projection: False\n  qkv_bias: True\n  bias_in_proj: True\n\n  rotary_embedding: False\n  seq_op_in_fp32: False # whether to always cast the operation over the sequence into fp32 (e.g.. the softmax in normal attn)\n  sequence_op: torch-softmax # Can be normalization\n  sub_normalization: False # could be turned off separately # Is only used if type=self-attention (i.e the hand-made version)\n\ninit:\n  type: normal\n  std: 0.02\n\nthrottle: False # only active during TBPTT\nlocal_compilation: True # Try to compile the static block, no matter what the global compile setting is set to\n"
  },
  {
    "path": "cramming/config/arch/crammed-depthrecurrent.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This inherits architecture changes from the crammed-bert project\n\nmodel_type: ScriptableCrammedDepthRecurrent\n\nlayers_in_recurrent_block: 4\nmaximal_recurrence: 4\nmax_backprop: # use half of maximal_recurrence if not given, minimal is 1\nmaximal_recurrence_in_eval: ${arch.maximal_recurrence} # could be set to think longer\n\nhidden_size: 768\nintermed_size: 3072\ninput_injection_type: add\ninitial_hidden_randomized: True\nstate_init: embed # initialized random like embedding\n\n\nnorm: LayerNorm\nnorm_eps: 1e-12\nnorm_scheme: post # can be \"pre\", \"post\"\n\nnonlin: GELUglu\nsub_normalization: False # Sub-normalization in attn and ffn blocks\n\ntie_weights: False # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\nuse_bias: False # Whether to learn biases on all dense layers\nfinal_norm: True # Add a final norm layer before the end\nhead: ffn\n\nobjective_layout: TBPTT\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding: learned\n  max_seq_length: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n  embedding_dim: ${arch.hidden_size} # has to be this value for crammedBERT\n  normalization: True\n  stable_low_precision: False\n  max_abacus_len: 100\n\nattention:\n  type: pytorch # also works with \"pytorch\"\n  num_attention_heads: 16 # for flash\n  skip_output_projection: False\n  qkv_bias: False\n  bias_in_proj: False\n  max_length: 0 # for randomised PE's (NOT IMPLEMENTED FOR ALL)\n\n  rotary_embedding: False\n  seq_op_in_fp32: False # whether to always cast the operation over the sequence into fp32 (e.g.. the softmax in normal attn)\n  sequence_op: torch-softmax # Can be normalization\n  sub_normalization: ${arch.sub_normalization} # could be turned off separately # Is only used if type=self-attention (i.e the hand-made version)\n\ninit:\n  type: deepnorm-straight\n  std: 0.02 # only used if type=normal\n\nthrottle: False # only active during TBPTT\nalpha: 1.0 # only active during TBPTT\nmask_before_equals: False\nlocal_compilation: True # Try to compile the static block, no matter what the global compile setting is set to\nloss_reduction: mean\nforward_only_model_with_skip: False # forward only model with skip"
  },
  {
    "path": "cramming/config/arch/crammed-fakeRNN.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable encoder-based LM with BERT as baseline\n# Modernized version of bert-c5\n\n# These are the huggingface bert parameters\nmodel_type: ScriptableFakeRNN\n\nn_blocks: 5\nstate_size: 512\nhidden_size: 512\nbottle_size: 256\nblock_type: resnet\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\n\nloss: cross-entropy\nobjective_layout: autoregressive\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding: None\n  dropout_prob: 0.1 # equal to hidden_dropout_prob in BERT\n  pad_token_id: 0\n  max_seq_length: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n  embedding_dim: ${arch.hidden_size} # has to be this value for crammedBERT\n  normalization: False\n  stable_low_precision: False\n\ninit:\n  type: normal\n  std: 0.02\n\n# Set dynamically:\neos_token_id:\n"
  },
  {
    "path": "cramming/config/arch/crammed-janus.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable janus-type RNN, right now with all tested bells-and-whistles\n\n# These are the huggingface bert parameters\nmodel_type: ScriptableCrammedJanus\n\nnum_transformer_layers: 8\nstate_dim: 1024\n\nnorm_scheme: shaped\nnorm: LayerNorm\nnorm_eps: 1e-12\n\nnonlin: GELUglu\nsub_normalization: False # Sub-normalization in attn and ffn blocks\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\nuse_bias: True # Whether to learn biases on all dense layers\nfinal_norm: True # crashes without this improvement to stability\nforce_normalized_state: False # last normalization learnable?\n\nloss: cross-entropy\nobjective_layout: autoregressive # nothing else implemented so far\n\nffn_block:\n  structure: joined-injection # state-branch-embedding-injection\n\n  intermed_multiplier: 4\n  hidden_dropout_prob: 0.0\n\n  num_chunks_in_sequence: 16 # only necessary if head.structure=chunked\n\nhead:\n  structure: ffn # dense-nonlin-norm\n  nonlin: GELU\n  norm: LayerNorm\n  norm_eps: 1e-12\n  use_bias: True\n  include_attn_in_chunked_heads: False # only valid for chunked heads\n  num_chunked_heads: 4 # only valid for chunked heads\n  intermed_multiplier: 4\n\nobjective:\n  historian_weight: 1.0\n  predictor_weight: 1.0\n  present_historian_weight: 1.0\n  present_predictor_weight: 1.0\n  rscale_correction: False\n\n  antiquarian_weight: 0.0 #\n  antiquarian_range: ${data.seq_length} # maximal range a previous state may be looked up with # set to -1 to encompass all previous states\n  historian_loss_fn: MSE # can also be cosine\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding:\n  embedding_dim: 512\n  normalization: True\n  stable_low_precision: False\n  max_seq_length: ${data.seq_length} # legacy position, do not use\n\n\nmax_seq_length: ${data.seq_length} # max seq length during training (not always used)\nposition_information: learned # none learned or simple\n\ninit:\n  type: megatron\n  std: 0.02 # only used if type=normal\n\n# Experimental options:\nstate_corruption: 0.0\nstate_init: unit\neos_state_reset: True\n\n# Set dynamically:\neos_token_id:\n"
  },
  {
    "path": "cramming/config/arch/crammed-rnn.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable encoder-based LM with BERT as baseline\n# Modernized version of bert-c5\n\n# These are the huggingface bert parameters\nmodel_type: ScriptableCrammedRNN\n\n# PyTorch LSTM settings:\ninput_size: 512\nhidden_size: 512\nnum_layers: 2\nbias: True\nseq_first: True\ndropout: 0.1\nbidirectional: False\nproj_size: 0\n\nnorm: LayerNorm\nnorm_eps: 1e-12\nfinal_norm: True # Add a final norm layer before the end\nskip_head_transform: True # This is only possible if embedding_dim=hidden_size\nuse_bias: False # Whether to learn biases on all dense layers\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\n\nloss: cross-entropy\nobjective_layout: autoregressive\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding: scaled-sinusoidal\n  dropout_prob: 0.1 # equal to hidden_dropout_prob in BERT\n  pad_token_id: 0\n  max_seq_length: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n  embedding_dim: ${arch.input_size} # has to be this value for crammedBERT\n  normalization: True\n  stable_low_precision: False\n\n# Set dynamically:\neos_token_id:\n"
  },
  {
    "path": "cramming/config/arch/crammed-stack-janus.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable janus-type RNN, right now with all tested bells-and-whistles\n\n# These are the huggingface bert parameters\nmodel_type: ScriptableCrammedJanus\n\nnum_transformer_layers: 8\nstate_dim: 3584\n\nnorm_scheme: shaped\nnorm: LayerNorm\nnorm_eps: 1e-12\n\nnonlin: GELUglu\nsub_normalization: False # Sub-normalization in attn and ffn blocks\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\nuse_bias: True # Whether to learn biases on all dense layers\nfinal_norm: True # crashes without this improvement to stability\nforce_normalized_state: True # last normalization learnable?\n\nloss: cross-entropy\nobjective_layout: autoregressive # nothing else implemented so far\n\nffn_block:\n  structure: stack-sideways-transformer\n  intermed_multiplier: 4\n  hidden_dropout_prob: 0.0\n\n  # settings only relevant for structure=state-attention:\n  qkv_bias: True\n  proj_bias: True\n  num_chunks_in_sequence: 16\n  num_read_write_heads: 8\n  run_causal_heads: False\n  positional_info: True\n  garbage_collect_state: False\n  num_blocks_to_accumulate: 0 # Can be any number of embedding chunks that will added to state, this is N^2 atttention again :>\n  gradient_checkpointing: False\n  workspace: ${arch.ffn_block.num_chunks_in_sequence} # only used if block in structure, can be smaller than num_chunks_in_sequence\n\nhead:\n  structure: chunked # dense-nonlin-norm\n  nonlin: GELU\n  norm: LayerNorm\n  norm_eps: 1e-12\n  use_bias: True\n  include_attn_in_chunked_heads: True # only valid for chunked heads\n  num_chunked_heads: 4 # only valid for chunked heads\n  intermed_multiplier: 4\n\nobjective:\n  historian_weight: 1.0\n  predictor_weight: 1.0\n  present_historian_weight: 1.0\n  present_predictor_weight: 1.0\n  rscale_correction: False\n\n  antiquarian_weight: 0.0 #\n  antiquarian_range: ${data.seq_length} # maximal range a previous state may be looked up with # set to -1 to encompass all previous states\n  historian_loss_fn: MSE\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding:\n  embedding_dim: 512\n  normalization: True\n  stable_low_precision: False\n  max_seq_length: ${data.seq_length} # legacy position, do not use\n\n\nmax_seq_length: ${data.seq_length} # max seq length during training (not always used)\nposition_information: learned # none learned or simple\n\ninit:\n  type: deepnorm-straight\n  std: 0.02\n\n# Experimental options:\nstate_corruption: 0.0\neos_state_reset: True\nstate_init: unit\n\n# Set dynamically:\neos_token_id:\n"
  },
  {
    "path": "cramming/config/arch/crammed-tiny.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This is the tiny setting, modified from bert-tiny with larger hidden and lower number of heads\n\nmodel_type: ScriptableCrammedTransformer\n\nnum_transformer_layers: 4\nhidden_size: 256\nintermed_size: 1024\n\nnorm: LayerNorm\nnorm_eps: 1e-12\nnorm_scheme: pre # can be \"pre\", \"post\", \"sandwich\"\nnonlin: GELUglu\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\nuse_bias: False # Whether to learn biases on all dense layers\nfinal_norm: True # Add a final norm layer before the end\nsub_normalization: False # Sub-normalization in attn and ffn blocks\n\nloss: cross-entropy\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding: scaled-sinusoidal\n  max_seq_length: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n  embedding_dim: ${arch.hidden_size} # has to be this value for crammedBERT\n  normalization: True\n  stable_low_precision: False\n\nattention:\n  type: pytorch # also works with \"pytorch\"\n  num_attention_heads: 8\n  skip_output_projection: False\n  qkv_bias: False\n  bias_in_proj: False\n\n  rotary_embedding: False\n  seq_op_in_fp32: False # whether to always cast the operation over the sequence into fp32 (e.g.. the softmax in normal attn)\n  sequence_op: torch-softmax # Can be normalization\n  sub_normalization: ${arch.sub_normalization} # could be turned off separately # Is only used if type=self-attention (i.e the hand-made version)\n\ninit:\n  type: normal\n  std: 0.02\n"
  },
  {
    "path": "cramming/config/arch/crammed-transformer.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This inherits architecture changes from the crammed-bert project\n# How performant is this?\n\nmodel_type: ScriptableCrammedTransformer\n\nnum_transformer_layers: 16\nhidden_size: 768\nintermed_size: 3072\n\nnorm: LayerNorm\nnorm_eps: 1e-12\nnorm_scheme: pre # can be \"pre\", \"post\"\nnonlin: GELUglu\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\nuse_bias: False # Whether to learn biases on all dense layers\nfinal_norm: True # Add a final norm layer before the end\nsub_normalization: False # Sub-normalization in attn and ffn blocks\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding: scaled-sinusoidal\n  max_seq_length: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n  embedding_dim: ${arch.hidden_size} # has to be this value for crammedBERT\n  normalization: True\n  stable_low_precision: False\n\nattention:\n  type: pytorch # also works with \"pytorch\"\n  num_attention_heads: 16 # for flash\n  skip_output_projection: False\n  qkv_bias: False\n  bias_in_proj: False\n\n  rotary_embedding: False\n  seq_op_in_fp32: False # whether to always cast the operation over the sequence into fp32 (e.g.. the softmax in normal attn)\n  sequence_op: torch-softmax # Can be normalization\n  sub_normalization: ${arch.sub_normalization} # could be turned off separately # Is only used if type=self-attention (i.e the hand-made version)\n\ninit:\n  type: normal\n  std: 0.02\n"
  },
  {
    "path": "cramming/config/arch/gpt2-base.yaml",
    "content": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This matches the gpt2 settings in the custom implementation\n# (minus dropout which I did not even implement)\n\nmodel_type: ScriptableCrammedTransformer\n\nnum_transformer_layers: 12\nhidden_size: 768\nintermed_size: 3072\n\nnorm: LayerNorm\nnorm_eps: 1e-05\nnorm_scheme: post # can be \"pre\", \"post\"\nnonlin: GELU\n\ntie_weights: True # Tie input/output embedding\ndecoder_bias: False # Whether to include a bias in the decoding step\nuse_bias: True # Whether to learn biases on all dense layers\nfinal_norm: True # Add a final norm layer before the end\nsub_normalization: False\n\nembedding:\n  vocab_size: # will be populated automatically\n  pos_embedding: learned\n  max_seq_length: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n  embedding_dim: ${arch.hidden_size} # has to be this value for crammedBERT\n  normalization: True\n  stable_low_precision: False\n\nattention:\n  type: pytorch # also works with \"pytorch\"\n  num_attention_heads: 12\n  skip_output_projection: False\n  qkv_bias: True\n  bias_in_proj: True\n\n  rotary_embedding: False\n  seq_op_in_fp32: True # whether to always cast the operation over the sequence into fp32 (e.g.. the softmax in normal attn)\n  sequence_op: torch-softmax # Can be normalization\n  sub_normalization: False\n\ninit:\n  type: normal\n  std: 0.02\n"
  },
  {
    "path": "cramming/config/arch/hf-gpt2.yaml",
    "content": "# These are the huggingface bert parameters\n\nmodel_type: \"gpt2\"\n\nn_ctx: 1024\nn_embd: 768\nn_head: 12\nn_layer: 12\nn_positions: ${data.seq_length} # max seq length that the positional embedding is instantiated for\n\n\nactivation_function: \"gelu_new\"\nattn_pdrop: 0.1\nresid_pdrop: 0.1\nembd_pdrop: 0.1\ninitializer_range: 0.02\nlayer_norm_epsilon: 1e-05\n\n\n\n\nsummary_activation: null\nsummary_first_dropout: 0.1\nsummary_proj_to_labels: true\nsummary_type: \"cls_index\"\nsummary_use_proj: true\n\nbos_token_id: 50256\neos_token_id: 50256\n"
  },
  {
    "path": "cramming/config/arch/sanitycheck.yaml",
    "content": "model_type: SanityCheckLM\n\nwidth: 1024 # 8352\n"
  },
  {
    "path": "cramming/config/cfg_eval.yaml",
    "content": "# Configuration defaults\n# Settings are separated into hyperparameters for architecture, data, implementation and train/eval hyperparams\ndefaults:\n  - impl: torch-default\n  - train: common\n  - wandb: default\n  - eval: pythia\n  - data: arithemtic\n  - _self_\n  - override hydra/job_logging: custom\n\nreverse_inputs: True\npad_zeros: 0\nextended_eval: False\ngreedy: True\ntemp: 1.0\ntoken_limit: 30 # number of tokens in 'thinking plot'\nmax_rec: null # to give more or less recurrence at evaluation that during training\n\n## Addition\nremove_padding: True # used as our eval data has some padding in it that needs to be removed on the fly\nlarge: True\nood_only: False\nup_to_40: False\nup_to_50: False\n\ncheckerboard: null\nbig_eval_step_1: False\nbig_eval_step_2: False\nbig_eval_step_3: False\nbig_eval_step_4: False\nbig_eval_step_5: False\nbig_eval_step_6: False\nbig_eval_step_7: False\nbig_eval_step_8: False\nbig_eval_step_9: False\nbig_eval_step_10: False\n\n# for doing custom splits\nmax_size_given: null\nstart_ind_1_given: null\nstart_ind_2_given: null\n\n## Multiplication\nmul: False\n\n## Pos arithmetic\npos_arth: False\npos_arth_ood: False\n\nwandb:\n  project: generative-eval\n\n# Total and central computation budget in hours:\nbudget: 24\noverall_budget: ${budget}\n\nbase_dir: outputs\nmodel_dir:\n\nhydra:\n  sweep:\n    dir: ${base_dir}/${name}/downstream/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  run:\n    dir: ${base_dir}/${name}/downstream/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  job:\n    chdir: True\n\nseed: # Optional: Set initial seed\n\n# A name for this run [will draw the checkpoint from runs with this name\n# and use this name for the summary table and outputs folder]\nname: default\n\n# debug implementation by running every loop just once:\ndryrun: False\n"
  },
  {
    "path": "cramming/config/cfg_pretrain.yaml",
    "content": "# Configuration defaults\n# Settings are separated into hyperparameters for architecture, data, implementation and train/eval hyperparams\ndefaults:\n  - arch: crammed-depthrecurrent\n  - data: arithmetic\n  - impl: torch-default\n  - wandb: default\n  - train: cramming\n  - _self_\n  - override hydra/job_logging: custom\n\nbase_dir: outputs\nmodel_dir: ${base_dir}\ndata_dir:\n\nhydra:\n  sweep:\n    dir: ${base_dir}/${name}/pretrain/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  run:\n    dir: ${base_dir}/${name}/pretrain/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  job:\n    chdir: True\n\nseed: # Optional: Set initial seed\nname: default # A name for this run [will be used for the summary table and outputs folder]\n\n# Total and central computation budget in hours:\nbudget: 4\noverall_budget: ${budget}\n\n# debug implementation by running every loop just once:\ndryrun: False\n"
  },
  {
    "path": "cramming/config/data/__init__.py",
    "content": ""
  },
  {
    "path": "cramming/config/data/arithmetic.yaml",
    "content": "name: arithmetic\ndefaults:\n  - sources:\n      - arithmetic\n\n\n\n# all the below stuff may not be required\n# Preprocessing\nnormalizer:\n  force_lowercase: False\n  strip_accents: False\n  force_english_keyboard: False\ntokenizer: bigcode/starcoder\nvocab_size: 49152 #32768 # 2^17\n\n# Dataset Formation\nseq_length: 512\ninclude_eot_token_in_corpus: True\n\nmax_entries_in_raw_dataset: 20e6 # Select only this many examples from the dataset # 20e6 are ok if all are chosen. Oversample if filtering\nmax_seq_in_tokenized_dataset: 80e6 # Select only this many tokenized sequences.\n# max_seq_in_tokenized_dataset should be just slightly more than budget * 60 * 60 * expected tokens/sec for the single epoch of training\n\n# Data Cleaning:\nremove_trash: False\ntrash_cutoff: 0.25\ndeduplicate_entries: False\ndeduplication_threshold: 75\n\n# Data Order:\nordering: randomized # for now\n\n# Validation Split\nvalidation_seqs: 4096 # how many sequences to reserve for validation\n"
  },
  {
    "path": "cramming/config/data/c4-subset-processed.yaml",
    "content": "# This would be a slice of C4\nname: c4-subset\ndefaults:\n  - sources:\n      - c4\n\n# Preprocessing\nnormalizer:\n  force_lowercase: False\n  strip_accents: False\n  force_english_keyboard: False\ntokenizer: SentencePieceBPE\nvocab_size: 131072 # 2^17\n\n# Dataset Formation\nseq_length: 512\ninclude_eot_token_in_corpus: True\n\nmax_entries_in_raw_dataset: 25e6 # Select only this many examples from the dataset # 20e6 are ok if all are chosen. Oversample if filtering\nmax_seq_in_tokenized_dataset: 85e6 # Select only this many tokenized sequences.\n# max_seq_in_tokenized_dataset should be just slightly more than budget * 60 * 60 * expected tokens/sec for the single epoch of training\n\n# Data Cleaning:\nremove_trash: False\ntrash_cutoff: 0.25\ndeduplicate_entries: False\ndeduplication_threshold: 75\n\n# Data Order:\nordering: randomized # for now\n\n# Validation Split\nvalidation_seqs: 4096 # how many sequences to reserve for validation\n"
  },
  {
    "path": "cramming/config/data/openweb.yaml",
    "content": "# Selection of English sources from the ROOTS project\nname: openweb\ndefaults:\n  - sources:\n      - openwebtext\n\n# Preprocessing\nnormalizer:\n  force_lowercase: False\n  strip_accents: False\n  force_english_keyboard: False\ntokenizer: BPE\nvocab_size: 32768 # 2^17\n\n# Dataset Formation\nseq_length: 512\ninclude_eot_token_in_corpus: True\n\nmax_entries_in_raw_dataset: 20e6 # Select only this many examples from the dataset # 20e6 are ok if all are chosen. Oversample if filtering\nmax_seq_in_tokenized_dataset: 80e6 # Select only this many tokenized sequences.\n# max_seq_in_tokenized_dataset should be just slightly more than budget * 60 * 60 * expected tokens/sec for the single epoch of training\n\n# Data Cleaning:\nremove_trash: False\ntrash_cutoff: 0.25\ndeduplicate_entries: False\ndeduplication_threshold: 75\n\n# Data Order:\nordering: randomized # for now\n\n# Validation Split\nvalidation_seqs: 4096 # how many sequences to reserve for validation\n"
  },
  {
    "path": "cramming/config/data/proofpile.yaml",
    "content": "name: proofpile\ndefaults:\n  - sources:\n      - proofpiledata\n\n# Preprocessing\nnormalizer:\n  force_lowercase: False\n  strip_accents: False\n  force_english_keyboard: False\ntokenizer: EleutherAI/llemma_34b\nvocab_size: 49152 #32768 # 2^17\n\n# Dataset Formation\nseq_length: 512\ninclude_eot_token_in_corpus: True\n\nmax_entries_in_raw_dataset: 10e5 #10e6 # Select only this many examples from the dataset # 20e6 are ok if all are chosen. Oversample if filtering\nmax_seq_in_tokenized_dataset: 5e4 #5e5 # Select only this many tokenized sequences.\n# max_seq_in_tokenized_dataset should be just slightly more than budget * 60 * 60 * expected tokens/sec for the single epoch of training\n\n# Data Cleaning:\nremove_trash: False\ntrash_cutoff: 0.25\ndeduplicate_entries: False\ndeduplication_threshold: 75\n\n# Data Order:\nordering: randomized # for now\n\n# Validation Split\nvalidation_seqs: 4096 # how many sequences to reserve for validation\n"
  },
  {
    "path": "cramming/config/data/sanity-check-1.yaml",
    "content": "# Just a bunch of fake data ...\nname: sanity-check-1\ndefaults:\n  - sources:\n      - fake\n\n#\n# Preprocessing\nnormalizer: # This is ignored and the default bert normalizer is used instead\n  force_lowercase:\n  strip_accents:\n  force_english_keyboard:\ntokenizer: gpt2\nvocab_size: 50257\n\n# Dataset Formation\nseq_length: 64\ninclude_eot_token_in_corpus:\n\nmax_entries_in_raw_dataset: 1e12 # Select only this many examples from the dataset\nmax_seq_in_tokenized_dataset: 1e12 # Select only this many tokenized sequences.\n# max_seq_in_tokenized_dataset should be just slightly more than budget * 60 * 60 * expected tokens/sec for the single epoch of training\n\n# Data Cleaning:\nremove_trash: False\ntrash_cutoff: 0.3\ndeduplicate_entries: False\ndeduplication_threshold: 100\n\n# Data Order:\nordering: randomized # could be a curriculum\n\n# Validation Split\nvalidation_seqs: 128 # how many sequences to reserve for validation\n"
  },
  {
    "path": "cramming/config/data/sanity-check-2.yaml",
    "content": "# Just a tiny test dataset ...\nname: sanity-check-2\n# https://hydra.cc/docs/patterns/select_multiple_configs_from_config_group/\ndefaults:\n  - sources:\n      - ag_news\n\n# Preprocessing\nnormalizer:\n  force_lowercase: False\n  strip_accents: False\n  force_english_keyboard: False\ntokenizer: BPE # faster for sanity checks\nvocab_size: 32768 # to make sure there are not memory surprises compared to the actual data\n\n# Dataset Formation\nseq_length: 128\ninclude_eot_token_in_corpus: True\n\nmax_entries_in_raw_dataset: 1e10 # Select only this many examples from the dataset\nmax_seq_in_tokenized_dataset: 1e10 # Select only this many tokenized sequences.\n# max_seq_in_tokenized_dataset should be just slightly more than budget * 60 * 60 * expected tokens/sec for the single epoch of training\n\n# Data Cleaning:\nremove_trash: False\ntrash_cutoff: 0.3\ndeduplicate_entries: False\ndeduplication_threshold: 100\n\n# Data Order:\nordering: randomized # could be a curriculum\n\n# Validation Split\nvalidation_seqs: 128 # how many sequences to reserve for validation\n"
  },
  {
    "path": "cramming/config/data/sources/ag_news.yaml",
    "content": "# For sanity testing\nag_news:\n  provider: huggingface\n  partition: default\n  split: train\n\n  streaming: False\n\n  remove_columns: label\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/arithmetic.yaml",
    "content": "# Just a bunch of fake data ...\narithmetic:\n  provider: arithmetic\n  split:\n\n  randgen_seed: 0\n  size: 2048\n\n  tokenized_dataset_path: \"arithmetic_data/+_n_3_m_3_examples_100_seed_42/hf_tokenized_dataset\"\n  tokenizer_type: # for specifiying which arthmetic tokenizer we want to use\n"
  },
  {
    "path": "cramming/config/data/sources/bookcorpus.yaml",
    "content": "# The bookcorpus dataset, drawn from it huggingface mirror\nbookcorpus:\n  provider: huggingface\n  partition: plain_text\n  split: train\n\n  streaming: False\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 16\n"
  },
  {
    "path": "cramming/config/data/sources/c4.yaml",
    "content": "# The wikipedia en dataset, drawn from it huggingface mirror\nc4:\n  provider: huggingface\n  partition: en\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/dash_books.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_book_dash_books:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/fake.yaml",
    "content": "# Just a bunch of fake data ...\nfake:\n  provider: fake\n  split:\n\n  randgen_seed: 0\n  size: 2048\n"
  },
  {
    "path": "cramming/config/data/sources/iwslt.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_ted_talks_iwslt:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/local.yaml",
    "content": "# Just a bunch of fake data ...\nlocal:\n  provider: local\n  split:\n\n  randgen_seed: 0\n  size: 2048\n"
  },
  {
    "path": "cramming/config/data/sources/no_code_stackexchange.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_no_code_stackexchange:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/openwebtext.yaml",
    "content": "# The open webtext replication, as mirrored on HF\nopenwebtext:\n  provider: huggingface\n  partition: plain_text\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/oscar.yaml",
    "content": "# The oscar dataset, drawn from it huggingface mirror\n# should be 1.2T in this deduplicated version\noscar:\n  provider: huggingface\n  partition: unshuffled_deduplicated_en\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0 # cannot concat when streaming\n"
  },
  {
    "path": "cramming/config/data/sources/proofpiledata.yaml",
    "content": "# The open webtext replication, as mirrored on HF\nEleutherAI/proof-pile-2:\n  provider: huggingface\n  partition: open-web-math #['default', 'arxiv', 'open-web-math', 'algebraic-stack']\n  split: train\n\n  streaming: False #True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/the_pile.yaml",
    "content": "#\nthe_pile:\n  provider: local\n  file_type: json\n  files:\n    - \"/fs/cml-datasets/Pile/train/00.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/01.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/02.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/03.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/04.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/05.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/06.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/07.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/08.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/09.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/10.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/11.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/12.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/13.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/14.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/15.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/16.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/17.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/18.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/19.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/20.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/21.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/22.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/23.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/24.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/25.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/26.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/27.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/28.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/29.jsonl.zst\"\n  filter:\n    #  pile_set_name:\n    # possible pile_set_name values are\n    # Pile-CC 227.12 GiB 18.11% 1.0 227.12 GiB 4.33 KiB\n    # PubMed Central 90.27 GiB 14.40% 2.0 180.55 GiB 30.55 KiB\n    # # Books3† 100.96 GiB 12.07% 1.5 151.44 GiB 538.36 KiB\n    # OpenWebText2 62.77 GiB 10.01% 2.0 125.54 GiB 3.85 KiB\n    # ArXiv 56.21 GiB 8.96% 2.0 112.42 GiB 46.61 KiB\n    # Github 95.16 GiB 7.59% 1.0 95.16 GiB 5.25 KiB\n    # FreeLaw 51.15 GiB 6.12% 1.5 76.73 GiB 15.06 KiB\n    # Stack Exchange 32.20 GiB 5.13% 2.0 64.39 GiB 2.16 KiB\n    # USPTO Backgrounds 22.90 GiB 3.65% 2.0 45.81 GiB 4.08 KiB\n    # PubMed Abstracts 19.26 GiB 3.07% 2.0 38.53 GiB 1.30 KiB\n    # Gutenberg (PG-19)† 10.88 GiB 2.17% 2.5 27.19 GiB 398.73 KiB\n    # OpenSubtitles† 12.98 GiB 1.55% 1.5 19.47 GiB 30.48 KiB\n    # Wikipedia (en)† 6.38 GiB 1.53% 3.0 19.13 GiB 1.11 KiB\n    # DM Mathematics† 7.75 GiB 1.24% 2.0 15.49 GiB 8.00 KiB\n    # Ubuntu IRC 5.52 GiB 0.88% 2.0 11.03 GiB 545.48 KiB\n    # BookCorpus2 6.30 GiB 0.75% 1.5 9.45 GiB 369.87 KiB\n    # EuroParl† 4.59 GiB 0.73% 2.0 9.17 GiB 68.87 KiB\n    # HackerNews 3.90 GiB 0.62% 2.0 7.80 GiB 4.92 KiB\n    # YoutubeSubtitles 3.73 GiB 0.60% 2.0 7.47 GiB 22.55 KiB\n    # PhilPapers 2.38 GiB 0.38% 2.0 4.76 GiB 73.37 KiB\n    # NIH ExPorter 1.89 GiB 0.30% 2.0 3.79 GiB 2.11 KiB\n    # Enron Emails† 0.88 GiB 0.14% 2.0 1.76 GiB 1.78 KiB\n  split: train\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/the_pileCC.yaml",
    "content": "#\nthe_pileCC:\n  provider: local\n  file_type: json\n  files:\n    - \"/fs/cml-datasets/Pile/train/00.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/01.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/02.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/03.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/04.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/05.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/06.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/07.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/08.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/09.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/10.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/11.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/12.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/13.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/14.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/15.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/16.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/17.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/18.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/19.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/20.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/21.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/22.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/23.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/24.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/25.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/26.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/27.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/28.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/29.jsonl.zst\"\n  filter:\n    pile_set_name:\n      - Pile-CC\n  # possible pile_set_name values are\n  # Pile-CC 227.12 GiB 18.11% 1.0 227.12 GiB 4.33 KiB\n  # PubMed Central 90.27 GiB 14.40% 2.0 180.55 GiB 30.55 KiB\n  # # Books3† 100.96 GiB 12.07% 1.5 151.44 GiB 538.36 KiB\n  # OpenWebText2 62.77 GiB 10.01% 2.0 125.54 GiB 3.85 KiB\n  # ArXiv 56.21 GiB 8.96% 2.0 112.42 GiB 46.61 KiB\n  # Github 95.16 GiB 7.59% 1.0 95.16 GiB 5.25 KiB\n  # FreeLaw 51.15 GiB 6.12% 1.5 76.73 GiB 15.06 KiB\n  # Stack Exchange 32.20 GiB 5.13% 2.0 64.39 GiB 2.16 KiB\n  # USPTO Backgrounds 22.90 GiB 3.65% 2.0 45.81 GiB 4.08 KiB\n  # PubMed Abstracts 19.26 GiB 3.07% 2.0 38.53 GiB 1.30 KiB\n  # Gutenberg (PG-19)† 10.88 GiB 2.17% 2.5 27.19 GiB 398.73 KiB\n  # OpenSubtitles† 12.98 GiB 1.55% 1.5 19.47 GiB 30.48 KiB\n  # Wikipedia (en)† 6.38 GiB 1.53% 3.0 19.13 GiB 1.11 KiB\n  # DM Mathematics† 7.75 GiB 1.24% 2.0 15.49 GiB 8.00 KiB\n  # Ubuntu IRC 5.52 GiB 0.88% 2.0 11.03 GiB 545.48 KiB\n  # BookCorpus2 6.30 GiB 0.75% 1.5 9.45 GiB 369.87 KiB\n  # EuroParl† 4.59 GiB 0.73% 2.0 9.17 GiB 68.87 KiB\n  # HackerNews 3.90 GiB 0.62% 2.0 7.80 GiB 4.92 KiB\n  # YoutubeSubtitles 3.73 GiB 0.60% 2.0 7.47 GiB 22.55 KiB\n  # PhilPapers 2.38 GiB 0.38% 2.0 4.76 GiB 73.37 KiB\n  # NIH ExPorter 1.89 GiB 0.30% 2.0 3.79 GiB 2.11 KiB\n  # Enron Emails† 0.88 GiB 0.14% 2.0 1.76 GiB 1.78 KiB\n  split: train\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/the_pile_dedup.yaml",
    "content": "# The EleutherAI/the_pile_deduplicated\nEleutherAI/the_pile_deduplicated:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/the_pile_natural.yaml",
    "content": "#\nthe_pile_natural:\n  provider: local\n  file_type: json\n  files:\n    - \"/fs/cml-datasets/Pile/train/00.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/01.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/02.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/03.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/04.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/05.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/06.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/07.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/08.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/09.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/10.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/11.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/12.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/13.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/14.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/15.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/16.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/17.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/18.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/19.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/20.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/21.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/22.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/23.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/24.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/25.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/26.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/27.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/28.jsonl.zst\"\n    - \"/fs/cml-datasets/Pile/train/29.jsonl.zst\"\n  filter:\n    pile_set_name:\n      - Gutenberg\n      - Books3\n      - Wikipedia (en)\n  # possible pile_set_name values are\n  # Pile-CC 227.12 GiB 18.11% 1.0 227.12 GiB 4.33 KiB\n  # PubMed Central 90.27 GiB 14.40% 2.0 180.55 GiB 30.55 KiB\n  # # Books3† 100.96 GiB 12.07% 1.5 151.44 GiB 538.36 KiB\n  # OpenWebText2 62.77 GiB 10.01% 2.0 125.54 GiB 3.85 KiB\n  # ArXiv 56.21 GiB 8.96% 2.0 112.42 GiB 46.61 KiB\n  # Github 95.16 GiB 7.59% 1.0 95.16 GiB 5.25 KiB\n  # FreeLaw 51.15 GiB 6.12% 1.5 76.73 GiB 15.06 KiB\n  # Stack Exchange 32.20 GiB 5.13% 2.0 64.39 GiB 2.16 KiB\n  # USPTO Backgrounds 22.90 GiB 3.65% 2.0 45.81 GiB 4.08 KiB\n  # PubMed Abstracts 19.26 GiB 3.07% 2.0 38.53 GiB 1.30 KiB\n  # Gutenberg (PG-19)† 10.88 GiB 2.17% 2.5 27.19 GiB 398.73 KiB\n  # OpenSubtitles† 12.98 GiB 1.55% 1.5 19.47 GiB 30.48 KiB\n  # Wikipedia (en)† 6.38 GiB 1.53% 3.0 19.13 GiB 1.11 KiB\n  # DM Mathematics† 7.75 GiB 1.24% 2.0 15.49 GiB 8.00 KiB\n  # Ubuntu IRC 5.52 GiB 0.88% 2.0 11.03 GiB 545.48 KiB\n  # BookCorpus2 6.30 GiB 0.75% 1.5 9.45 GiB 369.87 KiB\n  # EuroParl† 4.59 GiB 0.73% 2.0 9.17 GiB 68.87 KiB\n  # HackerNews 3.90 GiB 0.62% 2.0 7.80 GiB 4.92 KiB\n  # YoutubeSubtitles 3.73 GiB 0.60% 2.0 7.47 GiB 22.55 KiB\n  # PhilPapers 2.38 GiB 0.38% 2.0 4.76 GiB 73.37 KiB\n  # NIH ExPorter 1.89 GiB 0.30% 2.0 3.79 GiB 2.11 KiB\n  # Enron Emails† 0.88 GiB 0.14% 2.0 1.76 GiB 1.78 KiB\n  split: train\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/the_pile_stream.yaml",
    "content": "# Pile streaming from huggingface with new streaming tech :>\n# should be 1.2T in this deduplicated version\nEleutherAI/the_pile:\n  provider: huggingface\n  partition: unshuffled_deduplicated_en\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0 # cannot concat when streaming\n"
  },
  {
    "path": "cramming/config/data/sources/uncorpus.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_uncorpus:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/uspto.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_the_pile_uspto:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/wikibooks.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_wikibooks:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: False\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/wikinews.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_wikinews:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: False\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/wikipedia.yaml",
    "content": "# The wikipedia en dataset, drawn from it huggingface mirror\nwikipedia:\n  provider: huggingface\n  partition: 20220301.en\n  split: train\n\n  streaming: False\n\n  # source-specific cleaning rules?\n  remove_columns: title\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/wikiquote.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_wikiquote:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/wikiversity.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_wikiversity:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/data/sources/wikivoyage.yaml",
    "content": "# A part of ROOTS\nbigscience-data/roots_en_wikivoyage:\n  provider: huggingface\n  partition:\n  split: train\n\n  streaming: True\n\n  # source-specific cleaning rules?\n  remove_columns:\n  concatenate_successive_entries: 0\n"
  },
  {
    "path": "cramming/config/eval/__init__.py",
    "content": ""
  },
  {
    "path": "cramming/config/eval/pythia.yaml",
    "content": "# defaults:\n#   - optim: adam\n#   - tasks:\n      # - winogrande\n      # - lambada_openai\n      # - piqa\n      # - winograd_wsc\n      # - arc\n      # - sciq\n      # - logiqa\n\nname: pythia-tests\n\narch_modifications: null\n# checkpoint name:\n# This can be either \"latest\", or a reference to a specific checkpoint in a subfolder\ncheckpoint: latest\npath: ${impl.path} # Path for caches of datasets and tokenizers\n"
  },
  {
    "path": "cramming/config/eval/tasks/lambada_openai.yaml",
    "content": "# dataset-specific settings\nlambada_openai:\n"
  },
  {
    "path": "cramming/config/eval/tasks/winogrande.yaml",
    "content": "# dataset-specific settings\nwinogrande:\n"
  },
  {
    "path": "cramming/config/hydra/__init__.py",
    "content": ""
  },
  {
    "path": "cramming/config/hydra/job_logging/custom.yaml",
    "content": "# python logging configuration for tasks\nversion: 1\nformatters:\n  simple:\n    format: \"[%(asctime)s] %(message)s\"\nhandlers:\n  console:\n    class: logging.StreamHandler\n    formatter: simple\n    stream: ext://sys.stdout\n  file:\n    class: logging.FileHandler\n    formatter: simple\n    # relative to the job log directory\n    filename: ${name}_${hydra.job.name}.log\nroot:\n  level: INFO\n  handlers: [console, file]\n\ndisable_existing_loggers: false\n"
  },
  {
    "path": "cramming/config/impl/__init__.py",
    "content": ""
  },
  {
    "path": "cramming/config/impl/_default.yaml",
    "content": "# Settings for implementation details\n# These settings \"should\" not influence the outcome of the computation in major ways, only its speed.\n# These settings are generic implementation details\n# -----------------------------------------------------------------------------------------------------\n\n# This is the main folder where data will be stored (such as caches of datasets and tokenizers):\n# This can be an absolute path (which will be honored) or a relative path\n# The relative path will be executed relative to the cfg.base_dir\n# This behavior is controlled in the main_launcher\npath: data\n\n# data implementation:\nlocal_staging_dir: # Optionally copy a preprocessed dataset into this folder before loading it for training\nforbid_dataset_preprocessing: True\ntemporary_corpus: False # Save data directly into local staging dir, forget after use\nmax_raw_chunk_size: 1e14\n\n# checkpointing and logging:\nprint_loss_every_nth_step: 1000\nsave_intermediate_checkpoints: False\nsave_every_nth_step: -1\nsave_every_n_minutes: -1\nsave_intermediate_model_name:\n\n# early termination, cancel runs that do not meet this loss threshold early.\nearly_termination:\n  enabled: False\n  budget: 3 # budget in hours\n  loss_threshold: 6.0 # modify this for non-xent losses\n  overall_budget: -1\n\n# Batch size settings:\n# batch_size: This is handled in train after commit 982a4d33cd7f79a48b691114ae78f6ad1cdbee69\nmicrobatch_size: 128 # dont make it larger than batch_size...\n\n# Basic compute settings\nthreads: 32 # maximal number of cpu dataloader workers used per GPU, this value will never exceed num_gpus * num_physical threads\n# Dataloader multiprocessing\npad_to_multiple_of: 8 # padding in dataloader during downstream\nshuffle_in_dataloader: False # There is still shuffling in the preprocessing pipeline.\npin_memory: True\nprefetch_factor: 2\npersistent_workers: True # this clashes with pin_memory in pytorch<1.7.1\n\n# Default floating point precision:\ndefault_precision: float # needs to be a pytorch datatype\n\n# Distributed training\ndist_backend: nccl\nsharing_strategy: # file_descriptor # if no argument is given, then the OS default is picked by pytorch\n\n# Misc:\nenable_huggingface_offline_mode: False\nlocal_rank: # This is set automatically by the system_startup\n\nsave_final_model: False\npush_to_huggingface_hub: False\nhf_directoy_name: \"test-crammedBERT-c5\" # set a clever name here!\n\nadd_env_variables:\n# should be NAME: stringval\n\n# TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE\n# TORCHINDUCTOR_MAX_AUTOTUNE_GEMM\n\n# Other constants:\n# OMP_NUM_THREADS:[number_of_physical_cores]\n# OMP_SCHEDULE:  # STATIC\n# OMP_PROC_BIND: # CLOSE\n# GOMP_CPU_AFFINITY:  # \"N-M\"\n# KMP_AFFINITY: # \"granularity=fine,compact,1,0\"\n# KMP_BLOCKTIME: # 1\n# optional_ld_preloads:\n#  - libiomp5.so\n# - jemalloc.so\n\n#\n# ### jemalloc\n# export MALLOC_CONF=\"oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1\"\n# export LD_PRELOAD=/home/mingfeim/packages/jemalloc-5.2.1/lib/libjemalloc.so\n#\n# ### tcmalloc\n# export LD_PRELOAD=/home/mingfeim/packages/gperftools-2.8/install/lib/libtcmalloc.so\n\nexample_token_limit: 30 # never generate more example tokens than this\n# example_prompts:\n#   - \"Oh, distinctly I remember, it was in the bleak\"\n#   - \"The capital of Germany is\"\n#   - \"The Westphalian peace ended the\"\n#   - \"Hi! My name is\"\n#   - \"In the place where we were born,\"\n#   - \"Time is a\"\n\n# example_prompts:\n#   - \"System.out.println(\"\n#   - \"public class \"\n#   - \"public static void main\"\n#   - \"/* print hello world */\"\n#   - \"System.out.println(2);\"\n#   - \"for (let i = 0; i < myarray.length; i++) {\"\nexample_prompts:\n    - \"3 + 3 = \"\n    - \"44 + 56 = \"\n    - \"003 + 003 = \"\n    - \"070 + 094 = \"\n    - \"345 + 324 = \"\n    - \"598 + 527 = \"\n    - \"1234 + 4321 = \"\n    - \"94633 + 91826 = \""
  },
  {
    "path": "cramming/config/impl/torch-default.yaml",
    "content": "# Settings for implementation details\n# These settings \"should\" not influence the outcome of the computation in major ways, only its speed.\n# These settings are pytorch implementation details, tuned for singl(ish) GPU, sane pytorch stuff\n# -----------------------------------------------------------------------------------------------------\n\nname: torch-default\ndefaults:\n  - _default\n  - _self_\n\n\n# Basic pytorch settings\nbenchmark: True # CUDNN benchmarking\ndeterministic: False # This option will disable non-deterministic ops\nnon_blocking: True # unblocked .to(device) handles\ntf32_allowed: True\nmatmul_precision: medium # highest/high/medium\n\nmixed_precision: True # turns on AMP on GPUs/Intel devices. The default precision needs to be float\ngrad_scaling: True # Only activates when mixed_precision=True\nmixed_precision_target_dtype: float16 # you might try your luck with bfloat16 too\n\n# Distributed training:\nzero_redundancy_optimizer: False # requires limited_decay_keys=[] for pytorch<=1.10.2\nbroadcast_buffers: False\nbucket_cap_mb: 25\ngradient_as_bucket_view: True\nstatic_graph: True\n\n# scaled dot products:\nenable_mem_efficient_sdp: False\nenable_math_sdp: True\nenable_flash_sdp: True\n\n# Misc:\nforeach_optimizer: False\n\n# Compilation\ncompile_torch: True\nmode: default # overwritten by manual selection of inductor variables below\ndynamic: False # this is a world of pain (when I last tested it, around torch2.0 release)\nfullgraph: True # why even compile when not compile everywhere :>\nbackend: inductor\n_inductor_vars:\n  # max_autotune_gemm: True\n  # max_autotune_pointwise: False # was better in some tests not to enable this?\n  # triton:\n  #   cudagraphs: False # cannot fit with overhead\n  #   # cudagraph_trees: False # fixes memory problems but has scary warning messages\n  # # epilogue_fusion: True # true by default is latest nightly\n  # # aggressive_fusion: False # oom on latest nightly\n  # permute_fusion: True # nice\n  # shape_padding: True # flaky on the new nightly?\n  # optional to mess with the internal inductor config. Maybe not advisable\n  # - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set\n  # - `max_autotune` which will profile to pick the best matmul configuration\n  # - `fallback_random` which is useful when debugging accuracy issues\n  # - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores\n  # - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs\n  # - `trace.enabled` which is the most useful debugging flag to turn on\n  # - `trace.graph_diagram` which will show you a picture of your graph after fusion\n  # - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()`\n  # or directly at https://github.com/pytorch/pytorch/blob/master/torch/_inductor/config.py\n"
  },
  {
    "path": "cramming/config/train/__init__.py",
    "content": ""
  },
  {
    "path": "cramming/config/train/common.yaml",
    "content": "# Basic hyperparameter for normal BERT pretraining\n# working hard here to separate \"impl\" implementation details and \"train\" abstract hyperparameters\n\nname: common\n\ndefaults:\n  - optim: adam_classic\n  - optim_mod: disabled\n\noptim:\n  lr: 1e-4\n\nlimited_decay_keys: [bias, LayerNorm.bias, LayerNorm.weight, norm] # no weight decay for these layers\n\n# steps:\nwarmup_steps: 80_000 # These are microbatch steps\ncooldown_steps: 0\nsteps: 8_000_000 # These are microbatch steps at bs=64. The original 1mio steps for BERT are recovered with 512/64=8\nscheduler: polynomial-decay\n\n# Training settting:\nstream_depth: ${data.seq_length} # full sequence as input to model\nbatch_size: 512\nbatch_size_ramp: 0\n\ngradient_clipping:\npretrain_in_train_mode: True # default BERT trains with dropout layers\nreverse_dataset_order: False\n\nbudget: ${budget}\noverall_budget: ${overall_budget}\n"
  },
  {
    "path": "cramming/config/train/cramming.yaml",
    "content": "# Version 4 of changes to bert training hyperparameters\n# Optimizes MLM rate for torch.compile, includes improved weight decay limitation, finally updated to a relative bs ramp\n\nname: cramming-o4\n\ndefaults:\n  - optim: adam\n  - optim_mod: disabled\n\noptim:\n  lr: 1e-3\n  weight_decay: 0.01\n\nlimited_decay_keys: [bias, LayerNorm.bias, LayerNorm.weight, norm] # no weight decay for these layers\n\n# steps:\nwarmup_steps: 0.1\ncooldown_steps: 0.1\nsteps: 12_000_000 # these are microbatch steps. This is an upper limit that is usually never reached\nscheduler: budget-constant\n\n# Training settting:\nstream_depth: ${data.seq_length} # full sequence as input to model\nbatch_size: 8192\nbatch_size_ramp: 0.60\n\ngradient_clipping: 0.5\npretrain_in_train_mode: True # default BERT trains with dropout layers enabled in pretrain\nreverse_dataset_order: False\n\nbudget: ${budget}\noverall_budget: ${overall_budget}\n\n# for loading previously saved\narch_modifications: null\n# checkpoint name:\n# This can be either \"latest\", or a reference to a specific checkpoint in a subfolder\ncheckpoint: latest\npath: ${impl.path} # Path for caches of datasets and tokenizers\n"
  },
  {
    "path": "cramming/config/train/janus-regime.yaml",
    "content": "# Version 4 of changes to bert training hyperparameters\n# Optimizes MLM rate for torch.compile, includes improved weight decay limitation, finally updated to a relative bs ramp\n\nname: cramming-o4\n\ndefaults:\n  - optim: adam\n  - optim_mod: disabled\n\noptim:\n  lr: 1e-3\n  weight_decay: 0.01\n\nlimited_decay_keys: [bias, LayerNorm.bias, LayerNorm.weight, norm] # no weight decay for these layers\n\n# steps:\nwarmup_steps: 0.1\ncooldown_steps: 0.1\nsteps: 4_000_000 # these are microbatch steps. This is an upper limit that is usually never reached\nscheduler: budget-constant\n\n# Training settting:\nstream_depth: 2 # Train one token at a time\nbatch_size: 16384\nbatch_size_ramp: 0.60\n\ngradient_clipping: 0.5\npretrain_in_train_mode: True # default BERT trains with dropout layers enabled in pretrain\nreverse_dataset_order: False\n\nbudget: ${budget}\n"
  },
  {
    "path": "cramming/config/train/optim/adafactor.yaml",
    "content": "type: Adafactor\n\nlr: 0.001\neps:\n  - 1e-30\n  - 0.001\nclip_threshold: 1.0\ndecay_rate: -0.8\nbeta1:\nweight_decay: 0.0\nscale_parameter: False\nrelative_step: False\nwarmup_init: False\n"
  },
  {
    "path": "cramming/config/train/optim/adahessian.yaml",
    "content": "type: AdaHessian\n\nlr: 0.15\nbetas:\n  - 0.9\n  - 0.98\neps: 1e-12\nweight_decay: 0.01\nhessian_power: 1.0\n"
  },
  {
    "path": "cramming/config/train/optim/adam.yaml",
    "content": "type: AdamW\n\nlr: 0.0005\nbetas:\n  - 0.9\n  - 0.98\neps: 1e-12\nweight_decay: 0.01\namsgrad: False\nfused:\n"
  },
  {
    "path": "cramming/config/train/optim/adam8bit.yaml",
    "content": "type: Adam8bit\n\nlr: 0.0005\nbetas:\n  - 0.9\n  - 0.98\neps: 1e-12\nweight_decay: 0.01\namsgrad: False\n"
  },
  {
    "path": "cramming/config/train/optim/adam_classic.yaml",
    "content": "type: Adam\n\nlr: 0.0005\nbetas:\n  - 0.9\n  - 0.999\neps: 1e-8\nweight_decay: 0.01\namsgrad: False\n"
  },
  {
    "path": "cramming/config/train/optim/adamscale.yaml",
    "content": "type: AdamWScale\n\nlr: 0.0005\nbetas:\n  - 0.9\n  - 0.98\neps: 1e-12\nweight_decay: 0.01\ncorrect_bias: True # adamw fix\n"
  },
  {
    "path": "cramming/config/train/optim/agd.yaml",
    "content": "type: AGD\n\ngain: 1.0\n"
  },
  {
    "path": "cramming/config/train/optim/lion.yaml",
    "content": "type: Lion\n\nlr: 1e-4\nbetas:\n  - 0.9\n  - 0.99\n# use 0.95, 0.98 if unstable\nweight_decay: 0.1\n"
  },
  {
    "path": "cramming/config/train/optim/radam.yaml",
    "content": "type: RAdam\n\nlr: 0.0005\nbetas:\n  - 0.9\n  - 0.98\neps: 1e-12\nweight_decay: 0.01\n"
  },
  {
    "path": "cramming/config/train/optim/sgd.yaml",
    "content": "type: SGD\n\nlr: 0.0005\nmomentum: 0.9\ndampening: 0.0\nweight_decay: 0.01\nnesterov: True\n"
  },
  {
    "path": "cramming/config/train/optim/shampoo.yaml",
    "content": "type: Shampoo\n\nlr: 0.0005\nbetas:\n  - 0.9\n  - 0.98\nepsilon: 1e-12\nuse_bias_correction: True\nadam_w_mode: True\nweight_decay: 0.01\ngrafting_type: 4\ngrafting_epsilon: 1e-08\ngrafting_beta2: 0.999\n\nroot_inv_dist: False\n# update_freq (int): frequency for updating inverse preconditioner (Default: 100)\n# init_delay (int): initial delay before starting to compute root inverse (Default: 1000)\n# threshold (int): threshold for switching to diagonal preconditioner (Default: 1024)\n# preconditioner_dtype (torch.dtype): data type for preconditioner (Default: torch.float)\n# large_dim_method (LargeDimMethod): method for handling large scale tensors. (Default: LargeDimMethod.BLOCKING)\n# root_inv_dist (bool): distributes root inverse computation across multiple GPU workers (Default: True)\n# use_merge_dims (bool): merge dimensions if possible while respecting threshold. (Default: True)\n# grafting_type (GraftingType): Selects grafting method. (Default: GraftingType.ADAGRAD)\n# grafting_epsilon (float): Epsilon for grafting method. (Default: 1e-3)\n# grafting_beta2 (float): Exponential moving average factor for grafting method. (Default: 1.0)\n\n# class PreconditionerType(enum.Enum):\n#     FULL = 0\n#     DIAGONAL = 1\n#\n#\n# class GraftingType(enum.Enum):\n#     NONE = 0\n#     SGD = 1\n#     ADAGRAD = 2\n#     RMSPROP = 3\n#     ADAM = 4\n#\n#\n# class LargeDimMethod(enum.Enum):\n#     DIAGONAL = 0\n#     ADAGRAD = 1\n#     BLOCKING = 2\n"
  },
  {
    "path": "cramming/config/train/optim_mod/disabled.yaml",
    "content": "name: none\n"
  },
  {
    "path": "cramming/config/train/optim_mod/larc.yaml",
    "content": "name: LARC\n\ntrust_coefficient: 0.02\nclip: True\neps: 1e-8\n"
  },
  {
    "path": "cramming/config/train/optim_mod/lars.yaml",
    "content": "name: LARS\n\ntrust_coefficient: 0.02\nclip: False\neps: 1e-8\n"
  },
  {
    "path": "cramming/config/train/optim_mod/progressive.yaml",
    "content": "name: progressive-batching\n\nprogress_rule: norm-based\n\nmonotone: False\ntheta: 0.9\n\nmin_sample_guard: 2\nmax_sample_guard: 128\n"
  },
  {
    "path": "cramming/config/train/optim_mod/sam.yaml",
    "content": "name: SAM\nrho: 0.05\n"
  },
  {
    "path": "cramming/config/wandb/default.yaml",
    "content": "enabled: True\nentity: placeholder # change this obviously ;>\nproject: arithmetic\ntags: []\n"
  },
  {
    "path": "cramming/config/wandb/none.yaml",
    "content": "enabled: False\nentity:\nproject:\ntags: []\n"
  },
  {
    "path": "cramming/data/__init__.py",
    "content": "\"\"\"This module handles and hides the data away ;)\"\"\"\n\nfrom .pretraining_preparation import load_pretraining_corpus, prepare_dataloaders\n"
  },
  {
    "path": "cramming/data/arithmetic_tokenizers.py",
    "content": "\"\"\"\nCharacter level tokenizers for arithemtic projects\nMultiple tokenizers for different tasks\n\"\"\"\n\nfrom transformers import PreTrainedTokenizer\nimport re\nimport torch\nimport random\n\nclass CustomCharLevelTokenizerForAddingPadding(PreTrainedTokenizer):\n    \"\"\"Simple char level math tokenizer\"\"\"\n    def __init__(self, **kwargs):\n        # Define the characters to tokenize\n        characters = '0123456789+-x= '\n\n        # Define and set special tokens\n        self.pad_token = '[PAD]'\n        self.unk_token = '[UNK]'\n        self.bos_token = '[BOS]'\n        self.eos_token = '[EOS]'\n\n        # Combine characters and special tokens to form the custom vocabulary\n        self.vocab = {char: i + 4 for i, char in enumerate(characters)}  # Starting from 4 to account for special tokens\n        self.vocab.update({self.pad_token: 0, self.unk_token: 1, self.bos_token: 2, self.eos_token: 3})\n\n        # Create the reverse mapping from IDs to tokens\n        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}\n\n        super().__init__(**kwargs)\n\n        # Define and set special tokens\n        self.pad_token = '[PAD]'\n        self.unk_token = '[UNK]'\n        self.bos_token = '[BOS]'\n        self.eos_token = '[EOS]'\n\n        # Combine characters and special tokens to form the custom vocabulary\n        self.vocab = {char: i + 4 for i, char in enumerate(characters)}  # Starting from 4 to account for special tokens\n        self.vocab.update({self.pad_token: 0, self.unk_token: 1, self.bos_token: 2, self.eos_token: 3})\n\n        # Create the reverse mapping from IDs to tokens\n        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}\n\n    @property\n    def vocab_size(self):\n        return len(self.vocab)\n\n    def get_vocab(self):\n        return self.vocab\n\n    def _tokenize(self, text):\n        # Tokenize the text character by character\n        # text = re.sub('\\s+',' ',text)\n        temp = [char if char in self.vocab else self.unk_token for char in text]\n        temp = [item.replace(' ', '[PAD]') for item in temp]\n        return temp\n\n    def _convert_token_to_id(self, token):\n        return self.vocab.get(token, self.vocab[self.unk_token])\n\n    def _convert_id_to_token(self, index):\n        # Convert an ID to its corresponding token\n        return self.ids_to_tokens.get(index, self.unk_token)\n\n    def __call__(self, text, **kwargs):\n        # Tokenize text and convert to input IDs\n        tokens = self._tokenize(text)\n        input_ids = [self._convert_token_to_id(token) for token in tokens]\n        return {\"input_ids\": input_ids}\n\n    def decode(self, token_ids, **kwargs):\n        # Convert token IDs to tokens and join into a string\n        tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]\n        return ''.join(tokens).replace(self.pad_token, '').replace(self.bos_token, '').replace(self.eos_token, '')\n\n\nclass CustomCharLevelTokenizerForAddingPaddingWithIndexHints(PreTrainedTokenizer):\n    \"\"\"Tokenizer for index hints\"\"\"\n    def __init__(self, **kwargs):\n        # Define the characters to tokenize\n        characters = '0123456789+-x= '\n        self.char_set = \"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwyz!@£#$%^&*()~?.,<>{}[]:;/|βΓΔδεζηθκΛλμΞξΠπΣςτΦφχΨψΩω\"\n        characters = characters + self.char_set\n\n        # Define and set special tokens\n        self.pad_token = '[PAD]'\n        self.unk_token = '[UNK]'\n        self.bos_token = '[BOS]'\n        self.eos_token = '[EOS]'\n\n        # Combine characters and special tokens to form the custom vocabulary\n        self.vocab = {char: i + 4 for i, char in enumerate(characters)}  # Starting from 4 to account for special tokens\n        self.vocab.update({self.pad_token: 0, self.unk_token: 1, self.bos_token: 2, self.eos_token: 3})\n\n        # Create the reverse mapping from IDs to tokens\n        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}\n\n        super().__init__(**kwargs)\n\n        # Define and set special tokens\n        self.pad_token = '[PAD]'\n        self.unk_token = '[UNK]'\n        self.bos_token = '[BOS]'\n        self.eos_token = '[EOS]'\n\n        # Combine characters and special tokens to form the custom vocabulary\n        self.vocab = {char: i + 4 for i, char in enumerate(characters)}  # Starting from 4 to account for special tokens\n        self.vocab.update({self.pad_token: 0, self.unk_token: 1, self.bos_token: 2, self.eos_token: 3})\n\n        # Create the reverse mapping from IDs to tokens\n        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}\n        \n    @property\n    def vocab_size(self):\n        return len(self.vocab)\n\n    def get_vocab(self):\n        return self.vocab\n\n    def _tokenize(self, text):\n        # Tokenize the text character by character\n        # text = re.sub('\\s+',' ',text)\n        temp = [char if char in self.vocab else self.unk_token for char in text]\n        temp = [item.replace(' ', '[PAD]') for item in temp]\n        return temp\n\n    def _convert_token_to_id(self, token):\n        return self.vocab.get(token, self.vocab[self.unk_token])\n\n    def _convert_id_to_token(self, index):\n        # Convert an ID to its corresponding token\n        return self.ids_to_tokens.get(index, self.unk_token)\n\n    def __call__(self, text, **kwargs):\n        # Tokenize text and convert to input IDs\n        tokens = self._tokenize(text)\n        input_ids = [self._convert_token_to_id(token) for token in tokens]\n        return {\"input_ids\": input_ids}\n\n    def decode(self, token_ids, **kwargs):\n        # Convert token IDs to tokens and join into a string\n        tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]\n        return ''.join(tokens).replace(self.pad_token, '').replace(self.bos_token, '').replace(self.eos_token, '')\n\n\nclass CustomCharLevelTokenizerSort(PreTrainedTokenizer):\n    \"\"\"Tokenizer for sorting\"\"\"\n    def __init__(self, **kwargs):\n        # Define the characters to tokenize\n        characters = '0123456789D,:= '\n        set_of_chars = ['A', 'B', 'C', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',\n                        'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',\n                        'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', '!', '@', '£', '#', '$', '%', '^',\n                        '&', '*', '(', ')', '~', '?', '.', '<', '>', '{', '}', '[', ']', ';', '/', '|', 'β', 'Γ', 'Δ',\n                        'δ', 'ε', 'ζ', 'η', 'θ', 'κ', 'Λ', 'λ', 'μ', 'Ξ', 'ξ', 'Π', 'π', 'Σ', 'ς', 'τ', 'Φ', 'φ', 'χ',\n                        'Ψ', 'ψ', 'Ω', 'ω']\n        self.char_set = ''.join(set_of_chars)\n        characters = characters + self.char_set\n\n        # Define and set special tokens\n        self.pad_token = '[PAD]'\n        self.unk_token = '[UNK]'\n        self.bos_token = '[BOS]'\n        self.eos_token = '[EOS]'\n\n        # Combine characters and special tokens to form the custom vocabulary\n        self.vocab = {char: i + 4 for i, char in enumerate(characters)}  # Starting from 4 to account for special tokens\n        self.vocab.update({self.pad_token: 0, self.unk_token: 1, self.bos_token: 2, self.eos_token: 3})\n\n        # Create the reverse mapping from IDs to tokens\n        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}\n\n        super().__init__(**kwargs)\n\n        # Define and set special tokens\n        self.pad_token = '[PAD]'\n        self.unk_token = '[UNK]'\n        self.bos_token = '[BOS]'\n        self.eos_token = '[EOS]'\n\n        # Combine characters and special tokens to form the custom vocabulary\n        self.vocab = {char: i + 4 for i, char in enumerate(characters)}  # Starting from 4 to account for special tokens\n        self.vocab.update({self.pad_token: 0, self.unk_token: 1, self.bos_token: 2, self.eos_token: 3})\n\n        # Create the reverse mapping from IDs to tokens\n        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}\n\n    @property\n    def vocab_size(self):\n        return len(self.vocab)\n\n    def get_vocab(self):\n        return self.vocab\n\n    def _tokenize(self, text):\n        # Tokenize the text character by character\n        temp = [char if char in self.vocab else self.unk_token for char in text]\n        temp = [item.replace(' ', '[PAD]') for item in temp]\n        return temp\n\n    def _convert_token_to_id(self, token):\n        return self.vocab.get(token, self.vocab[self.unk_token])\n\n    def _convert_id_to_token(self, index):\n        # Convert an ID to its corresponding token\n        return self.ids_to_tokens.get(index, self.unk_token)\n\n    def __call__(self, text, **kwargs):\n        # Tokenize text and convert to input IDs\n        tokens = self._tokenize(text)\n        input_ids = [self._convert_token_to_id(token) for token in tokens]\n        return {\"input_ids\": input_ids}\n\n    def decode(self, token_ids, **kwargs):\n        # Convert token IDs to tokens and join into a string\n        tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]\n        return ''.join(tokens).replace(self.pad_token, '').replace(self.bos_token, '').replace(self.eos_token, '')\n"
  },
  {
    "path": "cramming/data/curriculum_sorting.py",
    "content": "\"\"\"Baseline curricula.\"\"\"\nimport torch\nimport numpy as np\n\nimport logging\n\nlog = logging.getLogger(__name__)\n\n\ndef _sort_tokenized_dataset_by_unigram(tokenized_dataset, tokenizer, num_threads=1, ngram=1, reverse=False):\n    # Force unigram counts per token:\n    map_setup = dict(\n        batched=True,\n        batch_size=1024,\n        # num_proc=None,  # have to reimplement counting as in-out instead of side effects for this to work. Lets see how slow num_proc=0 is\n        load_from_cache_file=False,\n        # keep_in_memory=True,\n    )\n\n    unigrams_counts_per_token = np.zeros(tokenizer.vocab_size, dtype=np.int64)\n\n    def count_unigrams(examples):\n        nonlocal unigrams_counts_per_token\n        unigrams_counts_per_token += np.bincount(np.asarray(examples[\"input_ids\"]).reshape(-1), minlength=tokenizer.vocab_size)\n\n    tokenized_dataset.map(count_unigrams, desc=\"Counting token unigrams\", **map_setup, num_proc=None)\n\n    token_count = sum(unigrams_counts_per_token)\n    k = 1\n    k_smoothed_probs = (unigrams_counts_per_token + k) / (token_count + k * tokenizer.vocab_size)\n    log2_probs = np.log2(k_smoothed_probs)\n\n    def return_seq_prob(examples):\n        logprob_scores = log2_probs[np.asarray(examples[\"input_ids\"])].sum(axis=1) / tokenizer.model_max_length\n        return dict(scores=logprob_scores)\n\n    dataset_probs = tokenized_dataset.map(\n        return_seq_prob,\n        desc=\"Computing log probs per sequence\",\n        remove_columns=tokenized_dataset.column_names,\n        **map_setup,\n        num_proc=num_threads if num_threads > 0 else None,\n    )\n\n    new_order = np.argsort(np.asarray(dataset_probs[\"scores\"]))\n\n    if reverse:\n        new_order = new_order[::-1]\n\n    return tokenized_dataset.select(indices=new_order, writer_batch_size=1024)\n\n\ndef _sort_tokenized_dataset_by_token(tokenized_dataset, tokenizer, target_token_id, num_threads=1):\n    map_setup = dict(\n        batched=True,\n        batch_size=1024,\n        num_proc=num_threads if num_threads > 0 else None,\n        load_from_cache_file=False,\n        # keep_in_memory=True,\n    )\n\n    def count_token(examples):\n        return dict(counts=(np.asarray(examples[\"input_ids\"]) == target_token_id).sum(axis=1))\n\n    dataset_counts = tokenized_dataset.map(\n        count_token,\n        desc=f\"Counting occurrences of token {tokenizer.decode(target_token_id)}\",\n        remove_columns=tokenized_dataset.column_names,\n        **map_setup,\n    )\n\n    new_order = np.argsort(np.asarray(dataset_counts[\"counts\"]))[::-1]\n\n    # Print sentence with most occurrences:\n    sentence_idx = int(new_order[0])\n    input_data = torch.as_tensor(tokenized_dataset[sentence_idx][\"input_ids\"]).squeeze()  # squeeze because hf has leading dim\n    dataset_size = len(tokenized_dataset)\n\n    log.info(\"Sentence with most occurrences of token ...\")\n    log.info(tokenizer.batch_decode(input_data[None])[0])\n\n    sentence_idx = int(new_order[-1])\n    input_data = torch.as_tensor(tokenized_dataset[sentence_idx][\"input_ids\"]).squeeze()  # squeeze because hf has leading dim\n    dataset_size = len(tokenized_dataset)\n\n    log.info(\"Sentence with least occurrences of token ...\")\n    log.info(tokenizer.batch_decode(input_data[None])[0])\n\n    return tokenized_dataset.select(indices=new_order, writer_batch_size=1024)\n\n\ndef _sort_tokenized_dataset_by_word_length(tokenized_dataset, tokenizer, num_threads=1):\n    map_setup = dict(\n        batched=True,\n        batch_size=1024,\n        num_proc=num_threads if num_threads > 0 else None,\n        load_from_cache_file=False,\n    )\n\n    def count_word_lengths(examples):\n        return dict(lengths=[len(s) for s in tokenizer.batch_decode(torch.as_tensor(examples[\"input_ids\"]))])\n\n    dataset_counts = tokenized_dataset.map(\n        count_word_lengths,\n        desc=\"Counting word lengths per sequence\",\n        remove_columns=tokenized_dataset.column_names,\n        **map_setup,\n    )\n\n    new_order = np.argsort(np.asarray(dataset_counts[\"lengths\"]))  # shortest sentences first\n\n    # Print sentence with shortest length\n    sentence_idx = int(new_order[0])\n    input_data = torch.as_tensor(tokenized_dataset[sentence_idx][\"input_ids\"]).squeeze()  # squeeze because hf has leading dim\n    dataset_size = len(tokenized_dataset)\n\n    log.info(\"Sentence with shortest length ...\")\n    log.info(tokenizer.batch_decode(input_data[None])[0])\n\n    sentence_idx = int(new_order[-1])\n    input_data = torch.as_tensor(tokenized_dataset[sentence_idx][\"input_ids\"]).squeeze()  # squeeze because hf has leading dim\n    dataset_size = len(tokenized_dataset)\n\n    log.info(\"and longest ...\")\n    log.info(tokenizer.batch_decode(input_data[None])[0])\n\n    return tokenized_dataset.select(indices=new_order, writer_batch_size=1024)\n"
  },
  {
    "path": "cramming/data/deduplicate.py",
    "content": "\"\"\"This is glue code to connect to the rust-based deduplication of https://github.com/google-research/deduplicate-text-datasets\nthere is probably a smart way to implement deduplication for huggingface datasets directly,\nbut this is just a dumb dump-everything-into-tmp-files solution.\n\nCode based on branch https://github.com/google-research/deduplicate-text-datasets/tree/dev-v1\nSee original license below.\n\"\"\"\n\n\"\"\"Installation how-to:\ncargo install --target-dir ../cramming/dedup\nMake sure to make sure that path_to_rust_code is set to the correct value if installing differently\n\"\"\"\n\n# ORIGINAL LICENSE:\n\n# Copyright 2021 Google LLC\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport datasets\n\nimport os\nimport numpy as np\nfrom tqdm import tqdm\n\nimport time\nimport tempfile\n\nimport torch\n\n\ndef deduplicate_huggingface_dataset(dataset, threshold=100, original_cwd=\".\"):\n    \"\"\" \"Seamlessly\" run exact deduplication as in Lee et al.\"\"\"\n    path_to_rust_code = os.path.join(original_cwd, \"dedup\", \"release\")\n    with tempfile.TemporaryDirectory() as tmpdir:\n        text_file = _write_tmp_file(dataset, dirname=tmpdir)\n        _make_suffix_array(text_file, tmpdir, path_to_rust_code)\n\n        # Run other rust code directly\n        options = f\"--length-threshold {threshold} --cache-dir {tmpdir}/cache/\"\n\n        print(\"Finding self-similar parts...\")\n        os.popen(\n            f\"{path_to_rust_code}/dedup_dataset self-similar --data-file {text_file} \" f\"{options} --num-threads {torch.get_num_threads()}\"\n        ).read()\n        print(\"Collect self-similar from all parts...\")\n        os.popen(f\"{path_to_rust_code}/dedup_dataset collect --data-file {text_file} \" f\"{options}> {tmpdir}/drop_tokens_file\").read()\n        dataset = _finish_and_return_to_hf_dataset(text_file, f\"{tmpdir}/drop_tokens_file\")\n    return dataset\n\n\ndef _write_tmp_file(dataset, dirname):\n    text_file = os.path.join(dirname, \"tmp_full_dataset_as_text\")\n\n    with open(text_file, \"wb\") as fout:\n        for example in tqdm(dataset, desc=\"Writing dataset to tmp files.\"):  # not batched...\n            fout.write((example[\"text\"] + \"<EOT>\").encode(\"utf-8\"))\n    return text_file\n\n\ndef _make_suffix_array(text_file, tmpdir, path_to_rust_code):\n    data_size = os.path.getsize(text_file)\n    HACK = 100000\n\n    started = []\n\n    if data_size > 10e9:\n        total_jobs = 100\n        jobs_at_once = 20\n    elif data_size > 1e9:\n        total_jobs = 96\n        jobs_at_once = 96\n    elif data_size > 10e6:\n        total_jobs = 4\n        jobs_at_once = 4\n    else:\n        total_jobs = 4\n        jobs_at_once = 1\n\n    S = data_size // total_jobs\n    print(\"Partition into parts and create suffix arrays...\")\n    for jobstart in range(0, total_jobs, jobs_at_once):\n        wait = []\n        for i in range(jobstart, jobstart + jobs_at_once):\n            s, e = i * S, min((i + 1) * S + HACK, data_size)\n            cmd = f\"{path_to_rust_code}/dedup_dataset make-part --data-file {text_file} --start-byte {s} --end-byte {e}\"\n            started.append((s, e))\n            # print(cmd)\n            wait.append(os.popen(cmd))\n\n            if e == data_size:\n                break\n\n        print(\"Waiting for jobs to finish\")\n        [x.read() for x in wait]\n\n    print(\"Checking all wrote correctly\")\n\n    while True:\n        files = [f\"{text_file}.part.{s}-{e}\" for s, e in started]\n\n        wait = []\n        for x, (s, e) in zip(files, started):\n            size_data = os.path.getsize(x)\n            FACT = np.ceil(np.log(size_data) / np.log(2) / 8)\n            # print(\"FACT\", FACT)\n            size_table = os.path.getsize(x + \".table.bin\")\n            if not os.path.exists(x) or not os.path.exists(x + \".table.bin\") or size_table == 0 or size_data * FACT != size_table:\n                cmd = f\"{path_to_rust_code}/dedup_dataset make-part --data-file {text_file} --start-byte {s} --end-byte {e}\"\n                # print(cmd)\n                wait.append(os.popen(cmd))\n        print(\"Rerunning\", len(wait), \"jobs because they failed.\")\n        [x.read() for x in wait]\n        time.sleep(1)\n        if len(wait) == 0:\n            break\n\n    print(\"Merging suffix trees\")\n\n    torun = \" --suffix-path \".join(files)\n    options = f\"--output-file {tmpdir}/out.table.bin --suffix-path {torun} --num-threads {torch.get_num_threads()}\"\n    print(f\"{path_to_rust_code}/dedup_dataset merge {options}\")\n    os.popen(f\"{path_to_rust_code}/dedup_dataset merge {options}\").read()\n    # exit(0)\n    print(\"Now merging individual tables\")\n    os.popen(f\"cat {tmpdir}/out.table.bin.* > {tmpdir}/out.table.bin\").read()\n    print(\"Cleaning up\")\n    os.popen(f\"mv {tmpdir}/out.table.bin {text_file}.table.bin\").read()\n\n\ndef _finish_and_return_to_hf_dataset(original_text_file, remove_file_cache):\n    \"\"\"For simplicity the entire new dataset has to fit into memory...\"\"\"\n    remove = []\n    with open(remove_file_cache) as fin:\n        for line in fin:\n            if \"out\" in line:\n                break\n        for line in fin:\n            remove.append(list(map(int, line.split())))\n        remove = remove[::-1]\n\n    print(f\"Number of removal tuples is {len(remove)}\")\n\n    with open(original_text_file, \"rb\") as original_dataset:\n        deduped_dataset = dict(text=[])\n        start = 0\n        buffer = \"\"\n        for _ in tqdm(range(len(remove)), desc=\"Writing deduplicated data back to hf dataset\"):\n            a, b = remove.pop()\n            buffer += original_dataset.read(a - start).decode(\"utf-8\", errors=\"ignore\")  # Is the error ignore here a terrible idea??\n            original_dataset.seek(b)\n            start = b\n\n            buf_split = buffer.split(\"<EOT>\")\n            if len(buf_split) > 1:\n                deduped_dataset[\"text\"] += buf_split[:-1]\n                buffer = buf_split[-1]\n        deduped_dataset[\"text\"] += (buffer + original_dataset.read().decode(\"utf-8\")).split(\"<EOT>\")[:-1]\n\n    dataset = datasets.Dataset.from_dict(deduped_dataset)\n    return dataset\n"
  },
  {
    "path": "cramming/data/pretraining_preparation.py",
    "content": "\"\"\"Prepare and preprocess datasets.\"\"\"\n\nimport torch\nimport datasets\nimport hydra\nimport pandas as pd\nimport os\nimport contextlib\nimport logging\nimport tempfile\nfrom itertools import chain\nfrom collections import defaultdict\n\nimport json\nfrom omegaconf import OmegaConf\n\nfrom .tokenizer_preparation import construct_tokenizer, load_tokenizer\nfrom .curriculum_sorting import _sort_tokenized_dataset_by_unigram, _sort_tokenized_dataset_by_token, _sort_tokenized_dataset_by_word_length\nfrom .deduplicate import deduplicate_huggingface_dataset\nfrom .utils import checksum_config, stage_dataset, detailed_OSError\nfrom .tokenizer_preparation import get_tokenizer\n\n\nimport random\nimport transformers\n\nfrom datasets.distributed import split_dataset_by_node\nimport random\n\nfrom torch.utils.data import DataLoader\nfrom typing import Dict\n\n\nlog = logging.getLogger(__name__)\ndatasets.enable_progress_bar()\ndatasets.disable_caching()  # We'll save only the final preprocessed dataset\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n\ndef get_num_workers(cfg_impl):\n    if cfg_impl is None:\n        return 0\n    elif cfg_impl.threads > 0:\n        return min(torch.get_num_threads() // max(1, torch.cuda.device_count()), cfg_impl.threads)\n    else:\n        return 0\n\n\ndef load_pretraining_corpus(cfg_data, cfg_impl, data_dir: str = None):\n    \"\"\"Load (and optionally stage) a pre-processed corpus. Create one if it doesn't exist.\"\"\"\n    datasets.disable_caching()\n    checksum = checksum_config(cfg_data)\n\n    data_path = data_dir\n    if data_path is None:\n        data_path = cfg_impl.path\n    data_src = list(cfg_data.sources.values())[0]\n    provider = data_src[\"provider\"]\n    tokenizer_type = data_src[\"tokenizer_type\"]\n    if provider == \"fake\":\n        # Shortcut for fake data\n        return _load_fake_dataset(cfg_data, data_src, path=cfg_impl.path)\n    elif provider == \"hub\":\n        # pulling from huggingface\n        return _load_from_hub(cfg_data, data_path)\n    elif provider == \"arithmetic\":\n        # our math data\n        tokenized_dataset_path = data_src[\"tokenized_dataset_path\"]\n        tokenized_dataset_path = os.path.join(data_path, tokenized_dataset_path)\n        print(f\"Loading tokenized dataset from {tokenized_dataset_path}\")\n        tokenized_data = load_tokenized_data(tokenized_dataset_path)\n        print(f\"Loaded tokenized dataset from {tokenized_dataset_path}\")\n        tokenizer = get_tokenizer(tokenizer_type)\n        print(f\"Loaded tokenizer {tokenizer_type}\")\n        tokenizer.model_max_length = cfg_data[\"seq_length\"]  # not perfect but better than nothing\n        return tokenized_data, tokenizer\n    else:\n        # not found so creating\n        try:\n            if cfg_impl.local_staging_dir is not None:\n                with main_process_first():\n                    data_path = stage_dataset(data_path, cfg_impl.local_staging_dir)\n            # Load already processed dataset\n            tokenized_dataset = datasets.load_from_disk(data_path)\n            tokenizer = load_tokenizer(\n                os.path.join(data_path, \"tokenizer\"),\n                seq_length=cfg_data.seq_length,\n                vocab_size=cfg_data.vocab_size,\n                cache_dir=cfg_impl.path,\n            )\n        except FileNotFoundError:\n            if cfg_impl.forbid_dataset_preprocessing:\n                raise ValueError(\n                    f\"Cannot find processed at path {data_path}. Dataset preprocessing disabled. \"\n                    \"Dataset preprocessing can be enabled with 'impl.forbid_dataset_preprocessing=False'.\"\n                )\n            # Run preprocessing to create dataset\n            with main_process_first():\n                num_threads = min(torch.get_num_threads(), cfg_impl.threads)  # Mitigate worker overloading\n                preprocessed_dataset, new_tokenizer = preprocess_dataset(\n                    cfg_data,\n                    download_path=cfg_impl.path,\n                    num_threads=num_threads,\n                    max_raw_chunk_size=cfg_impl.max_raw_chunk_size,\n                )\n\n                def save_corpus(path):\n                    preprocessed_dataset.save_to_disk(path)\n                    new_tokenizer.save_pretrained(os.path.join(path, \"tokenizer\"))\n                    with open(os.path.join(path, \"model_config.json\"), \"w\") as file:\n                        json.dump(OmegaConf.to_container(cfg_data, resolve=True), file)\n\n                if not cfg_impl.temporary_corpus:\n                    # Save to base directory:\n                    save_corpus(os.path.join(cfg_impl.path, processed_dataset_dir))\n                    if cfg_impl.local_staging_dir is not None:\n                        # Optionally also copy into local staging directory\n                        data_path = stage_dataset(data_path, cfg_impl.local_staging_dir)\n                else:\n                    # Directly use staging directory\n                    save_corpus(os.path.join(cfg_impl.local_staging_dir, processed_dataset_dir))\n\n            # Reload dataset\n            tokenized_dataset = datasets.load_from_disk(data_path)\n            tokenizer = load_tokenizer(\n                os.path.join(data_path, \"tokenizer\"),\n                seq_length=cfg_data.seq_length,\n                vocab_size=cfg_data.vocab_size,\n                cache_dir=cfg_impl.path,\n            )\n\n    # Cast to tensors after loading from arrow:\n    tokenized_dataset.set_format(\"torch\")\n\n    # 4) Log overviews so we always know what's going on with weird tokenization tricks\n    dataset_size = len(tokenized_dataset[\"train\"])\n    random_sentence_idx = torch.randint(0, dataset_size, (1,)).item()\n    input_data = tokenized_dataset[\"train\"][random_sentence_idx][\"input_ids\"].squeeze()  # squeeze because hf has leading dim\n\n    log.info(f\"Random sentence with seq_length {tokenizer.model_max_length} from dataset of size {dataset_size:,}: ...\")\n    log.info(tokenizer.batch_decode(input_data[None])[0])\n    log.info(\"above is tokenized into below with _ joined to every token\")\n    log.info(\"_\".join(tokenizer.decode(t) for t in input_data))\n    return tokenized_dataset, tokenizer\n\ndef load_tokenized_data(tokenized_dataset_path):\n    tokenized_dataset = datasets.load_from_disk(tokenized_dataset_path)\n    return tokenized_dataset\n\ndef convert_to_hf_dataset(tokenized_data):\n    # Convert the PyTorch tensor to a list of lists (if it's not already)\n    data_list = tokenized_data.tolist()\n\n    # Create a DataFrame from the list\n    df = pd.DataFrame({'tokens': data_list})\n\n    # Convert the DataFrame to a Hugging Face dataset\n    hf_dataset = datasets.Dataset.from_pandas(df)\n    return hf_dataset\n\ndef preprocess_dataset(cfg_data, download_path, num_threads=1, max_raw_chunk_size=1e14):\n    \"\"\"A lot of loading and preprocessing.\"\"\"\n    # 1) Collect raw source datasets\n    raw_datasets = []\n    for name, details in cfg_data.sources.items():\n        log.info(f\"Now preparing source {name}...\")\n        if details.provider == \"huggingface\":\n            if name == \"EleutherAI/proof-pile-2\":\n                raw_dataset = datasets.load_dataset(\n                    name,\n                    name=details.partition,\n                    split=details.split,\n                    cache_dir=download_path,\n                    streaming=details.streaming,\n                )\n            else:              \n                raw_dataset = datasets.load_dataset(\n                    name,\n                    data_dir=details.partition,\n                    split=details.split,\n                    cache_dir=download_path,\n                    streaming=details.streaming,\n                )\n        elif details.provider == \"local\":\n            raw_dataset = datasets.load_dataset(details.file_type, data_files=details.files, streaming=details.streaming)[details.split]\n        else:\n            raise ValueError(f\"Invalid data provider {details.provider} given.\")\n\n        # remove columns that break later processing steps\n        if details.remove_columns is not None:\n            raw_dataset = raw_dataset.remove_columns(details.remove_columns)\n        # Filter?\n        if getattr(details, \"filter\", None) is not None:\n\n            def filter_fn(entry):\n                \"\"\"Assume a metadata key 'meta' is present\"\"\"\n                for key, values in details.filter.items():\n                    if entry[\"meta\"][key] in values:\n                        return True\n                return False\n\n            raw_dataset = raw_dataset.filter(filter_fn)\n        # move streams to fixed datasets to make everything sane (and to allow concatenation with unstreamed data)\n        if details.streaming:\n            raw_dataset = raw_dataset.take(int(cfg_data.max_entries_in_raw_dataset))\n            raw_dataset = _move_stream_to_fixed_map(raw_dataset, cfg_data.max_entries_in_raw_dataset, max_raw_chunk_size)\n        else:\n            if cfg_data.max_entries_in_raw_dataset < len(raw_dataset):\n                raw_dataset = raw_dataset.select(range(int(cfg_data.max_entries_in_raw_dataset)))\n        # concatenate dataset that were cut into pieces that are too small\n        if details.concatenate_successive_entries > 0:\n            raw_dataset = _concatenate_entries(raw_dataset, details.concatenate_successive_entries, num_threads=num_threads)\n        raw_datasets += [raw_dataset]\n\n    # 2) Preprocess and tokenize\n    raw_data = datasets.concatenate_datasets(raw_datasets)\n    raw_data = raw_data.shuffle(seed=89)  # Shuffle once here so that multiproc has shards of similar size!\n    # This shuffle is crucial for fast multiprocessing tokenization\n    # because datasets.map uses a contiguous sharding under the hood.\n\n    # However, we also shuffle so we can now select a smaller range:\n    if cfg_data.max_entries_in_raw_dataset < len(raw_data):\n        raw_data = raw_data.select(range(int(cfg_data.max_entries_in_raw_dataset)))\n\n    raw_data = raw_dataset_preprocessing(raw_data, num_threads, cfg_data)  # This is by default a no-op, but can be dedup, filtering...\n    tokenizer = construct_tokenizer(raw_data, cfg_data, path=download_path)\n    tokenized_dataset = _huggingface_preprocessing(raw_data, tokenizer, cfg_data, num_threads=num_threads)  # Tokenize, group, sort...\n\n    return tokenized_dataset, tokenizer\n\n\ndef _move_stream_to_fixed_map(raw_data_streamed, max_entries_in_raw_dataset, max_raw_chunk_size=1e14):\n    \"\"\"Save streaming dataset to a fixed mapping-style database.\"\"\"\n    # I'm tired of IterableDatasets and will take the performance hit to write them out instead:\n    try:\n        if max_raw_chunk_size > max_entries_in_raw_dataset:\n            with tempfile.TemporaryDirectory() as tmpdirname:\n                datasets.Dataset.from_dict(dict(text=[v[\"text\"] for v in raw_data_streamed])).save_to_disk(tmpdirname + \"raw_data\")\n                raw_data_mapped = datasets.load_from_disk(tmpdirname + \"raw_data\")\n            # This used to be only a move into RAM but this breaks memory later using C4:\n            # raw_data = datasets.Dataset.from_dict(dict(text=[v[\"text\"] for v in raw_data]))\n            return raw_data_mapped\n        else:\n            with tempfile.TemporaryDirectory() as tmpdirname:\n                mapped_sets = []\n                data_in_RAM = defaultdict(list)\n                for idx, value_stream in enumerate(raw_data_streamed):\n                    data_in_RAM[\"text\"].append(value_stream[\"text\"])\n                    if ((idx + 1) % max_raw_chunk_size == 0) or ((idx - 1) == max_entries_in_raw_dataset):\n                        datasets.Dataset.from_dict(data_in_RAM).save_to_disk(tmpdirname + \"raw_data\" + str(idx))\n                        mapped_dataset = datasets.load_from_disk(tmpdirname + \"raw_data\" + str(idx))\n                        log.info(\n                            f\"Saved temporary copy at idx {idx} of {max_entries_in_raw_dataset} at {tmpdirname + 'raw_data' + str(idx)}.\"\n                        )\n                        data_in_RAM[\"text\"] = []\n                        mapped_sets.append(mapped_dataset)\n            return datasets.concatenate_datasets(mapped_sets)\n    except OSError as e:\n        detailed_OSError(e)\n\n\ndef _huggingface_preprocessing(raw_dataset, tokenizer, cfg_data, num_threads=4):\n    \"\"\"Dataset preprocessing and tokenization.\n\n    This is basically the default HF routine from\n    https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_mlm.py\n    \"\"\"\n    # Preprocessing the datasets.\n    # First we tokenize all the texts.\n    column_names = getattr(raw_dataset, \"column_names\", \"text\")\n    text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n\n    max_seq_length = tokenizer.model_max_length\n    map_setup = dict(\n        batched=True,\n        batch_size=512,\n        num_proc=num_threads if num_threads > 0 else None,\n        # load_from_cache_file=False,\n        # keep_in_memory=False,\n    )\n    parellism_flag = os.environ[\"TOKENIZERS_PARALLELISM\"]\n    if num_threads > 0:\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.\n    # The Collator is modified not to read special_masks anyway:\n\n    def tokenize_function(examples):\n        return tokenizer(\n            examples[text_column_name],\n            return_special_tokens_mask=False,\n            return_attention_mask=False,  # handle this manually elsewhere if necessary\n            return_token_type_ids=False,\n        )\n\n    tokenizer.model_max_length = 1e30\n    tokenized_dataset = raw_dataset.map(\n        tokenize_function, remove_columns=column_names, desc=\"Running tokenizer on every text in dataset\", **map_setup\n    )\n    tokenizer.model_max_length = max_seq_length\n\n    # Main data processing function that will concatenate all texts from our dataset and generate chunks of\n    # max_seq_length.\n    def group_texts(examples):\n        # Concatenate all texts.\n        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n        total_length = len(concatenated_examples[list(examples.keys())[0]])\n        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n        # customize this part to your needs.\n        if total_length >= max_seq_length:\n            total_length = (total_length // max_seq_length) * max_seq_length\n        # Split by chunks of max_len.\n        result = {k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] for k, t in concatenated_examples.items()}\n        return result\n\n    tokenized_dataset = tokenized_dataset.map(group_texts, desc=f\"Grouping texts in chunks of {max_seq_length}\", **map_setup)\n\n    # Reduce size to maximal limit:\n    if cfg_data.max_seq_in_tokenized_dataset < len(tokenized_dataset):\n        tokenized_dataset = tokenized_dataset.select(range(int(cfg_data.max_seq_in_tokenized_dataset)), keep_in_memory=True)\n\n    # Split into train-val\n    tokenized_dataset = tokenized_dataset.train_test_split(test_size=cfg_data.validation_seqs, shuffle=False)\n\n    # Shuffle?\n    if cfg_data.ordering == \"randomized\":\n        tokenized_dataset[\"train\"] = tokenized_dataset[\"train\"].shuffle(seed=233)\n    elif cfg_data.ordering == \"unigram-curriculum\":\n        tokenized_dataset[\"train\"] = _sort_tokenized_dataset_by_unigram(tokenized_dataset[\"train\"], tokenizer, num_threads)\n    elif cfg_data.ordering == \"word-length-curriculum\":\n        tokenized_dataset[\"train\"] = _sort_tokenized_dataset_by_word_length(tokenized_dataset[\"train\"], tokenizer, num_threads)\n    elif cfg_data.ordering == \"sentence-length-curriculum\":\n        tokenized_dataset[\"train\"] = _sort_tokenized_dataset_by_token(\n            tokenized_dataset[\"train\"],\n            tokenizer,\n            tokenizer.vocab[\" .\"],\n            num_threads,\n        )\n    elif cfg_data.ordering == \"fragment-curriculum\":\n        tokenized_dataset[\"train\"] = _sort_tokenized_dataset_by_token(\n            tokenized_dataset[\"train\"],\n            tokenizer,\n            tokenizer.vocab[\"<eot>\"],\n            num_threads,\n        )\n    else:\n        raise ValueError(f\"Invalid dataset ordering {cfg_data.ordering} provided.\")\n\n    # Finally flatten\n    # This is necessary for the save_to_disk call that comes next. If skipped here, the call will be invoked from save_to_disk\n    # This way, atleast it shares the same batch parameters and prints a progress bar.\n    tokenized_dataset = tokenized_dataset.map(desc=\"Flattening the indices\", **map_setup)\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = parellism_flag\n    return tokenized_dataset\n\n\ndef _load_fake_dataset(cfg_data, details, path=None):\n    tokenizer = load_tokenizer(cfg_data.tokenizer, cfg_data.seq_length, cfg_data.vocab_size, cache_dir=path)\n    tokenizer.model_max_length = cfg_data.seq_length\n    generator = torch.Generator()\n    generator.manual_seed(details.randgen_seed)\n    dataset = torch.randint(0, cfg_data.vocab_size, (details.size, cfg_data.seq_length), generator=generator)\n    return dataset, tokenizer\n\n\ndef _concatenate_entries(dataset, num_entries_in_group, num_threads):\n    parellism_flag = os.environ[\"TOKENIZERS_PARALLELISM\"]\n    if num_threads > 0:\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n    def group_texts(examples):\n        result = dict()\n        for key, entries in examples.items():\n            reduced_list = []\n            state, num_collected = None, 0\n            for entry in entries:\n                num_collected += 1\n                if num_collected == 1:\n                    state = entry\n                else:\n                    state += entry\n                if num_collected == num_entries_in_group:\n                    reduced_list.append(state)\n                    state, num_collected = None, 0\n\n            result[key] = reduced_list\n\n        return result\n\n    map_setup = dict(\n        batched=True,\n        batch_size=512,\n        num_proc=num_threads if num_threads > 0 else None,\n        # load_from_cache_file=False,\n        # keep_in_memory=True,\n    )\n    dataset = dataset.map(group_texts, desc=\"Concatenating examples\", **map_setup)\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = parellism_flag\n    return dataset\n\n\ndef raw_dataset_preprocessing(raw_dataset, num_threads, cfg_data):\n    \"\"\"Some dataset \"improvements\". These are optional filtering or normalization rules that are only applied to the pretraining corpus.\n    This separates them from generic normalizations that are baked into the tokenizer.\"\"\"\n    column_names = getattr(raw_dataset, \"column_names\", \"text\")\n    text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n    known_tokens = []\n    map_setup = dict(\n        batched=True,\n        batch_size=512,\n        num_proc=None,  # a bit messy but c4 in RAM can be overbearing otherwise\n    )\n    parellism_flag = os.environ[\"TOKENIZERS_PARALLELISM\"]\n    if num_threads > 0:\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n    if cfg_data.remove_trash:\n        # experimental first test based on Unigram tokenization:\n        from transformers import AutoTokenizer\n\n        if cfg_data.remove_trash == \"self\":\n            os.environ[\"TOKENIZERS_PARALLELISM\"] = parellism_flag\n            tokenizer = construct_tokenizer(raw_dataset, cfg_data, path=None)\n            if num_threads > 0:\n                os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n        else:\n            tokenizer = AutoTokenizer.from_pretrained(\"albert-base-v2\")\n        tokenizer.model_max_length = 1e30\n\n        def filtering_rule(examples):\n            tokenized = tokenizer(examples[text_column_name])[\"input_ids\"]\n            return [len(t) < cfg_data.trash_cutoff * len(e) for t, e in zip(tokenized, examples[text_column_name])]\n\n        log.info(f\"Size of dataset before trash removal: {len(raw_dataset)}.\")\n        raw_dataset = raw_dataset.filter(\n            filtering_rule,\n            desc=\"Filter sentences that cannot be tokenized well.\",\n            **map_setup,\n        )\n        log.info(f\"Size of filtered dataset: {len(raw_dataset)}.\")\n\n    if cfg_data.deduplicate_entries:\n        log.info(f\"Size of dataset before deduplication: {len(raw_dataset)}.\")\n        raw_dataset = deduplicate_huggingface_dataset(\n            raw_dataset, threshold=cfg_data.deduplication_threshold, original_cwd=hydra.utils.get_original_cwd()\n        )\n        log.info(f\"Size of deduplicated dataset: {len(raw_dataset)}.\")\n\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = parellism_flag\n    return raw_dataset\n\n\n@contextlib.contextmanager\ndef main_process_first():\n    \"\"\"\n    A context manager for torch distributed environment where on needs to do something on the main process, while\n    blocking replicas, and when it's finished releasing the replicas.\n    One such use is for `datasets`'s `map` feature which to be efficient should be run once on the main process,\n    which upon completion saves a cached version of results and which then automatically gets loaded by the\n    replicas.\n\n    This is a stripped-down version of the the huggingface context manager from commit 2eb7bb15e771f13192968cd4657c78f76b0799fe\n    \"\"\"\n    if torch.distributed.is_initialized():\n        is_main_process = torch.distributed.get_rank() == 0\n        try:\n            if not is_main_process:\n                # tell all replicas to wait\n                torch.distributed.barrier()\n            yield\n        finally:\n            if is_main_process:\n                torch.distributed.barrier()\n    else:\n        yield\n\n\ndef _load_from_hub(cfg_data, data_path):\n    from huggingface_hub import hf_hub_download\n\n    tokenized_dataset = datasets.load_dataset(cfg_data.hf_location, \"train\", streaming=cfg_data.streaming, cache_dir=data_path)[\"train\"]\n    tokenized_dataset = tokenized_dataset.with_format(\"torch\")\n\n    tokenizer_req_files = [\"special_tokens_map.json\", \"tokenizer.json\", \"tokenizer_config.json\"]\n    os.makedirs(os.path.join(data_path, \"tokenizer\"), exist_ok=True)\n    for file in tokenizer_req_files:\n        hf_hub_download(\n            cfg_data.hf_location,\n            file,\n            subfolder=\"tokenizer\",\n            repo_type=\"dataset\",\n            local_dir=os.path.join(data_path),\n        )\n    tokenizer = load_tokenizer(os.path.join(data_path, \"tokenizer\"), seq_length=cfg_data.seq_length, cache_dir=data_path)\n    return tokenized_dataset, tokenizer\n\n\ndef prepare_dataloaders(datasets, tokenizer, cfg_train, cfg_impl) -> Dict[str, DataLoader]:\n    dataloaders = dict()\n    train_loader = prepare_pretraining_dataloader(datasets[\"train\"], tokenizer, cfg_train, cfg_impl)\n    dataloaders[\"train\"] = train_loader\n    dataloaders[\"test\"] = prepare_validation_dataloader(datasets[\"test\"], tokenizer, cfg_impl)\n    return dataloaders\n\n\ndef prepare_pretraining_dataloader(dataset, tokenizer, cfg_train, cfg_impl) -> torch.utils.data.DataLoader:\n\n    num_workers = get_num_workers(cfg_impl)\n    collate_fn = FastDataCollatorForLanguageModeling(tokenizer=tokenizer, pad_to_multiple_of=8, mlm=False)\n\n    if dataset is None:\n        # generate data at runtime\n        return RuntimeInfiniteDataLoader(tokenizer, device)\n    elif isinstance(dataset, torch.utils.data.IterableDataset):\n        # streaming mode for ready-made datasets, speed not tested\n        if torch.distributed.is_initialized():\n            dataset = split_dataset_by_node(dataset, rank=int(os.environ[\"RANK\"]), world_size=int(os.environ[\"WORLD_SIZE\"]))\n\n        if cfg_impl.shuffle_in_dataloader:\n            dataset = dataset.shuffle(seed=42, buffer_size=256)\n        if cfg_train.reverse_dataset_order:\n            raise ValueError(\"Reverse stream not implemented.\")\n        sampler = None\n    else:\n        # Normally, we'd just use nice map-style datasets:\n        if torch.distributed.is_initialized():\n            sampler = torch.utils.data.distributed.DistributedSampler(\n                dataset,\n                shuffle=cfg_impl.shuffle_in_dataloader,\n                drop_last=True,\n            )\n        else:\n            if cfg_impl.shuffle_in_dataloader:\n                sampler = torch.utils.data.RandomSampler(dataset)\n            else:\n                sampler = torch.utils.data.SequentialSampler(dataset)\n\n    if cfg_train.reverse_dataset_order:\n        dataset = dataset.select(reversed(range(len(dataset))))\n    repeated_dataloader = InfiniteDataLoader(\n        dataset,\n        sampler=sampler,\n        batch_size=min(cfg_impl.microbatch_size, len(dataset)),\n        num_workers=num_workers,\n        pin_memory=cfg_impl.pin_memory,\n        drop_last=True,\n        prefetch_factor=cfg_impl.prefetch_factor if num_workers > 0 else None,\n        persistent_workers=cfg_impl.persistent_workers if num_workers > 0 else False,\n        collate_fn=collate_fn,\n    )\n    return repeated_dataloader\n\n\ndef prepare_validation_dataloader(dataset, tokenizer, cfg_impl):\n\n    num_workers = get_num_workers(cfg_impl)\n    collate_fn = FastDataCollatorForLanguageModeling(tokenizer=tokenizer, pad_to_multiple_of=8, mlm=False)\n    if dataset is None:\n        # generate data at runtime\n        return RuntimeInfiniteDataLoader(tokenizer, device)\n    elif isinstance(dataset, torch.utils.data.IterableDataset):\n        sampler = None\n    else:\n        sampler = torch.utils.data.SequentialSampler(dataset)\n\n    dataloader = torch.utils.data.DataLoader(\n        dataset,\n        sampler=sampler,\n        batch_size=min(cfg_impl.microbatch_size, len(dataset)),\n        num_workers=num_workers,\n        pin_memory=cfg_impl.pin_memory,\n        drop_last=True,  # better make it fit elsewhere\n        prefetch_factor=cfg_impl.prefetch_factor if num_workers > 0 else None,\n        persistent_workers=False,\n        collate_fn=collate_fn,\n    )\n    return dataloader\n\n\n\"\"\"This is a minor modification of huggingface's toking masking:\"\"\"\n\"\"\"original source:\nhttps://github.com/huggingface/transformers/blob/130b987880a9b1ade5c76dc1413c12c8924fda50/src/transformers/data/data_collator.py#L748\nat commit f00f22a3e290fd377b979124dcf9800b3d73eb11\"\"\"\n\n\nclass FastDataCollatorForLanguageModeling(transformers.DataCollatorForLanguageModeling):\n    def __init__(self, *args, create_labels_entry=False, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.mlm = False\n        self.create_labels_entry = create_labels_entry\n\n    def torch_call(self, examples):\n        \"\"\"Simplified call assuming all dicts in the list of examples have the same layout and contain tensors.\n        Assume further that all these tensors contain vectors of Long Tensors  [AND THEY HAVE TO BE LONG]\"\"\"\n        if isinstance(examples[0], torch.Tensor):\n            examples = [{\"input_ids\": ex} for ex in examples]\n        # So this is the handmade version\n        batch = dict()\n        for key in examples[0].keys():\n            elem = torch.as_tensor(examples[0][key])\n            out = None\n            if torch.utils.data.get_worker_info() is not None:\n                storage = elem._typed_storage()._new_shared(len(examples) * elem.shape[0], device=elem.device)\n                out = elem.new(storage).resize_(len(examples), elem.shape[0])\n\n            batch[key] = torch.stack([torch.as_tensor(example[key]) for example in examples], 0, out=out).contiguous()\n\n        if self.create_labels_entry:\n            labels = batch[\"input_ids\"].clone()\n            if self.tokenizer.pad_token_id is not None:\n                labels[labels == self.tokenizer.pad_token_id] = -100\n            batch[\"labels\"] = labels\n        return batch\n\n\nclass InfiniteDataLoader(torch.utils.data.DataLoader):\n    \"\"\"Lazy copy-paste from https://gist.github.com/MFreidank/821cc87b012c53fade03b0c7aba13958.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        # Initialize an iterator over the dataset.\n        self.dataset_iterator = super().__iter__()\n        self.epoch_counter = 0\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        try:\n            batch = next(self.dataset_iterator)\n        except StopIteration:\n            # Dataset exhausted, use a new fresh iterator.\n            self.dataset_iterator = super().__iter__()\n            self.epoch_counter += 1\n            if hasattr(self.sampler, \"set_epoch\"):\n                self.sampler.set_epoch(self.epoch_counter)\n            batch = next(self.dataset_iterator)\n        return batch\n\n    def set_epoch(self, epoch: int):\n        self.epoch_counter = epoch\n\nclass RuntimeInfiniteDataLoader(torch.utils.data.DataLoader):\n    \"\"\"Lazy copy-paste from https://gist.github.com/MFreidank/821cc87b012c53fade03b0c7aba13958.\"\"\"\n\n    def __init__(self, tokenizer, device, *args, **kwargs):\n        self.epoch_counter = 0\n        ## All need to be moved to cfg\n        self.max_n = 20\n        self.max_m = 20\n        self.batch_size = 16\n        self.reverse_answer = False\n        self.reverse_all = False\n        self.operation = '+'\n\n        self.tokenizer = tokenizer\n        self.eos_token_id = self.tokenizer.vocab[self.tokenizer.eos_token]\n        self.device = device\n        self.current_batch = []\n\n    def get_arithmetic(self, n, m):\n        batch = []\n        for _ in range(self.batch_size):\n            num1 = random.randint((10**(n-1)), (10**n - 1))\n            num2 = random.randint(10**(m-1), 10**m - 1)\n\n            num1_str = str(num1)\n            num2_str = str(num2)\n\n            result = num1 + num2\n\n            result = str(result)\n\n            if self.reverse_answer:\n                result = result[::-1]\n            if self.reverse_all:\n                result = result[::-1]\n                num1_str = num1_str[::-1]\n                num2_str = num2_str[::-1]\n\n            batch.append(f\"{num1_str}{self.operation}{num2_str}={result}\")\n\n        return batch\n\n    def tokenize_batch(self, batch):\n        # todo this can be sped up using the HF dataset.map\n        tokenized_list = [self.tokenizer(entry)[\"input_ids\"] + [self.eos_token_id] for entry in batch]\n\n        max_length = max(len(entry) for entry in tokenized_list)\n        pad_token_id = self.tokenizer.pad_token_id\n        tokenized_list = [entry + [pad_token_id] * (max_length - len(entry)) for entry in tokenized_list]\n\n        tokenized_tensor = torch.tensor(tokenized_list, device=self.device)\n        return tokenized_tensor\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        n = random.randint(1, self.max_n)\n        m = random.randint(1, self.max_m)\n        batch = self.get_arithmetic(n, m)\n        tokenized_batch = self.tokenize_batch(batch)\n        return {'input_ids': tokenized_batch, 'max_recur': max(n, m)+5}\n"
  },
  {
    "path": "cramming/data/tokenizer_preparation.py",
    "content": "\"\"\"Tokenizer functionality.\n\nNote: CANNOT name this file \"tokenizers.py ;>\n\"\"\"\n\nfrom transformers import AutoTokenizer, PreTrainedTokenizerFast\nfrom tokenizers import Tokenizer, models, normalizers, pre_tokenizers, decoders, trainers, Regex, processors\nfrom cramming.data.arithmetic_tokenizers import CustomCharLevelTokenizerForAddingPadding, CustomCharLevelTokenizerForAddingPaddingWithIndexHints, CustomCharLevelTokenizerSort\n\n\n\ndef get_tokenizer(tokenizer_type: str):\n    \"\"\"Get an arithemtic tokenizer\"\"\"\n    if tokenizer_type == \"pad\":\n        tokenizer = CustomCharLevelTokenizerForAddingPadding()\n    elif tokenizer_type == \"index\":\n        tokenizer = CustomCharLevelTokenizerForAddingPaddingWithIndexHints()\n    elif tokenizer_type == \"sort\":\n        # also has the index hints charset\n        tokenizer = CustomCharLevelTokenizerSort()\n    else:\n        print(\"tokenizer not found\")\n        exit()\n    return tokenizer\n\n\ndef load_tokenizer(tokenizer_path_or_name, seq_length=512, vocab_size=None, cache_dir=None):\n    \"\"\"Load a tokenizer from disk/huggingface. This will never construct a new tokenizer.\"\"\"\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_or_name, model_max_length=seq_length)\n    except FileNotFoundError:\n        tokenizer = _download_tokenizer(tokenizer_path_or_name, seq_length, cache_dir)\n    if vocab_size is not None and tokenizer.vocab_size != vocab_size:\n        raise ValueError(f\"Loaded tokenizer with vocab_size {tokenizer.vocab_size} incompatible with given vocab size {vocab_size}.\")\n    return tokenizer\n\n\ndef construct_tokenizer(raw_datasets, cfg_data, path, known_tokens=[]):\n    \"\"\"Construct a new tokenizer. This may include downloading from huggingface.\"\"\"\n    if cfg_data.tokenizer not in [\"BPE\", \"Unigram\", \"WordLevel\", \"WordPiece\", \"WordPieceBERT\", \"SentencePieceUnigram\", \"SentencePieceBPE\",\"starcoder\"]:\n        tokenizer = _download_tokenizer(cfg_data.tokenizer, cfg_data.seq_length, cache_dir=path)\n    else:\n        tokenizer = _construct_tokenizer(raw_datasets, cfg_data, known_tokens)\n    tokenizer.name = f\"{cfg_data.tokenizer}-{cfg_data.name}-{cfg_data.vocab_size}.json\"\n    return tokenizer\n\n\ndef _download_tokenizer(tokenizer_path_or_name, seq_length, cache_dir=None):\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_or_name, cache_dir=cache_dir)\n        tokenizer.model_max_length = seq_length\n    except OSError as error_msg:\n        raise OSError(f\"Invalid huggingface tokenizer {tokenizer_path_or_name} given: {error_msg}\")\n    return tokenizer\n\n\ndef _get_sane_token_args():\n    return dict(\n        pad_token=\"<pad>\",\n        bos_token=\"<eot>\",\n        eos_token=\"<eot>\",\n        sep_token=\"<eot>\",\n        unk_token=\"<unk>\",\n    )\n\n\ndef _get_sane_normalizers(force_english_keyboard=False, force_lowercase=False, strip_accents=False, whitespace_escape=False, sanity=False):\n    \"\"\"original rules as in XLNET with optional modifications. force_english_keyboard is actually an ascii normalization.\"\"\"\n    if sanity:\n        return normalizers.BertNormalizer(lowercase=force_lowercase)\n    normalize_ops = []\n    normalize_ops.append(normalizers.Replace(\"``\", '\"'))\n    normalize_ops.append(normalizers.Replace(\"''\", '\"'))\n    normalize_ops.append(normalizers.NFD() if strip_accents else normalizers.NFKC())\n    if force_lowercase:\n        normalize_ops.append(normalizers.Lowercase())\n    if strip_accents:\n        normalize_ops.append(normalizers.StripAccents())\n    normalize_ops.append(normalizers.Replace(Regex(\" {2,}\"), \" \"))\n    if force_english_keyboard:\n        normalize_ops.append(normalizers.Replace(Regex(r\"[^\\x00-\\x7F]+\"), \"\"))  # start from 00 instead of 1F to include tab\n    return normalizers.Sequence(normalize_ops)\n\n\ndef _construct_tokenizer(raw_datasets, cfg_data, known_tokens=[]):\n    \"\"\"The actual generation instructions for a new tokenizer. Might make this more scriptable in the future...\n\n    Follows closely along with https://huggingface.co/course/chapter6\"\"\"\n    try:\n        len_dataset = len(raw_datasets)\n\n        def batch_iterator(batch_size=1024):\n            for i in range(0, len_dataset, batch_size):\n                try:\n                    yield raw_datasets[i : i + batch_size][\"content\"]\n                except:\n                    yield raw_datasets[i : i + batch_size][\"text\"]\n\n    except TypeError:\n        # streaming dataset\n        len_dataset = int(cfg_data.max_entries_in_dataset)\n\n        def batch_iterator():\n            for entry in iter(raw_datasets):\n                try:\n                    yield entry[\"content\"]\n                except:\n                    yield entry[\"text\"]\n\n    special_token_args = _get_sane_token_args()\n    normalizer_sequence = _get_sane_normalizers(**cfg_data.normalizer)\n    # Outline tokenizer rules:\n    if cfg_data.tokenizer == \"Unigram\":  # without the sentencepice part\n        tokenizer = Tokenizer(models.Unigram())\n        tokenizer.add_tokens(known_tokens)\n        tokenizer.normalizer = normalizer_sequence\n        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()\n        # tokenizer.decoder = None\n        special_tokens = list(set(v for k, v in special_token_args.items()))\n\n        trainer = trainers.UnigramTrainer(\n            vocab_size=cfg_data.vocab_size,\n            special_tokens=special_tokens,\n            unk_token=special_token_args[\"unk_token\"],\n        )\n    elif cfg_data.tokenizer == \"BPE\":\n        tokenizer = Tokenizer(models.BPE())\n        tokenizer.add_tokens(known_tokens)\n\n        tokenizer.normalizer = normalizer_sequence\n        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)\n        tokenizer.decoder = decoders.ByteLevel()\n\n        trainer = trainers.BpeTrainer(\n            vocab_size=cfg_data.vocab_size,\n            min_frequency=2,\n            special_tokens=list(set(special_token_args.values())),\n            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),\n        )\n    elif cfg_data.tokenizer == \"WordPiece\":\n        tokenizer = Tokenizer(models.WordPiece(unk_token=special_token_args[\"unk_token\"]))\n        tokenizer.add_tokens(known_tokens)\n\n        tokenizer.normalizer = normalizer_sequence\n        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()\n        tokenizer.decoder = decoders.WordPiece(prefix=\"##\")\n\n        trainer = trainers.WordPieceTrainer(vocab_size=cfg_data.vocab_size, special_tokens=list(set(special_token_args.values())))\n    elif cfg_data.tokenizer == \"WordPieceBERT\":\n        # Sanity check tokenizer\n        tokenizer = Tokenizer(models.WordPiece(unk_token=\"<unk>\"))\n        tokenizer.add_tokens(known_tokens)\n        tokenizer.normalizer = normalizers.BertNormalizer()\n        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()\n        tokenizer.decoder = decoders.WordPiece(prefix=\"##\")\n\n        trainer = trainers.WordPieceTrainer(vocab_size=cfg_data.vocab_size, special_tokens=list(set(special_token_args.values())))\n    elif cfg_data.tokenizer == \"WordLevel\":\n        tokenizer = Tokenizer(models.WordLevel(unk_token=special_token_args[\"unk_token\"]))\n        tokenizer.add_tokens(known_tokens)\n        tokenizer.normalizer = normalizer_sequence\n        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()\n        trainer = trainers.WordLevelTrainer(vocab_size=cfg_data.vocab_size, special_tokens=list(set(special_token_args.values())))\n    elif cfg_data.tokenizer == \"SentencePieceBPE\":\n        \"\"\"ref https://github.com/huggingface/tokenizers/blob/main/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py\"\"\"\n        tokenizer = Tokenizer(models.BPE())\n        tokenizer.add_tokens(known_tokens)\n\n        tokenizer.normalizer = normalizer_sequence\n        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(\n            [pre_tokenizers.Metaspace(replacement=\"▁\", add_prefix_space=True), pre_tokenizers.ByteLevel(add_prefix_space=False)],\n        )\n        tokenizer.decoder = decoders.Sequence([decoders.ByteLevel(), decoders.Metaspace(replacement=\"▁\", add_prefix_space=True)])\n\n        trainer = trainers.BpeTrainer(\n            vocab_size=cfg_data.vocab_size,\n            min_frequency=2,\n            special_tokens=list(set(special_token_args.values())),\n            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),\n        )\n    elif cfg_data.tokenizer == \"SentencePieceUnigram\":\n        tokenizer = Tokenizer(models.Unigram())\n        tokenizer.add_tokens(known_tokens)\n        tokenizer.normalizer = normalizer_sequence\n        tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=\"▁\", add_prefix_space=True)\n        tokenizer.decoder = decoders.Metaspace(replacement=\"▁\", add_prefix_space=True)\n        special_tokens = list(set(v for k, v in special_token_args.items()))\n\n        trainer = trainers.UnigramTrainer(\n            vocab_size=cfg_data.vocab_size,\n            special_tokens=special_tokens,\n            unk_token=special_token_args[\"unk_token\"],\n        )\n    else:\n        raise ValueError(f\"Invalid tokenization strategy {cfg_data.tokenizer} given.\")\n\n    # Construct tokenizer\n    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len_dataset)\n\n    if tokenizer.get_vocab_size() != cfg_data.vocab_size:\n        raise RuntimeError(f\"Tokenizer generation failure. Vocab size of trained tokenizer is {tokenizer.get_vocab_size()}.\")\n\n    # Postprocess:\n    eot_token_id = tokenizer.token_to_id(\"<eot>\")\n\n    # Generate template:\n    single_template = \"$A\"\n    if cfg_data.include_eot_token_in_corpus:\n        single_template = single_template + \" <eot>\"\n    tokenizer.post_processor = processors.TemplateProcessing(\n        single=single_template,\n        special_tokens=[(\"<eot>\", eot_token_id)],\n    )\n    # Wrap into fast codebase\n    wrapped_tokenizer = PreTrainedTokenizerFast(\n        tokenizer_object=tokenizer,\n        model_max_length=cfg_data.seq_length,\n        **special_token_args,\n    )\n    return wrapped_tokenizer\n"
  },
  {
    "path": "cramming/data/utils.py",
    "content": "\"\"\"Various utilities.\"\"\"\nimport os\nfrom omegaconf import OmegaConf\nimport hashlib\nimport json\nimport shutil\nimport subprocess\n\nimport logging\nimport time\n\nimport datasets\n\nlog = logging.getLogger(__name__)\n\n\ndef checksum_config(cfg):\n    \"\"\"This is more annoying that I thought it would be. But a json-dump of the config file is hashed and used as checksum.\"\"\"\n    bindump = json.dumps(OmegaConf.to_container(cfg, resolve=True), sort_keys=True).encode(\"utf-8\")\n    checksum_of_config = hashlib.md5(bindump).hexdigest()\n    if \"tokenizer\" in cfg and \"vocab_size\" in cfg:\n        checksum_of_config = f\"{cfg.tokenizer}x{cfg.vocab_size}_{checksum_of_config}\"\n    return checksum_of_config\n\n\ndef stage_dataset(data_directory_path, local_staging_dir):\n    \"\"\"This is a mess because our network drives are a mess. You might not need this.\"\"\"\n    data_directory_name = os.path.basename(data_directory_path)\n    new_path = os.path.join(local_staging_dir, data_directory_name)\n    if os.path.isdir(data_directory_path):\n        try:\n            if not os.path.isdir(new_path):\n                try:\n                    shutil.copytree(data_directory_path, new_path)\n                    log.info(f\"Staging dataset to {new_path}...\")\n                except FileExistsError:\n                    log.info(f\"Concurrent writing to {new_path} detected. Stopping staging in this run and waiting for 300 seconds.\")\n                    time.sleep(300)\n            else:\n                log.info(f\"Using staged dataset found at {new_path}...\")\n\n            for retries in range(15):\n                _, _, free = shutil.disk_usage(new_path)\n                used = _get_size(new_path)\n                try:\n                    tokenized_dataset = datasets.load_from_disk(new_path)\n                    log.info(f\"Staged dataset size is {used / 1024**3:,.3f}GB. {free/ 1024**3:,.3f}GB free in staging dir.\")\n                    return new_path\n                except FileNotFoundError:\n                    log.info(\n                        f\"Staged dataset is incomplete. Size is {used / 1024**3:,.3f}GB. \"\n                        f\" Waiting for 60 more secs for staging race condition.\"\n                    )\n                    time.sleep(60)\n            log.info(f\"Staging dataset corrupted. Falling back to network drive location {data_directory_path}\")\n            return data_directory_path\n\n        except Exception as e:  # noqa\n            log.info(f\"Staging failed with error {e}. Falling back to network drive location {data_directory_path}\")\n            return data_directory_path\n    else:\n        raise FileNotFoundError(f\"Dataset not yet generated or not found at {data_directory_path}.\")\n\n\ndef _get_size(start_path=\".\"):\n    \"\"\"Compute the size of a directory path. Why is this not in the standard library?\"\"\"\n    \"\"\"Stolen from https://stackoverflow.com/questions/1392413/calculating-a-directorys-size-using-python\"\"\"\n    total_size = 0\n    for dirpath, dirnames, filenames in os.walk(start_path):\n        for f in filenames:\n            fp = os.path.join(dirpath, f)\n            # skip if it is symbolic link\n            if not os.path.islink(fp):\n                total_size += os.path.getsize(fp)\n    return total_size\n\n\ndef detailed_OSError(e):\n    if e.errno == 28:  # \"no space left on device\"\n        if e.filename:\n            df_output = subprocess.check_output([\"df\", \"-h\", e.filename]).decode(\"utf-8\")\n            df_lines = df_output.strip().split(\"\\n\")[1:]\n            if df_lines:\n                # The file system containing the file is full\n                device_name, size, used, available, percent, mount_point = df_lines[0].split()\n                error_path = os.path.abspath(e.filename)\n                error_message = f\"Error writing to {error_path}: {e.strerror}\"\n                space_message = f\"{available} space left on {mount_point}\"\n                full_error_message = f\"{error_message}\\nDevice {device_name} is full. {space_message}\"\n        else:\n            # The file name is unknown\n            error_message = f\"Error: {e.strerror}\"\n            full_error_message = f\"{error_message}\\nUnknown file name. Device may be full.\"\n        raise OSError(full_error_message)\n    else:\n        raise e\n"
  },
  {
    "path": "cramming/utils.py",
    "content": "\"\"\"System utilities.\"\"\"\n\nimport socket\nimport sys\n\nimport os\nimport csv\nimport yaml\nimport psutil\nimport pynvml\n\nimport multiprocess  # hf uses this for some reason\nimport collections\n\nimport torch\nimport torch._inductor.config\nimport transformers\n\n\nimport json\nimport random\nimport numpy as np\nimport time\nimport datetime\nimport tempfile\nfrom .data.utils import checksum_config\n\nimport logging\nimport hydra\nfrom omegaconf import OmegaConf, open_dict\nimport cramming\n\nlog = logging.getLogger(__name__)\nos.environ[\"HYDRA_FULL_ERROR\"] = \"0\"\n\n\ndef main_launcher(cfg, main_fn, job_name=\"\"):\n    \"\"\"This is boiler-plate code for a launcher.\"\"\"\n    launch_time = time.time()\n    # Set definitive random seed:\n    if cfg.seed is None:\n        cfg.seed = torch.randint(0, 2**32 - 1, (1,)).item()\n\n    # Figure out all paths:\n    cfg = pathfinder(cfg)\n\n    # Decide GPU and possibly connect to distributed setup\n    setup, kWh_counter = system_startup(cfg)\n    # Initialize wanDB\n    if cfg.wandb.enabled:\n        _initialize_wandb(setup, cfg)\n    log.info(\"--------------------------------------------------------------\")\n    log.info(f\"--------------Launching {job_name} run! ---------------------\")\n    log.info(OmegaConf.to_yaml(cfg, resolve=True))\n    metrics = main_fn(cfg, setup)\n    metrics = collect_system_metrics(cfg, metrics, kWh_counter, setup)\n\n    log.info(\"-------------------------------------------------------------\")\n    log.info(f\"Finished running job {cfg.name} with total train time: \" f\"{str(datetime.timedelta(seconds=time.time() - launch_time))}\")\n    if is_main_process():\n        metrics = flatten(metrics)\n        dump_metrics(cfg, metrics)\n        # Export to wandb:\n        if cfg.wandb.enabled:\n            import wandb\n\n            for k, v in metrics.items():\n                wandb.run.summary[k] = v\n\n        if torch.cuda.is_available():\n            max_alloc = f\"{torch.cuda.max_memory_allocated(setup['device'])/float(1024**3):,.3f} GB\"\n            max_reserved = f\"{torch.cuda.max_memory_reserved(setup['device'])/float(1024**3):,.3f} GB\"\n            log.info(f\"Max. Mem allocated: {max_alloc}. Max. Mem reserved: {max_reserved}.\")\n            log.info(f\"{metrics['kWh']:.2e} kWh of electricity used for GPU(s) during job.\")\n    log.info(\"-----------------Shutdown complete.--------------------------\")\n\n\ndef get_cpus() -> int:\n    # Number of threads\n    try:\n        return min(psutil.cpu_count(logical=False), len(psutil.Process().cpu_affinity()))  # covering both affinity and phys.\n    except:\n        pass\n    try:\n        return os.cpu_count()  # when running on mac\n    except:\n        return 1\n\n\ndef system_startup(cfg):\n    \"\"\"Decide and print GPU / CPU / hostname info. Generate local distributed setting if running in distr. mode.\n\n    Set all required and interesting environment variables.\n    \"\"\"\n    torch.backends.cudnn.benchmark = cfg.impl.benchmark\n    torch.backends.cuda.enable_flash_sdp(cfg.impl.enable_flash_sdp) if cfg.impl.enable_flash_sdp is not None else 0\n    torch.backends.cuda.enable_math_sdp(cfg.impl.enable_math_sdp) if cfg.impl.enable_math_sdp is not None else 0\n    torch.backends.cuda.enable_mem_efficient_sdp(cfg.impl.enable_mem_efficient_sdp) if cfg.impl.enable_mem_efficient_sdp is not None else 0\n    torch.set_float32_matmul_precision(cfg.impl.matmul_precision)\n\n    if cfg.impl.sharing_strategy is not None:\n        torch.multiprocessing.set_sharing_strategy(cfg.impl.sharing_strategy)\n\n    if cfg.impl.tf32_allowed:\n        torch.backends.cudnn.allow_tf32 = True\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True  # Should be true anyway\n\n    multiprocess.set_start_method(\"forkserver\")\n    if cfg.impl.local_staging_dir is not None:\n        tmp_path = os.path.join(cfg.impl.local_staging_dir, \"tmp\")\n        os.makedirs(tmp_path, exist_ok=True)\n        os.environ[\"TMPDIR\"] = tmp_path\n        tempfile.tempdir = None  # Force temporary directory regeneration\n    if cfg.impl.enable_huggingface_offline_mode:\n        os.environ[\"HF_DATASETS_OFFLINE\"] = \"1\"\n        os.environ[\"TRANSFORMERS_OFFLINE\"] = \"1\"\n\n    if cfg.impl.add_env_variables is not None:\n        # Note that for any environment variables added here, they have to be able to change behavior at runtime\n        # for example, the torchdynamo settings are read at import and cannot be changed at runtime here\n        for env_var, string_val in cfg.impl.add_env_variables.items():\n            os.environ[str(env_var)] = str(string_val)\n        log.info(os.environ)\n\n    allowed_cpus_available = get_cpus()\n    # Distributed launch?\n    if \"LOCAL_RANK\" in os.environ:\n        torch.distributed.init_process_group(backend=cfg.impl.dist_backend)\n        local_rank = int(os.environ[\"LOCAL_RANK\"])\n        global_rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n        run = os.environ.get(\"TORCHELASTIC_RUN_ID\", \"unknown\")\n        threads_per_gpu = max(1, min(allowed_cpus_available // max(1, torch.cuda.device_count()), cfg.impl.threads))\n        log.info(\n            f\"Distributed worker initialized on rank {global_rank} (local rank {local_rank}) \"\n            f\"with {world_size} total processes. OMP Threads set to {threads_per_gpu}. Run ID is {run}.\"\n        )\n        log.setLevel(logging.INFO if is_main_process() else logging.ERROR)\n    else:\n        threads_per_gpu = max(1, min(allowed_cpus_available, cfg.impl.threads))\n        global_rank = local_rank = 0\n\n    torch.set_num_threads(threads_per_gpu)\n    os.environ[\"OMP_NUM_THREADS\"] = str(threads_per_gpu)\n    cfg.impl.local_rank = local_rank\n\n    # datasets will automatically disable tokenizer parallelism when needed:\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n    os.environ[\"RAYON_RS_NUM_CPUS\"] = str(threads_per_gpu)\n    max_dataset_memory = f\"{psutil.virtual_memory().total // 2 // max(torch.cuda.device_count(), 1)}\"\n    os.environ[\"HF_DATASETS_IN_MEMORY_MAX_SIZE\"] = max_dataset_memory\n\n    # Construct setup dictionary:\n    dtype = getattr(torch, cfg.impl.default_precision)  # :> dont mess this up\n    device = torch.device(f\"cuda:{local_rank}\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    if torch.cuda.is_available():\n        torch.cuda.set_device(local_rank)\n        log.info(f\"GPU : {torch.cuda.get_device_name(device=device)}. CUDA: {torch.version.cuda}.\")\n\n        # Populate kwH counter:\n        pynvml.nvmlInit()\n        miilijoule_start = pynvml.nvmlDeviceGetTotalEnergyConsumption(pynvml.nvmlDeviceGetHandleByIndex(device.index))\n        kWh_counter = dict(initial_value=miilijoule_start * 1e-6 / 3600)  # kilojoule per hour\n    else:\n        kWh_counter = dict(initial_value=float(\"NaN\"))\n    setup = dict(device=device, dtype=dtype)\n    python_version = sys.version.split(\" (\")[0]\n\n    if local_rank == 0:\n        log.info(f\"Platform: {sys.platform}, Python: {python_version}, PyTorch: {torch.__version__}\")\n        log.info(f\"CPUs: {allowed_cpus_available}, GPUs: {torch.cuda.device_count()} on {socket.gethostname()}.\")\n\n    if cfg.impl.deterministic:\n        set_deterministic()\n    if cfg.seed is not None:\n        if is_main_process():\n            log.info(f\"Seeding with random seed {cfg.seed} on rank 0.\")\n        set_random_seed(cfg.seed + 10 * global_rank)\n\n    return setup, kWh_counter\n\n\ndef is_main_process():\n    return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0\n\n\ndef num_processes():\n    num_procs = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()\n    return num_procs\n\n\n# def find_pretrained_checkpoint(cfg, downstream_classes=None):\ndef find_pretrained_checkpoint(checkpoint: str, local_checkpoint_folder: str = None, arch_modifications=None):\n    \"\"\"Load a checkpoint either locally or from the internet.\"\"\"\n    # tokenizer is only returned for HF models\n    tokenizer = None\n    cfg_arch = None\n    checkpoint_path = None\n    if checkpoint is None:\n        checkpoint_name = local_checkpoint_folder\n    elif checkpoint == \"latest\":\n        # Load the latest local checkpoint\n        all_checkpoints = [f for f in os.listdir(local_checkpoint_folder)]\n        checkpoint_paths = [os.path.join(local_checkpoint_folder, c) for c in all_checkpoints]\n        # checkpoint_paths = [x for x in checkpoint_paths if x[:6] != \"FINAL_\"]\n        checkpoint_name = max(checkpoint_paths, key=os.path.getmtime)\n    elif checkpoint == \"smallest\":\n        # Load maybe the local checkpoint with smallest loss\n        all_checkpoints = [f for f in os.listdir(local_checkpoint_folder)]\n        checkpoint_paths = [os.path.join(local_checkpoint_folder, c) for c in all_checkpoints]\n        checkpoint_losses = [float(path[-5:]) for path in checkpoint_paths]\n        checkpoint_name = checkpoint_paths[np.argmin(checkpoint_losses)]\n    elif not os.path.isabs(checkpoint) and not checkpoint.startswith(\"hf://\"):\n        # Look locally for a checkpoint with this name\n        checkpoint_name = os.path.join(local_checkpoint_folder, checkpoint)\n    elif checkpoint.startswith(\"hf://\"):\n        # Download this checkpoint directly from huggingface\n        model_name = checkpoint.split(\"hf://\")[1].removesuffix(\"-untrained\")\n        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n        cfg_arch = transformers.AutoConfig.from_pretrained(model_name)\n        checkpoint_path = checkpoint\n        checkpoint_name = None\n    else:\n        # Look for this name as an absolute path\n        checkpoint_name = checkpoint\n\n    if checkpoint_name is not None:\n        # Load these checkpoints locally, might not be a huggingface model\n        try:\n            tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_name)\n        except:\n            log.warning(f\"Could not load tokenizer from checkpoint: {checkpoint_name}\")\n\n        with open(os.path.join(checkpoint_name, \"model_config.json\"), \"r\") as file:\n            cfg_arch = OmegaConf.create(json.load(file))  # Could have done pure hydra here, but wanted interop\n\n        # Optionally modify parts of the arch at eval time. This is not guaranteed to be a good idea ...\n        # All mismatched parameters will be randomly initialized ...\n        if arch_modifications is not None:\n            cfg_arch = OmegaConf.merge(cfg_arch, arch_modifications)\n            log.info(\"Using arch modifications\")\n\n        checkpoint_path = checkpoint_name\n\n        log.info(f\"Architecture: {cfg_arch}\")\n\n    if checkpoint_path is not None:\n        log.info(f\"Loading from checkpoint {checkpoint_path}...\")\n    else:\n        log.error(f\"No checkpoint to be loaded by: {checkpoint}\")\n\n    return tokenizer, cfg_arch, checkpoint_path\n\n\ndef save_summary(table_name, cfg, stats, local_time, setup, original_cwd=True):\n    \"\"\"Save two summary tables. A detailed table of iterations/loss+acc and a summary of the end results.\"\"\"\n    # 1) detailed table:\n    for step in range(len(stats[\"loss\"])):\n        iteration = dict()\n        for key in stats:\n            iteration[key] = stats[key][step] if step < len(stats[key]) else None\n        save_to_table(\".\", f\"{cfg.name}_convergence_results\", dryrun=cfg.dryrun, **iteration)\n\n    def _maybe_record(key, step=-1):\n        try:\n            return stats[key][step]\n        except (IndexError, ValueError):\n            return \"\"\n\n    if \"data\" in cfg:\n        processed_dataset_dir = f\"{cfg.data.name}_{checksum_config(cfg.data)}\"\n    else:\n        processed_dataset_dir = None\n    base_name = cfg.base_dir.rstrip(os.sep).split(os.sep)[-1]\n    local_folder = os.getcwd().split(base_name)[1].lstrip(os.sep)\n\n    # 2) save a reduced summary\n    if table_name == \"pretrain\":\n        summary = dict(\n            name=cfg.name,\n            budget=cfg.budget,\n            dataset=\"_\".join(processed_dataset_dir.split(\"_\")[:-1]),\n            backend=cfg.impl.name,\n            arch=\" \".join(cfg.arch.architectures),\n            loss=_maybe_record(\"loss\"),\n            final_step=_maybe_record(\"step\"),\n            final_epoch=_maybe_record(\"epoch\"),\n            step_time=np.mean(stats[\"train_time\"]) if len(stats[\"train_time\"]) > 0 else \"\",\n            loss100k=_maybe_record(\"loss\", step=100_000 // cfg.impl.print_loss_every_nth_step),\n            loss200k=_maybe_record(\"loss\", step=200_000 // cfg.impl.print_loss_every_nth_step),\n            loss300k=_maybe_record(\"loss\", step=300_000 // cfg.impl.print_loss_every_nth_step),\n            total_time=str(datetime.timedelta(seconds=local_time)).replace(\",\", \"\"),\n            batch_size=cfg.train.batch_size,\n            lr=cfg.train.optim.lr,\n            warmup=cfg.train.warmup_steps,\n            steps=cfg.train.steps,\n            # System settings:\n            seed=cfg.seed,\n            dataset_hash=processed_dataset_dir.split(\"_\")[-1],\n            base_dir=cfg.base_dir,\n            impl_path=cfg.impl.path,\n            local_folder=local_folder,\n            # # Dump configs from here on:\n            **{f\"Data_{k}\": v for k, v in cfg.data.items()},\n            **{f\"Arch_{k}\": v for k, v in cfg.arch.items()},\n            **{f\"Train_{k}\": v for k, v in cfg.train.items()},\n        )\n    else:\n        summary = dict(\n            name=cfg.name,\n            backend=cfg.impl.name,\n            checkpoint=cfg.eval.checkpoint,\n            loss=_maybe_record(\"loss\"),\n            avg_loss=_maybe_record(\"avg_loss\"),\n            final_epoch=_maybe_record(\"epoch\"),\n            step_time=np.mean(stats[\"train_time\"]) if len(stats[\"train_time\"]) > 0 else \"\",\n            total_time=str(datetime.timedelta(seconds=local_time)).replace(\",\", \"\"),\n            batch_size=cfg.eval.batch_size,\n            lr=cfg.eval.optim.lr,\n            warmup=cfg.eval.warmup_steps,\n            # System settings:\n            seed=cfg.seed,\n            base_dir=cfg.base_dir,\n            impl_path=cfg.impl.path,\n            local_folder=local_folder,\n            # # Dump configs from here on:\n            **{f\"Eval_{k}\": v for k, v in cfg.eval.items()},\n        )\n    location = os.path.join(cfg.original_cwd, \"tables\") if original_cwd else \"tables\"\n    save_to_table(location, f\"{table_name}_reports\", dryrun=cfg.dryrun, **summary)\n\n\ndef save_to_table(out_dir, table_name, dryrun, **kwargs):\n    \"\"\"Save keys to .csv files.\"\"\"\n    # Check for file\n    if not os.path.isdir(out_dir):\n        os.makedirs(out_dir)\n    fname = os.path.join(out_dir, f\"table_{table_name}.csv\")\n    fieldnames = list(kwargs.keys())\n    # Read or write header\n    try:\n        with open(fname, \"r\") as f:\n            reader = csv.reader(f, delimiter=\"\\t\")\n            header = next(reader)  # noqa  # this line is testing the header\n            # assert header == fieldnames[:len(header)]  # new columns are ok, but old columns need to be consistent\n            # dont test, always write when in doubt to prevent erroneous table deletions\n    except Exception as e:  # noqa\n        if not dryrun:\n            # print('Creating a new .csv table...')\n            with open(fname, \"w\") as f:\n                writer = csv.DictWriter(f, delimiter=\"\\t\", fieldnames=fieldnames)\n                writer.writeheader()\n        else:\n            pass\n\n    # Write a new row\n    if not dryrun:\n        # Add row for this experiment\n        with open(fname, \"a\") as f:\n            writer = csv.DictWriter(f, delimiter=\"\\t\", fieldnames=fieldnames)\n            writer.writerow(kwargs)\n    else:\n        pass\n\n\ndef set_random_seed(seed=233):\n    \"\"\".\"\"\"\n    torch.manual_seed(seed + 1)\n    torch.cuda.manual_seed(seed + 2)\n    torch.cuda.manual_seed_all(seed + 3)\n    np.random.seed(seed + 4)\n    torch.cuda.manual_seed_all(seed + 5)\n    random.seed(seed + 6)\n    # Can't be too careful :>\n\n\ndef set_deterministic():\n    \"\"\"Switch pytorch into a deterministic computation mode.\"\"\"\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    torch.use_deterministic_algorithms(True)\n    os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n\ndef avg_n_dicts(dicts):\n    \"\"\"https://github.com/wronnyhuang/metapoison/blob/master/utils.py.\"\"\"\n    # given a list of dicts with the same exact schema, return a single dict with same schema whose values are the\n    # key-wise average over all input dicts\n    means = {}\n    for dic in dicts:\n        for key in dic:\n            if key not in means:\n                if isinstance(dic[key], list):\n                    means[key] = [0 for entry in dic[key]]\n                else:\n                    means[key] = 0\n            if isinstance(dic[key], list):\n                for idx, entry in enumerate(dic[key]):\n                    means[key][idx] += entry / len(dicts)\n            else:\n                means[key] += dic[key] / len(dicts)\n    return means\n\n\ndef dump_metrics(cfg, metrics):\n    \"\"\"Simple yaml dump of metric values.\"\"\"\n\n    filepath = f\"metrics_{cfg.name}.yaml\"\n    sanitized_metrics = dict()\n    for metric, val in metrics.items():\n        try:\n            sanitized_metrics[metric] = np.asarray(val).item()\n        except ValueError:\n            sanitized_metrics[metric] = np.asarray(val).tolist()\n    with open(filepath, \"w\") as yaml_file:\n        yaml.dump(sanitized_metrics, yaml_file, default_flow_style=False)\n\n\ndef _initialize_wandb(setup, cfg):\n    if is_main_process():\n        import wandb\n\n        config_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)\n        settings = wandb.Settings(start_method=\"thread\")\n        settings.update({\"git_root\": cfg.original_cwd})\n        run = wandb.init(\n            entity=cfg.wandb.entity,\n            project=cfg.wandb.project,\n            settings=settings,\n            name=cfg.name,\n            mode=\"disabled\" if cfg.dryrun else None,\n            tags=cfg.wandb.tags if len(cfg.wandb.tags) > 0 else None,\n            config=config_dict,\n        )\n        run.summary[\"GPU\"] = torch.cuda.get_device_name(device=setup[\"device\"]) if torch.cuda.device_count() > 0 else \"\"\n        run.summary[\"numGPUs\"] = torch.cuda.device_count()\n\n\ndef wandb_log(stats, cfg):\n    if cfg.wandb.enabled:\n        if is_main_process():\n            import wandb\n\n            wandb.log({k: v[-1] for k, v in stats.items()}, step=stats[\"step\"][-1] if \"step\" in stats else None)\n\n\ndef flatten(d, parent_key=\"\", sep=\"_\"):\n    \"\"\"Straight-up from https://stackoverflow.com/a/6027615/3775820.\"\"\"\n    items = []\n    for k, v in d.items():\n        new_key = parent_key + sep + k if parent_key else k\n        if isinstance(v, collections.abc.MutableMapping):\n            items.extend(flatten(v, new_key, sep=sep).items())\n        else:\n            items.append((new_key, v))\n    return dict(items)\n\n\ndef collect_system_metrics(cfg, metrics, kWh_counter, setup):\n    # Finalize some compute metrics:\n    metrics[\"GPU\"] = torch.cuda.get_device_name(device=setup[\"device\"]) if torch.cuda.device_count() > 0 else \"\"\n    metrics[\"numGPUs\"] = torch.cuda.device_count()\n    metrics[\"VRAM\"] = torch.cuda.max_memory_allocated(setup[\"device\"]) / float(1 << 30)\n    metrics[\"RAM\"] = psutil.Process(os.getpid()).memory_info().rss / 1024**3\n    if torch.cuda.device_count() == 1:\n        metrics[\"kWh\"] = get_kWh(kWh_counter, setup)\n    else:\n        if torch.distributed.is_initialized():\n            local_kWh = get_kWh(kWh_counter, setup)\n            kWh_comm = torch.as_tensor(local_kWh).cuda() if torch.cuda.is_available() else kWh_comm.float()\n            torch.distributed.all_reduce(kWh_comm, torch.distributed.ReduceOp.SUM, async_op=False)\n            metrics[\"kWh\"] = kWh_comm.item()\n        else:\n            metrics[\"kWh\"] = float(\"NaN\")\n    return metrics\n\n\ndef get_kWh(kWh_counter, setup):\n    miilijoule_final = pynvml.nvmlDeviceGetTotalEnergyConsumption(pynvml.nvmlDeviceGetHandleByIndex(setup[\"device\"].index))\n    kWh_final = miilijoule_final * 1e-6 / 3600  # kilojoule per hour\n    kWh = kWh_final - kWh_counter[\"initial_value\"]\n    return kWh\n\n\ndef pathfinder(cfg):\n    with open_dict(cfg):\n        cfg.original_cwd = hydra.utils.get_original_cwd()\n        # ugliest way to get the absolute path to output subdir\n        if not os.path.isabs(cfg.base_dir):\n            base_dir_full_path = os.path.abspath(os.getcwd())\n            while os.path.basename(base_dir_full_path) != cfg.base_dir:\n                base_dir_full_path = os.path.dirname(base_dir_full_path)\n                if base_dir_full_path == \"/\":\n                    raise ValueError(\"Cannot find base directory.\")\n            cfg.base_dir = base_dir_full_path\n\n        cfg.impl.path = os.path.expanduser(cfg.impl.path)\n        if not os.path.isabs(cfg.impl.path):\n            cfg.impl.path = os.path.join(cfg.base_dir, cfg.impl.path)\n    return cfg\n"
  },
  {
    "path": "create_data_split.py",
    "content": "from transformers import PreTrainedTokenizer\nimport random\nimport os\nimport torch\nfrom transformers import AutoTokenizer\nfrom torch.nn.utils.rnn import pad_sequence\nfrom datasets import Dataset, DatasetDict\nimport pandas as pd\nimport datasets\nimport json\nimport argparse\nfrom cramming.data.tokenizer_preparation import get_tokenizer\nimport matplotlib.pyplot as plt\nfrom collections import Counter\nfrom matplotlib import cm\nimport re\nfrom dataset_analysis import main as data_analysis_main\nimport numpy as np\n\ndef generate_no_carry_addition(n, m):\n    \"\"\"No carries addition, brute force implementation\"\"\"\n    num1 = random.randint(10**(n-1), 10**n - 1)\n    num2 = random.randint(10**(m-1), 10**m - 1)\n\n    while has_carry(num1, num2):\n        num1 = random.randint(10**(n-1), 10**n - 1)\n        num2 = random.randint(10**(m-1), 10**m - 1)\n\n    return num1, num2, num1 + num2\n\ndef has_carry(num1, num2):\n    # Check if there is a carry in any column during addition\n    for digit1, digit2 in zip(str(num1)[::-1], str(num2)[::-1]):\n        if int(digit1) + int(digit2) >= 10:\n            return True\n    return False\n\n# Function to generate the arithmetic dataset\ndef generate_dataset(dir_name, operation, n, m, num_examples, base_folder_name, keep_places, exact, prepend_zeros, reverse_answer, reverse_all, p=0, no_carry_addition=False, seed=42, interleave=False):\n    \"\"\"\n    generate a dataset, NOT using the bucket method!\n    p = probability for random padding to be inserted\n    \"\"\"\n    if p < 0 or p >= 1:\n        raise ValueError(\"Probability p must be strictly between 0 and 1.\")\n\n    random.seed()\n    dataset = []\n\n    for _ in range(num_examples):\n        if exact: # exactly length n,m \n            num1 = random.randint(10**(n-1), 10**n - 1)\n            num2 = random.randint(10**(m-1), 10**m - 1)\n        elif no_carry_addition and operation == '+':\n            num1, num2, _ = generate_no_carry_addition(n,m)\n        else:\n            num1 = random.randint(0, 10**n - 1)\n            num2 = random.randint(0, 10**m - 1)\n\n        if keep_places: # fill with zeros so it is always the same length\n            num1_str = str(num1).zfill(n)\n            num2_str = str(num2).zfill(m)\n        else:\n            num1_str = str(num1)\n            num2_str = str(num2)\n\n        if operation == '+':\n            result = num1 + num2\n        elif operation == '-':\n            result = num1 - num2\n        elif operation == 'x':\n            result = num1 * num2\n        else:\n            raise ValueError(\"Invalid operation\")\n\n        result = str(result)\n\n        if prepend_zeros > 0:\n            zeros = \"0\"*prepend_zeros\n            num1_str = zeros + num1_str\n            num2_str = zeros + num2_str\n            result = \"0\" + zeros + result\n\n        orgional_p = p\n\n        if reverse_all: # reversals \n            result = result[::-1]\n            num1_str = num1_str[::-1]\n            num2_str = num2_str[::-1]\n        elif reverse_answer:\n            result = result[::-1]\n        \n\n        dataset_entry = f\"{num1_str}{operation}{num2_str}={result}\"\n        if interleave: # interleave the operands so the digits of the same significance are  next to eachother\n            dataset_entry = ''.join([a + b for a, b in zip(num1_str, num2_str)]) + num1_str[len(num2_str):] + num2_str[len(num1_str):]+f\"={result}\"\n        p = orgional_p\n        if p > 0: # adds random spaces, exponentially decaying\n            dataset_entry = f\"{num1_str}{operation}{num2_str}={result}\"\n            if interleave:\n                dataset_entry = ''.join([a + b for a, b in zip(num1_str, num2_str)]) + num1_str[len(num2_str):] + num2_str[len(num1_str):]+f\"={result}\"\n            spaced_string = \"\"\n            for char in dataset_entry:\n                space_p = p\n                while random.random() < space_p:\n                    space_p *= 0.1\n                    spaced_string += \" \"\n                spaced_string += char\n            dataset_entry = spaced_string\n        dataset.append(dataset_entry)\n\n    for i in range(0,min(len(dataset),5)):\n        print(dataset[i])\n    \n    folder_name = f\"{base_folder_name}/{dir_name}\"\n    os.makedirs(folder_name, exist_ok=True)\n    # automated file name\n    file_name = f\"{operation}_n_{n}_m_{m}_examples_{num_examples}{'_diff_lens' if not keep_places else ''}{'_exact' if exact else ''}{f'_prepend_{prepend_zeros}zeros' if prepend_zeros>0 else ''}{f'_reverse_ans' if reverse_answer else ''}{f'_prob_space_{p}' if p>0 else ''}_seed_{seed}.txt\"\n    file_path = os.path.join(folder_name, file_name)\n\n    with open(file_path, 'w') as file:\n        for entry in dataset:\n            file.write(entry + '\\n')\n    print(f\"created: {file_path}\")\n    return dataset, folder_name, file_path\n\n\ndef tokenize_and_save_dataset(dataset, tokenizer, directory, test_split_ratio=0.05, pad_sequences=False):\n    # tokenization, slow but gets the job done\n\n    os.makedirs(directory, exist_ok=True)\n\n    # Tokenize the dataset and add EOS token at the end of each entry\n    eos_token_id = tokenizer.vocab[tokenizer.eos_token]\n    tokenized_dataset = [tokenizer(entry)[\"input_ids\"] + [eos_token_id] for entry in dataset]\n\n    # print some of them say 5 input and its tokenized version\n    print(\"Some examples of tokenized dataset:\")\n    for i in range(0,min(len(dataset),5)):\n        print(f\"Input: {dataset[i]}\")\n        print(f\"Tokenized: {tokenized_dataset[i]}\")\n        decoded = tokenizer.decode(tokenized_dataset[i])\n        print(f\"Decoded: {decoded}\")\n        print()\n\n    # Optionally pad the sequences\n    if pad_sequences:\n        max_length = max(len(entry) for entry in tokenized_dataset)\n        pad_token_id = tokenizer.pad_token_id\n        tokenized_dataset = [entry + [pad_token_id] * (max_length - len(entry)) for entry in tokenized_dataset]\n\n    save_to_json_intermed = False # save the tokenized dataset to a json instead of hf\n    if save_to_json_intermed:\n        print(tokenized_dataset)\n        data_path = os.path.join(directory, \"dataset.json\")\n        with open(data_path, \"w\") as outfile:\n            # Iterate over each dictionary in the list\n            for entry in tokenized_dataset:\n                # Convert dictionary to JSON string and write it to the file\n                json.dump({'input_ids': entry}, outfile)\n                # Write a newline character to separate each JSON object\n                outfile.write('\\n')\n        exit()\n\n    # Split the data into train and test sets\n    test_size = int(len(tokenized_dataset) * test_split_ratio)\n    train_data = tokenized_dataset[:-test_size]\n    test_data = tokenized_dataset[-test_size:]\n    # Convert to Hugging Face datasets with 'input_ids' column\n    train_dataset = Dataset.from_pandas(pd.DataFrame({\"input_ids\": train_data}))\n    test_dataset = Dataset.from_pandas(pd.DataFrame({\"input_ids\": test_data}))\n\n    # Create a DatasetDict with train and test splits\n    dataset_dict = DatasetDict({\n        \"train\": train_dataset,\n        \"test\": test_dataset\n    })\n\n    # Save the dataset to disk\n    hf_dataset_path = os.path.join(directory, \"hf_tokenized_dataset\")\n    dataset_dict.save_to_disk(hf_dataset_path)\n\n    # # Save tokenizer\n    # print(f\"Tokenized data saved to {tokenized_data_path}\")\n    print(f\"HuggingFace Dataset saved to {hf_dataset_path}\")\n\n    # return dataset_dict, tokenized_data_path, hf_dataset_path #, tokenizer_dir\n    return dataset_dict, hf_dataset_path\n\ndef character_histogram(dir_name, condense_white_space=False):\n    \"\"\"Histogram of character occurences\"\"\"\n    base_directory = \"./cramming-data/data/arithmetic_data\"\n    dir_name = os.path.join(base_directory, dir_name)\n\n    # open all data files and append to big list\n    dataset = []\n    for filename in os.listdir(dir_name):\n        if filename.endswith(\".txt\"):\n            file_path = os.path.join(dir_name, filename)\n            with open(file_path, \"r\") as file:\n                lines = file.readlines()\n                stripped_lines = [line.replace(\"\\n\", \"\") for line in lines]\n                if condense_white_space:\n                    stripped_lines = [re.sub('\\s+',' ', line) for line in lines]\n                dataset.extend(stripped_lines)\n\n    for i in range(0,min(len(dataset),5)):\n        print(dataset[i])\n\n    max_length = max(map(len, dataset))\n    \n    counters_list = [Counter() for _ in range(max_length)]\n\n    for string in dataset:\n        for index, char in enumerate(string):\n            counters_list[index][char] += 1\n\n    # Plot the occurrences for each index\n    plt.figure(figsize=(10, 6))\n    indices = list(range(max_length))\n    bottom = [0] * max_length\n    sorted_chars = sorted(set(''.join(dataset)))\n\n    colors = cm.get_cmap('tab20', len(sorted_chars))\n\n    for char, color in zip(sorted_chars, colors.colors):\n        occurrences = [counter[char] for counter in counters_list]\n        legend_char = char if char != \" \" else \"\\' \\'\"\n        plt.bar(indices, occurrences, label=legend_char, bottom=bottom, color=color)\n        bottom = [b + o for b, o in zip(bottom, occurrences)]\n\n    plt.xlabel('Index')\n    plt.ylabel('Occurrences')\n    plt.title(\"Character Frequency\")\n    plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=10)\n    plt.savefig(f\"{dir_name}/char_histogram{'_condensed_ws' if condense_white_space else ''}\", bbox_inches='tight')\n\ndef token_histogram(dir_name, tokenizer_type=\"normal\"):\n    \"\"\"Histogram of token occurences\"\"\"\n    base_directory = \"./cramming-data/data/arithmetic_data\"\n    dir_name = os.path.join(base_directory, dir_name)\n    hf_dir_name = os.path.join(dir_name, \"hf_tokenized_dataset\")\n    tokenized_dataset = datasets.load_from_disk(hf_dir_name)\n    train_part = tokenized_dataset[\"train\"]\n    test_part = tokenized_dataset[\"test\"]\n    \n    tokenizer = get_tokenizer(tokenizer_type)\n    EOS_token = tokenizer._convert_token_to_id(\"[EOS]\")\n    \n    dataset = []\n    for example in train_part:\n        tokens = example[\"input_ids\"]\n        eos_index = tokens.index(EOS_token) if EOS_token in tokens else len(tokens) # not including the EOS token\n        tokens = tokens[:eos_index]\n        dataset.append(tokens)\n    for example in test_part:\n        tokens = example[\"input_ids\"]\n        eos_index = tokens.index(EOS_token) if EOS_token in tokens else len(tokens) # not including the EOS token\n        tokens = tokens[:eos_index]\n        dataset.append(tokens)\n\n    for i in range(0,min(len(dataset),5)):\n        print(dataset[i])\n\n    max_length = max(map(len, dataset))\n    counters_list = [Counter() for _ in range(max_length)]\n\n    for string in dataset:\n        for index, char in enumerate(string):\n            counters_list[index][str(char)] += 1\n\n    plt.figure(figsize=(10, 6))\n    indices = list(range(max_length))\n    bottom = [0] * max_length\n    print(tokenizer.vocab.values())\n    sorted_chars = [str(x) for x in sorted(tokenizer.vocab.values())]\n    \n    colors = cm.get_cmap('tab20', len(sorted_chars))\n\n    for char, color in zip(sorted_chars, colors.colors):\n        occurrences = [counter[char] for counter in counters_list]\n        tokenizer_char = tokenizer._convert_id_to_token(int(char))\n        tokenizer_char = tokenizer_char if tokenizer_char != \" \" else \"\\' \\'\"\n        legend_char = f\"{char} => {tokenizer_char}\"\n        plt.bar(indices, occurrences, label=legend_char, bottom=bottom, color=color)\n        bottom = [b + o for b, o in zip(bottom, occurrences)]\n\n    plt.xlabel('Index')\n    plt.ylabel('Occurrences')\n    plt.title(\"Token Frequency\")\n    legend = plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.35), ncol=6)\n    legend.set_title(\"token => char\")\n\n    plt.savefig(f\"{dir_name}/token_histogram\", bbox_inches='tight')\n\ndef main_dataset_gen(dir_name, op, n, m, num_samples, exact=False, keep_places=False, prepend_zeros=0, reverse_answer=False, reverse_all=False, p=0, no_carry_addition=False, seed=42, interleave=False):\n    \"\"\"Main method for non bucket datasets\"\"\"\n    base_directory = \"./cramming-data/data\"\n    os.makedirs(base_directory, exist_ok=True)\n    base_directory = f\"{base_directory}/arithmetic_data\"\n    os.makedirs(base_directory, exist_ok=True)\n    \n    dataset, data_folder_name, _ = generate_dataset(dir_name, op, n, m, num_samples, base_directory, keep_places, exact, prepend_zeros, reverse_answer, reverse_all, p, no_carry_addition, seed=seed, interleave=interleave)\n\ndef tokenize_main(dir_name, tokenizer_type, test_split_ratio=0.05):\n    \"\"\"Main tokenizer method\"\"\"\n    base_directory = \"./cramming-data/data/arithmetic_data\"\n    dir_name = os.path.join(base_directory, dir_name)\n    data_folder_name = dir_name\n\n    # Initialize the tokenizer\n    tokenizer = get_tokenizer(tokenizer_type)\n\n    # open all data files and append to big list\n    dataset = []\n\n    for filename in os.listdir(dir_name):\n        if filename.endswith(\".txt\"):\n            file_path = os.path.join(dir_name, filename)\n            with open(file_path, \"r\") as file:\n                lines = file.readlines()\n                # stripped_lines = [line.strip() for line in lines]\n                stripped_lines = [line.replace(\"\\n\", \"\") for line in lines]\n                dataset.extend(stripped_lines)\n    random.shuffle(dataset) # shuffling all the datasets together\n\n    dataset_dict, hf_dataset_path = tokenize_and_save_dataset(dataset, tokenizer, data_folder_name,\n                                                                                   pad_sequences=True,\n                                                                                   test_split_ratio=test_split_ratio)\n    tokenized_dataset = datasets.load_from_disk(hf_dataset_path)\n    print(tokenized_dataset)\n\n\ndef pick_char_set(max_len):\n    \"\"\"Pick a set of characters in a cyclic method for index hints\"\"\"\n    # 102 characters\n    set_of_chars = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', '!', '@', '£', '#', '$', '%', '^', '&', '*', '(', ')', '~', '?', '.', ',', '<', '>', '{', '}', '[', ']', ':', ';','/','|','β','Γ', 'Δ', 'δ', 'ε', 'ζ', 'η', 'θ', 'κ','Λ', 'λ', 'μ', 'Ξ', 'ξ','Π', 'π','Σ', 'ς', 'τ', 'Φ', 'φ', 'χ', 'Ψ', 'ψ', 'Ω', 'ω']\n    \n    output = []\n    start = random.randint(0, len(set_of_chars))\n    if start + max_len > len(set_of_chars): # i.e. cycle round\n        return set_of_chars[start:len(set_of_chars)] + set_of_chars[:start + max_len-len(set_of_chars)]\n    else:\n        return set_of_chars[start:start + max_len]\n\ndef hints_helper(num_str, chars):\n    # returns the positional hints with the number\n    result = \"\"\n    for char, digit in zip(chars, num_str):\n        result += f\"{char}{digit}\"\n    return result\n\ndef bucket_method_gen(n=3, m=3, operation='+', limit=1000, p=0, no_carry_addition=False, reverse_answer=False, start=1, reverse_all=False, keep_0_for_len_1=False, Flags=None):\n    \"\"\"Bucket method generator, samples all operand lengths equally\"\"\"\n    dataset = []\n    while True:\n        for i in range(start,n+1):\n            for j in range(start,m+1):\n                start_i = 10**(i-1)\n                start_j = 10**(j-1)\n                if keep_0_for_len_1 and i==1: # i.e. use natruals including 0, we just use naturals\n                    start_i = 0\n                if keep_0_for_len_1 and j==1:\n                    start_j = 0\n                num1 = random.randint(start_i, (10**i - 1))\n                num2 = random.randint(start_j, 10**j - 1)\n\n                if no_carry_addition and operation == '+':\n                    num1, num2, _ = generate_no_carry_addition(i,j)\n                num1_str = str(num1)\n                num2_str = str(num2)\n\n                if operation == '+':\n                    result = num1 + num2\n                elif operation == '-':\n                    result = num1 - num2\n                elif operation == 'x':\n                    result = num1 * num2\n                else:\n                    raise ValueError(\"Invalid operation\")\n\n                result = str(result)\n                if reverse_answer: # reversals\n                    result = result[::-1]\n                if reverse_all:\n                    result = result[::-1]\n                    num1_str = num1_str[::-1]\n                    num2_str = num2_str[::-1]\n                if Flags.index_hints: # adding the index hints\n                    max_len = max(len(result), max(len(num1_str),len(num2_str)))\n                    chars = pick_char_set(max_len)\n                    result = hints_helper(result, chars)\n                    num1_str = hints_helper(num1_str, chars)\n                    num2_str = hints_helper(num2_str, chars)\n                else:\n                    dataset_entry = f\"{num1_str}{operation}{num2_str}={result}\"\n\n                    if p > 0: # adds random spaces\n                        spaced_string = \"\"\n                        for char in dataset_entry:\n                            space_p = p\n                            while random.random() < space_p:\n                                space_p *= 0.1\n                                spaced_string += \" \"\n                            spaced_string += char\n                        dataset_entry = spaced_string\n                \n                dataset.append(dataset_entry)\n                if len(dataset) == limit:\n                    return dataset\n\ndef bucket_method_main(n, m, operation, limit, dir_name, p=0, no_carry_addition=False, reverse_answer=False, start=1, reverse_all=False, keep_0_for_len_1=False, Flags=None):\n    \"\"\"Mains method for bucket style generation\"\"\"\n    dataset = bucket_method_gen(n, m, operation, limit, p, no_carry_addition, reverse_answer, start, reverse_all=reverse_all, keep_0_for_len_1=keep_0_for_len_1, Flags=Flags)\n    for i in range(0,10):\n        print(dataset[i])\n    \n    base_directory = \"./cramming-data/data\"\n    os.makedirs(base_directory, exist_ok=True)\n    base_directory = f\"{base_directory}/arithmetic_data\"\n    os.makedirs(base_directory, exist_ok=True)\n    \n    folder_name = f\"{base_directory}/{dir_name}\"\n    os.makedirs(folder_name, exist_ok=True)\n    file_name = f\"{operation}_n_{n}_m_{m}_examples_{limit}.txt\"\n    file_path = os.path.join(folder_name, file_name)\n\n    random.seed()\n    random.shuffle(dataset)\n    with open(file_path, 'w') as file:\n        for entry in dataset:\n            file.write(entry + '\\n')\n    print(f\"created: {file_path}\")\n    return dataset, folder_name, file_path\n\n\ndef uniform_distribution_sort_basic(maximum_number_of_digts, maximum_length, limit, FLAGS):\n    \"\"\"sorting dataset generator\"\"\"\n    dataset = []\n    for i in range(0, limit):\n        dataset_entry = \"\"\n        chars = pick_char_set(maximum_length)\n        local_chars = pick_char_set(maximum_number_of_digts)\n        all_nums = []\n        for j in range(0, maximum_length):\n            # choose a random number of digit between 1 and maximum_number_of_digts\n            num_digit = random.randint(1, maximum_number_of_digts)\n            # pick a number with num_digit digits\n            num = random.randint(10**(num_digit-1), 10**num_digit - 1)\n            all_nums.append([chars[j], num])\n\n            num = str(num)\n            if FLAGS.reverse_all:\n                num = num[::-1]\n            if FLAGS.index_hints:\n                num = hints_helper(num, local_chars)\n            dataset_entry += f\"{chars[j]}:{num},\"\n\n        dataset_entry = dataset_entry[:-1]\n        all_nums = sorted(all_nums, key=lambda x: x[1]) # get the answer\n        sorted_chars = [x[0] for x in all_nums]\n        dataset_entry += f\"={','.join(sorted_chars)}\" # convert them into a string separated by ,\n        dataset.append(dataset_entry)\n\n    return dataset\n\ndef bucket_uniform_distribution(maximum_number_of_digts, maximum_length, limit, FLAGS):\n    \"\"\"Use a uniform distribution over -- i.e. bucket method for sorting\"\"\"\n    bucket_limit = limit // (maximum_length * maximum_number_of_digts)\n    dataset = []\n    for i in range(0, maximum_length):\n        for j in range(0, maximum_number_of_digts):\n            dataset += uniform_distribution_sort_basic(j+1, i+1, bucket_limit, FLAGS)\n    return dataset\n\ndef uniform_distribution_sort_main(FLAGS, dir_name):\n    \"\"\"Main method for sorting generation\"\"\"\n    maximum_number_of_digts = FLAGS.n\n    maximum_length = FLAGS.m\n    limit = FLAGS.limit\n\n    dataset = bucket_uniform_distribution(maximum_number_of_digts, maximum_length, limit, FLAGS)\n\n    for i in range(0, 10):\n        print(dataset[i])\n\n    base_directory = \"./cramming-data/data\"\n    os.makedirs(base_directory, exist_ok=True)\n    base_directory = f\"{base_directory}/arithmetic_data\"\n    os.makedirs(base_directory, exist_ok=True)\n\n    folder_name = f\"{base_directory}/{dir_name}\"\n    os.makedirs(folder_name, exist_ok=True)\n    file_name = f\"sort_maximum_number_of_digts_{FLAGS.n}\" \\\n                f\"_maximum_length_{FLAGS.m}_examples_{limit}.txt\"\n    file_path = os.path.join(folder_name, file_name)\n\n    random.seed()\n    random.shuffle(dataset)\n    with open(file_path, 'w') as file:\n        for entry in dataset:\n            file.write(entry + '\\n')\n    print(f\"created: {file_path}\")\n    return dataset, folder_name, file_path\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Train a model\")\n    # General addition\n    parser.add_argument(\"--dir_name\", type=str, required=True, help='name of dataset')\n    parser.add_argument(\"--op\", type=str, default='+', help=\"operation e.g. +,-,x\")\n    parser.add_argument(\"--n\", default=2, type=int, help=\"num digits in first number\")\n    parser.add_argument(\"--m\", default=2, type=int, help=\"num digits in second number\")\n    parser.add_argument(\"--num_samples\", default=100, type=int, help=\"number of samples\")\n    parser.add_argument(\"--seed\", default=42, type=int, help=\"seed for random generation\")\n    parser.add_argument('--keep_places', action='store_true') # i.e. default is different length numbers\n    parser.add_argument('--exact', action='store_true') # will only take numbers which are exactly length n,m if turned on\n    parser.add_argument('--special', action='store_true') # special flag to do any crazy ideas\n    parser.add_argument('--p', default=0.0, type=float, help=\"prob for adding padding\")\n    parser.add_argument(\"--prepend_zeros\", default=0, type=int, help=\"prepend this number of zeros to n, m and answer (adds 1 more to answer)\")\n    parser.add_argument('--reverse_answer', action='store_true', help=\"reverses the answer\")\n    parser.add_argument('--reverse_all', action='store_true', help=\"reverses the inputs and answer\")\n    parser.add_argument('--no_carry_addition', action='store_true', help=\"no carried in the addition\")\n    parser.add_argument('--test_split_ratio', default=0.05, type=float, help=\"test split percentage\")\n    parser.add_argument('--interleave', action='store_true', help=\"interleave digits of the operands\")\n    parser.add_argument('--keep_0_for_len_1', action='store_true', help='keep 0 as a possible digit for length 1 digits, i.e. Naturals including 0')\n    \n    # bucket method to sample all operands equally\n    parser.add_argument('--bucket', action='store_true', help='all operand lengths sampled equally')\n    parser.add_argument(\"--limit\", default=1000000, type=int, help=\"number of samples if using the bucket method\")\n    parser.add_argument('--index_hints', action='store_true', help='use index hints for numbers')\n\n    # tokenize\n    parser.add_argument('--tokenize', action='store_true', help='tokenize the all txt files in the dir_name given') # i.e. tokenize the folder\n    parser.add_argument(\"-tt\", \"--tokenizer_type\", type=str, default=\"pad\", help='tokenizer type used')\n    \n    # sort\n    parser.add_argument('--uniform_distribution_sort_data', action='store_true', help='sort data')\n    parser.add_argument(\"--extra_path\", type=str, default=None, help='extra path infront of the autogenerated sort data path')\n\n    FLAGS = parser.parse_args()\n    random.seed(FLAGS.seed)\n    if FLAGS.no_carry_addition and FLAGS.op != '+':\n        print(\"no carries is only for addition\")\n        exit()\n        \n    if FLAGS.bucket:\n        # automated nameing scheme for the most common flags\n        index_hints = \"_with_index_hints_circular\" if FLAGS.index_hints else \"\"\n        folder_name = f\"{FLAGS.op}_bucket_method_n_{FLAGS.n}_m_{FLAGS.m}_{FLAGS.limit}_p_{str(FLAGS.p).replace('.','')}{'_reverse_ans' if FLAGS.reverse_answer else ''}{'_reverse_all' if FLAGS.reverse_all else ''}{'_keep_0_for_len_1' if FLAGS.keep_0_for_len_1 else ''}{index_hints}\"\n        print(f\"folder name = {folder_name}\")\n        if FLAGS.no_carry_addition:\n            folder_name = FLAGS.dir_name\n        bucket_method_main(FLAGS.n, FLAGS.m, FLAGS.op, FLAGS.limit, folder_name, FLAGS.p, FLAGS.no_carry_addition, FLAGS.reverse_answer,reverse_all=FLAGS.reverse_all,keep_0_for_len_1=FLAGS.keep_0_for_len_1, Flags=FLAGS)\n        print(\"dataset made\")\n        character_histogram(folder_name)\n        print(\"char histogram made\")\n        data_analysis_main(folder_name) # more automated analysis\n        exit()\n\n    if FLAGS.uniform_distribution_sort_data:\n        index_hints = \"_with_index_hints_circular\" if FLAGS.index_hints else \"\"\n\n        # uniform_distribution_steps\n        # bucket_uniform_distribution\n\n        # sort\n        # n - max length of a number\n        # m - number of numbers in the list to sort\n        folder_name = f\"sort_bucket_uniform_distribution_max_digits_n_{FLAGS.n}_max_length_m_{FLAGS.m}_\" \\\n                      f\"{FLAGS.limit}_\" \\\n                      f\"p_{str(FLAGS.p).replace('.','')}\" \\\n                      f\"{'_reverse_all' if FLAGS.reverse_all else ''}\" \\\n                      f\"{index_hints}\"\n        if FLAGS.extra_path != None:\n            folder_name = f\"{FLAGS.extra_path}/{folder_name}\"\n        print(f\"folder name = {folder_name}\")\n\n        uniform_distribution_sort_main(FLAGS, folder_name)\n        FLAGS.dir_name = folder_name\n\n    if FLAGS.tokenize:\n        if FLAGS.tokenizer_type != \"sort\": # do some automated plotting for each dataset\n            character_histogram(FLAGS.dir_name)\n            print(\"char histogram made\")\n        tokenize_main(FLAGS.dir_name, FLAGS.tokenizer_type, test_split_ratio=FLAGS.test_split_ratio)\n        print(\"tokenized\")\n        if FLAGS.tokenizer_type != \"sort\": # do some automated plotting for each dataset\n            token_histogram(FLAGS.dir_name, FLAGS.tokenizer_type)\n            print(\"token histogram made\")\n            data_analysis_main(FLAGS.dir_name) # more automated analysis\n    else:\n        main_dataset_gen(FLAGS.dir_name, FLAGS.op, FLAGS.n, FLAGS.m, FLAGS.num_samples, FLAGS.exact, FLAGS.keep_places, FLAGS.prepend_zeros, FLAGS.reverse_answer, FLAGS.reverse_all, FLAGS.p, FLAGS.no_carry_addition, FLAGS.seed, interleave=FLAGS.interleave)\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "create_pos_or_variants.py",
    "content": "import numpy as np\nimport argparse\nimport random\nimport os\n\ndef one_hot_vector(length, index=None):\n    \"\"\"return a one hot vector\"\"\"\n    if index is None:\n        index = np.random.randint(length)\n    one_hot = np.zeros(length)\n    one_hot[index] = 1\n    return one_hot\n\ndef zero_vector(length):\n    \"\"\"return a zero vector\"\"\"\n    zeros = np.zeros(length)\n    return zeros\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Train a model\")\n    parser.add_argument(\"--dir_name\", type=str, required=True, help=\"dir to save to\")\n    parser.add_argument(\"--op\", type=str, default='+', help=\"operation\")\n    parser.add_argument(\"--n\", default=2, type=int, help=\"num digits in first number\")\n    parser.add_argument(\"--m\", default=2, type=int, help=\"num digits in second number\")\n    parser.add_argument('--p', default=0.0, type=float, help=\"prob for adding padding\")\n    parser.add_argument(\"--max\", default=-1, type=int, help=\"num digits in second number\")\n    parser.add_argument('--exact', action='store_true', help='only this size')\n    parser.add_argument('--eval', action='store_true', help='save as part of eval dataset')\n    FLAGS = parser.parse_args()\n\n    p = FLAGS.p\n    dir_name = FLAGS.dir_name\n    lengths_n = lengths_n_range = list(range(1,FLAGS.n+1))\n    lengths_m = lengths_m_range = list(range(1,FLAGS.m+1))\n    if FLAGS.exact:\n        lengths_n = [FLAGS.n]\n        lengths_m = [FLAGS.m]\n        \n    ds = []\n    # 2d loop to sample exaustively\n    for i in lengths_n:\n        for j in lengths_m:\n            i_len=i\n            j_len=j\n            combined_len=max(i,j)\n            for index in list(range(0,min(i,j))):\n                if i_len > j_len: # put one hot in longer vector\n                    vec1 = zero_vector(i_len)\n                    vec2 = one_hot_vector(j_len, index)\n                elif i_len < j_len:\n                    vec1 = one_hot_vector(i_len, index)\n                    vec2 = zero_vector(j_len)\n                else: # i.e. same length so either can be the zeros\n                    if random.random() > 0.5:\n                        vec1 = one_hot_vector(i_len, index)\n                        vec2 = zero_vector(j_len)\n                    else:\n                        vec1 = zero_vector(i_len)\n                        vec2 = one_hot_vector(j_len, index)\n                ans = one_hot_vector(combined_len, index)\n\n                vec1_str = \"\".join(map(lambda x: str(int(x)), vec1))\n                vec2_str = \"\".join(map(lambda x: str(int(x)), vec2))\n                ans_str = \"\".join(map(lambda x: str(int(x)), ans))\n\n                dataset_entry = f\"{vec1_str}{FLAGS.op}{vec2_str}={ans_str}\"\n                \n                if p>0: # add random padding, exponentially decaying\n                    spaced_string = \"\"\n                    for char in dataset_entry:\n                        space_p = p\n                        while random.random() < space_p:\n                            space_p *= 0.1\n                            spaced_string += \" \"\n                        spaced_string += char\n                    dataset_entry = spaced_string\n            \n                ds.append(dataset_entry)\n\n    if FLAGS.max != -1:\n        ds = random.sample(ds, min(len(ds),FLAGS.max)) # cut to maximum size\n    if FLAGS.eval:\n        data_dir = f\"./cramming-data/data/arithmetic_data/pos_or_one_vec_zeros/{dir_name}\"\n        file_name = f\"positional_arithmetic_n_{FLAGS.n}_m_{FLAGS.m}.txt\"\n    else:\n        data_dir = f\"./cramming-data/data/arithmetic_data/{dir_name}\"\n        file_name = f\"positional_or_one_vec_zeros_n_{FLAGS.n}_m_{FLAGS.m}_examples_{len(ds)}.txt\"\n    os.makedirs(data_dir, exist_ok=True)\n    file_path = os.path.join(data_dir, file_name)\n\n    with open(file_path, 'w') as file:\n        for entry in ds:\n            file.write(entry + '\\n')\n    print(f\"created: {file_path}\")\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "dataset_analysis.py",
    "content": "import os\nimport re\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport seaborn as sns\nimport pandas as pd\nimport argparse\n\ndef read_dataset(dir_name, condense_white_space=False):\n    # open all data files and append to big list\n    dataset = []\n    for filename in os.listdir(dir_name):\n        if filename.endswith(\".txt\"):\n            file_path = os.path.join(dir_name, filename)\n            with open(file_path, \"r\") as file:\n                lines = file.readlines()\n                stripped_lines = [line.replace(\"\\n\", \"\") for line in lines]\n                if condense_white_space:\n                    stripped_lines = [re.sub('\\s+',' ', line) for line in lines]\n                dataset.extend(stripped_lines)\n\n    for i in range(0,min(len(dataset),5)):\n        print(dataset[i])\n    return dataset\n\ndef remove_leading_zeros(match):\n    \"\"\"Removes all leading zeros\"\"\"\n    return str(int(match.group(0)))\n\ndef count_digits(dataset, remove_formatting=False):\n    \"\"\"Count the digits in each operand\"\"\"\n    pairs = {}\n    input_1 = {}\n    input_2 = {}\n    ans = {}\n    for input_string in dataset:\n        cleaned_string = input_string.replace(' ', '')\n        if remove_formatting:\n            cleaned_string = re.sub(r'\\b0+\\d+', remove_leading_zeros, cleaned_string)\n\n        numbers = re.findall(r'\\d+', cleaned_string)\n        digit_counts = [len(number) for number in numbers]\n\n        input_1[digit_counts[0]] = input_1.get(digit_counts[0], 0) + 1\n        input_2[digit_counts[1]] = input_2.get(digit_counts[1], 0) + 1\n        ans[digit_counts[2]] = ans.get(digit_counts[2], 0) + 1\n\n        input_tuple = (digit_counts[0], digit_counts[1])\n        pairs[input_tuple] = pairs.get(input_tuple, 0) + 1\n\n    return pairs, input_1, input_2, ans\n\ndef plot_pairs_heatmap(pairs, dir_name=\".\", remove_formatting=False):\n    \"\"\"plot a heatmap of the lengths of the operands\"\"\"\n    max_length = int(max(max(pair) for pair in pairs.keys()))\n    heatmap_matrix = np.zeros((max_length + 1, max_length + 1))\n\n    # Populate the matrix with counts\n    for pair, count in pairs.items():\n        heatmap_matrix[pair[0],pair[1]] = count\n\n    df = pd.DataFrame.from_dict(heatmap_matrix)\n\n    # Create a heatmap using seaborn\n    plt.figure(figsize=(10, 8))\n    sns.heatmap(df, annot=True, cmap=\"YlGnBu\", fmt=\".4g\", cbar_kws={'label': 'Count'}, annot_kws={'size': 8,'rotation':45})\n    plt.xlabel('Length of First Number')\n    plt.ylabel('Length of Second Number')\n    plt.title('Input Pairs Length Heatmap')\n    plt.savefig(f\"{dir_name}/pairs_heatmap{'_removed_prepended_zeros' if remove_formatting else ''}.png\", bbox_inches='tight')\n    plt.clf()\n\ndef line_plotter(data, name, dir_name=\".\", remove_formatting=False):\n    \"\"\"plot a line graph for the length of the operand \"\"\"\n    data = dict(sorted(data.items()))\n    x_values = list(data.keys())\n    y_values = list(data.values())\n\n    # Plotting the line plot\n    plt.plot(x_values, y_values, marker='o')\n\n    # Adding labels and title\n    plt.xlabel('Length of number')\n    plt.ylabel('Count')\n    plt.title(f\"Line Plot for {name}\")\n    plt.savefig(f\"{dir_name}/{name}_line_plot{'_removed_prepended_zeros' if remove_formatting else ''}.png\", bbox_inches='tight')\n    plt.clf()\n\ndef consecutive_digit_counts(input_strings):\n    \"\"\"Count the number of times a digit is repeated\"\"\"\n    counts_by_digit = {}\n\n    for input_str in input_strings:\n        current_digit = None\n        consecutive_count = 0\n\n        for char in input_str:\n            if char.isdigit():\n                if char == current_digit:\n                    consecutive_count += 1\n                else:\n                    if current_digit is not None:\n                        # Update the dictionary with consecutive count\n                        if consecutive_count != 1:\n                            counts_by_digit.setdefault(current_digit, {}).setdefault(consecutive_count, 0)\n                            counts_by_digit[current_digit][consecutive_count] += 1\n\n                    current_digit = char\n                    consecutive_count = 1\n\n        # Update the dictionary for the last digit in the string\n        if current_digit is not None:\n            if consecutive_count != 1:\n                counts_by_digit.setdefault(current_digit, {}).setdefault(consecutive_count, 0)\n                counts_by_digit[current_digit][consecutive_count] += 1\n\n    return counts_by_digit\n\ndef create_repetition_heatmap(data, dir_name=\".\", remove_formatting=False):\n    \"\"\"plot heat map for, consecutive_digit_counts\"\"\"\n    data = dict(sorted(data.items()))\n    # Convert the dictionary to a DataFrame\n    df = pd.DataFrame.from_dict(data, orient='index').fillna(0)\n\n    # Create a heatmap using seaborn\n    plt.figure(figsize=(10, 8))\n    sns.heatmap(df, annot=True, cmap=\"YlGnBu\", fmt=\".4g\", cbar_kws={'label': 'Count'}, annot_kws={'size': 8,'rotation':45})\n    plt.title('Consecutive Digit Counts Heatmap')\n    plt.xlabel('Consecutive Count')\n    plt.ylabel('Digit')\n    plt.savefig(f\"{dir_name}/repetition_count_heatmap{'_removed_prepended_zeros' if remove_formatting else ''}.png\", bbox_inches='tight')\n    plt.clf()\n\ndef main(dir_name):\n    base_directory = \"./cramming-data/data/arithmetic_data\"\n    dir_name = os.path.join(base_directory, dir_name)\n    dataset = read_dataset(dir_name)\n\n    options = [True, False]\n    for remove_formatting in options:\n        pairs, input_1, input_2, ans = count_digits(dataset, remove_formatting=remove_formatting)\n        print(f\"{'removed prepended zeros' if remove_formatting else 'keeping prepended zeros'}\")\n        print(\"pairs: \",pairs)\n        print(\"input 1: \",input_1)\n        print(\"input 2: \",input_2)\n        print(\"answers: \",ans)\n\n        plot_pairs_heatmap(pairs, dir_name=dir_name, remove_formatting=remove_formatting)\n        line_plotter(input_1, \"input_1\", dir_name=dir_name, remove_formatting=remove_formatting)\n        line_plotter(input_2, \"input_2\", dir_name=dir_name, remove_formatting=remove_formatting)\n        line_plotter(ans, \"answer\", dir_name=dir_name, remove_formatting=remove_formatting)\n\n        result_list = consecutive_digit_counts(dataset)\n        print(\"repetitions: \",result_list)\n        create_repetition_heatmap(result_list, dir_name=dir_name, remove_formatting=remove_formatting)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Data analysis\")\n    parser.add_argument(\"--dir_name\", type=str, required=True)\n    FLAGS = parser.parse_args()\n\n    main(FLAGS.dir_name)"
  },
  {
    "path": "gen_eval_script.py",
    "content": "# input your model name and base_dir\nname = \"sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_reycle_with_fire_8x1_1_24_run_1\"\nbase_dir = \"cramming-data\"\n\n# pick which eval you are doing\nadd_100 = False\nadd_110+ = False\nadd_small = False\nmul = False\nsort = True\nbitwise_or = False\n\n# set the model parameters for eval\nprint(\"remember to edit max_rec and tokenizer!!\")\nmax_rec = 1\ntokenizer = ' data.sources.arithmetic.tokenizer_type=\"pad\"'\nif sort:\n    tokenizer = ' data.sources.arithmetic.tokenizer_type=\"sort\"'\n\n## print statements for all tasks below\nif add_100:\n    for checkerboard_str in [\" checkerboard=odd\",\" checkerboard=even\"]:\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=55 big_eval_step_1=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=60 big_eval_step_2=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=70 big_eval_step_3=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=85 big_eval_step_4=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=90 big_eval_step_5=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=100 big_eval_step_6=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=100 big_eval_step_7=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=110 big_eval_step_8=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=110 big_eval_step_9=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n        print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=110 big_eval_step_10=True reverse_inputs=True{tokenizer}{checkerboard_str}\")\n\nif add_100:\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=105 big_eval_step_1=True reverse_inputs=True checkerboard=even extended_eval=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=105 big_eval_step_2=True reverse_inputs=True checkerboard=even extended_eval=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=105 big_eval_step_3=True reverse_inputs=True checkerboard=even extended_eval=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=105 big_eval_step_4=True reverse_inputs=True checkerboard=even extended_eval=True{tokenizer}\")\n\nif add_small:\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=30 reverse_inputs=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=35 ood_only=True reverse_inputs=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=45 up_to_40=True reverse_inputs=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=55 up_to_50=True reverse_inputs=True{tokenizer}\")\n\nif mul:\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=30 pos_arth=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=50 pos_arth_ood=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=30 mul=True{tokenizer}\")\n\nif sort:\n    for i in range(0,30):\n        print(f\"python sort_eval.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} sort_reverse=True data.sources.arithmetic.tokenizer_type='sort' max_size_given={i+2} start_ind_1_given={i+1} start_ind_2_given={i+1}\")\n\nif bitwise_or: # we give data to evaluate up to 100x100 as we show in the paper, but the evaluation loop in only arithmetic_eval_quicker.py evaluates up to 40x40. This can be easily edited if required\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=30 pos_arth=True{tokenizer}\")\n    print(f\"python arithmetic_eval_quicker.py name={name} base_dir={base_dir} data=arithmetic max_rec={max_rec} token_limit=50 pos_arth_ood=True{tokenizer}\")\n                    \n\n"
  },
  {
    "path": "load_local_model.py",
    "content": "\"\"\"Example for a script to load a local saved model.\n\nUse as e.g.\n\npython load_local_model.py name=A6000amp_b4096_c5_o3_final base_dir=\n> wandb=none impl.push_to_huggingface_hub=True arch=bert-c5 train=bert-o3 train.batch_size=4096\n> data=c4-subset-processed dryrun=True +eval=GLUE_sane\n\n\"\"\"\nimport os\n\nimport hydra\nimport time\n\nimport logging\n\n\nimport cramming\n\nlog = logging.getLogger(__name__)\n\n\ndef main_load_process(cfg, setup):\n    \"\"\"This function controls the central routine.\"\"\"\n    local_time = time.time()\n\n    local_checkpoint_folder = os.path.join(cfg.base_dir, cfg.name, \"checkpoints\")\n    tokenizer, cfg_arch, model_file = cramming.utils.find_pretrained_checkpoint(cfg.eval.checkpoint,\n                                                                                local_checkpoint_folder,\n                                                                                cfg.eval.arch_modifications)\n\n    model = cramming.construct_model(cfg_arch, tokenizer.vocab_size, downstream_classes=None)\n    model_engine, _, _, _ = cramming.load_backend(model, tokenizer, cfg.train, cfg.impl, setup=setup)\n    model_engine.load_checkpoint(cfg_arch, model_file)\n\n    if cramming.utils.is_main_process():\n        if cfg.impl.push_to_huggingface_hub:\n            model_engine.push_to_hub(tokenizer, cfg, dryrun=cfg.dryrun)\n\n\n@hydra.main(config_path=\"cramming/config\", config_name=\"cfg_pretrain\", version_base=\"1.3\")\ndef launch(cfg):\n    cramming.utils.main_launcher(cfg, main_load_process, job_name=\"load and push model\")\n\n\nif __name__ == \"__main__\":\n    launch()\n"
  },
  {
    "path": "pretrain.py",
    "content": "\"\"\"Script for a pretraining run.\"\"\"\n\nimport torch\nimport hydra\n\nimport os\nimport time\nimport datetime\nimport logging\nfrom collections import defaultdict\n\nimport cramming\n\nlog = logging.getLogger(__name__)\n\n\ndef main_training_process(cfg, setup):\n    \"\"\"This function controls the central training loop.\"\"\"\n    model, model_engine, tokenizer, dataloaders, prior_metadata = cramming.backend.get_model_engine_tokenizer_dataloaders(\n        cfg, setup, True)\n\n    data_source = list(cfg.data.sources.values())[0][\"provider\"]\n    stats = defaultdict(list)\n\n    # Start the clocks now:\n    wallclock_timer = time.time()\n    last_save_time = wallclock_timer\n    train_time = time.time()  # Crude time measurement for print_loss_every_nth_step\n    training_allowed = True\n    loss_vals, loss_ppls = [], []\n\n    loss = prior_metadata.get(\"loss\", 0)\n    total_steps = prior_metadata.get(\"steps\", 0)\n    epochs = prior_metadata.get(\"epochs\", 0)\n    elapsed_time = prior_metadata.get(\"elapsed_time\", 0.0)\n    prev_data_idx = prior_metadata.get(\"data_idx\", 0)\n\n    # Launch training\n    log.info(f\"Training run for {cfg.budget} hours{f'' if cfg.overall_budget < 0 else f' and {cfg.overall_budget} hours overall'}{f'' if elapsed_time <= 0 else f' of which {elapsed_time/3600:.2f} hours was used so far.'}\")\n    run_time = min(cfg.budget, cfg.overall_budget - elapsed_time/3600)\n    log.info(f\"Running for {run_time:.2f} hours\")\n    if run_time <= 0:\n        log.info(f\"Already used budget!\")\n        return {}\n\n    for data_idx, batch in enumerate(dataloaders[\"train\"], prev_data_idx):\n        logged_stats = False\n\n        device_batch = model_engine.to_device(batch)\n        model_outputs = {}\n        for seq_idx in range(0, max(1, device_batch[\"input_ids\"].shape[1] - cfg.train.stream_depth), cfg.train.stream_depth):\n            # Run over seq_dim and dispatch multiple model updates while maintaining state in model_outputs\n            # .clone() is required for new nightly so compilation is not stuck recompiling due to StorageOffsets\n            input_ids = device_batch[\"input_ids\"][:, seq_idx: seq_idx + cfg.train.stream_depth + 1].clone()  # last token is only a target\n            model_outputs = model_engine.forward(input_ids=input_ids, **model_outputs)\n            loss = model_outputs[\"loss\"]\n\n            model_engine.backward(loss)\n            model_engine.optimizer_step()\n            loss_vals.append(loss.detach())\n            loss_ppls.append(model_outputs[\"log_perplexity\"].detach())\n\n            if cfg.dryrun:\n                break\n\n        # Check stopping criteria\n        if check_deadline(wallclock_timer, cfg.budget, elapsed_time, cfg.overall_budget) or data_idx == cfg.train.steps:\n            training_allowed = False\n\n            log.info(f\"Reached deadline: Used {get_time_elapsed(wallclock_timer)/3600:.2f}/{cfg.budget} hours {'' if cfg.overall_budget < 0 else f' since reset and {get_time_elapsed(wallclock_timer, elapsed_time)/3600:.2f}/{cfg.overall_budget} hours overall'}. \"\n                     f\"Stopping training ...\")\n                     \n        if check_checkpointing(data_idx, cfg.impl, last_save_time):\n            if cramming.utils.is_main_process():\n                loss_vals, loss_ppls, train_time = collect_stats(\n                    data_idx,\n                    loss_vals,\n                    loss_ppls,\n                    model_outputs,\n                    train_time,\n                    stats,\n                    model_engine,\n                    dataloaders[\"train\"],\n                    cfg,\n                )\n                logged_stats = True\n\n                # Save intermediate training checkpoint?\n                epochs = dataloaders[\"train\"].epoch_counter\n                last_save_time = time.time()\n                last_save_time_datetime = datetime.datetime.fromtimestamp(last_save_time)\n                if cfg.impl.save_intermediate_model_name is None:\n                    # if name is given use it (will overwrite), else use time to save\n                    checkpoint_name = f\"{cfg.arch.model_type}_{last_save_time_datetime.strftime('%Y-%m-%d')}_{loss.item():2.4f}\"\n                else:\n                    checkpoint_name = cfg.impl.save_intermediate_model_name\n                checkpoint_path = os.path.join(cfg.model_dir, cfg.name, \"checkpoints\")\n\n                metadata = {\"epochs\": epochs,\n                            \"loss\": loss.item(),\n                            \"data_idx\": data_idx,\n                            \"steps\": model_engine.steps,\n                            \"elapsed_time\": (time.time() - wallclock_timer) + elapsed_time\n                            }\n\n                saved_path_temp = model_engine.save_model(checkpoint_path, checkpoint_name, cfg.arch, metadata)\n                log.info(\n                    f\"Saving training checkpoint! Number of epochs/optim steps/data steps trained for: {epochs}/{model_engine.steps}/{data_idx},\"\n                    f\"saving to: {saved_path_temp}\")\n\n                if cfg.impl.push_to_huggingface_hub:\n                    model_engine.push_to_hub(tokenizer, cfg, dryrun=cfg.dryrun)\n\n        # Collect stats and print to console and upload to wandb\n        if data_idx % cfg.impl.print_loss_every_nth_step == 0:\n            if not logged_stats:\n                loss_vals, loss_ppls, train_time = collect_stats(\n                    data_idx,\n                    loss_vals,\n                    loss_ppls,\n                    model_outputs,\n                    train_time,\n                    stats,\n                    model_engine,\n                    dataloaders[\"train\"],\n                    cfg,\n                )\n\n            if check_early_termination(wallclock_timer, stats[\"loss\"][-1], cfg.impl.early_termination, elapsed_time):\n                training_allowed = False\n                log.info(\"Loss higher than allowed threshold. Stopping training early...\")\n\n        if not loss.detach().isfinite():\n            log.info(f\"Non-finite loss in block {data_idx} on device {cfg.impl.local_rank}.\")\n            training_allowed = False\n\n        flag_communication(training_allowed)\n\n        if (cfg.dryrun and data_idx > (model_engine.accumulation_steps_expected + 1)) or not training_allowed:\n            break\n\n    epochs = dataloaders[\"train\"].epoch_counter\n    log.info(f\"Number of epochs/optim steps/data steps trained for: {epochs}/{model_engine.steps}/{data_idx}\")\n\n    if cramming.utils.is_main_process():\n        # Save final checkpoint?\n        if cfg.impl.save_final_model:\n            metadata = {\"epochs\": epochs,\n                        \"loss\": loss.item(),\n                        \"data_idx\": data_idx,\n                        \"steps\": model_engine.steps,\n                        \"elapsed_time\": time.time() - wallclock_timer + elapsed_time\n                        }\n                        \n            if cfg.model_dir is None:\n                save_dir = cfg.base_dir\n            else:\n                save_dir = cfg.model_dir\n            checkpoint_path = os.path.join(save_dir, cfg.name, \"checkpoints\")\n            checkpoint_name = f\"FINAL_{loss.item():2.4f}\"\n            saved_path = model_engine.save_model(checkpoint_path, checkpoint_name, cfg.arch, metadata, None, save_safe=True)\n\n            log.info(f\"Saving training checkpoint to: {saved_path}\")\n\n            if cfg.impl.push_to_huggingface_hub:\n                model_engine.push_to_hub(tokenizer, cfg, dryrun=cfg.dryrun)\n            \n            # Print some example completions\n        if loss.detach().isfinite():\n            generate(model_engine, tokenizer, cfg.impl.example_prompts, token_limit=cfg.impl.example_token_limit)\n    \n    # Save to summary:\n    if loss.detach().isfinite():\n        validation_log_p = validate(model_engine, dataloaders[\"test\"], setup, cfg)\n    else:\n        validation_log_p = float(\"Inf\")\n    log.info(f\"Log-Perplexity on validation data is {validation_log_p:2.4f}.\")\n    metrics = dict(\n        validation_log_ppl=validation_log_p,\n        validation_ppl=torch.as_tensor(validation_log_p).exp().item(),\n        num_params=sum([p.numel() for p in model.parameters()]),\n    )\n\n    return metrics\n\n\ndef get_time_elapsed(start_time: float, additional_time: float = 0.0) -> float:\n    return time.time() - start_time + additional_time\n\ndef check_checkpointing(data_idx: int, cfg_impl, last_save_time) -> bool:\n    step_condition = cfg_impl.save_every_nth_step > 0 and (data_idx % cfg_impl.save_every_nth_step == 0)\n    time_condition = cfg_impl.save_every_n_minutes > 0 and (time.time() - last_save_time) / 60 > cfg_impl.save_every_n_minutes\n    return cfg_impl.save_intermediate_checkpoints and (step_condition or time_condition)\n\n\ndef check_deadline(launch_time, hour_limit, prev_budget: float = 0.0, overall_hour_limit: float = 0.0):\n    \"\"\"These measurements are deliberately wall-clock based.\"\"\"\n    current_time = time.time()\n    overall_budget = overall_hour_limit if overall_hour_limit >= 0 else hour_limit\n    current_violated = (current_time - launch_time) / 3600 > hour_limit\n    overall_violated = (prev_budget + (current_time - launch_time)) / 3600 > overall_budget\n    return current_violated or overall_violated\n\n\ndef check_early_termination(start_time, loss, early_termination, prev_budget: float = 0.0):\n    \"\"\"Early termination based on terrible loss.\"\"\"\n    if early_termination.enabled and loss > early_termination.loss_threshold:\n        current_time = time.time()\n        overall_budget = early_termination.overall_budget if early_termination.overall_budget > 0 else early_termination.budget\n        current_violated = (current_time - start_time) / 3600 > early_termination.budget\n        overall_violated = (prev_budget + (current_time - start_time)) / 3600 > overall_budget\n        return current_violated or overall_violated\n    else:\n        return False\n\n\ndef collect_stats(data_step, loss_vals, log_ppls, model_outputs, train_time, stats, model_engine, dataloader, cfg):\n    \"\"\" \"data_step\" here refers to one step on the dataloader, which may be multiple steps on the model_engine.\"\"\"\n    stats[\"data_step\"] += [data_step]\n    stats[\"epoch\"] += [dataloader.epoch_counter]\n    stats[\"model_steps\"] += [model_engine.steps]\n\n    tokens_per_step = model_engine.record_tokens_per_step()\n    stats[\"tokens\"] += [data_step * tokens_per_step]\n    stats[\"loss\"] += [torch.stack(loss_vals).mean().item()]  # Averaged loss\n    stats[\"log_ppl\"] += [torch.stack(log_ppls).mean().item()]  # Averaged loss\n    if \"losses\" in model_outputs:\n        for key, acccum_loss in model_outputs[\"losses\"].items():\n            if key != \"count\":\n                stats[key] += [acccum_loss.item()]\n    if \"logits\" in model_outputs:\n        try:\n            precise_logits = model_outputs[\"logits\"].to(dtype=torch.float32)\n            stats[\"entropy\"] += [torch.distributions.Categorical(torch.softmax(precise_logits, dim=-1)).entropy().mean().item()]\n        except ValueError:\n            stats[\"entropy\"] += [float(\"NaN\")]  # can happen if invalid values in logits, or softmax numerical issues\n\n    current_lr = model_engine.optimizer.param_groups[0].get(\"lr\", float(\"NaN\"))\n    log_msg = f\"Train loss {loss_vals[-1].item():2.4f} at data block {data_step} with lr {current_lr:.5f}. \"\n    log_msg += f\"[Avg: {stats['loss'][-1]:2.4f}] \"\n    if data_step > 0:\n        stats[\"train_time\"] += [(time.time() - train_time) / cfg.impl.print_loss_every_nth_step]\n        estimated_train_finish = str(datetime.timedelta(seconds=stats[\"train_time\"][-1] * cfg.train.steps))\n        tokens_per_second = tokens_per_step / stats[\"train_time\"][-1]\n        stats[\"tok/sec\"] += [int(tokens_per_second)]\n        log_msg += f\" Perf: {stats['train_time'][-1]:2.4f}s per block ({tokens_per_second:.0f}t/s). \"\n        # log_msg += f\"Est.for all sched. blocks: {estimated_train_finish}.\"\n\n    # Adaptive optim stats\n    stats[\"lr\"] += [current_lr]\n    stats[\"batch_size\"] += [model_engine.record_batch_size()]\n    stats[\"seq_length\"] = [model_engine.current_seq_length]\n\n    # Publish\n    cramming.utils.wandb_log(stats, cfg)\n    log.info(log_msg)\n\n    # Clear:\n    loss_vals, log_ppls = [], []\n    train_time = time.time()\n    return loss_vals, log_ppls, train_time\n\n\n@torch.no_grad()\ndef validate(model_engine, validloader, setup, cfg):\n    \"\"\"Evaluate on validation set.\"\"\"\n    log.info(\"Starting model validation.\")\n    model_engine.eval()\n    val_timer = time.time()\n    # Cut up smaller streams so the inductor doesn't break, but keep parallelizable archs at full depth:\n    eval_depth = 1 if cfg.train.stream_depth < cfg.data.seq_length else cfg.data.seq_length\n\n    log_perplexity = 0\n    len_validloader = len(validloader)\n\n    for step, batch in enumerate(validloader):\n        device_batch = model_engine.to_device(batch)\n        seq_len = max(1, device_batch[\"input_ids\"].shape[1] - eval_depth)\n        num_entries = len(range(0, seq_len))\n        # Stream over sequence\n        model_outputs = {}\n        for seq_idx in range(0, seq_len, eval_depth):\n            input_ids = device_batch[\"input_ids\"][:, seq_idx : seq_idx + eval_depth + 1].clone()  # last token is used as target\n            model_outputs = model_engine.forward(input_ids=input_ids, **model_outputs)\n            log_perplexity += model_outputs.get(\"log_perplexity\", model_outputs[\"loss\"].detach()) / num_entries\n            if cfg.dryrun:\n                break\n\n        if step % cfg.impl.print_loss_every_nth_step == 0:\n            log_msg = f\"Avg Log-Perplexity: {log_perplexity/(step + 1):2.4f} at step {step} \"\n            if step > 1:\n                validation_time = (time.time() - val_timer) / cfg.impl.print_loss_every_nth_step\n                estimated_train_finish = str(datetime.timedelta(seconds=validation_time * len(validloader)))\n                tokens_per_step = cramming.utils.num_processes() * model_engine.record_tokens_per_step()\n                tokens_per_second = tokens_per_step / validation_time\n                log_msg += f\" Perf: {validation_time:2.4f}s per step ({tokens_per_second:.0f}t/s). \"\n                log_msg += f\"Estimated Total validation Time: {estimated_train_finish}.\"\n\n            val_timer = time.time()\n            log.info(log_msg)\n        \n        if step > 200000: # putting hard limit of 200,000 steps for validation\n            len_validloader = step\n            break\n\n        if cfg.dryrun:\n            break\n\n    model_engine.train(cfg.train.pretrain_in_train_mode)\n    return log_perplexity.item() / len_validloader\n\n\ndef generate(model_engine, tokenizer, example_prompts, token_limit=10, temp=1.0):\n    model_engine.eval()\n    # Just do a dumb generation for now, can implement efficient generation later\n    for prompt in example_prompts:\n\n        tokenized_inputs = torch.as_tensor(tokenizer(prompt)[\"input_ids\"], dtype=torch.long)[None, :]#-1]  # cut off EOT NOT ALWAYS SAFE\n        print(\"tokenised input is \",tokenized_inputs)\n        device_inputs = model_engine.to_device(dict(input_ids=tokenized_inputs))[\"input_ids\"]\n        print(\"device inputs: \", device_inputs)\n        # Generate new tokens\n        predicted_ids = model_engine.dynamic_generation(device_inputs, temperature=temp, token_limit=token_limit)\n        print(\"predicted ids: \", predicted_ids, \" with length \", predicted_ids.shape)\n        # print(type(predicted_ids[0]))\n        decoded_completion = tokenizer.decode(predicted_ids[0].tolist())  # drop batch dim before decoding\n\n        log.info(f\"[{prompt}] {decoded_completion}\")\n\n\ndef flag_communication(training_allowed):\n    \"\"\"A quick and dirty communication through NCCL. Should not be a major burden.\"\"\"\n    if torch.distributed.is_initialized():\n        comm_tensor = torch.as_tensor(training_allowed).cuda()\n        torch.distributed.all_reduce(comm_tensor, torch.distributed.ReduceOp.MIN, async_op=False)\n        if comm_tensor >= 1:\n            return True\n        else:\n            return False\n    else:\n        return training_allowed\n\n\n@hydra.main(config_path=\"cramming/config\", config_name=\"cfg_pretrain\", version_base=\"1.3\")\ndef launch(cfg):\n    cramming.utils.main_launcher(cfg, main_training_process, job_name=\"pretraining\")\n\n\nif __name__ == \"__main__\":\n    launch()\n"
  },
  {
    "path": "pretty_plotter.py",
    "content": "## combine multiple testing plots and make a pretty one \n\nimport os\nimport numpy as np\nimport json\nimport matplotlib.patches as patches\nimport matplotlib.pyplot as plt\nimport pandas as pd\nimport seaborn as sns\nfrom omegaconf import OmegaConf\n\ndef find_file(starting_directory, target_file):\n    \"\"\"Find target_file in the tree from starting_directory\"\"\"\n    for root, dirs, files in os.walk(starting_directory):\n        if target_file in files:\n            return os.path.join(root, target_file)\n\ndef grid_plotter(data, type=\"accs\", path=\"\", title=None, rect_size=20, up_to_50=False):\n    \"\"\"plot the 2d grid (up to 50x50)\"\"\"\n    if title is None:\n        title = \"All numbers are percetanges rounded to 1dp\"\n    data = np.array(data)*100\n    df = pd.DataFrame(data)\n\n    plt.figure(figsize=(10, 8))\n    sns.heatmap(df, annot=True, cmap=\"YlGnBu\", fmt=\".0f\", annot_kws={'size': 8,'rotation':0})\n    if up_to_50:\n        rect = patches.Rectangle((0, 0), rect_size, rect_size, linewidth=1.5, edgecolor='red', facecolor='none')\n    else:\n        rect = patches.Rectangle((0, 0), rect_size, rect_size, linewidth=1, edgecolor='red', facecolor='none')\n    plt.gca().add_patch(rect)\n    rect_size = data.shape[0]\n    plt.xticks(np.arange(1, rect_size+1) - 0.5, labels=np.arange(1, rect_size+1), rotation=90, fontsize=10)\n    plt.yticks(np.arange(1, rect_size+1) - 0.5, labels=np.arange(1, rect_size+1), rotation=0, fontsize=10)\n    \n    # Customize the plot\n    plt.title(title)\n    plt.ylabel(\"1st Number Length\")\n    plt.xlabel(\"2nd Number Length\")\n    \n    plt.savefig(f\"{path}combined_{type}_grid_plot{'_50' if up_to_50 else ''}\", bbox_inches='tight', dpi=300)\n    plt.clf()\n\ndef main():\n    # replace with model name\n    model_name = \"cramming-data/add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_abacus_attn_emb_nope_run_1\"\n\n    file_path = f\"{model_name}/downstream\"\n    # get latest checkpoint for the model data\n    config_path = f\"{model_name}/checkpoints\"\n    all_checkpoints = [f for f in os.listdir(config_path)]\n    checkpoint_paths = [os.path.join(config_path, c) for c in all_checkpoints]\n    checkpoint_name = max(checkpoint_paths, key=os.path.getmtime)\n    with open(os.path.join(checkpoint_name, \"model_config.json\"), \"r\") as file:\n        cfg_arch = OmegaConf.create(json.load(file))\n    max_rec = cfg_arch['maximal_recurrence']\n    layers_in_block = cfg_arch['layers_in_recurrent_block']\n    mask_bf_eq = cfg_arch['mask_before_equals']\n    attn_type = cfg_arch['attention']['type']\n    loss_reduc = cfg_arch['loss_reduction']\n    throttle = cfg_arch['throttle']\n    title = f\"Model name:\\n{model_name[14:]}\\nNum layers in block: {layers_in_block}, Num blocks in training: {max_rec}\\n Mask all before equals: {mask_bf_eq}, Train time: 24 hr\\n attn: {attn_type}, temp: Greedy{', loss: 'if loss_reduc == 'none' else ''}{', throttle' if throttle else ''}\"\n\n    # works up in tiers starting from the smallest grid (large) up to the largest for this size (up_to_50)\n    large_path = find_file(file_path, f\"accs_grid_quick_large.json\")\n    with open(large_path, 'r') as file:\n        data = json.load(file)\n    large_data = np.array(data)\n\n    ood_path = find_file(file_path, f\"accs_grid_quick_ood_only.json\")\n    with open(ood_path, 'r') as file:\n        data = json.load(file)\n    ood_data = np.array(data)\n\n    num_rows_to_add = ood_data.shape[0] - large_data.shape[0]\n    num_cols_to_add = ood_data.shape[1] - large_data.shape[1]\n\n    padded_array = np.pad(large_data, ((0, num_rows_to_add), (0, num_cols_to_add)), mode='constant', constant_values=0)\n    combined = padded_array+ood_data\n\n    rect_size=20\n    path_40 = find_file(file_path, f\"accs_grid_quick_up_to_40.json\")\n    if path_40 is not None:\n        with open(path_40, 'r') as file:\n            data = json.load(file)\n        data_40 = np.array(data)\n        num_rows_to_add = data_40.shape[0] - combined.shape[0]\n        num_cols_to_add = data_40.shape[1] - combined.shape[1]\n        padded_array = np.pad(combined, ((0, num_rows_to_add), (0, num_cols_to_add)), mode='constant', constant_values=0)\n        combined = padded_array+data_40\n\n    path_50 = find_file(file_path, f\"accs_grid_quick_up_to_50.json\")\n    up_to_50 = False\n    if path_50 is not None:\n        with open(path_50, 'r') as file:\n            data = json.load(file)\n        data_50 = np.array(data)\n        num_rows_to_add = data_50.shape[0] - combined.shape[0]\n        num_cols_to_add = data_50.shape[1] - combined.shape[1]\n        padded_array = np.pad(combined, ((0, num_rows_to_add), (0, num_cols_to_add)), mode='constant', constant_values=0)\n        combined = padded_array+data_50\n        up_to_50 = True\n        \n    grid_plotter(combined, type=\"accs\", path=f\"{file_path}/\", title=title, rect_size=rect_size, up_to_50=up_to_50)\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "pretty_plotter_big.py",
    "content": "## combine multiple testing plots and make a pretty one \n\nimport os\nimport numpy as np\nimport json\nimport matplotlib.patches as patches\nimport matplotlib.pyplot as plt\nimport pandas as pd\nimport seaborn as sns\nfrom omegaconf import OmegaConf\nimport glob\nimport re\n\ndef grid_plotter(data, type=\"accs\", path=\"\", title=None, rect_size=20):\n    \"\"\"Plot the large 100x100 grid\"\"\"\n    if title is None:\n        title = \"All numbers are percetanges rounded to 1dp\"\n    data = np.array(data)*100\n    df = pd.DataFrame(data)\n\n    plt.figure(figsize=(10, 8))\n    annotate = False\n    # use interpolant\n    sns.heatmap(df, annot=annotate, cmap=\"YlGnBu\", fmt=\".0f\", annot_kws={'size': 8,'rotation':0})\n\n    rect = patches.Rectangle((0, 0), rect_size, rect_size, linewidth=1.8, edgecolor='red', facecolor='none')\n    plt.gca().add_patch(rect)\n    rect_size = data.shape[0]\n    plt.xticks(np.arange(1, rect_size+1, 2) - 0.5, labels=np.arange(1, rect_size+1, 2), rotation=90, fontsize=10)\n    plt.yticks(np.arange(1, rect_size+1, 2) - 0.5, labels=np.arange(1, rect_size+1, 2), rotation=0, fontsize=10)\n    \n    # Customize the plot\n    plt.title(title)\n    plt.ylabel(\"1st Number Length\")\n    plt.xlabel(\"2nd Number Length\")\n    \n    plt.savefig(f\"{path}combined_accs_grid_plot_big_run\", bbox_inches='tight', dpi=300)\n    plt.clf()\n\ndef main():\n    # replace with your model name\n    model_name = \"cramming-data/add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_abacus_attn_emb_nope_with_skip_connections_run_1\"\n    rect_size = 20\n\n    directory_path = f\"{model_name}/downstream\"\n    # get latest checkpoint for the model data\n    config_path = f\"{model_name}/checkpoints\"\n    all_checkpoints = [f for f in os.listdir(config_path)]\n    checkpoint_paths = [os.path.join(config_path, c) for c in all_checkpoints]\n    checkpoint_name = max(checkpoint_paths, key=os.path.getmtime)\n    with open(os.path.join(checkpoint_name, \"model_config.json\"), \"r\") as file:\n        cfg_arch = OmegaConf.create(json.load(file))\n    max_rec = cfg_arch['maximal_recurrence']\n    layers_in_block = cfg_arch['layers_in_recurrent_block']\n    mask_bf_eq = cfg_arch['mask_before_equals']\n    attn_type = cfg_arch['attention']['type']\n    loss_reduc = cfg_arch['loss_reduction']\n    throttle = cfg_arch['throttle']\n    title = f\"Model name:\\n{model_name[14:]}\\nNum layers in block: {layers_in_block}, Num blocks in training: {max_rec}\\n Mask all before equals: {mask_bf_eq}, Train time: 24 hr\\n attn: {attn_type}, temp: Greedy{', loss: 'if loss_reduc == 'none' else ''}{', throttle' if throttle else ''}\"\n\n\n    # Define the pattern to search for\n    file_pattern = directory_path + \"/accs_grid_quick_big_eval_?_even.json\"\n    matching_files_even = glob.glob(file_pattern, recursive=True)\n    file_pattern = directory_path + \"/accs_grid_quick_big_eval_??_even.json\"\n    matching_files_even += glob.glob(file_pattern, recursive=True)\n\n    file_pattern = directory_path + \"/accs_grid_quick_big_eval_?_odd.json\"\n    matching_files_odd = glob.glob(file_pattern, recursive=True)\n    file_pattern = directory_path + \"/accs_grid_quick_big_eval_??_odd.json\"\n    matching_files_odd += glob.glob(file_pattern, recursive=True)\n\n    # Print the matching files\n    number_pattern_even = re.compile(r'accs_grid_quick_big_eval_(\\d+)_even.json')\n    number_pattern_odd = re.compile(r'accs_grid_quick_big_eval_(\\d+)_odd.json')\n\n    # Print the matching files and the numbers extracted from them\n    file_paths = []\n    even_nums = []\n    odd_nums = []\n\n    for file_path in matching_files_even:\n        match = number_pattern_even.search(file_path)\n        if match:\n            number = match.group(1)\n            if number not in even_nums:\n                even_nums.append(number)\n                print(\"Number:\", number)\n            else:\n                continue\n        print(\"File:\", file_path)\n        file_paths.append(file_path)\n\n    for file_path in matching_files_odd:\n        match = number_pattern_odd.search(file_path)\n        if match:\n            number = match.group(1)\n            if number not in odd_nums:\n                odd_nums.append(number)\n                print(\"Number:\", number)\n            else:\n                continue\n        print(\"File:\", file_path)\n        file_paths.append(file_path)\n\n    arr = np.zeros((100, 100))\n    for file_path in file_paths:\n        with open(file_path, 'r') as file:\n            data = json.load(file)\n            if len(data) == 3:\n                data = data[0]\n        arr = arr + np.array(data)\n        \n    title = title + \"\\n Even: \"+', '.join(sorted(even_nums, key=lambda x: int(x))) + \"\\n Odd: \"+', '.join(sorted(odd_nums, key=lambda x: int(x)))\n    grid_plotter(arr, type=type, path=f\"{directory_path}/\", title=title, rect_size=rect_size)\n    print(f\"{model_name}\")\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "pretty_plotter_sort.py",
    "content": "import numpy as np\nimport os\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport cv2\n\ndef grid_plotter(data, title=\"\", path=None):\n    data = np.array(data)\n    df = pd.DataFrame(data)\n\n    # find the average accuracy\n    avg = np.mean(data)\n\n    # Create the heatmap\n    plt.figure(figsize=(10, 8))\n    sns.heatmap(df, annot=True, cmap=\"YlGnBu\", fmt=\".1f\", annot_kws={'size': 8, 'rotation': 0}, vmin=0, vmax=100)\n\n    # Customize the plot\n    plt.title(f\"Accuracy - percetange, rounded to 1dp : {title}, Avg acc: {avg}\")\n    plt.ylabel(\"Maximum n-digit number (1-n)\")\n    plt.xlabel(\"Length of array to sort\")\n    size = data.shape[0]\n    plt.xticks(np.arange(0.5, size + 0.5, 1), labels=np.arange(1, size + 1, 1))\n    plt.yticks(np.arange(0.5, size + 0.5, 1), labels=np.arange(1, size + 1, 1))\n\n    plt.savefig(f\"{path}\", bbox_inches='tight')\n    plt.clf()\n\n\ndef run(names, short_hand, base_dir, sort_plots_path):\n    os.makedirs(sort_plots_path, exist_ok=True)\n    all_data_acc_dict = {}\n    all_data_top_1_acc_dict = {}\n\n    for i in range(len(names)):\n        name = names[i]\n        extra_name = short_hand[i]\n        dict_key = extra_name[0]\n        extra_name = extra_name[0] + \"_\" + extra_name[1]\n        all_data_path = base_dir + name + \"/downstream/\"\n\n        # get all the directories in the path that start with all_outputs\n        all_dirs = os.listdir(all_data_path)\n        # remove the ones that are not directories\n        all_dirs = [dir for dir in all_dirs if os.path.isdir(all_data_path + dir)]\n        all_images = []\n        for dir in all_dirs:\n            if \"all_outputs\" in dir:\n                # get the recurrence\n                recurrence = dir.split(\"_\")[-1]\n                if \"recurrence\" not in recurrence:\n                    continue\n\n                # get all the files in the directory\n                files = os.listdir(all_data_path + dir + \"/\")\n                all_images_local = []\n\n                all_data_acc = {}\n                all_data_top_1_acc = {}\n                max_size = 0\n\n                print(extra_name)\n                print(\"dir\", dir)\n\n                for file in files:\n                    if \".txt\" in file:\n                        all_info = file.split(\".\")[0]\n                        all_info = all_info.split(\"_\")\n                        data_size_1 = int(all_info[-2])\n                        data_size_2 = int(all_info[-1])\n\n                        if data_size_1 > max_size:\n                            max_size = data_size_1\n                        if data_size_2 > max_size:\n                            max_size = data_size_2\n\n                        # get the accuracy\n                        with open(all_data_path + dir + \"/\" + file, \"r\") as f:\n                            acc = float(f.read())\n                            if \"top_1_acc\" in file:\n                                all_data_top_1_acc[(data_size_1, data_size_2)] = acc\n                            else:\n                                all_data_acc[(data_size_1, data_size_2)] = acc\n\n                # create the grid plot\n                data = np.zeros((max_size, max_size))\n                for key in all_data_acc.keys():\n                    data[key[0] - 1][key[1] - 1] = all_data_acc[key]\n                grid_plotter(data,\n                            title=f\"{extra_name} {recurrence} acc\",\n                            path=f\"./{sort_plots_path}/{extra_name}_{recurrence}_acc.png\")\n\n                if dict_key not in all_data_acc_dict.keys():\n                    all_data_acc_dict[dict_key] = []\n                    all_data_top_1_acc_dict[dict_key] = []\n\n                all_data_acc_dict[dict_key].append(data)\n\n                data = np.zeros((max_size, max_size))\n                for key in all_data_top_1_acc.keys():\n                    data[key[0] - 1][key[1] - 1] = all_data_top_1_acc[key]\n                grid_plotter(data,\n                            title=f\"{extra_name} {recurrence} top_1_acc\",\n                            path=f\"./{sort_plots_path}/{extra_name}_{recurrence}_top_1_acc.png\")\n\n                all_data_top_1_acc_dict[dict_key].append(data)\n\n\n                all_images_local.append(cv2.imread(f\"./{sort_plots_path}/{extra_name}_{recurrence}_acc.png\"))\n                all_images_local.append(cv2.imread(f\"./{sort_plots_path}/{extra_name}_{recurrence}_top_1_acc.png\"))\n                all_images_local = cv2.hconcat(all_images_local)\n                # write this image\n                all_images.append((all_images_local, f\"{extra_name}_{recurrence}.png\"))\n\n        os.makedirs(f\"./{sort_plots_path}/final/\", exist_ok=True)\n        if len(all_images) == 1:\n            all_images_local, name = all_images[0]\n            cv2.imwrite(f\"./{sort_plots_path}/final/{name}\", all_images_local)\n        else:\n            os.makedirs(f\"./{sort_plots_path}/final/{extra_name}/\", exist_ok=True)\n            for all_images_local, name in all_images:\n                cv2.imwrite(f\"./{sort_plots_path}/final/{extra_name}/{name}\", all_images_local)\n\nif __name__ == \"__main__\":\n    names = [\"sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_abacus_with_fire_8x1_1_24_run_1\"]\n    short_hand = [(\"rev_abacus_fire_8x1\", \"v1\")] # the shrothand names for the runs you want to plot in the same order\n\n    base_dir = \"cramming-data/\"\n    sort_plots_path = \"./sort_plots/\"\n    run(names, short_hand, base_dir, sort_plots_path)"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.black]\nline-length = 140\n"
  },
  {
    "path": "setup.cfg",
    "content": "\n\n[metadata]\nname = cramming\nversion = 0.1.0\nauthor = Sean McLeish\nauthor_email = smcleish@umd.edu\nurl = https://github.com/mcleish7/arithmetic\ndescription = Fork of cramming for next token predicition\nlong_description = file: README.md, LICENSE.md\nlong_description_content_type = text/markdown\nlicense = MIT\nlicense_file = LICENSE.md\nplatform = any\nkeywords = Machine Learning, Language Modeling\nclassifiers =\n    License :: OSI Approved :: MIT License\n    Operating System :: OS Independent\n    Programming Language :: Python\nhomepage = \"https://github.com/mcleish7/arithmetic\"\nrepository = \"https://github.com/mcleish7/arithmetic\"\ndocumentation = \"\"\"\n\n[options]\nzip_safe = False\ninclude_package_data = True\npython_requires = >= 3.10\npackages = find:\n\nsetup_requires =\n    setuptools\n\ninstall_requires =\n    torch >= 2.0.0\n    hydra-core >= 1.1\n    datasets\n    tokenizers\n    transformers\n    evaluate\n    scipy\n    scikit-learn # for metrics\n    pynvml\n    psutil\n    einops\n    safetensors\n    apache-beam  # only used for wikipedia ...\n    zstandard    # only used for the Pile\n    wandb # if you want to use it\n    matplotlib==3.8.3 # the versions of plt and sns are fixed for annotating the heatmaps\n    seaborn==0.13.2\n    opencv-python\n\nscripts =\n  pretrain.py\n  arithmetic_eval_quicker.py\n\n[options.package_data]\n* =  \"*.yaml\", \"*.txt\"\n\n\n[check-manifest]\nignore =\n    .ipynb\n    .sh\n\n\n#basically the pytorch flake8 setting from https://github.com/pytorch/pytorch/blob/master/.flake8\n[flake8]\nselect = B,C,E,F,P,T4,W,B9\nmax-line-length = 140\n# C408 ignored because we like the dict keyword argument syntax\n# E501 is not flexible enough, we're using B950 instead\nignore =\n    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,\nper-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950\noptional-ascii-coding = True\nexclude =\n    .git,\n    __pycache__,\n    scripts,\n    tables,\n    outputs,\n    *.pyi\n"
  },
  {
    "path": "shells/addition_ff.sh",
    "content": "## FF\n# nope\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_nope_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None\n\n# fire\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_fire_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\" \n\n# abacus\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_abacus_attn_emb_nope_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=abacus\n\n## FF w/ II\n# nope\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_nope_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.forward_only_model_with_skip=True\n# fire\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_fire_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\"  arch.forward_only_model_with_skip=True\n# abacus\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_abacus_attn_emb_nope_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=abacus arch.forward_only_model_with_skip=True\n\n\n## FF w/ II\n# Abacus + FIRE\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_abacus_attn_emb_fire_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\" arch.forward_only_model_with_skip=True arch.embedding.pos_embedding=abacus \n# Abacus + RoPE\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_abacus_attn_emb_rope_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=abacus arch.forward_only_model_with_skip=True arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=true"
  },
  {
    "path": "shells/addition_lt.sh",
    "content": "### Looped Transformer experiments\n# vary number of layers in recurrent_block: arch.layers_in_recurrent_block\n# vary number of recurrences: arch.maximal_recurrence\n\n# NOPE\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_1_16_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_nope_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=16 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None\n# FIRE\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_1_16_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_fire_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=16 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\" \n# ABACUS\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_1_16_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_abacus_attn_emb_nope_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=16 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/+_bucket_method_n_20_m_20_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=abacus\n"
  },
  {
    "path": "shells/bitwise_or.sh",
    "content": "# bitwise or is sometimes refered to as pos_arth in the code\n\n## LT\n# NOPE\npython pretrain.py name=pos_or_one_vec_zeros_bucket_20_20_reverse_all_pad_00_depthrec_1_16_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_nope_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=1 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=16 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/or_one_vec_zeros/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None\n#  FIRE\npython pretrain.py name=pos_or_one_vec_zeros_bucket_20_20_reverse_all_pad_00_depthrec_1_16_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_nope_attn_emb_fire_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=1 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=16 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/or_one_vec_zeros/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\"\n# abacus\npython pretrain.py name=pos_or_one_vec_zeros_bucket_20_20_reverse_all_pad_00_depthrec_1_16_TBPTT_1024_batch_size_512_mask_before_equals_true_start_emb_abacus_attn_emb_nope_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=1 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=16 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/or_one_vec_zeros/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=abacus\n\n## FF\n#nope\npython pretrain.py name=pos_or_one_vec_zeros_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_nope_attn_emb_nope_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=1 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/or_one_vec_zeros/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.forward_only_model_with_skip=True\n# fire\npython pretrain.py name=pos_or_one_vec_zeros_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_nope_attn_emb_fire_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=1 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/or_one_vec_zeros/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\" arch.forward_only_model_with_skip=True\n# abacus\npython pretrain.py name=pos_or_one_vec_zeros_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_abacus_attn_emb_nope_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=1 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/or_one_vec_zeros/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=abacus arch.forward_only_model_with_skip=True\n\n## FF w/ II\n# nope\npython pretrain.py name=pos_or_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_nope_attn_emb_nope_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/pos_arith_add_20_20_p_00/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.forward_only_model_with_skip=True\n# fire\npython pretrain.py name=pos_or_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_nope_attn_emb_fire_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/pos_arith_add_20_20_p_00/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=None arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\"  arch.forward_only_model_with_skip=True\n# abacus\npython pretrain.py name=pos_or_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_256_mask_before_equals_true_start_emb_abacus_attn_emb_nope_with_skip_connections_run_1 wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=256 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=16 arch.maximal_recurrence=1 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/pos_arith_add_20_20_p_00/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True arch.embedding.pos_embedding=abacus arch.forward_only_model_with_skip=True\n"
  },
  {
    "path": "shells/evaluation.sh",
    "content": "# there is an automated helper in gen_eval_script.py for generating these evaluation scripts\n\n# Addition\npython arithmetic_eval_quicker.py name=<name> base_dir=$cramming_base_dir data=arithmetic max_rec=<max_rec> token_limit=105 big_eval_step_<STEP_NUM>=True reverse_inputs=True checkerboard=<EVEN/ODD> remove_padding=True data.sources.arithmetic.tokenizer_type=\"pad\"\n\n# Extended Addition Eval, i.e. 100\npython arithmetic_eval_quicker.py name=<name> base_dir=$cramming_base_dir data=arithmetic max_rec=<max_Rec> token_limit=105 big_eval_step_5=True reverse_inputs=True checkerboard=even remove_padding=True extended_eval=True data.sources.arithmetic.tokenizer_type=\"pad\"\n\n# Multiplication\npython arithmetic_eval_quicker.py name=<NAME> base_dir=$cramming_base_dir data=arithmetic max_rec=<max_rec> token_limit=30 mul=True data.sources.arithmetic.tokenizer_type=\"pad\"\n\n# Sorting\n# max_size_given = end of grid, start_ind_... = start of grid, i.e. this evaluates from 1,1 to final_size, final_size\npython sort_eval.py name=<name> base_dir=$cramming_base_dir data=arithmetic max_rec=<max_rec> sort_reverse=True data.sources.arithmetic.tokenizer_type='sort' max_size_given={final_size + 1} start_ind_1_given={1} start_ind_2_given={1}\n\n# Bitwise OR\npython arithmetic_eval_quicker.py name=<name> base_dir=$cramming_base_dir data=arithmetic max_rec=<max_rec> token_limit=105 big_eval_step_<STEP_NUM>=True checkerboard=<EVEN/ODD> pos_arth_ood=True data.sources.arithmetic.tokenizer_type=\"pad\" remove_padding=False"
  },
  {
    "path": "shells/generate_and_tokenize_data.sh",
    "content": "## Training Data -- these commands approximately correspond to the zipped data we provide\n\n# bitwise or\npython create_pos_or_variants.py --n 20 --m 20 --dir_name <NAME> --max 100\npython create_data_split.py --tokenize --dir_name <NAME> --tokenizer_type pad --test_split_ratio 0.01\n\n# addition\npython create_data_split.py --bucket --op + --n 20 --m 20 --limit 20000000 --p 0.0 --dir_name <NAME> --reverse_all\npython create_data_split.py --tokenize --dir_name <NAME> --tokenizer_type pad --test_split_ratio 0.01\n\n# addition with index hints\npython create_data_split.py --bucket --op + --n 20 --m 20 --limit 20000000 --p 0.0 --dir_name <NAME> --reverse_all --index_hints\npython create_data_split.py --tokenize --dir_name <NAME> --tokenizer_type index\n\n# multiplication\npython create_data_split.py --bucket --op x --n 15 --m 15 --limit 20000000 --dir_name <NAME>  --reverse_all --p 0.0\npython create_data_split.py --tokenize --dir_name <NAME> --tokenizer_type pad --test_split_ratio 0.01\n\n# sorting\npython create_data_split.py --uniform_distribution_sort_data --continue_to_tokenize --tokenize --tokenizer_type sort --test_split_ratio 0.01 --n 10 --m 10 --limit 20000000 --dir <NAME> --sort_generation_method bucket_uniform_distribution --reverse_all\n\n## Evaluation Data -- run line and tokenize once for each operand length\n# bitwise or\npython create_pos_or_variants.py --n <i> --m <j> --dir_name <NAME> --exact --eval --max 100\npython create_data_split.py --tokenize --dir_name <NAME> --tokenizer_type pad --test_split_ratio 0.0\n\n# addition\npython create_data_split.py --op + --n <i> --m <j> --num_samples 100 --dir_name <NAME> --exact\npython create_data_split.py --tokenize --dir_name <NAME> --tokenizer_type pad --test_split_ratio 0.0\n\n# multiplication\npython create_data_split.py --op x --n <i> --m <j> --num_samples 100 --dir_name <NAME> --exact\npython create_data_split.py --tokenize --dir_name <NAME> --tokenizer_type pad --test_split_ratio 0.0\n\n# sorting\npython create_data_split.py --uniform_distribution_sort_data --continue_to_tokenize --tokenize --tokenizer_type sort --test_split_ratio 0.01 --n <i> --m <j> --limit 100 --dir <NAME> --sort_generation_method bucket_uniform_distribution --reverse_all --exact"
  },
  {
    "path": "shells/multiplication.sh",
    "content": "## only Looped Transformer experiments for multiplication\ntorchrun --nproc_per_node=8 --standalone pretrain.py name=mul_bucket_15_15_reverse_all_pad_00_depthrec_4_4_TBPTT_1024_nope_mask_before_equals_batch_512_fire_abacus_8_gpu wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=4 arch.maximal_recurrence=4 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/x_bucket_method_n_15_m_15_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.00006 data.sources.arithmetic.tokenizer_type=\"pad\" arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\" arch.mask_before_equals=True impl.fullgraph=false arch.loss_reduction=none arch.throttle=True arch.embedding.pos_embedding=\"abacus\"\n\ntorchrun --nproc_per_node=8 --standalone pretrain.py name=mul_bucket_15_15_reverse_all_pad_00_depthrec_4_4_TBPTT_1024_nope_mask_before_equals_batch_512_fire_nope_8_gpu wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=4 arch.maximal_recurrence=4 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/x_bucket_method_n_15_m_15_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.00006 data.sources.arithmetic.tokenizer_type=\"pad\" arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\" arch.mask_before_equals=True impl.fullgraph=false arch.loss_reduction=none arch.throttle=True arch.embedding.pos_embedding=None\n\ntorchrun --nproc_per_node=8 --standalone pretrain.py name=mul_bucket_15_15_reverse_all_pad_00_depthrec_4_4_TBPTT_1024_nope_mask_before_equals_batch_512_abacus_8_gpu wandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir impl.microbatch_size=512 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=4 arch.maximal_recurrence=4 arch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True data.sources.arithmetic.tokenized_dataset_path=\"arithmetic_data/x_bucket_method_n_15_m_15_20000000_p_00_reverse_all/hf_tokenized_dataset\" train.optim.lr=0.0001 data.sources.arithmetic.tokenizer_type=\"pad\" arch.mask_before_equals=True impl.fullgraph=false arch.loss_reduction=none arch.throttle=True arch.embedding.pos_embedding=\"abacus\""
  },
  {
    "path": "shells/sorting.sh",
    "content": "# REMINDER SET BASE DIR\n\n\n## fire reverse\n## fire reverse recall\n## fire reverse recurrence\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_fire_8x1_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=8 arch.maximal_recurrence=1 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.attention.type='self-attention' \\\n\tarch.attention.rotary_embedding='fire' impl.fullgraph=false impl.save_every_n_minutes=60 impl.save_intermediate_model_name='last'\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_fire_recall_8x1_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=8 arch.maximal_recurrence=1 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.attention.type='self-attention' \\\n\tarch.attention.rotary_embedding='fire' impl.fullgraph=false impl.save_every_n_minutes=60 impl.save_intermediate_model_name='last' arch.forward_only_model_with_skip=True\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_fire_1x8_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=8 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.attention.type='self-attention' \\\n\tarch.attention.rotary_embedding='fire' impl.fullgraph=false impl.save_every_n_minutes=60 impl.save_intermediate_model_name='last'\n\n## abacus reverse\n## abacus reverse recall\n## abacus reverse recurrence\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_abacus_8x1_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=8 arch.maximal_recurrence=1 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.embedding.pos_embedding=\"abacus\"\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_abacus_8x1_skip_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=8 arch.maximal_recurrence=1 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.embedding.pos_embedding=\"abacus\" arch.forward_only_model_with_skip=True\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_abacus_1x8_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=8 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.embedding.pos_embedding=\"abacus\"\n\n\n## abacus fire reverse\n## abacus fire reverse recall\n## abacus fire reverse recurrence\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_abacus_with_fire_8x1_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=8 arch.maximal_recurrence=1 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.embedding.pos_embedding=\"abacus\" \\\n\tarch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\"\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_abacus_with_fire_8x1_skip_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=8 arch.maximal_recurrence=1 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.embedding.pos_embedding=\"abacus\" \\\n\tarch.forward_only_model_with_skip=True arch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\"\n\ntorchrun --nproc_per_node=1 --standalone pretrain.py name=sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all_abacus_with_fire_1x8_1_24_run_1 \\\n\twandb=none arch=crammed-depthrecurrent data=arithmetic base_dir=$cramming_base_dir \\\n\timpl.microbatch_size=32 budget=24 impl.compile_torch=False arch.objective_layout=TBPTT arch.layers_in_recurrent_block=1 arch.maximal_recurrence=8 \\\n\tarch.hidden_size=1024 arch.intermed_size=2048 impl.forbid_dataset_preprocessing=False impl.save_intermediate_checkpoints=True impl.save_final_model=True \\\n\tdata.sources.arithmetic.tokenized_dataset_path='arithmetic_data/sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p_00_reverse_all/hf_tokenized_dataset' \\\n\ttrain.optim.lr=0.0001 arch.embedding.pos_embedding=None data.sources.arithmetic.tokenizer_type='sort' arch.mask_before_equals=True arch.embedding.pos_embedding=\"abacus\" \\\n\tarch.attention.type=\"self-attention\" arch.attention.rotary_embedding=\"fire\""
  },
  {
    "path": "sort_eval.py",
    "content": "import logging\nimport hydra\nfrom omegaconf import OmegaConf\nimport cramming\nimport torch\nfrom safetensors.torch import load_file\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport json\nimport numpy as np\nimport re\nimport pandas as pd\nimport datasets\nimport os\nfrom typing import List, Dict\nfrom cramming.data.tokenizer_preparation import get_tokenizer\nimport random\n\nlog = logging.getLogger(__name__)\n\ndef grid_plotter(data, type=\"accs\", name='_large', extra_path=None):\n    \"\"\"plot a 2d accuracy grid\"\"\"\n    data = np.array(data)*100\n    df = pd.DataFrame(data)\n\n    # Create the heatmap\n    plt.figure(figsize=(10, 8))\n    sns.heatmap(df, annot=True, cmap=\"YlGnBu\", fmt=\".1f\", annot_kws={'size': 8,'rotation':0})\n    \n    # Customize the plot\n    plt.title(\"Accuracy - percetange, rounded to 1dp\")\n    plt.ylabel(\"1st Number Length\")\n    plt.xlabel(\"2nd Number Length\")\n    size = data.shape[0]\n    plt.xticks(np.arange(0.5, size+0.5, 1), labels=np.arange(1, size+1, 1))\n    plt.yticks(np.arange(0.5, size+0.5, 1), labels=np.arange(1, size+1, 1))\n\n    if extra_path is not None:\n        plt.savefig(f\"{extra_path}{type}{name}_grid_plot\", bbox_inches='tight')\n    else:\n        plt.savefig(f\"{type}{name}_grid_plot\", bbox_inches='tight')\n    plt.clf()\n\ndef grid_logic(cfg):\n    \"\"\"logic to select function to control which part of a 2d grid this run should be responsible for evaling\"\"\"\n\n    # origional testing\n    def logic_func_large(data_size_1, data_size_2):\n        return (data_size_1 <= 23 or data_size_2 <=23)\n    logic_func = logic_func_large\n    name = '_large'\n    max_size = 23+1\n    \n    if cfg.ood_only:\n        def logic_func_ood(data_size_1, data_size_2):\n            return (data_size_1 >=24 or data_size_2 >=24) and (data_size_1 <= 30 or data_size_2 <=30)\n        logic_func = logic_func_ood\n        name = '_ood_only'\n        max_size = 30+1\n        \n    if cfg.up_to_40:\n        def logic_func_40(data_size_1, data_size_2):\n            return (data_size_1 >=31 or data_size_2 >=31) and (data_size_1 <=40 or data_size_2 <=40)\n        logic_func = logic_func_40\n        name = '_up_to_40'\n        max_size = 40+1\n        \n    if cfg.up_to_50:\n        def logic_func_50(data_size_1, data_size_2):\n            return (data_size_1 >=41 or data_size_2 >=41) and (data_size_1 <=50 or data_size_2 <=50)\n        logic_func = logic_func_50\n        name = '_up_to_50'\n        max_size = 50+1\n\n    # checkerboarding: for the large eval we can checkerboard:\n\n    if cfg.checkerboard is not None:\n        if cfg.checkerboard == 'even':\n            def checkerboard_even(data_size_1, data_size_2):\n                return ((data_size_1+data_size_2)%2 ==0)\n            checkerboard_func = checkerboard_even\n            checkerboard_str = \"_even\"\n        elif cfg.checkerboard == 'odd':\n            def checkerboard_odd(data_size_1, data_size_2):\n                return ((data_size_1+data_size_2)%2 ==1)\n            checkerboard_func = checkerboard_odd\n            checkerboard_str = \"_odd\"\n        else:\n            print(\"checkerboard config not allowed\")\n            exit()\n    else:\n        def always_true(data_size_1, data_size_2):\n            return True\n        checkerboard_func = always_true\n        checkerboard_str = \"\"\n\n\n    # if we are testing up to 100, split into 10 steps each of approximately equal number of forward passes required\n    if cfg.big_eval_step_1: # 1 -> 46\n        def logic_func_big_1(data_size_1, data_size_2):\n            return (data_size_1 <= 46 and data_size_2 <= 46) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_1\n        name = '_big_eval_1'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_2: # 47 -> 58\n        def logic_func_big_2(data_size_1, data_size_2):\n            return (data_size_1 >=47 or data_size_2 >=47) and (data_size_1 <=58 and data_size_2 <=58) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_2\n        name = '_big_eval_2'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_3: # 59 -> 67\n        def logic_func_big_3(data_size_1, data_size_2):\n            return (data_size_1 >=59 or data_size_2 >=59) and (data_size_1 <=67 and data_size_2 <=67) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_3\n        name = '_big_eval_3'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_4: # 68 -> 74\n        def logic_func_big_4(data_size_1, data_size_2):\n            return (data_size_1 >=68 or data_size_2 >=68) and (data_size_1 <=74 and data_size_2 <=74) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_4\n        name = '_big_eval_4'+checkerboard_str\n        max_size = 100+1\n      \n    if cfg.big_eval_step_5: # 75 -> 80\n        def logic_func_big_5(data_size_1, data_size_2):\n            return (data_size_1 >= 75 or data_size_2 >=75) and (data_size_1 <=80 and data_size_2 <=80) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_5\n        name = '_big_eval_5'+checkerboard_str\n        max_size = 100+1\n\n    if cfg.big_eval_step_6: # 81 -> 85\n        def logic_func_big_6(data_size_1, data_size_2):\n            return (data_size_1 >= 81 or data_size_2 >=81) and (data_size_1 <=85 and data_size_2 <=85) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_6\n        name = '_big_eval_6'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_7: # 86 -> 90\n        def logic_func_big_7(data_size_1, data_size_2):\n            return (data_size_1 >= 86 or data_size_2 >=86) and (data_size_1 <=90 and data_size_2 <=90) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_7\n        name = '_big_eval_7'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_8: # 91 -> 94\n        def logic_func_big_8(data_size_1, data_size_2):\n            return (data_size_1 >= 91 or data_size_2 >=91) and (data_size_1 <=94 and data_size_2 <=94) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_8\n        name = '_big_eval_8'+checkerboard_str\n        max_size = 100+1\n    \n    if cfg.big_eval_step_9: # 95 -> 97\n        def logic_func_big_9(data_size_1, data_size_2):\n            return (data_size_1 >= 95 or data_size_2 >=95) and (data_size_1 <=97 and data_size_2 <=97) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_9\n        name = '_big_eval_9'+checkerboard_str\n        max_size = 100+1\n        \n    if cfg.big_eval_step_10: # 98 -> 100\n        def logic_func_big_10(data_size_1, data_size_2):\n            return (data_size_1 >= 98 or data_size_2 >=98) and (data_size_1 <=100 and data_size_2 <=100) and checkerboard_func(data_size_1, data_size_2)\n        logic_func = logic_func_big_10\n        name = '_big_eval_10'+checkerboard_str\n        max_size = 100+1\n\n    # boolean_list_precidence = [large, ood_only, up_to_40, up_to_50, big_eval_step_1, big_eval_step_2, big_eval_step_3, big_eval_step_4, big_eval_step_5]\n\n    log.info(f\"large = {cfg.large}\")\n    log.info(f\"ood only = {cfg.ood_only}\")\n    log.info(f\"up to 40 = {cfg.up_to_40}\")\n    log.info(f\"up to 50 = {cfg.up_to_50}\")\n    log.info(f\"big eval 1 = {cfg.big_eval_step_1}\")\n    log.info(f\"big eval 2 = {cfg.big_eval_step_2}\")\n    log.info(f\"big eval 3 = {cfg.big_eval_step_3}\")\n    log.info(f\"big eval 4 = {cfg.big_eval_step_4}\")\n    log.info(f\"big eval 5 = {cfg.big_eval_step_5}\")\n    log.info(f\"big eval 6 = {cfg.big_eval_step_6}\")\n    log.info(f\"big eval 7 = {cfg.big_eval_step_7}\")\n    log.info(f\"big eval 8 = {cfg.big_eval_step_8}\")\n    log.info(f\"big eval 9 = {cfg.big_eval_step_9}\")\n    log.info(f\"big eval 10 = {cfg.big_eval_step_10}\")\n    log.info(f\"the last true value in the above list will be run, mul and pos arith can take control after this\")\n\n    return logic_func, name, max_size\n\ndef main(cfg):\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    local_checkpoint_folder = os.path.join(cfg.base_dir, cfg.name, \"checkpoints\")\n    tokenizer, cfg_arch, model_file = cramming.utils.find_pretrained_checkpoint(cfg.eval.checkpoint,\n                                                                                local_checkpoint_folder,\n                                                                                cfg.eval.arch_modifications)\n    if cfg.max_rec is not None: # can have more/less recurrences for eval\n        cfg_arch.maximal_recurrence_in_eval = cfg.max_rec\n    else:\n        cfg_arch.maximal_recurrence_in_eval = cfg_arch.maximal_recurrence\n    log.info(f\"cfg_arch.maximal_recurrence_in_eval changed to {cfg_arch.maximal_recurrence_in_eval}\")\n    cfg_arch.throttle = False # turn throttle off\n\n    logic_func, name, max_size = grid_logic(cfg)\n\n    # import tokeniser\n    cfg_data_sources_values_list = list(cfg.data.sources.values())[0]\n    if cfg_data_sources_values_list[\"provider\"] == \"arithmetic\":\n        tokenizer = get_tokenizer(cfg_data_sources_values_list[\"tokenizer_type\"])\n    else: \n        log.info(\"exiting as this is only for arithmetic\")\n        exit()\n    vocab = tokenizer.ids_to_tokens\n    EOS_token = tokenizer._convert_token_to_id(tokenizer.eos_token)\n    PAD_token = tokenizer._convert_token_to_id(tokenizer.pad_token)\n    assert PAD_token == 0, \"PAD token must be token zero for our code to work\"\n\n    # Load model\n    if 'alpha' not in cfg_arch:\n        cfg_arch['alpha'] = 1.0\n\n    model = cramming.construct_model(cfg_arch, tokenizer).to(device)\n    model = cramming.backend.load_model_checkpoint(model, model_file)\n    model.to(device)\n    model.eval()\n\n    log.info(f\"greedy = {cfg.greedy}, note: if greedy = True this overrides any temperature arguments\")\n    ## Greedy decoding will overide any temperature arguments\n\n    if cfg.max_size_given is not None: # allows unique splits for eval\n        max_size = cfg.max_size_given\n\n    # Grid plots - grid search from 1x1 to 12x12 data\n    data_sizes = list(range(1, max_size))\n    acc_grid = np.zeros((len(data_sizes),len(data_sizes)))\n    start_ind_1 = 0\n    start_ind_2 = 0\n    tuple_method = False\n    completed_one = False\n    if \"big_eval\" in name:\n        tuple_method = True\n        # go up two layers and search for grid\n        try:\n            with open(f\"../../accs_grid_quick{name}.json\", 'r') as file:\n                data = json.load(file)\n            start_ind_1 = data[1]\n            start_ind_2 = data[2]\n            acc_grid = np.array(data[0])\n            log.info(\"loaded grid from previous run\")\n        except:\n            pass\n\n    if cfg.start_ind_1_given is not None: # allows unique splits for eval\n        start_ind_1 = cfg.start_ind_1_given\n    if cfg.start_ind_2_given is not None:\n        start_ind_2 = cfg.start_ind_2_given\n    log.info(f\"start_ind_1 = {start_ind_1}, start_ind_2 = {start_ind_2}\")\n\n    os.makedirs(\"outputs\", exist_ok=True)\n\n    all_outputs_folder_path = f\"../../all_outputs_max_recurrence={cfg_arch.maximal_recurrence_in_eval}\"\n    os.makedirs(all_outputs_folder_path, exist_ok=True)\n\n    if not cfg.extended_eval:\n        # main 2d loop\n        for data_size_1 in data_sizes:\n            for data_size_2 in data_sizes:\n                proceed = False\n                if data_size_1 >= start_ind_1 or data_size_2 >= start_ind_2:\n                    proceed = True\n\n                if not proceed:\n                    continue\n\n                # check if done\n                # if done it will be done and saved in f\"../../acc_for_{data_size_1}_{data_size_2}.txt\"\n                if os.path.exists(f\"{all_outputs_folder_path}/acc_for_{data_size_1}_{data_size_2}.txt\"):\n                    with open(f\"{all_outputs_folder_path}/acc_for_{data_size_1}_{data_size_2}.txt\", 'r') as file:\n                        acc = float(file.read())\n                    acc_grid[data_size_1-1, data_size_2-1] = acc\n                    continue\n\n                if logic_func(data_size_1, data_size_2):\n                    completed_one = True\n                    log.info(f\"Starting iteration in grid eval for size: {data_size_1} and {data_size_2}\")\n                    # only one option -- sorting with reversed numbers\n                    file_path = f\"../../../../data/arithmetic_data/sort_reverse/sort_uniform_distribution_sort_basic_max_digits_n_{data_size_1}_max_length_m_{data_size_2}_200_p_00_reverse_all/hf_tokenized_dataset\"\n                   \n                    tokenized_dataset = datasets.load_from_disk(file_path)[\"test\"]\n                    data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=100, shuffle=False)\n\n                    # keep track of totals for a batch as we only eval one sample at a time\n                    correct_total = 0\n                    all_total = 0\n                    top_1_total = 0\n                    for batch in data_loader:\n                        input_ids = batch[\"input_ids\"]\n                        input_ids = torch.stack(input_ids).to(device)\n                        input_ids = torch.transpose(input_ids, 0, 1)\n\n                        all = 0\n                        correct = 0\n                        top_1 = 0\n                        for i in range(len(input_ids)):\n                            example = input_ids[i]\n                            equals_token = tokenizer._convert_token_to_id(\"=\")\n                            equals_indices = torch.where(example == equals_token)[0].item()\n                            question = example[:equals_indices + 1]\n                            answer = example[equals_indices + 1:]\n                            \n                            question = question.unsqueeze(0)\n\n                            local_token_limit = int(len(answer) * 2)\n                            predicted_ids = model._generate(question,\n                                                            token_limit=local_token_limit,\n                                                            temperature=cfg.temp,\n                                                            steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval,\n                                                            greedy=cfg.greedy, quick=True)\n                            predicted_ids = predicted_ids.squeeze()\n\n                            # get the answer\n                            eos_token = tokenizer._convert_token_to_id(tokenizer.eos_token)\n                            eos_indices = torch.where(answer == eos_token)[0].item()\n                            answer = answer[:eos_indices]\n\n                            predicted_ids = predicted_ids[:len(answer)]\n                            if torch.equal(predicted_ids, answer):\n                                correct += 1\n\n                            top_1_target = answer[0]\n                            top_1_predicted = predicted_ids[0]\n                            if torch.equal(top_1_target, top_1_predicted):\n                                top_1 += 1\n\n                            all += 1\n\n                        correct_total += correct\n                        top_1_total += top_1\n                        all_total += all\n\n\n                    acc = correct_total / all_total\n                    acc_top_1 = top_1_total / all_total\n\n                    log.info(f\"accuracy for data that has numbers \"\n                             f\"with maximum number of digits as {data_size_1} , \"\n                             f\"and the array of length {data_size_2} is {acc * 100}\")\n                    log.info(f\"Top 1 accuracy for data that has numbers \"\n                             f\"with maximum number of digits as {data_size_1} , \"\n                             f\"and the array of length {data_size_2} is {acc_top_1 * 100}\")\n\n                    question = tokenizer.decode(question.squeeze().tolist())\n                    answer = tokenizer.decode(answer.tolist())\n                    predicted = tokenizer.decode(predicted_ids.tolist())\n                    log.info(f\"For example : sort {question} for which the answer is {answer} , \"\n                             f\"and the predicted is {predicted}\")\n                    acc_grid[(data_size_1-1), (data_size_2-1)] = acc * 100\n                    \n                    # save all in case of crash\n                    with open(f\"{all_outputs_folder_path}/acc_for_{data_size_1}_{data_size_2}.txt\", \"w\") as file:\n                        file.write(f\"{acc * 100}\")\n                    with open(f\"{all_outputs_folder_path}/top_1_acc_for_{data_size_1}_{data_size_2}.txt\", \"w\") as file:\n                        file.write(f\"{acc_top_1 * 100}\")\n\n        log.info(f\"acc grid: {acc_grid}\")\n\n        with open(f\"accs_grid_quick_{start_ind_1}_{start_ind_2}_{max_size}.json\", \"w\") as file:\n            json.dump(acc_grid.tolist(), file)\n\n        # Grid plots - one for accs one for contains\n        grid_plotter(acc_grid, name=f\"{start_ind_1}_{start_ind_2}_{max_size}\")\n        grid_plotter(acc_grid, name=f\"{start_ind_1}_{start_ind_2}_{max_size}\", extra_path=all_outputs_folder_path)\n\n    log.info(\"Eval complete\")\n\n@hydra.main(config_path=\"cramming/config\", config_name=\"cfg_eval\", version_base=\"1.3\")\ndef launch(cfg):\n    log.info(\"calling main launch\")\n    cfg = cramming.utils.pathfinder(cfg)\n    log.info(OmegaConf.to_yaml(cfg, resolve=True))\n    main(cfg)\n\nif __name__ == \"__main__\":\n    launch()"
  },
  {
    "path": "upload_processed_dataset.py",
    "content": "\"\"\"Script to upload a processed dataset to the huggingface hub. You probably don't need this :)\"\"\"\n\n\nimport hydra\nimport logging\nfrom omegaconf import OmegaConf\nimport tempfile\nimport os\n\nfrom datasets import load_dataset\n\nimport cramming\n\n\nlog = logging.getLogger(__name__)\n\n\ndef upload(cfg, setup):\n    dataset, tokenizer = cramming.load_pretraining_corpus(cfg.data, cfg.impl)\n    checksum = cramming.data.utils.checksum_config(cfg.data)\n    processed_dataset_name = f\"{cfg.data.name}_{checksum}\"\n\n    use_own_chunking = True\n    chunk_size = 8192 * 32\n    num_files = len(dataset) // chunk_size + 1\n    target_types = [\"input_ids\"]\n\n    files = []\n    # Split dataset in parquet files\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        if use_own_chunking:\n            # Loop through the dataset and write each chunk to a Parquet file\n            # This is not really necessary, but nice to save only target_types and to match chunk sizes to target batch sizes\n            for idx in range(num_files):\n                chunk = dataset.select(range(idx * chunk_size, min(len(dataset), (idx + 1) * chunk_size)))\n                filename = f\"{tmpdirname}/train_{idx}.parquet\"\n                chunk.to_pandas()[target_types].to_parquet(filename, index=False)\n                files.append(filename)\n                log.info(f\"Chunk {idx} written to file {filename}.\")\n\n            # Re-assemble parqueted dataset\n            dataset = load_dataset(\"parquet\", data_files=files)\n\n        # Define the dataset info\n        description = f\"\"\"This is a preprocessed dataset for the cramming-project.\n\n                                Use only with the tokenizer prescribed here.\n                                This version is {processed_dataset_name}, which corresponds to the following setup:\n                                {OmegaConf.to_yaml(cfg, resolve=True)}\n\n                                Limitations and bias:\n                                This training data was further filtered and sorted beyond the normal preprocessing.\n                                These modifications were not tested for unintended consequences.\n\n                              \"\"\"\n        dataset[\"train\"].info.description = description\n        # dataset_tags = [\"cramming\", \"English\", \"preprocessed\"]\n\n        # Launch upload\n        log.info(\"Preparing for dataset upload ...\")\n        dataset.push_to_hub(processed_dataset_name, private=True)\n\n        # Upload tokenizer to same adress - this is annoying because by default tokenizers are pushed to model directories\n        # tokenizer.push_to_hub(processed_dataset_name) -> this will push to a new directory in HF models\n        from huggingface_hub import HfApi\n\n        api = HfApi()\n        log.info(\"Preparing for tokenizer upload ...\")\n        tokenizer_loc = os.path.join(os.path.join(cfg.impl.path, processed_dataset_name), \"tokenizer\")\n        for file in os.listdir(tokenizer_loc):\n            api.upload_file(\n                path_or_fileobj=os.path.join(tokenizer_loc, file),\n                path_in_repo=os.path.join(\"tokenizer\", file),\n                repo_id=f\"{api.whoami()['name']}/{processed_dataset_name}\",\n                repo_type=\"dataset\",\n            )\n        log.info(\"Upload completed succesfully.\")\n\n\n@hydra.main(config_path=\"cramming/config\", config_name=\"cfg_pretrain\", version_base=\"1.3\")\ndef launch(cfg):\n    cramming.utils.main_launcher(cfg, upload, job_name=\"upload\")\n\n\nif __name__ == \"__main__\":\n    launch()\n"
  }
]