Showing preview only (516K chars total). Download the full file or copy to clipboard to get everything.
Repository: mcleish7/arithmetic
Branch: main
Commit: 86022a57d38c
Files: 132
Total size: 479.5 KB
Directory structure:
gitextract_shohcgjg/
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── abacus.py
├── arithmetic_eval_quicker.py
├── cramming/
│ ├── __init__.py
│ ├── architectures/
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── components.py
│ │ ├── construction.py
│ │ ├── crammed_depthrecurrent.py
│ │ ├── crammed_transformer.py
│ │ ├── embeddings.py
│ │ ├── huggingface_interface.py
│ │ ├── losses.py
│ │ └── sanity_check.py
│ ├── backend/
│ │ ├── __init__.py
│ │ ├── optimizers/
│ │ │ ├── __init__.py
│ │ │ ├── optimizer_modifiers.py
│ │ │ ├── progressive_batching.py
│ │ │ └── schedulers.py
│ │ ├── prepare_backend.py
│ │ ├── torch_default.py
│ │ └── utils.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── arch/
│ │ │ ├── __init__.py
│ │ │ ├── albert.yaml
│ │ │ ├── crammed-depthrecurrent.yaml
│ │ │ ├── crammed-fakeRNN.yaml
│ │ │ ├── crammed-janus.yaml
│ │ │ ├── crammed-rnn.yaml
│ │ │ ├── crammed-stack-janus.yaml
│ │ │ ├── crammed-tiny.yaml
│ │ │ ├── crammed-transformer.yaml
│ │ │ ├── gpt2-base.yaml
│ │ │ ├── hf-gpt2.yaml
│ │ │ └── sanitycheck.yaml
│ │ ├── cfg_eval.yaml
│ │ ├── cfg_pretrain.yaml
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── arithmetic.yaml
│ │ │ ├── c4-subset-processed.yaml
│ │ │ ├── openweb.yaml
│ │ │ ├── proofpile.yaml
│ │ │ ├── sanity-check-1.yaml
│ │ │ ├── sanity-check-2.yaml
│ │ │ └── sources/
│ │ │ ├── ag_news.yaml
│ │ │ ├── arithmetic.yaml
│ │ │ ├── bookcorpus.yaml
│ │ │ ├── c4.yaml
│ │ │ ├── dash_books.yaml
│ │ │ ├── fake.yaml
│ │ │ ├── iwslt.yaml
│ │ │ ├── local.yaml
│ │ │ ├── no_code_stackexchange.yaml
│ │ │ ├── openwebtext.yaml
│ │ │ ├── oscar.yaml
│ │ │ ├── proofpiledata.yaml
│ │ │ ├── the_pile.yaml
│ │ │ ├── the_pileCC.yaml
│ │ │ ├── the_pile_dedup.yaml
│ │ │ ├── the_pile_natural.yaml
│ │ │ ├── the_pile_stream.yaml
│ │ │ ├── uncorpus.yaml
│ │ │ ├── uspto.yaml
│ │ │ ├── wikibooks.yaml
│ │ │ ├── wikinews.yaml
│ │ │ ├── wikipedia.yaml
│ │ │ ├── wikiquote.yaml
│ │ │ ├── wikiversity.yaml
│ │ │ └── wikivoyage.yaml
│ │ ├── eval/
│ │ │ ├── __init__.py
│ │ │ ├── pythia.yaml
│ │ │ └── tasks/
│ │ │ ├── lambada_openai.yaml
│ │ │ └── winogrande.yaml
│ │ ├── hydra/
│ │ │ ├── __init__.py
│ │ │ └── job_logging/
│ │ │ └── custom.yaml
│ │ ├── impl/
│ │ │ ├── __init__.py
│ │ │ ├── _default.yaml
│ │ │ └── torch-default.yaml
│ │ ├── train/
│ │ │ ├── __init__.py
│ │ │ ├── common.yaml
│ │ │ ├── cramming.yaml
│ │ │ ├── janus-regime.yaml
│ │ │ ├── optim/
│ │ │ │ ├── adafactor.yaml
│ │ │ │ ├── adahessian.yaml
│ │ │ │ ├── adam.yaml
│ │ │ │ ├── adam8bit.yaml
│ │ │ │ ├── adam_classic.yaml
│ │ │ │ ├── adamscale.yaml
│ │ │ │ ├── agd.yaml
│ │ │ │ ├── lion.yaml
│ │ │ │ ├── radam.yaml
│ │ │ │ ├── sgd.yaml
│ │ │ │ └── shampoo.yaml
│ │ │ └── optim_mod/
│ │ │ ├── disabled.yaml
│ │ │ ├── larc.yaml
│ │ │ ├── lars.yaml
│ │ │ ├── progressive.yaml
│ │ │ └── sam.yaml
│ │ └── wandb/
│ │ ├── default.yaml
│ │ └── none.yaml
│ ├── data/
│ │ ├── __init__.py
│ │ ├── arithmetic_tokenizers.py
│ │ ├── curriculum_sorting.py
│ │ ├── deduplicate.py
│ │ ├── pretraining_preparation.py
│ │ ├── tokenizer_preparation.py
│ │ └── utils.py
│ └── utils.py
├── create_data_split.py
├── create_pos_or_variants.py
├── dataset_analysis.py
├── gen_eval_script.py
├── load_local_model.py
├── pretrain.py
├── pretty_plotter.py
├── pretty_plotter_big.py
├── pretty_plotter_sort.py
├── pyproject.toml
├── setup.cfg
├── shells/
│ ├── addition_ff.sh
│ ├── addition_lt.sh
│ ├── bitwise_or.sh
│ ├── evaluation.sh
│ ├── generate_and_tokenize_data.sh
│ ├── multiplication.sh
│ └── sorting.sh
├── sort_eval.py
└── upload_processed_dataset.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
outputs
tables/*/*.csv
tables/*/*.csv#
tables/*.csv
tables/*.csv#
tables/*.ods
*.png
*.pdf
# torchdynamo debug
isolate
repro.py
checkpoints
wandb-metadata.json
torch_compile_debug/
dedup
.vs/
*.pdf
images
*.temp.sh
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
*.csv
*.txt
*.pth
cramming-data/
sanity.sh
log/
del.sh
del.py
sort_plots/
================================================
FILE: .pre-commit-config.yaml
================================================
# precommit hooks from https://github.com/ashleve/lightning-hydra-template
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: debug-statements
- id: detect-private-key
# python code formatting
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
args: [--line-length, "140", "--fast"] # ;>
# yaml formatting
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.3.0
hooks:
- id: prettier
types: [yaml]
# python code analysis
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 Sean McLeish, Jonas Geiping
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
# added by check-manifest
include *.py
include *.yaml
recursive-include cramming *.md
recursive-include cramming *.yaml
global-exclude *.pyc
global-exclude __pycache__
================================================
FILE: README.md
================================================
# Transformers Can Do Arithmetic with the Right Embeddings! [Link to arXiv paper](https://arxiv.org/abs/2405.17399)
A joint project by: Sean McLeish, Arpit Bansal, Alex Stein, Neel Jain, John Kirchenbauer, Brian R. Bartoldson, Bhavya Kailkhura, Abhinav Bhatele, Jonas Geiping, Avi Schwarzschild and Tom Goldstein
This repository contains code to replicate our research. It is a fork of the language model training framework [cramming](https://github.com/JonasGeiping/cramming), edited for a next-token prediction objective.
We provide a standalone implementation of Abacus Embeddings in [abacus.py](abacus.py).
## Citing Our Work
To cite our work, please use this bibtex.
```
@article{mcleish2024transformers,
title={Transformers Can Do Arithmetic with the Right Embeddings},
author={Sean McLeish and Arpit Bansal and Alex Stein and Neel Jain and John Kirchenbauer and Brian R. Bartoldson and Bhavya Kailkhura and Abhinav Bhatele and Jonas Geiping and Avi Schwarzschild and Tom Goldstein},
journal={arXiv preprint arXiv:2405.17399},
year={2024}
}
```
# Getting Started
We developed in Python 3.10.4, to install run:
```
git clone git@github.com:mcleish7/arithmetic.git
cd arithmetic
pip install .
```
On some machines you will need to run:
1. `pip install multiprocess -U`
2. `pip install dill -U`
3. `pip install apache-beam -U`
# Arithmetic
## Datasets
We release our datasets on [Google Drive](https://drive.google.com/drive/folders/1DqjCrUM1cNV7069Zl25_qBw2Px2xAw9j?usp=sharing) in zipped format. We recommend you work with the zipped version until it is correctly placed in your file system.
Alternatively, you can make your own datasets using [create_data_split.py](create_data_split.py) using the commands from [shells/generate_and_tokenize_data.sh](shells/generate_and_tokenize_data.sh).
## File Structure
We recommend creating another directory `cramming-data` inside of arithmetic. This is where the models, logs and data will be stored.
You can either export your cramming base directory path in your `.bashrc` or replace `$cramming_base_dir` manually in the provided shells.
```
cd arithmetic
mkdir cramming-data
echo 'export cramming_base_dir=MY_BASE_DIR' >> ~/.bashrc
source ~/.bashrc
```
For example, this may look like: `echo 'export cramming_base_dir=~/arithmetic/cramming-data' >> ~/.bashrc`
For example our file system looks like:
```
cramming-generative
└── cramming-data
├── addition-train-one
│ ├── pretrain/<DATE>/<TIME>
│ │ ├── .hydra
│ │ │ ├── config.yaml
│ │ │ ├── hydra.yaml
│ │ │ └── overrides.yaml
│ │ └── addition-train-one_pretrain.log
│ ├── checkpoints/FINAL_<LOSS_VAL>
│ │ ├── model_config.json
│ │ ├── model.safetensors
│ │ └── state_dict.pth
│ └── downstream
└── data
└── arithmetic_data
├── +_grid_eval_dataset_reverse_all_tokenized
└── ... other datasets ...
```
## Training
Example commands are in the [shells](shells) directory, organised by task.
### Explanation of Some Commands
1. Give samples instead of tokens equal importance in loss: `arch.loss_reduction=none`
2. Divide the gradients in the recurrent block by the number of recurrences: `arch.throttle=True`
3. Mask before the equals sign: `arch.mask_before_equals=True`
4. Skip connections inside of the recurrent block: `arch.forward_only_model_with_skip=True`
5. Multi-GPU: `python` -> `torchrun --nproc_per_node=<NUM GPUS> --standalone ` and add `impl.fullgraph=false`
### Positional Embeddings:
#### Absolute
1. Learned: `arch.embedding.pos_embedding=learned`
2. Abacus: `arch.embedding.pos_embedding=abacus`
* If you want the maximum k in Abacus to be larger, set `arch.embedding.max_abacus_len` (e.g. `arch.embedding.max_abacus_len=100`); by default this value is 100. Abacus is also implemented in a standalone manner in [abacus.py](abacus.py).
#### Relative
1. NoPE: `arch.embedding.pos_embedding=None`
2. FIRE: `arch.embedding.pos_embedding=None arch.attention.type="self-attention" arch.attention.rotary_embedding="fire"`
3. FIRE randomised: e.g. `arch.embedding.pos_embedding=None arch.attention.type="self-attention" arch.attention.rotary_embedding="fire" arch.attention.max_length=128`. By default `arch.attention.max_length=0`, so setting this longer than the maximum sequence length gives some randomness in the embedding.
4. RoPE: `arch.attention.type="self-attention" arch.attention.rotary_embedding=true`
### Checkpointing
We have implemented *single* GPU training checkpointing, to do this use:
`impl.save_every_n_minutes=60 impl.save_intermediate_model_name='last'`
This saves a checkpoint every 60 minutes under the name 'last'
Caution: This feature is not fully tested for multi-GPU cases. We also cannot currently train models which have used their full budget for longer.
### WandB
You can log runs to your weights&biases account. To do so, simply modify `wandb.entity` and `wandb.project` on the command line or at [cramming/config/wandb/default.yaml](cramming/config/wandb/default.yaml).
## Testing
We show examples in [shells/evaluation.sh](shells/evaluation.sh).
We provide a very basic automation in [gen_eval_script.py](gen_eval_script.py), this prints the basic commands you may need to further edit these.
### Addition
For addition we have a very large possible evaluation set, we do a grid search over a 100x100 grid which we split into 20 pieces with the aim of balancing the number of forward calls across all 20 pieces.
We then have a further eval for operand lengths 100->160.
### Multiplication
We only evaluate up to 25x25, which we do in a single job.
### Sorting
Sorting uses a separate evaluation file [sort_eval.py](sort_eval.py), this is because the evaluation calls cannot be parallelised, making evaluation much longer.
The evaluation cannot be parallelised because the place of the equals sign is not fixed for a batch.
We currently evaluate across 30 jobs for a 30x30 grid but this can be reduced to a smaller number of jobs using these flags: `max_size_given, start_ind_1_given, start_ind_2_given`
### Bitwise OR
We use the same framework as for addition but the process is quicker as some of the batches do not contain 100 samples as there are not 100 possibilities for some batches. Unlike addition we do not sample with replacement for this task.
# Analysis
1. We provide [pretty_plotter.py](pretty_plotter.py) to combine the small evaluation grids together into one plot.
Use this by putting the model name into the string at the top of the `main` function.
2. For the large 100x100 grids we provide [pretty_plotter_big.py](pretty_plotter_big.py).
These are designed to be as flexible as possible but may need to be edited to fit your file set up.
3. For sorting, we provide [pretty_plotter_sort.py](pretty_plotter_sort.py), this allows us to read the individual `.txt` files created during testing and merge them all together into a nice plot.
# Contact
Please, feel free to contact us with any questions, or open an issue on Github.
================================================
FILE: abacus.py
================================================
"""Implementation of abacus embeddings"""
import random

import torch

# Example of how to extract digit tokens to pass into constructor
# digit_tokens = tokenizer.convert_tokens_to_ids(['0','1','2','3','4','5','6','7','8','9'])


class Abacus(torch.nn.Module):
    """
    Abacus Embeddings, learned embeddings reused for each digit.
    Integers must be reversed for this to work correctly.
    Transformers Can Do Arithmetic with the Right Embeddings, McLeish et al. (2024)

    Every token that is one of ``digit_tokens`` gets the embedding of its 1-based position
    within its run of consecutive digit tokens; every non-digit token gets embedding row 0.
    In training mode all non-zero positions are shifted by one random offset k in [0, max_k].
    """

    def __init__(self, digit_tokens, embedding_dim, max_seq_length=1024, max_k=99):
        """
        digit_tokens (list): list of the tokens for each of the 10 digits, `digit_tokens = tokenizer.convert_tokens_to_ids(['0','1','2','3','4','5','6','7','8','9'])`
        embedding_dim (int): dimension to embed into
        max_seq_length (int): number of rows in the embedding table; the longest run of digits
            plus max_k must stay below this, otherwise the embedding lookup goes out of range
        max_k (int): maximum k value which we randomly shift by during training
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        # Non-persistent buffer: follows .to(device) but is not written to the state dict.
        self.register_buffer("digits", torch.tensor(digit_tokens), persistent=False)
        self.max_k = max_k

    def helper(self, mask, device):
        """
        Converts a binary mask of digit locations into spans of consecutive digits.

        mask (tensor): 2-D (batch, seq_len) mask, True/1 where the token is a digit
        device: device on which intermediate tensors are created
        returns (tensor): (batch, seq_len) long tensor with the 1-based position of each digit
            inside its run of consecutive digits, and 0 at non-digit positions
        """
        batch_size, seq_len = mask.shape
        # Shift the mask right by one column so that a 0 -> 1 change marks the first digit of a run.
        shifted_mask = torch.cat([torch.zeros((batch_size, 1), device=device, dtype=mask.dtype), mask[:, :-1]], dim=1)
        starts = (shifted_mask != mask) & mask
        # Number the runs in each row: 0 before the first run, then 1, 2, ...
        segment_ids = torch.cumsum(starts, dim=1)
        index = torch.arange(seq_len, device=device).repeat(batch_size, 1)
        # For every run id, store the column where that run starts (non-start entries contribute 0).
        run_start = torch.zeros_like(mask, dtype=torch.long).scatter_add(1, segment_ids, index * starts.long())
        # Distance from the start of the run, made 1-based.
        positions = index - run_start.gather(1, segment_ids) + 1
        # Ensure only values within 1-segments are non-zero
        return positions * mask

    def forward(self, input_ids):
        """
        input_ids (tensor): a batch of inputs, each row is a sample
        returns (tensor): abacus embeddings with shape input_ids.shape + (embedding_dim,)
        """
        mask = torch.isin(input_ids, self.digits)
        output = self.helper(mask, input_ids.device)
        if self.training:
            # One random offset per forward call, shared by every digit in the batch.
            k = random.randint(0, self.max_k)
            # Positions start at 1, so shifted digit positions become k+1, k+2, ...
            output[output > 0] += k
        return self.embedding(output)
================================================
FILE: arithmetic_eval_quicker.py
================================================
import logging
import hydra
from omegaconf import OmegaConf
import cramming
import torch
from safetensors.torch import load_file
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
import re
import pandas as pd
import datasets
import os
from typing import List, Dict
from cramming.data.tokenizer_preparation import get_tokenizer
import random
log = logging.getLogger(__name__)
def grid_plotter(data, type="accs", name='_large', extra_path=None):
    """Plot a 2-D accuracy grid as a heatmap and save it to disk.

    Args:
        data: square 2-D array-like of accuracies (fractions; multiplied by 100 for display).
            Presumably row i / column j hold the accuracy for 1st / 2nd operand length i+1 / j+1,
            matching the axis labels below.
        type: prefix of the output file name (e.g. "accs").
        name: suffix identifying the evaluated grid region (e.g. "_large").
        extra_path: optional string prepended to the file name (e.g. a directory ending in "/").

    The figure is saved as f"{extra_path}{type}{name}_grid_plot" (no extension is given, so
    matplotlib's default savefig format applies) and is then closed.
    """
    data = np.array(data) * 100
    df = pd.DataFrame(data)
    fig = plt.figure(figsize=(10, 8))
    try:
        # NOTE: fmt / annot_kws only take effect when annot=True, which is not set here.
        sns.heatmap(df, cmap="YlGnBu", fmt=".1f", annot_kws={'size': 8, 'rotation': 0})
        plt.title("Accuracy - percentage, rounded to 1dp")
        plt.ylabel("1st Number Length")
        plt.xlabel("2nd Number Length")
        size = data.shape[0]
        # Centre ticks on the cells and label them with 1-based operand lengths.
        ticks = np.arange(0.5, size + 0.5, 1)
        tick_labels = np.arange(1, size + 1, 1)
        plt.xticks(ticks, labels=tick_labels)
        plt.yticks(ticks, labels=tick_labels)
        prefix = "" if extra_path is None else extra_path
        plt.savefig(f"{prefix}{type}{name}_grid_plot", bbox_inches='tight')
    finally:
        # Close (not just clear) the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
def index_hints_helper(num, tokenizer):
    """Add index hints into a tokenized number.

    Interleaves hint tokens with the digits: column i of ``num`` is preceded by the token id of
    ``tokenizer.char_set[i]``, so a row [d0, d1, d2] becomes [h0, d0, h1, d1, h2, d2].

    Args:
        num: 2-D (batch, width) tensor of token ids.
        tokenizer: object exposing ``char_set`` (indexable, at least ``width`` entries) and
            ``_convert_token_to_id``.

    Returns:
        (batch, 2 * width) tensor with the same dtype and device as ``num``.

    Raises:
        IndexError: if ``tokenizer.char_set`` has fewer than ``width`` entries.
    """
    batch, width = num.shape[0], num.shape[1]
    char_set = tokenizer.char_set
    hint_ids = [tokenizer._convert_token_to_id(char_set[i]) for i in range(width)]
    hints = torch.tensor(hint_ids, dtype=num.dtype, device=num.device).expand(batch, width)
    # Stack along a new last axis, then flatten pairs: (hint_i, digit_i) for each column i.
    # This builds the result in one pass instead of re-concatenating the tensor per column.
    return torch.stack((hints, num), dim=2).reshape(batch, 2 * width)
def grid_logic(cfg):
    """Select the part of the 2-D operand-length grid that this evaluation run is responsible for.

    Regions are selected by boolean flags on ``cfg``, checked in this fixed order:
    ``ood_only``, ``up_to_40``, ``up_to_50``, ``big_eval_step_1`` ... ``big_eval_step_10``.
    The last flag that is set wins; if none is set, the default ``'_large'`` region is used.
    ``cfg.checkerboard`` ('even', 'odd' or None) further restricts the big-eval shards to cells
    whose operand-length sum has that parity. ``cfg.large`` is only logged.

    Args:
        cfg: evaluation config exposing the attributes listed above.

    Returns:
        (logic_func, name, max_size):
            logic_func(data_size_1, data_size_2) -> bool, True if this run evaluates that cell;
            name: suffix used in output / plot file names;
            max_size: one past the largest operand length of the region.

    Raises:
        ValueError: if ``cfg.checkerboard`` is not None, 'even' or 'odd'.
    """
    # Default (original testing): every cell where at least one operand has length <= 23.
    def logic_func(data_size_1, data_size_2):
        return data_size_1 <= 23 or data_size_2 <= 23

    name = '_large'
    max_size = 23 + 1

    def _band(lower, upper):
        """Cells with at least one operand >= lower and at least one operand <= upper."""
        def in_band(data_size_1, data_size_2):
            return (data_size_1 >= lower or data_size_2 >= lower) and (data_size_1 <= upper or data_size_2 <= upper)
        return in_band

    # (cfg flag, lower, upper, name); later entries override earlier ones.
    for flag, lower, upper, band_name in (
        ("ood_only", 24, 30, '_ood_only'),
        ("up_to_40", 31, 40, '_up_to_40'),
        ("up_to_50", 41, 50, '_up_to_50'),
    ):
        if getattr(cfg, flag):
            logic_func = _band(lower, upper)
            name = band_name
            max_size = upper + 1

    # checkerboarding: for the large eval we can keep only cells with an even / odd length sum.
    if cfg.checkerboard is None:
        parity, checkerboard_str = None, ""
    elif cfg.checkerboard == 'even':
        parity, checkerboard_str = 0, "_even"
    elif cfg.checkerboard == 'odd':
        parity, checkerboard_str = 1, "_odd"
    else:
        raise ValueError(f"checkerboard config not allowed: {cfg.checkerboard!r}")

    # If we are testing up to 100, split into 10 shards, each needing approximately the same
    # number of forward passes. Inclusive (lower, upper) operand-length bounds per shard; the
    # lower bound of 1 for shard 1 is no constraint because operand lengths start at 1.
    big_eval_bounds = (
        (1, 46), (47, 58), (59, 67), (68, 74), (75, 80),
        (81, 85), (86, 90), (91, 94), (95, 97), (98, 100),
    )

    def _big_eval_shard(lower, upper):
        """Cells with at least one operand >= lower, both operands <= upper, and matching parity."""
        def in_shard(data_size_1, data_size_2):
            in_range = (data_size_1 >= lower or data_size_2 >= lower) and (data_size_1 <= upper and data_size_2 <= upper)
            return in_range and (parity is None or (data_size_1 + data_size_2) % 2 == parity)
        return in_shard

    for step, (lower, upper) in enumerate(big_eval_bounds, start=1):
        if getattr(cfg, f"big_eval_step_{step}"):
            logic_func = _big_eval_shard(lower, upper)
            name = f'_big_eval_{step}{checkerboard_str}'
            max_size = 100 + 1

    log.info(f"large = {cfg.large}")
    log.info(f"ood only = {cfg.ood_only}")
    log.info(f"up to 40 = {cfg.up_to_40}")
    log.info(f"up to 50 = {cfg.up_to_50}")
    log.info(f"big eval 1 = {cfg.big_eval_step_1}")
    log.info(f"big eval 2 = {cfg.big_eval_step_2}")
    log.info(f"big eval 3 = {cfg.big_eval_step_3}")
    log.info(f"big eval 4 = {cfg.big_eval_step_4}")
    log.info(f"big eval 5 = {cfg.big_eval_step_5}")
    log.info(f"big eval 6 = {cfg.big_eval_step_6}")
    log.info(f"big eval 7 = {cfg.big_eval_step_7}")
    log.info(f"big eval 8 = {cfg.big_eval_step_8}")
    log.info(f"big eval 9 = {cfg.big_eval_step_9}")
    log.info(f"big eval 10 = {cfg.big_eval_step_10}")
    log.info(f"the last true value in the above list will be run, mul and pos arith can take control after this")
    return logic_func, name, max_size
def main(cfg):
device = "cuda" if torch.cuda.is_available() else "cpu"
local_checkpoint_folder = os.path.join(cfg.base_dir, cfg.name, "checkpoints")
tokenizer, cfg_arch, model_file = cramming.utils.find_pretrained_checkpoint(cfg.eval.checkpoint,
local_checkpoint_folder,
cfg.eval.arch_modifications)
if cfg.max_rec is not None: # can have more/less recurrences for eval
cfg_arch.maximal_recurrence_in_eval = cfg.max_rec
else:
cfg_arch.maximal_recurrence_in_eval = cfg_arch.maximal_recurrence
log.info(f"cfg_arch.maximal_recurrence_in_eval changed to {cfg_arch.maximal_recurrence_in_eval}")
cfg_arch.throttle = False # turn throttle off
logic_func, name, max_size = grid_logic(cfg)
if cfg.mul: # multiplication
def logic_func_for_mul(data_size_1, data_size_2):
return (data_size_1 <= 25 or data_size_2 <= 25)
logic_func = logic_func_for_mul
name = '_large'
max_size = 25+1
log.info(f"mul = {cfg.mul}")
if cfg.pos_arth: # bitwise OR
def logic_func_for_pos(data_size_1, data_size_2):
return (data_size_1 <= 25 or data_size_2 <= 25)
logic_func = logic_func_for_pos
name = '_large'
max_size = 25+1
log.info(f"pos_arth = {cfg.pos_arth}")
if cfg.pos_arth_ood:
def logic_func_for_pos_ood(data_size_1, data_size_2):
return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)
logic_func = logic_func_for_pos_ood
name = '_ood_only'
max_size = 40+1
log.info(f"pos_arth_ood = {cfg.pos_arth_ood}")
# import tokeniser
cfg_data_sources_values_list = list(cfg.data.sources.values())[0]
if cfg_data_sources_values_list["provider"] == "arithmetic":
tokenizer = get_tokenizer(cfg_data_sources_values_list["tokenizer_type"])
else:
log.info("exiting as this is only for arithmetic")
exit()
vocab = tokenizer.ids_to_tokens
EOS_token = tokenizer._convert_token_to_id(tokenizer.eos_token)
PAD_token = tokenizer._convert_token_to_id(tokenizer.pad_token)
assert PAD_token == 0, "PAD token must be token zero for our code to work"
# Load model
if 'alpha' not in cfg_arch:
cfg_arch['alpha'] = 1.0
model = cramming.construct_model(cfg_arch, tokenizer).to(device)
model = cramming.backend.load_model_checkpoint(model, model_file)
model.to(device)
model.eval()
log.info(f"greedy = {cfg.greedy}, note: if greedy = True this overrides any temperature arguments")
## Greedy decoding will overide any temperature arguments
if cfg.max_size_given is not None: # allows unique splits for eval
max_size = max_size_given
# Grid plots - grid search from 1x1 to 12x12 data
data_sizes = list(range(1, max_size))
acc_grid = np.zeros((len(data_sizes),len(data_sizes)))
start_ind_1 = 0
start_ind_2 = 0
tuple_method = False
completed_one = False
if "big_eval" in name:
tuple_method = True
# go up two layers and search for grid
try:
with open(f"../../accs_grid_quick{name}.json", 'r') as file:
data = json.load(file)
start_ind_1 = data[1]
start_ind_2 = data[2]
acc_grid = np.array(data[0])
log.info("loaded grid from previous run")
except:
pass
if cfg.start_ind_1_given is not None: # allows unique splits for eval
start_ind_1 = cfg.start_ind_1_given
if cfg.start_ind_2_given is not None:
start_ind_2 = cfg.start_ind_2_given
log.info(f"start_ind_1 = {start_ind_1}, start_ind_2 = {start_ind_2}")
os.makedirs("outputs", exist_ok=True)
if not cfg.extended_eval:
# main 2d loop
for data_size_1 in data_sizes:
for data_size_2 in data_sizes:
if (data_size_1 < start_ind_1 or data_size_2 < start_ind_2) and not completed_one:
continue
else:
proceed = False
# if both data sizes are less than the start indices, then dont proceed
# but if one of them is greater than the start indices, then proceed
if data_size_1 >= start_ind_1 or data_size_2 >= start_ind_2:
proceed = True
if not proceed:
continue
print(f"evaluating for {data_size_1} and {data_size_2}")
if logic_func(data_size_1, data_size_2):
completed_one = True
log.info(f"Starting iteration in grid eval for size: {data_size_1} and {data_size_2}")
correct_total = 0
# get the correct dataset, these names may need to be changed if you make new datasets
file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_padded_tokenized/+_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_seed_42/hf_tokenized_dataset"
if cfg.reverse_inputs:
file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized/+_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_seed_42/hf_tokenized_dataset"
if cfg.mul:
file_path = f"../../../../data/arithmetic_data/x_grid_eval_dataset_2_reverse_all_tokenized/x_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_exact_seed_91/hf_tokenized_dataset"
if cfg.pos_arth or cfg.pos_arth_ood:
file_path = f"../../../../data/arithmetic_data/pos_or_one_vec_zeros_eval/or_one_vec_zeros_{data_size_1}_{data_size_2}/hf_tokenized_dataset"
tokenized_dataset = datasets.load_from_disk(file_path)["test"]
data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=100, shuffle=False)
equals_tensor = data_size_1+data_size_2+6
if cfg.pos_arth or cfg.pos_arth_ood:
equals_tensor = data_size_1+data_size_2+2
for batch in data_loader:
# split prompt and answer
tokenized_prompts = batch["input_ids"][:equals_tensor]
tokenized_prompts = torch.stack(tokenized_prompts).to(device)
tokenized_prompts = torch.transpose(tokenized_prompts, 0, 1)
tokenized_answers = batch["input_ids"][equals_tensor:]
tokenized_answers = torch.stack(tokenized_answers).to(device)
tokenized_answers = torch.transpose(tokenized_answers, 0, 1)
if cfg.remove_padding and (cfg_data_sources_values_list["tokenizer_type"] != "index"):
# removes the padding from the eval data
num1 = tokenized_prompts[:,:data_size_1]
op = tokenized_prompts[:,data_size_1+1:data_size_1+2]
num2 = tokenized_prompts[:,data_size_1+3:data_size_1+data_size_2+3]
equals = tokenized_prompts[:,data_size_1+data_size_2+4:data_size_1+data_size_2+5]
tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)
if cfg_data_sources_values_list["tokenizer_type"] == "index":
# adding in the index hints to the input numbers
num1 = tokenized_prompts[:,:data_size_1]
num1 = index_hints_helper(num1, tokenizer)
op = tokenized_prompts[:,data_size_1+1:data_size_1+2]
num2 = tokenized_prompts[:,data_size_1+3:data_size_1+data_size_2+3]
num2 = index_hints_helper(num2, tokenizer)
equals = tokenized_prompts[:,data_size_1+data_size_2+4:data_size_1+data_size_2+5]
tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)
predicted_ids = None
## below inserts the characters for the model, we decided against this in the end
predicted_ids = model._generate(tokenized_prompts, token_limit=(tokenized_answers.shape[1]*2), temperature=cfg.temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=cfg.greedy, quick=True)
predicted_ids = torch.transpose(predicted_ids, 0, 1)
new_tensor = torch.zeros_like(predicted_ids)
for i in range(predicted_ids.size(0)): # inefficient!!
# Filter out values greater than 17
filtered_values = predicted_ids[i][predicted_ids[i] <= 17]
# Place filtered values in new tensor and pad with zeros
new_tensor[i, :len(filtered_values)] = filtered_values
predicted_ids = new_tensor[:, :tokenized_answers.shape[1]] # trim off the excess
predicted_ids = torch.transpose(predicted_ids, 0, 1)
else:
predicted_ids = model._generate(tokenized_prompts, token_limit=tokenized_answers.shape[1], temperature=cfg.temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=cfg.greedy, quick=True)
if len(predicted_ids.shape) > 1: # i.e. we have a batch of more than one
predicted_ids = torch.transpose(predicted_ids, 0, 1)
else:
predicted_ids = predicted_ids.reshape((1,-1)) # add a batch dim otherwise
# ignore everything after EOS on eval but replacing all after EOS with PAD
eval_tensor = predicted_ids.clone()
input_tensor_EOS = (eval_tensor == EOS_token).int()
indices_of_EOS = torch.argmax(input_tensor_EOS, dim=1)
mask = torch.arange(eval_tensor.size(1)).to(device) > indices_of_EOS[:, None]
eval_tensor[mask] = PAD_token
# compare eval tensor to correct outputs
elementwise_equal = torch.eq(eval_tensor, tokenized_answers)
rows_equal = torch.all(elementwise_equal, dim=1)
num_equal_rows = torch.sum(rows_equal).item()
correct_total += (num_equal_rows/tokenized_prompts.shape[0])
log.info(f"accuracy for {data_size_1}, {data_size_2}: {num_equal_rows} = {correct_total*100}%")
# combine the prompts and outputs
complete_lines = torch.cat((tokenized_prompts,predicted_ids), dim=1)
tokens_list = complete_lines.tolist()
decoded_batch = list(map(lambda seq: list(map(lambda token: vocab[token], seq)), tokens_list)) # map token ids to tokens
log.info(f"example for {data_size_1}, {data_size_2}: {decoded_batch[0]}")
# save the answers down so we don't eval twice ever
with open(f"outputs/+_n_{data_size_1}_m_{data_size_2}.json", 'w') as json_file:
json.dump(decoded_batch, json_file)
acc_grid[(data_size_1-1),(data_size_2-1)] = correct_total
if tuple_method:
with open(f"../../accs_grid_quick{name}.json", "w") as file:
tuple_to_save = (acc_grid.tolist(),data_size_1,data_size_2)
json.dump(tuple_to_save, file)
log.info(f"acc grid: {acc_grid}")
with open(f"accs_grid_quick{name}.json", "w") as file:
json.dump(acc_grid.tolist(), file)
# Grid plots - one for accs one for contains
grid_plotter(acc_grid, name=name)
if cfg.extended_eval:
# extended eval to eval large numbers easily, used the large eval numebers to split up into multiple parts
number = int(re.findall(r'\d+', name)[0])
log.info("starting extended eval")
# this is hard coded for reverse all, addition past 100x100 grid, removing the padding
accs = dict()
batch_size_extended_eval = 100
old_data_path = None
for root, dirs, files in os.walk("../.."):
if f"over_100_{number}.json" in files:
old_data_path = os.path.join(root, f"over_100_{number}.json")
if number == 1:
start = 101
list_to_do = range(start,161)
elif number == 2:
list_to_do = [1000, 800]
elif number == 3:
list_to_do = [200, 700, 900]
elif number == 4:
list_to_do = [300, 400, 500, 600]
else:
print("number too high")
exit()
if old_data_path is not None: # read the old accs dict and don't repeat what we have already done
with open(old_data_path, 'r') as file:
data = json.load(file)
accs = {int(k): v for k, v in data.items()}
to_do = set(list_to_do).difference(set(accs.keys()))
list_to_do = list(to_do)
log.info(f"In extended eval with number {number}")
for data_size in list_to_do:
log.info(f"Extended eval {data_size}")
correct_total = 0
file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized_over_100/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_exact_seed_42/hf_tokenized_dataset"
tokenized_dataset = datasets.load_from_disk(file_path)["test"]
data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=batch_size_extended_eval, shuffle=False)
equals_tensor = data_size+data_size+6
for batch in data_loader:
# get prompt and answer
tokenized_prompts = batch["input_ids"][:equals_tensor]
tokenized_prompts = torch.stack(tokenized_prompts).to(device)
tokenized_prompts = torch.transpose(tokenized_prompts, 0, 1)
tokenized_answers = batch["input_ids"][equals_tensor:]
tokenized_answers = torch.stack(tokenized_answers).to(device)
tokenized_answers = torch.transpose(tokenized_answers, 0, 1)
# remove the padding
num1 = tokenized_prompts[:,:data_size]
op = tokenized_prompts[:,data_size+1:data_size+2]
num2 = tokenized_prompts[:,data_size+3:data_size+data_size+3]
equals = tokenized_prompts[:,data_size+data_size+4:data_size+data_size+5]
tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)
# get the output from the model
predicted_ids = model._generate(tokenized_prompts, token_limit=tokenized_answers.shape[1], temperature=cfg.temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=cfg.greedy, quick=True)
predicted_ids = torch.transpose(predicted_ids, 0, 1) # add a batch dim
eval_tensor = predicted_ids.clone()
input_tensor_EOS = (eval_tensor == EOS_token).int()
indices_of_EOS = torch.argmax(input_tensor_EOS, dim=1)
mask = torch.arange(eval_tensor.size(1)).to(device) > indices_of_EOS[:, None]
eval_tensor[mask] = PAD_token
elementwise_equal = torch.eq(eval_tensor, tokenized_answers)
rows_equal = torch.all(elementwise_equal, dim=1)
num_equal_rows = torch.sum(rows_equal).item()
correct_total += (num_equal_rows/tokenized_prompts.shape[0])
log.info(f"accuracy for {data_size}, {data_size}: {num_equal_rows} = {correct_total*100}%")
# combine the prompts and outputs
complete_lines = torch.cat((tokenized_prompts,predicted_ids), dim=1)
tokens_list = complete_lines.tolist()
decoded_batch = list(map(lambda seq: list(map(lambda token: vocab[token], seq)), tokens_list)) # map token ids to tokens
log.info(f"example for {data_size}, {data_size}: {decoded_batch[0]}")
# save the answers down so we don't eval twice ever
accs[data_size] = correct_total
with open(f"over_100_{number}.json", 'w') as json_file:
json.dump(accs, json_file)
log.info("Eval complete")
@hydra.main(config_path="cramming/config", config_name="cfg_eval", version_base="1.3")
def launch(cfg):
    """Hydra entry point for evaluation.

    Passes the composed ``cfg`` through ``cramming.utils.pathfinder``, logs the fully
    resolved config as YAML, and then hands it to ``main``.
    """
    log.info("calling main launch")
    resolved_cfg = cramming.utils.pathfinder(cfg)
    log.info(OmegaConf.to_yaml(resolved_cfg, resolve=True))
    main(resolved_cfg)


if __name__ == "__main__":
    launch()
================================================
FILE: cramming/__init__.py
================================================
"""Initialize cramming"""
from cramming import utils
from cramming.architectures import construct_model
from cramming.backend import load_backend
from cramming.data import load_pretraining_corpus, prepare_dataloaders
# Public API of the ``cramming`` package, i.e. the names exported by ``from cramming import *``.
__all__ = [
    "construct_model",
    "load_backend",
    "prepare_dataloaders",
    "load_pretraining_corpus",
    "utils",
]
import hydra

# The helper functions below compose hydra configs from this package's ``config`` folders.
"""Construct interfaces to some cfg folders for use in packaged installations:"""
def get_config(overrides=None):
    """Return the default hydra config composed from ``config/cfg``.

    Args:
        overrides: optional list of hydra override strings (e.g. ``["name=foo"]``).
            ``None`` (the default) means no overrides. ``None`` replaces the previous
            shared mutable ``[]`` default.

    Returns:
        The composed config object. Its ``name`` is printed on load.

    NOTE(review): the visible ``cramming/config`` listing shows ``cfg_eval.yaml`` and
    ``cfg_pretrain.yaml`` but no ``cfg.yaml``. Confirm that ``config_name="cfg"`` still resolves.
    """
    overrides = [] if overrides is None else overrides
    with hydra.initialize(config_path="config"):
        cfg = hydra.compose(config_name="cfg", overrides=overrides)
        print(f"Loading default config {cfg.name}.")
    return cfg
def get_model_config(arch="hf-bert-tiny", overrides=None):
    """Return the hydra architecture config ``config/arch/<arch>``.

    Args:
        arch: name of the architecture config (file name without ``.yaml``).
        overrides: optional list of hydra override strings. ``None`` (the default) means
            no overrides. ``None`` replaces the previous shared mutable ``[]`` default.

    Returns:
        The composed architecture config. Its ``architecture`` field is printed on load.

    NOTE(review): the visible ``config/arch`` listing has no ``hf-bert-tiny.yaml``.
    Confirm the default ``arch``.
    """
    overrides = [] if overrides is None else overrides
    with hydra.initialize(config_path="config/arch"):
        cfg = hydra.compose(config_name=arch, overrides=overrides)
        print(f"Loading model configuration {cfg.architecture}.")
    return cfg
def get_backend_config(backend="torch-default", overrides=None):
    """Return the hydra backend/implementation config ``config/impl/<backend>``.

    Args:
        backend: name of the backend config (file name without ``.yaml``).
        overrides: optional list of hydra override strings. ``None`` (the default) means
            no overrides. ``None`` replaces the previous shared mutable ``[]`` default.

    Returns:
        The composed backend config. Its ``name`` is printed on load.
    """
    overrides = [] if overrides is None else overrides
    with hydra.initialize(config_path="config/impl"):
        cfg = hydra.compose(config_name=backend, overrides=overrides)
        print(f"Loading backend {cfg.name}.")
    return cfg
================================================
FILE: cramming/architectures/__init__.py
================================================
"""This module handles all questions of model architecture."""
from .construction import construct_model
# Public API of this subpackage: only the model factory is exported.
__all__ = ["construct_model"]
================================================
FILE: cramming/architectures/attention.py
================================================
"""Attention modules. Most code heavily stolen from the GPT-neoX implementation"""
import torch
from transformers.models.bert.modeling_bert import BertSelfAttention
from .embeddings import Rotary, RotarySanityCheck, RotaryEleutherAI, RotaryLLAMA, FIRE
from typing import Optional
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear # use to mark output projections of attn while it exists
def get_attention_mechanism(idx, hidden_size, cfg_attention, norm_fn: torch.nn.Identity):
    """Build the token-mixing module selected by ``cfg_attention.type``.

    Args:
        idx: layer index. It is accepted for interface compatibility and is not used here.
        hidden_size: model width passed to the chosen module.
        cfg_attention: attention config. ``type`` picks the implementation.
        norm_fn: norm factory forwarded only to the ``"self-attention"`` implementation.

    Returns:
        The instantiated attention module.

    Raises:
        ValueError: if ``cfg_attention.type`` is not one of the supported names.
    """
    kind = cfg_attention.type
    # Main implementation (GPT-NeoX style, sequence-first).
    if kind == "self-attention":
        return SeqFirstSelfAttention(hidden_size, cfg_attention, norm_fn)
    # PyTorch's own MultiheadAttention, batch-first and sequence-first variants.
    if kind == "pytorch":
        return SelfAttentionPyTorch(hidden_size, cfg_attention)
    if kind == "pytorch-seqfirst":
        return SeqFirstSelfAttentionPyTorch(hidden_size, cfg_attention)
    # Hugging Face BERT attention (its q/k/v projections always carry a bias).
    if kind == "huggingface":
        return BertAttentionWrapper(hidden_size, cfg_attention)
    if kind == "fourier":
        return FourierMixing(hidden_size, cfg_attention)
    # Sanity-check mixers: no mixing at all, or pure noise (no signal on where to look).
    if kind == "none":
        return Identity(hidden_size)
    if kind == "rn":
        return RandomNoise(hidden_size)
    raise ValueError(f"Invalid attention type {cfg_attention.type} given.")
class Identity(torch.nn.Module):
    """No-op token mixer for sanity checks: returns its input unchanged and ignores any mask."""

    __constants__ = ["LAYOUT"]
    LAYOUT = "[B S H]"

    def __init__(self, hidden_size):
        super().__init__()
        # The output width equals the input width, since nothing is transformed.
        self.output_dim = hidden_size

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
        """Return ``hidden_states`` as-is. ``attention_mask`` exists only for interface compatibility."""
        del attention_mask
        return hidden_states
class RandomNoise(torch.nn.Module):
    """Sanity-check token mixer that adds Gaussian noise instead of attending.

    Selected via ``cfg_attention.type == "rn"``. The output carries no information about where to
    look. Each forward returns ``hidden_states + N(0, 0.1**2)`` noise of the same shape.
    """

    __constants__ = ["LAYOUT"]
    LAYOUT = "[B S H]"

    def __init__(self, hidden_size):
        super().__init__()
        self.output_dim = hidden_size

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
        """Add std-0.1 Gaussian noise to ``hidden_states``. ``attention_mask`` is ignored.

        The noise is sampled directly on the input's device, which avoids a CPU allocation plus a
        copy on every call. It uses the default floating dtype, as before. No logging is done here,
        because this runs on every forward pass.
        """
        noise = torch.randn(hidden_states.shape, device=hidden_states.device) * 0.1
        return hidden_states + noise
class BertAttentionWrapper(BertSelfAttention):
    """Thin adapter running Hugging Face's ``BertSelfAttention`` (decoder mode) as a sanity-check mixer.

    An output projection is added on top: ``torch.nn.Identity`` when
    ``cfg_attention.skip_output_projection`` is set, otherwise a ``hidden_size -> hidden_size`` linear
    layer with bias ``cfg_attention.bias_in_proj``.
    """

    __constants__ = ["LAYOUT"]
    LAYOUT = "[B S H]"

    def __init__(self, hidden_size, cfg_attention):
        # BertSelfAttention reads its settings from a config object. A bare attribute holder is
        # enough for the fields set here.
        class _ConfigStub:
            pass

        _ConfigStub.hidden_size = hidden_size
        _ConfigStub.num_attention_heads = cfg_attention.num_attention_heads
        _ConfigStub.attention_probs_dropout_prob = 0.0
        _ConfigStub.is_decoder = True
        super().__init__(_ConfigStub)

        self.dense = (
            torch.nn.Identity()
            if cfg_attention.skip_output_projection
            else torch.nn.Linear(hidden_size, hidden_size, bias=cfg_attention.bias_in_proj)
        )

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
        """Apply BERT self-attention, keep only its first output, then apply the output projection."""
        attended = super().forward(hidden_states, attention_mask)[0]
        return self.dense(attended)
class SelfAttentionPyTorch(torch.nn.Module):
    """Minimal causal wrapper around ``torch.nn.MultiheadAttention`` for batch-first ``[B S H]`` inputs."""

    __constants__ = ["LAYOUT"]
    LAYOUT = "[B S H]"

    def __init__(self, hidden_size, cfg_attention):
        super().__init__()
        # NOTE(review): ``add_bias_kv`` appends learned bias *tokens* to the key/value sequences. It is
        # not a q/k/v projection bias, even though it is driven by ``cfg_attention.qkv_bias``. Confirm intent.
        self.attn = torch.nn.MultiheadAttention(
            hidden_size,
            cfg_attention.num_attention_heads,
            dropout=0.0,
            batch_first=True,
            bias=cfg_attention.bias_in_proj,
            add_bias_kv=cfg_attention.qkv_bias,
        )

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
        """Causal self-attention over ``hidden_states`` of shape ``[B, S, H]``.

        Args:
            hidden_states: batch-first input.
            attention_mask: 4-D boolean mask. Its ``[0, 0]`` slice (``[S, S]``, True = blocked) is
                handed to PyTorch. If ``None``, an equivalent causal ``[S, S]`` mask is built here.
                Previously ``None`` raised a TypeError, and PyTorch also refuses ``is_causal=True``
                without a mask.
        """
        if attention_mask is None:
            seq_len = hidden_states.shape[1]
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device), diagonal=1
            )
        else:
            attn_mask = attention_mask[0, 0, :, :]
        return self.attn(
            hidden_states,
            hidden_states,
            hidden_states,
            attn_mask=attn_mask,
            need_weights=False,
            is_causal=True,
        )[0]
class SeqFirstSelfAttentionPyTorch(torch.nn.Module):
    """Minimal causal wrapper around ``torch.nn.MultiheadAttention`` for sequence-first ``[S B H]`` inputs."""

    __constants__ = ["LAYOUT"]
    LAYOUT = "[S B H]"

    def __init__(self, hidden_size, cfg_attention):
        super().__init__()
        # NOTE(review): ``add_bias_kv`` appends learned bias *tokens* to the key/value sequences. It is
        # not a q/k/v projection bias, even though it is driven by ``cfg_attention.qkv_bias``. Confirm intent.
        self.attn = torch.nn.MultiheadAttention(
            hidden_size,
            cfg_attention.num_attention_heads,
            dropout=0.0,
            batch_first=False,
            bias=cfg_attention.bias_in_proj,
            add_bias_kv=cfg_attention.qkv_bias,
        )

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
        """Causal self-attention over ``hidden_states`` of shape ``[S, B, H]``.

        Args:
            hidden_states: sequence-first input.
            attention_mask: 4-D boolean mask. Its ``[0, 0]`` slice (``[S, S]``, True = blocked) is
                handed to PyTorch. If ``None``, an equivalent causal ``[S, S]`` mask is built here.
                Previously ``None`` raised a TypeError, and PyTorch also refuses ``is_causal=True``
                without a mask.
        """
        if attention_mask is None:
            seq_len = hidden_states.shape[0]
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device), diagonal=1
            )
        else:
            attn_mask = attention_mask[0, 0, :, :]
        return self.attn(
            hidden_states,
            hidden_states,
            hidden_states,
            attn_mask=attn_mask,
            need_weights=False,
            is_causal=True,
        )[0]
class SeqFirstSelfAttention(torch.nn.MultiheadAttention):
    """Sequence-first multi-head self-attention (GPT-NeoX / Megatron style).

    Adapted from the gpt-neox implementation
    (https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py, itself a Megatron
    variant). It was modified so that it compiles without graph breaks.

    It subclasses ``torch.nn.MultiheadAttention`` to catch the same initialization. However, it only
    runs ``torch.nn.Module.__init__`` and defines the MHA-named attributes itself (``in_proj_weight``,
    ``in_proj_bias``, ``bias_k``, ``bias_v``, ``out_proj``).

    Shapes: ``forward`` takes ``hidden_states`` as ``[seq, batch, hidden]`` and returns
    ``[seq, batch, hidden]``.

    Args:
        hidden_size: model width. Each head gets ``hidden_size // num_attention_heads`` features.
        cfg_attention: attention config. Fields read here: ``num_attention_heads``, ``qkv_bias``,
            ``rotary_embedding``, ``max_length`` (FIRE only), ``sequence_op``, ``seq_op_in_fp32``,
            ``skip_output_projection``, ``bias_in_proj``.
        norm_module: zero-argument callable building a module that is applied to the attention output
            before ``out_proj`` (defaults to ``torch.nn.Identity``).

    Raises:
        ValueError: if ``cfg_attention.sequence_op`` is not a supported option.
    """

    __constants__ = ["LAYOUT"]
    LAYOUT: str = "[S B H]"

    def __init__(self, hidden_size: int, cfg_attention, norm_module=torch.nn.Identity):
        # Deliberately bypasses MultiheadAttention.__init__; the attributes it would create are set below.
        torch.nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.num_attention_heads = cfg_attention.num_attention_heads
        self.hidden_per_head = self.hidden_size // cfg_attention.num_attention_heads
        # 1/sqrt(head_dim) scaling for the q.k scores, stored as a buffer so it moves with the module.
        self.register_buffer("norm_factor", torch.tensor(self.hidden_per_head).rsqrt())
        self.cfg_attention = cfg_attention
        self.use_fire = False
        self.norm = norm_module()

        # Strided (fused) q/k/v projection. forward() views its output as [sq, b, np, 3 * hn] and splits
        # the last dim, so the rows are grouped per head as [q_h, k_h, v_h].
        self.in_proj_weight = torch.nn.Parameter(torch.randn(3 * self.hidden_size, self.hidden_size))
        if cfg_attention.qkv_bias:
            self.in_proj_bias = torch.nn.Parameter(torch.zeros(3 * self.hidden_size))
        else:
            self.in_proj_bias = None
        self.bias_k, self.bias_v = None, None  # for compat with MultiheadAttention
        self.output_dim = hidden_size

        # Positional scheme. FIRE produces an additive score bias instead of rotating q/k.
        # Any other truthy value selects the default Rotary.
        if cfg_attention.rotary_embedding == "sanity":
            self.rotary_emb = RotarySanityCheck(self.hidden_per_head, seq_dim=0)
        elif cfg_attention.rotary_embedding == "v2":
            self.rotary_emb = RotaryEleutherAI(self.hidden_per_head)
        elif cfg_attention.rotary_embedding == "llama":
            self.rotary_emb = RotaryLLAMA(self.hidden_per_head)
        elif cfg_attention.rotary_embedding == "fire":
            self.rotary_emb = FIRE(cfg_attention.num_attention_heads, max_length=cfg_attention.max_length)
            self.use_fire = True
        elif cfg_attention.rotary_embedding:
            self.rotary_emb = Rotary(self.hidden_per_head, seq_dim=0)
        else:
            self.rotary_emb = None

        # Operation turning raw scores [b, np, sq, sk] into mixing weights (softmax or an alternative).
        if cfg_attention.sequence_op == "torch-softmax":
            self.sequence_op = TorchSoftmax(cfg_attention.seq_op_in_fp32)
        elif cfg_attention.sequence_op == "shaped-attention":
            self.sequence_op = TorchShaped(cfg_attention.seq_op_in_fp32, hidden_size=self.hidden_size)
        elif cfg_attention.sequence_op == "swin-cosine":
            self.sequence_op = SwinCosine(cfg_attention.seq_op_in_fp32)
        elif cfg_attention.sequence_op == "torch-norm":
            self.sequence_op = TorchNormalize(self.num_attention_heads, cfg_attention.seq_op_in_fp32)
        elif cfg_attention.sequence_op == "none":
            self.sequence_op = ScaledIdentity(cfg_attention.seq_op_in_fp32)
        elif cfg_attention.sequence_op == "cumsum":
            self.sequence_op = Cumsum(cfg_attention.seq_op_in_fp32)
        elif cfg_attention.sequence_op == "cumsumexp":
            self.sequence_op = CumsumExp(cfg_attention.seq_op_in_fp32)
        else:
            raise ValueError(f"Invalid sequence operation {cfg_attention.sequence_op} given.")

        if cfg_attention.skip_output_projection:
            self.out_proj = torch.nn.Identity()
        else:
            self.out_proj = NonDynamicallyQuantizableLinear(hidden_size, hidden_size, bias=cfg_attention.bias_in_proj)

        self.attention_func = self.attention

    def attention(self, query_layer, key_layer, value_layer, attention_mask: Optional[torch.Tensor] = None, training: bool = False, fire: Optional[torch.Tensor] = None):
        """Scaled dot-product attention on per-head tensors.

        Args:
            query_layer, key_layer, value_layer: ``[s, b, np, hn]`` (seq, batch, heads, head dim).
            attention_mask: forwarded unchanged to ``self.sequence_op``.
            training: accepted but unused here.
            fire: optional additive bias on the ``[b, np, sq, sk]`` scores (FIRE positional scheme).

        Returns:
            Context tensor of shape ``[b, np, sq, hn]``.
        """
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.shape[1], query_layer.shape[2], query_layer.shape[0], key_layer.shape[0])

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # Batched q @ k^T over the merged (b * np) axis, scaled by 1/sqrt(hn): [b * np, sq, sk].
        # This better be fused in a clever way.
        matmul_result = torch.bmm(query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2)) * self.norm_factor

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])
        if fire is not None:
            attention_scores += fire

        # ===========================
        # Attention probs
        # ===========================
        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.sequence_op(attention_scores, attention_mask)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]. query_layer.shape[0] is still sq after the view above.
        output_size = (value_layer.shape[1], value_layer.shape[2], query_layer.shape[0], value_layer.shape[3])

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        return context_layer

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
        """Self-attention over sequence-first ``hidden_states`` ``[sq, b, h]``; returns ``[sq, b, h]``."""
        # =====================
        # hidden_states: [sq, b, h]
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = torch.nn.functional.linear(hidden_states, self.in_proj_weight, self.in_proj_bias)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        # Explicit shape (rather than size()[:-1] + ...) keeps the view compile-friendly.
        mixed_x_layer = mixed_x_layer.view(
            hidden_states.shape[0], hidden_states.shape[1], self.num_attention_heads, 3 * self.hidden_per_head
        )

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [self.hidden_per_head] * 3, dim=3)

        # FIRE returns an additive score bias built from the sequence length.
        # The other schemes rotate q and k in place of that.
        fire = None
        if self.rotary_emb is not None:
            if self.use_fire:
                fire = self.rotary_emb(query_layer.size(0), query_layer.device)
            else:
                query_layer, key_layer = self.rotary_emb(query_layer, key_layer)

        # ==================================
        # Attention computation
        # ==================================
        context_layer = self.attention_func(query_layer, key_layer, value_layer, attention_mask, self.training, fire)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp] (explicit shape again, for compile-friendliness)
        context_layer = context_layer.view(context_layer.shape[0], context_layer.shape[1], self.hidden_size)
        return self.out_proj(self.norm(context_layer))
class FourierMixing(torch.nn.Module):
    """FNet-style token mixing (FNet paper): the real part of a 2-D FFT over the sequence and hidden axes.

    Takes ``[Batch, Seq, Hidden]`` input and returns a tensor of the same size. An attention mask may be
    passed for interface compatibility, but it is ignored.
    """

    __constants__ = ["LAYOUT"]
    LAYOUT = "[B S H]"

    def __init__(self, hidden_size, cfg_attention):
        super().__init__()
        # The FFT always runs in fp32 (noted as necessary at least on PyTorch 1.12).
        self.fft_op_in_fp32 = True
        self.output_dim = hidden_size
        # NOTE(review): the scripted variant uses seq_dim=1 (matching "[B S H]") while the plain one
        # uses seq_dim=0. Confirm which is intended.
        if not cfg_attention.rotary_embedding:
            self.rotary_emb = None
        elif cfg_attention.low_level_fusion:
            self.rotary_emb = torch.jit.script(Rotary(hidden_size, seq_dim=1))
        else:
            self.rotary_emb = Rotary(hidden_size, seq_dim=0)

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
        """Mix tokens with an orthonormal 2-D FFT. ``attention_mask`` is accepted but never used."""
        if self.rotary_emb is not None:
            # Rotary applied to the full hidden vector: kept for compatibility, with no quality guarantees.
            cos, sin = self.rotary_emb.get_cos_sin_cache(hidden_states)
            hidden_states = (hidden_states * cos[:, 0]) + (self.rotary_emb.rotate_half(hidden_states) * sin[:, 0])

        original_dtype = hidden_states.dtype if self.fft_op_in_fp32 else None
        if self.fft_op_in_fp32:
            hidden_states = hidden_states.float()

        # Transform over (seq, hidden) and keep only the real part.
        mixed = torch.fft.fftn(hidden_states, dim=(1, 2), norm="ortho").real

        if self.fft_op_in_fp32:
            mixed = mixed.to(original_dtype)
        return mixed
class TorchSoftmax(torch.nn.Module):
    """Masked softmax over the last (key) axis of attention scores."""

    def __init__(self, seq_op_in_fp32=False):
        super().__init__()
        self.seq_op_in_fp32 = seq_op_in_fp32

    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
        """Softmax of ``inputs`` along ``dim=-1``, cast back to the input dtype.

        Positions where ``attention_mask`` is True are filled with -10000 before the softmax. The fill
        uses in-place ``masked_fill_``, so when no dtype conversion takes place the caller's tensor is
        modified.
        """
        out_dtype = inputs.dtype
        scores = inputs.to(dtype=torch.float) if self.seq_op_in_fp32 else inputs
        if attention_mask is not None:
            scores = scores.masked_fill_(attention_mask, -10000.0)
        return torch.softmax(scores, dim=-1).to(dtype=out_dtype)
class TorchShaped(torch.nn.Module):
    """Shaped attention (Noci et al.): ``softmax(scores * hidden_size**-0.5) + I - 1/S``.

    ``S`` is the size of the last (key) axis. The identity term is added per (query, key) square.
    """

    def __init__(self, seq_op_in_fp32=False, hidden_size=768):
        super().__init__()
        self.seq_op_in_fp32 = seq_op_in_fp32
        self.register_buffer("nfactor", torch.tensor(hidden_size).rsqrt())

    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
        """Return shaped attention weights with the same shape as ``inputs``.

        Masked positions (``attention_mask`` True) are set to -10000 before the scaled softmax, using
        the in-place ``masked_fill_``.
        """
        input_dtype = inputs.dtype
        if self.seq_op_in_fp32:
            inputs = inputs.to(dtype=torch.float)
        if attention_mask is not None:
            inputs = inputs.masked_fill_(attention_mask, -10000.0)
        probs = torch.softmax(inputs * self.nfactor, dim=-1).to(dtype=input_dtype)
        seq_len = probs.shape[-1]
        identity = torch.eye(seq_len, dtype=probs.dtype, device=probs.device)[None, None, :, :]
        return probs + identity - 1 / seq_len
class SwinCosine(torch.nn.Module):
    """Kind of SwinCosine, but not quite: normalization is scaled by mean(q) and mean(k).

    The scores are divided by ``max(row_norm * col_norm * tau, eps)``. ``row_norm`` is the norm over
    queries of the per-row key mean, and ``col_norm`` is the norm over keys of the per-column query mean.

    Args:
        seq_op_in_fp32: compute in fp32 and cast back to the input dtype.
        tau: temperature multiplying the normalizer. It is now honoured; previously it was always 0.1.
        eps: lower clamp on the normalizer.
    """

    def __init__(self, seq_op_in_fp32=False, tau=0.1, eps=1e-8):
        super().__init__()
        self.seq_op_in_fp32 = seq_op_in_fp32
        self.tau = tau
        self.eps = eps

    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
        """Normalize scores ``o_ij`` (from ``q_i``, ``k_j``). Masked entries are set to 0."""
        input_dtype = inputs.dtype
        if self.seq_op_in_fp32:
            inputs = inputs.to(dtype=torch.float)
        row_norm = inputs.mean(dim=-1, keepdim=True).norm(dim=-2, keepdim=True)
        col_norm = inputs.mean(dim=-2, keepdim=True).norm(dim=-1, keepdim=True)
        outputs = inputs / torch.clamp(row_norm * col_norm * self.tau, min=self.eps)
        if attention_mask is not None:
            outputs[:, :, attention_mask[0, 0]] = 0
        return outputs.to(dtype=input_dtype)
class TorchNormalize(torch.nn.Module):
    """Normalized attention pooling (Richter & Wattenhofer, 2020).

    Scores are layer-normalized over every axis except the batch axis. Each head then gets its own
    learned affine (``seq_gamma``, ``seq_beta``).
    """

    def __init__(self, num_attention_heads=1, seq_op_in_fp32=False):
        super().__init__()
        self.seq_op_in_fp32 = seq_op_in_fp32
        self.seq_gamma = torch.nn.Parameter(torch.ones(1, num_attention_heads, 1, 1))
        self.seq_beta = torch.nn.Parameter(torch.zeros(1, num_attention_heads, 1, 1))

    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
        """Normalize ``[b, np, sq, sk]`` scores.

        Masked entries are zeroed in place before normalizing, which modifies the caller's tensor when
        no dtype conversion happens.
        """
        out_dtype = inputs.dtype
        scores = inputs.to(dtype=torch.float) if self.seq_op_in_fp32 else inputs
        if attention_mask is not None:
            scores.masked_fill_(attention_mask, 0.0)
        normalized = torch.nn.functional.layer_norm(scores, scores.shape[1:], eps=1e-05)
        return (normalized * self.seq_gamma + self.seq_beta).to(dtype=out_dtype)
class ScaledIdentity(torch.nn.Module):
    """Use raw scores as mixing weights, scaled by ``1/sqrt(inputs.shape[2])``. No mask is applied."""

    def __init__(self, seq_op_in_fp32):
        super().__init__()
        self.seq_op_in_fp32 = seq_op_in_fp32

    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
        """Return the scaled scores in the input dtype. ``attention_mask`` is ignored."""
        out_dtype = inputs.dtype
        if self.seq_op_in_fp32:
            inputs = inputs.to(dtype=torch.float)
        scale = torch.as_tensor(inputs.shape[2]).rsqrt()
        return (inputs * scale).to(dtype=out_dtype)
class Cumsum(torch.nn.Module):
    """Cumulative sum of scores along the key axis, scaled by ``inputs.shape[2] ** -0.5``. No mask is applied."""

    def __init__(self, seq_op_in_fp32):
        super().__init__()
        self.seq_op_in_fp32 = seq_op_in_fp32

    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
        """Return the scaled cumulative sum in the input dtype. ``attention_mask`` is ignored."""
        out_dtype = inputs.dtype
        if self.seq_op_in_fp32:
            inputs = inputs.to(dtype=torch.float)
        running = inputs.cumsum(dim=-1)
        return (running * pow(inputs.shape[2], -0.5)).to(dtype=out_dtype)
class CumsumExp(torch.nn.Module):
    """Log-cumulative-sum-exp of scores along the key axis, scaled by ``inputs.shape[2] ** -0.5``.

    Always computed in fp32 (required as of PyTorch 1.13), so the constructor argument is ignored.
    No mask is applied.
    """

    def __init__(self, seq_op_in_fp32):
        super().__init__()
        self.seq_op_in_fp32 = True

    def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
        """Return the scaled ``logcumsumexp`` in the input dtype. ``attention_mask`` is ignored."""
        out_dtype = inputs.dtype
        if self.seq_op_in_fp32:
            inputs = inputs.to(dtype=torch.float)
        running = inputs.logcumsumexp(dim=-1)
        return (running * pow(inputs.shape[2], -0.5)).to(dtype=out_dtype)
================================================
FILE: cramming/architectures/components.py
================================================
"""Basic transformer components."""
import torch
from typing import Tuple
from functools import partial
from .embeddings import SinusoidalPositional, LearnablePositional, ScaledSinosoidal, Abacus
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear # use to mark output projections of attn while it exists
# Passed as ``inplace=`` by _get_nonlin_fn to activation classes that accept that keyword.
INPLACE = False
class EmbeddingComponent(torch.nn.Module):
    """Token embedding plus an optional absolute positional embedding, followed by an optional norm.

    ``cfg_embedding.pos_embedding`` selects the positional scheme: "learned", "learned_rand",
    "sinusoidal", "scaled-sinusoidal" or "abacus". Any other value disables it.
    """

    def __init__(self, cfg_embedding, norm, norm_eps):
        super().__init__()
        dim = cfg_embedding.embedding_dim
        self.word_embedding = torch.nn.Embedding(cfg_embedding.vocab_size, dim)

        scheme = cfg_embedding.pos_embedding
        if scheme == "learned":
            self.pos_embedding = LearnablePositional(dim, cfg_embedding.max_seq_length)
        elif scheme == "learned_rand":
            # NOTE(review): LearnablePositionalRand is not among this file's visible imports. Confirm it
            # is defined or imported somewhere; otherwise this branch raises NameError.
            self.pos_embedding = LearnablePositionalRand(dim, cfg_embedding.max_seq_length)
        elif scheme == "sinusoidal":
            self.pos_embedding = SinusoidalPositional(dim, cfg_embedding.max_seq_length)
        elif scheme == "scaled-sinusoidal":
            self.pos_embedding = ScaledSinosoidal(dim, cfg_embedding.max_seq_length)
        elif scheme == "abacus":
            self.pos_embedding = Abacus(dim, cfg_embedding.max_seq_length, max_k=cfg_embedding.max_abacus_len)
        else:
            self.pos_embedding = None

        if not cfg_embedding.normalization:
            self.stabilize_low_precision = False
            self.norm = torch.nn.Identity()
        else:
            self.stabilize_low_precision = cfg_embedding.get("stable_low_precision", False)
            self.norm = _get_norm_fn(norm)(dim, eps=norm_eps)

    def forward(self, input_ids):
        """Embed ``input_ids``, add positional information if configured, and normalize."""
        embeds = self.word_embedding(input_ids)
        if self.pos_embedding is not None:
            embeds += self.pos_embedding(input_ids)
        if not self.stabilize_low_precision:
            return self.norm(embeds)
        # Stabilize as in bnb's StableEmbedding: normalize in the default dtype, then cast back.
        return self.norm(embeds.to(torch.get_default_dtype())).to(embeds.dtype)
class PredictionHeadComponent(torch.nn.Module):
    """Projection head on final hidden states: ``norm(nonlin(dense(h)))``.

    Maps ``cfg_arch.hidden_size`` features to ``cfg_arch.embedding.embedding_dim`` features. The
    activation is never gated, even for *glu names.
    """

    def __init__(self, cfg_arch):
        super().__init__()
        # The output width is always the embedding width; when it equals hidden_size the two coincide.
        out_features = cfg_arch.embedding.embedding_dim
        self.dense = torch.nn.Linear(cfg_arch.hidden_size, out_features, bias=cfg_arch.use_bias)
        self.nonlin = _get_nonlin_fn(cfg_arch.nonlin, use_gating=False)()
        self.norm = _get_norm_fn(cfg_arch.norm)(out_features, eps=cfg_arch.norm_eps)

    def forward(self, hidden_states):
        projected = self.dense(hidden_states)
        activated = self.nonlin(projected)
        return self.norm(activated)
class NormalizedResidualConnection(torch.nn.Module):
    """Residual connection ``residual + layer(states)`` with a configurable normalization scheme.

    ``cfg_arch.norm_scheme`` selects the variant:

    * ``"pre"``      -- ``residual + dropout(layer(norm(states)))``
    * ``"post"``     -- ``norm(residual + dropout(layer(states)))``
    * ``"simple"``   -- ``residual + dropout(layer(states))`` (no norm)
    * ``"deepnorm"`` -- ``norm(alpha * residual + dropout(layer(states)))``, where ``alpha`` is
      ``(2 * num_transformer_layers) ** 0.25`` or
      ``(2 * layers_in_recurrent_block * maximal_recurrence) ** 0.25``
    * ``"shaped"``   -- ``alpha * residual + gamma * dropout(layer(norm(states)))`` (Noci et al.)
    * ``"sandwich"`` -- ``norm2(residual + dropout(layer(norm(states))))``

    The post-norm variant now applies dropout to the layer branch too, as the standard post-LN
    formulation and every other scheme here do. Previously dropout was skipped there.

    Args:
        input_dim: width of ``states``, used by norms applied before the layer.
        cfg_arch: config providing ``norm_scheme``, ``norm``, ``norm_eps``, and for deepnorm the layer count.
        output_dim: width of the layer output, used by norms applied after the sum. Defaults to ``input_dim``.
        dropout: dropout probability on the layer branch. ``0`` disables dropout.

    Raises:
        ValueError: for an unknown ``norm_scheme``, or a deepnorm config without a layer count.
    """

    def __init__(self, input_dim, cfg_arch, output_dim=None, dropout=0.0):
        super().__init__()
        output_dim = input_dim if output_dim is None else output_dim
        self.dropout = torch.nn.Dropout(dropout) if dropout > 0 else torch.nn.Identity()
        if cfg_arch.norm_scheme == "pre":
            self.norm = _get_norm_fn(cfg_arch.norm)(input_dim, eps=cfg_arch.norm_eps)
            self._chosen_forward_impl = self._prenormalization_residual
        elif cfg_arch.norm_scheme == "post":
            self.norm = _get_norm_fn(cfg_arch.norm)(output_dim, eps=cfg_arch.norm_eps)
            self._chosen_forward_impl = self._postnormalization_residual
        elif cfg_arch.norm_scheme == "simple":
            self._chosen_forward_impl = self._simple_residual
        elif cfg_arch.norm_scheme == "deepnorm":
            self.norm = _get_norm_fn(cfg_arch.norm)(output_dim, eps=cfg_arch.norm_eps)
            if "num_transformer_layers" in cfg_arch:
                self.alpha = (2.0 * cfg_arch.num_transformer_layers) ** 0.25
            elif "layers_in_recurrent_block" in cfg_arch:
                # Recurrent models: the effective depth is the block size times the maximal recurrence.
                self.alpha = (2.0 * cfg_arch.layers_in_recurrent_block * cfg_arch.maximal_recurrence) ** 0.25
            else:
                raise ValueError("Need to define `num_transformer_layers` in config for deepnorm.")
            self._chosen_forward_impl = self._deepnorm_residual
        elif cfg_arch.norm_scheme == "shaped":
            self.norm = _get_norm_fn(cfg_arch.norm)(input_dim, eps=cfg_arch.norm_eps)
            self.gamma = 0.214  # Noci et al., could make this into a parameter
            # alpha**2 + gamma**2 == 1 keeps the variance of the sum unchanged.
            self.alpha = torch.as_tensor(1 - self.gamma**2).sqrt().item()
            self._chosen_forward_impl = self._prenorm_equalized_residual
        elif cfg_arch.norm_scheme == "sandwich":
            self.norm = _get_norm_fn(cfg_arch.norm)(input_dim, eps=cfg_arch.norm_eps)
            self.norm2 = _get_norm_fn(cfg_arch.norm)(output_dim, eps=cfg_arch.norm_eps)
            self._chosen_forward_impl = self._sandwich_residual
        else:
            raise ValueError(f"Invalid type of residual connection {cfg_arch.norm_scheme} given.")

    def _simple_residual(self, residual, layer, states, *args, **kwargs):
        return residual + self.dropout(layer(states, *args, **kwargs))

    def _prenormalization_residual(self, residual, layer, states, *args, **kwargs):
        return residual + self.dropout(layer(self.norm(states), *args, **kwargs))

    def _postnormalization_residual(self, residual, layer, states, *args, **kwargs):
        return self.norm(residual + self.dropout(layer(states, *args, **kwargs)))

    def _deepnorm_residual(self, residual, layer, states, *args, **kwargs):
        return self.norm(residual * self.alpha + self.dropout(layer(states, *args, **kwargs)))

    def _prenorm_equalized_residual(self, residual, layer, states, *args, **kwargs):
        return residual * self.alpha + self.dropout(layer(self.norm(states), *args, **kwargs)) * self.gamma

    def _sandwich_residual(self, residual, layer, states, *args, **kwargs):
        return self.norm2(residual + self.dropout(layer(self.norm(states), *args, **kwargs)))

    def forward(self, residual: torch.Tensor, layer_callable: torch.nn.Module, states: torch.Tensor, *args, **kwargs):
        """Compute the configured residual update.

        The argument order reads like the pre/post schemes from left to right, i.e.
        ``residual + layer(state)``. Additional args are passed directly to ``layer_callable``.
        """
        return self._chosen_forward_impl(residual, layer_callable, states, *args, **kwargs)
def _get_norm_fn(norm_name):
    """Resolve a normalization-layer class from its config name.

    ``"ScaleNorm"`` and ``"RMSNorm"`` map to the implementations in this module. ``"ApexLayerNorm"``
    maps to apex's ``FusedLayerNorm``, imported only when requested. Any other name is looked up on
    ``torch.nn``.
    """
    if norm_name == "ApexLayerNorm":
        from apex.normalization import FusedLayerNorm

        return FusedLayerNorm
    if norm_name == "ScaleNorm":
        return ScaleNorm
    if norm_name == "RMSNorm":
        return RMSNorm
    return getattr(torch.nn, norm_name)
def _get_nonlin_fn(nonlin_name, use_gating=True):
    """Resolve an activation constructor from its config name.

    A name containing "glu" (case-insensitive) is treated as a gated variant. Its base activation is
    the part before the first lowercase "glu"; for example, "GELUglu" gives ``torch.nn.GELU``. The
    result is wrapped in ``GLU`` only when ``use_gating`` is truthy.

    Activations that accept ``inplace`` get it bound to ``INPLACE``. Others are returned as plain classes.
    """
    is_glu_variant = "glu" in nonlin_name.lower()
    if is_glu_variant:
        nonlin_name = nonlin_name.split("glu")[0]
    wrap_in_glu = is_glu_variant and use_gating

    base_cls = getattr(torch.nn, nonlin_name)  # dont mess this up :<
    try:
        constructor = partial(base_cls, inplace=INPLACE)
        constructor()  # probe: raises TypeError if the class has no `inplace` keyword
    except TypeError:
        constructor = base_cls

    return partial(GLU, constructor) if wrap_in_glu else constructor
class GLU(torch.nn.Module):
    """*-GLU activation (implementation mostly following Megatron).

    The last dimension is split in half: the first half is the value and the second half is the gate.
    The output is ``activation(gate) * value`` and has half the input width.
    """

    def __init__(self, sub_activation):
        super().__init__()
        # ``sub_activation`` is a class or factory; it is instantiated once here.
        self.sub_activation = sub_activation()

    def forward(self, inputs):
        value, gate = inputs.chunk(2, dim=-1)
        gated = self.sub_activation(gate)
        return gated * value
class ScaleNorm(torch.nn.Module):
    """ScaleNorm: rescale each vector to a single (optionally learnable) L2 norm.

    ``elementwise_affine`` is not an ideal name; it is used for compatibility with ``LayerNorm``.
    When True the scale is a learnable scalar initialized to ``hidden_size ** -0.5``; otherwise it is
    a fixed buffer. Whether FixNorm (cosine in the last layer) is also needed is an open question:
    https://github.com/lucidrains/performer-pytorch/issues/55#issuecomment-762544686
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        initial_scale = torch.tensor(float(hidden_size) ** -0.5)
        if not elementwise_affine:
            self.register_buffer("learnable_scale", initial_scale)
        else:
            self.learnable_scale = torch.nn.Parameter(initial_scale)

    def forward(self, inputs):
        """Return ``inputs * scale / max(||inputs||, eps)``, the original ScaleNorm eps clipping."""
        denom = torch.norm(inputs, dim=-1, keepdim=True).clamp(min=self.eps)
        return inputs * self.learnable_scale / denom
class RMSNorm(torch.nn.Module):
    """RMS variant of the scaling norms (LLaMA style).

    ``elementwise_affine`` is not the ideal name (kept for compatibility with LayerNorm): it decides whether
    the per-feature scale, initialized to ones, is a trainable parameter or a fixed buffer.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        # ones ** -0.5 is exactly ones, so the initial scale is simply all ones.
        initial_scale = torch.ones(hidden_size)
        if not elementwise_affine:
            self.register_buffer("learnable_scale", initial_scale)
        else:
            self.learnable_scale = torch.nn.Parameter(initial_scale)

    def _legacy_forward(self, inputs):
        """Old ScaleNorm-style path: L2-normalize over the last dim (norm clamped at 1e-8), then scale."""
        vector_length = torch.norm(inputs, dim=-1, keepdim=True).clamp(min=1e-8)
        return inputs * self.learnable_scale / vector_length

    def _norm(self, x):
        """LLaMA implementation: x / sqrt(mean(x^2) + eps) over the last dim."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype before scaling.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.learnable_scale
def get_causal_attention_mask(input_ids) -> torch.Tensor:
    """Return a boolean causal mask of shape ``(1, 1, S, S)`` for ``input_ids`` of shape ``(batch, S)``.

    Entries are True strictly above the diagonal, i.e. True marks the future positions that must
    not be attended to. The leading singleton dims let the mask broadcast over batch and heads.
    """
    seq_length = input_ids.shape[1]  # not transposed yet
    allowed = torch.ones((1, 1, seq_length, seq_length), device=input_ids.device).tril()
    return allowed < 0.5
def get_extended_attention_mask(attention_mask: torch.Tensor, input_shape: Tuple[int], causal_attention: bool = False) -> torch.Tensor:
    """Turn a padding or self-attention mask into an additive mask that broadcasts over attention heads.

    Method adapted from Hugging Face transformers.

    Args:
        attention_mask: either ``[batch, from_seq, to_seq]`` (used as given), or a ``[batch, seq]`` padding mask
            with ones for tokens to attend to and zeros for tokens to ignore.
        input_shape: ``(batch_size, seq_length)`` of the model input; only read for a 2D mask with
            ``causal_attention`` set.
        causal_attention: for a 2D mask, additionally block attention to later positions. If the padding mask is
            longer than ``seq_length``, the extra leading key positions are kept visible.

    Returns:
        ``(1.0 - mask) * -10000.0``: 0 where attention is allowed and -10000.0 where it is blocked. The shape is
        ``[batch, 1, from_seq, to_seq]`` for 3D or causal inputs and ``[batch, 1, 1, seq]`` for a plain padding
        mask; the dtype follows torch's promotion of ``1.0 - attention_mask``.

    Raises:
        ValueError: if ``attention_mask`` is neither 2D nor 3D.
    """
    rank = attention_mask.dim()
    if rank == 3:
        # Already [batch, from_seq, to_seq]; only the head axis is missing.
        extended = attention_mask[:, None, :, :]
    elif rank == 2 and not causal_attention:
        extended = attention_mask[:, None, None, :]
    elif rank == 2:
        batch_size, seq_length = input_shape
        device = attention_mask.device
        positions = torch.arange(seq_length, device=device)
        # causal[b, i, j] holds when key position j <= query position i
        causal = positions[None, None, :].repeat(batch_size, seq_length, 1) <= positions[None, :, None]
        causal = causal.to(attention_mask.dtype)
        prefix_len = attention_mask.shape[1] - causal.shape[1]
        if prefix_len > 0:
            prefix = torch.ones((batch_size, seq_length, prefix_len), device=device, dtype=causal.dtype)
            causal = torch.cat([prefix, causal], dim=-1)
        extended = causal[:, None, :, :] * attention_mask[:, None, None, :]
    else:
        raise ValueError(f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})")
    return (1.0 - extended) * -10000.0
"""Collect inits."""


def _layernorm_reset(module):
    """Reset a LayerNorm to weight 1 / bias 0, skipping whichever affine tensor is None (e.g. bias=False)."""
    if module.weight is not None:
        module.weight.data.fill_(1.0)
    if module.bias is not None:
        module.bias.data.zero_()


def _mha_zero_biases(module):
    """Zero every bias a torch.nn.MultiheadAttention may hold: in-proj, bias_k, bias_v and out-proj."""
    if module.in_proj_bias is not None:
        module.in_proj_bias.data.zero_()
    if module.bias_k is not None:
        module.bias_k.data.zero_()
    if module.bias_v is not None:
        module.bias_v.data.zero_()
    if module.out_proj is not None and module.out_proj.bias is not None:
        module.out_proj.bias.data.zero_()


@torch.no_grad()
def _init_module(module, init_method="normal", init_std=0.02, hidden_size=768, num_layers=12):
    """Initialize the parameters of one module in place according to ``init_method`` (no recursion).

    ``init_method`` is matched by substring:

    * containing "deepnorm": Xavier-normal weights with a gain picked by a second keyword —
      "normal" -> ``init_std``, "subln" -> sqrt(log(2 * num_layers)), "straight" -> (8 * num_layers) ** -0.25,
      "as-is" -> 1.0. Embeddings use N(0, embedding_dim ** -0.5).
    * otherwise: N(0, std) with std picked by "normal" (``init_std``), "small" (sqrt(2 / (5 * hidden_size))),
      "megatron" (sqrt(1 / (3 * hidden_size))) or "wang" (2 / num_layers / sqrt(hidden_size)); "as-is" leaves
      the module untouched.

    Adding "mimetic" switches MultiheadAttention to an SVD-based (mimetic) init. Only Linear, Embedding,
    LayerNorm and MultiheadAttention modules are modified. Biases are zeroed, as are Embedding padding rows.

    Args:
        module: the module to initialize.
        init_method: init keyword string, see above.
        init_std: std (or deepnorm gain) for the "normal" variants.
        hidden_size: model width, used by "small" and "megatron".
        num_layers: effective depth, used by "subln", "straight" and "wang".

    Raises:
        ValueError: if ``init_method`` contains none of the known keywords.
    """
    if "deepnorm" in init_method:  # This is a xavier init with changes in the MHA inits
        if "normal" in init_method:
            gain = init_std
        elif "subln" in init_method:
            gain = torch.as_tensor(2 * num_layers).log().sqrt()  # foundation transformer paper, use with subln
        elif "straight" in init_method:
            gain = torch.as_tensor(8 * num_layers).pow(-0.25)  # deepnorm paper, use with deepnorm
        elif "as-is" in init_method:  # use locally defined inits for each module
            gain = 1.0
        else:
            raise ValueError(f"Invalid init method {init_method} given.")
        if isinstance(module, torch.nn.Linear):
            # NonDynamicallyQuantizableLinear is the out_proj of MultiheadAttention; it is set with its parent below.
            if not isinstance(module, NonDynamicallyQuantizableLinear):
                if module.weight is not None:
                    torch.nn.init.xavier_normal_(module.weight, gain=gain)
                if module.bias is not None:
                    module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=module.weight.shape[1] ** -0.5)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            _layernorm_reset(module)
        elif isinstance(module, torch.nn.MultiheadAttention):  # be careful with other transformer definitions!
            if "mimetic" in init_method:
                if module.in_proj_weight is not None:
                    h = module.in_proj_weight.shape[1]
                    Z1 = module.in_proj_weight.new_empty([h, h])
                    torch.nn.init.xavier_normal_(Z1, gain=gain)  # as per deepnorm prescription
                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)
                    U1, S1, V1 = torch.linalg.svd(Z1 + I, full_matrices=False)
                    V = U1 @ torch.diag_embed(S1.sqrt())
                    O = V1 @ torch.diag_embed(S1.sqrt())
                    k = module.head_dim
                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)
                    Qlist, Klist = [], []
                    for head in range(module.num_heads):
                        Z2 = module.in_proj_weight.new_empty([h, h])
                        torch.nn.init.xavier_normal_(Z2, gain=1.0)  # as per deepnorm prescription
                        U2, S2, V2 = torch.linalg.svd(Z2 + I, full_matrices=False)
                        Qlist.append(U2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))
                        Klist.append(V2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))
                    Q, K = torch.cat(Qlist, dim=-1), torch.cat(Klist, dim=-1)
                    module.in_proj_weight.data.copy_(torch.cat([Q, K, V], dim=0).contiguous())
                    if module.out_proj is not None:
                        module.out_proj.weight.data.copy_(O)
            else:
                if module.in_proj_weight is not None:
                    h = module.in_proj_weight.shape[1]
                    Q, K, V = (
                        module.in_proj_weight.new_empty([h, h]),
                        module.in_proj_weight.new_empty([h, h]),
                        module.in_proj_weight.new_empty([h, h]),
                    )
                    torch.nn.init.xavier_normal_(Q, gain=1.0)  # as per deepnorm prescription
                    torch.nn.init.xavier_normal_(K, gain=1.0)
                    torch.nn.init.xavier_normal_(V, gain=gain)
                    module.in_proj_weight.data.copy_(torch.cat([Q, K, V], dim=0).contiguous())
                # init outproj:
                if module.out_proj is not None:
                    torch.nn.init.xavier_normal_(module.out_proj.weight, gain=gain)
            _mha_zero_biases(module)
    else:
        if "normal" in init_method:
            std = init_std
        elif "small" in init_method:
            # Previously ``init_method == "small" in init_method`` (a chained comparison) only matched the exact
            # string "small". Transformers without Tears: Improving the Normalization of Self-Attention
            # - Nguyen, T. & Salazar, J. (2019)
            std = torch.as_tensor(2 / (5 * hidden_size)).sqrt()
        elif "megatron" in init_method:
            # Megatron init is near-equal to normal if hidden=768, but otherwise smaller
            std = torch.as_tensor(1 / (3 * hidden_size)).sqrt()
        elif "wang" in init_method:
            std = 2 / num_layers / torch.as_tensor(hidden_size).sqrt()
        elif "as-is" in init_method:  # use locally defined inits for each module
            return
        else:
            raise ValueError(f"Invalid init method {init_method} given.")
        if isinstance(module, torch.nn.Linear):
            # NonDynamicallyQuantizableLinear is the out_proj of MultiheadAttention; it is set with its parent below.
            if not isinstance(module, NonDynamicallyQuantizableLinear):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                if module.weight is not None:
                    module.weight.data.normal_(mean=0.0, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            _layernorm_reset(module)
        elif isinstance(module, torch.nn.MultiheadAttention):  # be careful with other transformer definitions!
            if "mimetic" in init_method:
                if module.in_proj_weight is not None:
                    h = module.in_proj_weight.shape[1]
                    Z1 = module.in_proj_weight.new_empty([h, h]).normal_() / h
                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)
                    U1, S1, V1 = torch.linalg.svd(0.2 * Z1 + 0.2 * I, full_matrices=False)
                    V = U1 @ torch.diag_embed(S1.sqrt())
                    O = V1 @ torch.diag_embed(S1.sqrt())
                    k = module.head_dim
                    I = torch.eye(h, device=module.in_proj_weight.device, dtype=module.in_proj_weight.dtype)
                    Qlist, Klist = [], []
                    for head in range(module.num_heads):
                        # Z2 = module.in_proj_weight.new_empty([h, h]).normal_() / h
                        U2, S2, V2 = torch.linalg.svd(0 + 0.5 * I, full_matrices=False)  # alpha1 =0 from Trockman
                        Qlist.append(U2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))  # this is a bit pointless, ...
                        Klist.append(V2[:, :k] @ torch.diag_embed(S2[:k].sqrt()))  # ... I've left it here for alpha1 not zero
                    Q, K = torch.cat(Qlist, dim=-1), torch.cat(Klist, dim=-1)
                    module.in_proj_weight.data.copy_(torch.cat([Q, K, V], dim=0).contiguous())
                    if module.out_proj is not None:
                        module.out_proj.weight.data.copy_(O)
            else:
                if module.in_proj_weight is not None:
                    module.in_proj_weight.data.normal_(mean=0.0, std=std)
                if module.out_proj is not None:
                    module.out_proj.weight.data.normal_(mean=0.0, std=std)
            _mha_zero_biases(module)
================================================
FILE: cramming/architectures/construction.py
================================================
"""Interface to construct models."""
from .huggingface_interface import construct_huggingface_model
from .sanity_check import SanityCheckforPreTraining
from .crammed_transformer import construct_crammed_transformer
from .crammed_depthrecurrent import construct_crammed_recurrent
import logging
from ..utils import is_main_process
log = logging.getLogger(__name__)


def construct_model(cfg_arch, tokenizer):
    """Build a model from an architecture config.

    Local architectures are matched by substring on ``cfg_arch.model_type`` ("SanityCheckLM",
    "ScriptableCrammedTransformer", "ScriptableCrammedDepthRecurrent"); anything else is handed to
    ``construct_huggingface_model``.

    Args:
        cfg_arch: architecture config supporting ``"key" in cfg_arch`` and attribute access.
        tokenizer: provides ``vocab_size``, and ``vocab["="]`` for depth-recurrent models.

    Returns:
        The constructed model.

    Raises:
        ValueError: if no local architecture matches and the Hugging Face fallback fails. The original
            exception is chained as ``__cause__``.
    """
    model = None
    # Read model_type once; it may be absent, and the error path below must not re-raise on it.
    model_type = cfg_arch.model_type if "model_type" in cfg_arch else None
    if model_type is not None:
        # attempt to solve locally
        if "SanityCheckLM" in model_type:
            model = SanityCheckforPreTraining(cfg_arch.width, tokenizer.vocab_size)
        elif "ScriptableCrammedTransformer" in model_type:
            model = construct_crammed_transformer(cfg_arch, tokenizer.vocab_size)
        elif "ScriptableCrammedDepthRecurrent" in model_type:
            equals_token = tokenizer.vocab["="]
            model = construct_crammed_recurrent(cfg_arch, tokenizer.vocab_size, equals_token)
    if model is not None:  # Return local model arch
        num_params = sum(p.numel() for p in model.parameters())
        if is_main_process():
            log.info(f"Model with architecture {model_type} loaded with {num_params:,} parameters.")
        return model
    try:  # else try on HF
        model = construct_huggingface_model(cfg_arch, tokenizer.vocab_size)
    except Exception as e:
        raise ValueError(f"Invalid model architecture {model_type} given. Error: {e}") from e
    num_params = sum(p.numel() for p in model.parameters())
    if is_main_process():
        log.info(f"Model with config {cfg_arch} loaded with {num_params:,} parameters.")
    return model
================================================
FILE: cramming/architectures/crammed_depthrecurrent.py
================================================
"""Variant for modifications of the transformer architecture that are depth-recurrent"""
import torch
from transformers import PretrainedConfig, PreTrainedModel
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from typing import Optional
from omegaconf import OmegaConf
from .components import (
_get_norm_fn,
_get_nonlin_fn,
EmbeddingComponent,
GLU,
get_causal_attention_mask,
_init_module,
NormalizedResidualConnection,
)
from .attention import get_attention_mechanism
class crammedDepthRecurrentConfig(PretrainedConfig):
    """Hugging Face config that stores the resolved architecture dict under ``self.arch``."""

    model_type = "crammedDepthRecurrent"

    def __init__(self, cfg_arch_container: Optional[dict] = None, **kwargs):
        # Fresh dict per instance: the previous ``= {}`` default was a single dict shared by every config
        # constructed without arguments.
        self.arch = {} if cfg_arch_container is None else cfg_arch_container
        super().__init__(**kwargs)
def construct_crammed_recurrent(cfg_arch, vocab_size, equals_token):
    """Build a depth-recurrent model; see the config file for details on what is possible.

    Side effect: writes ``vocab_size`` into ``cfg_arch.embedding.vocab_size`` before resolving the config.
    ``objective_layout`` "fixed"/"albert" selects ``ScriptableRecurrentLMForPreTraining``; "TBPTT"/"deepthinking"
    selects ``ScriptableRecurrentLMBPTT``, which also receives ``equals_token``. Any other layout raises ValueError.
    """
    cfg_arch.embedding.vocab_size = vocab_size
    config = crammedDepthRecurrentConfig(OmegaConf.to_container(cfg_arch, resolve=True))
    layout = config.arch["objective_layout"]
    if layout in ("fixed", "albert"):
        return ScriptableRecurrentLMForPreTraining(config)
    if layout in ("TBPTT", "deepthinking"):
        return ScriptableRecurrentLMBPTT(config, equals_token)
    raise ValueError(f"Invalid layout {layout} of training objective given.")
class FFNComponent(torch.nn.Module):
    """Two-layer feed-forward block: dense_in -> nonlinearity -> optional sub-norm -> dense_out.

    Note: the FF layer is not auto-scaled when using a GLU type activation. Better do this manually and choose a
    sensible intermed_size that is nicely divisible. The neox suggestion for approx. equal parameter count is
    int(4 * 2 / 3 * hidden_size) * 2 [this is ~5.33].
    """

    def __init__(self, hidden_size, intermed_size, cfg_arch, output_size=None):
        super().__init__()
        self.dense_in = torch.nn.Linear(hidden_size, intermed_size, bias=cfg_arch.use_bias)
        self.nonlin = _get_nonlin_fn(cfg_arch.nonlin)()
        # A GLU halves the width (half of dense_in's output is consumed as the gate).
        gated = isinstance(self.nonlin, GLU)
        intermed_output_size = intermed_size // 2 if gated else intermed_size
        if not cfg_arch.sub_normalization:
            self.norm = torch.nn.Identity()
        else:
            self.norm = _get_norm_fn(cfg_arch.norm)(intermed_output_size, eps=cfg_arch.norm_eps)
        if output_size is None:
            output_size = hidden_size
        self.dense_out = torch.nn.Linear(intermed_output_size, output_size, bias=cfg_arch.use_bias)

    def forward(self, hidden_states):
        projected = self.dense_in(hidden_states)
        activated = self.nonlin(projected)
        return self.dense_out(self.norm(activated))
class TransformerLayer(torch.nn.Module):
    """One transformer layer: attention then feed-forward, each wrapped in a NormalizedResidualConnection."""

    def __init__(self, idx, cfg_arch):
        super().__init__()
        width = cfg_arch.hidden_size
        self.residual1 = NormalizedResidualConnection(width, cfg_arch)
        self.residual2 = NormalizedResidualConnection(width, cfg_arch)
        if cfg_arch.attention.sub_normalization:

            def sub_norm_fn():
                return _get_norm_fn(cfg_arch.norm)(width, eps=cfg_arch.norm_eps)

        else:
            sub_norm_fn = torch.nn.Identity
        self.attn = get_attention_mechanism(idx, width, cfg_arch.attention, sub_norm_fn)
        self.ffn = FFNComponent(width, cfg_arch.intermed_size, cfg_arch)
        # The attention implementation decides the tensor layout (e.g. "[S B H]") for the whole layer.
        self.LAYOUT = self.attn.LAYOUT

    def forward(self, states, attention_mask: Optional[torch.Tensor] = None):
        after_attention = self.residual1(states, self.attn, states, attention_mask)
        return self.residual2(after_attention, self.ffn, after_attention)
class TransformerBlock(torch.nn.Module):
    """A stack of transformer layers (no weight sharing inside) preceded by input injection.

    ``cfg_arch.input_injection_type`` controls how ``injected_state`` enters the block:
    "add" sums it onto the state, "linear"/"ffn" map the concatenation ``[state, injected]`` back to
    ``hidden_size``; "none" (and any unrecognized value) leaves the state untouched.
    """

    def __init__(self, layers, cfg_arch):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)
        # Sequence-first only if the first layer reports that layout; an empty block counts as batch-first.
        self.seq_first = len(self.layers) > 0 and self.layers[0].LAYOUT == "[S B H]"
        self.injection_type = cfg_arch.input_injection_type
        doubled = cfg_arch.hidden_size * 2
        if self.injection_type == "linear":
            self.adapter = torch.nn.Linear(doubled, cfg_arch.hidden_size, bias=False)
        elif self.injection_type == "ffn":
            self.ffn = FFNComponent(doubled, cfg_arch.intermed_size, cfg_arch, cfg_arch.hidden_size)

    def forward(self, states, injected_state, attention_mask: Optional[torch.Tensor] = None):
        mode = self.injection_type
        if mode == "add":  # this is the default in the config
            states = states + injected_state
        elif mode in ("linear", "ffn"):
            merged = torch.cat([states, injected_state], dim=-1)
            states = self.adapter(merged) if mode == "linear" else self.ffn(merged)
        for layer in self.layers:
            states = layer(states, attention_mask)
        return states
class TransposedAdapter(torch.nn.Linear):  # steal init
    """Map ``hidden_size`` features back to ``embedding_dim``, optionally tied to the input adapter.

    Subclasses ``torch.nn.Linear`` to reuse ``reset_parameters`` but skips ``Linear.__init__`` so the
    weight can be shared. The weight is stored as ``[hidden_size, embedding_dim]`` and applied transposed.
    """

    def __init__(self, embedding_dim, hidden_size, original_adapter, tie_weights=True):
        torch.nn.Module.__init__(self)
        # Linear.__init__ is skipped, so set the attributes Linear.extra_repr reads; without them repr(model)
        # raised AttributeError. This layer maps hidden_size -> embedding_dim.
        self.in_features = hidden_size
        self.out_features = embedding_dim
        # A transposed view cannot be registered as a parameter, so share the original weight and transpose
        # it in forward instead.
        if tie_weights:
            # NOTE(review): requires original_adapter to have a weight (a Linear, not an Identity).
            self.weight = original_adapter.weight
        else:
            self.adapter_active = False
            self.weight = torch.nn.Parameter(torch.randn([hidden_size, embedding_dim]))  # transposed
        self.register_parameter("bias", None)
        # NOTE(review): when tied, this re-initializes the shared weight of ``original_adapter`` in place.
        self.reset_parameters()

    def forward(self, inputs):
        return torch.nn.functional.linear(inputs, self.weight.T)
class ScriptableRecurrentLM(PreTrainedModel):
    """Depth-recurrent model. Trying to include most reasonable variations of this concept.

    The embedded input (projected to ``hidden_size`` when the widths differ) is passed as the injected state
    to one weight-shared ``TransformerBlock`` that is applied repeatedly. The final state then goes through
    ``head`` and ``final_norm``.
    """

    config_class = crammedDepthRecurrentConfig

    def __init__(self, config):
        super().__init__(config)
        self.cfg = OmegaConf.create(config.arch)
        self.embedding = EmbeddingComponent(self.cfg.embedding, self.cfg.norm, self.cfg.norm_eps)
        if self.cfg.embedding.embedding_dim != self.cfg.hidden_size:
            self.adapter = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.hidden_size, bias=False)
        else:
            self.adapter = torch.nn.Identity()
        self.state_init = self.cfg.state_init
        # Only the recurrent block is compiled; ``local_compilation`` toggles it.
        self.recurrent_block = torch.compile(
            TransformerBlock([TransformerLayer(idx, self.cfg) for idx in range(self.cfg.layers_in_recurrent_block)], self.cfg),
            mode="default",
            disable=not self.cfg.local_compilation,
        )
        self.seq_first = self.recurrent_block.seq_first
        if self.cfg.head == "identity":
            self.head = torch.nn.Identity()
        elif self.cfg.head == "ffn":
            self.head = FFNComponent(self.cfg.hidden_size, self.cfg.intermed_size, self.cfg)
        elif self.cfg.head == "linear":
            self.head = torch.nn.Linear(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.use_bias)
        else:
            raise ValueError(f"Invalid head layout {self.cfg.head} given.")
        if self.cfg.final_norm:
            self.final_norm = _get_norm_fn(self.cfg.norm)(self.cfg.hidden_size, eps=self.cfg.norm_eps)
        else:
            self.final_norm = torch.nn.Identity()
        # Cached causal mask, (1, 1, S, S) once built. It starts empty, so the first forward call builds it.
        self.register_buffer("attention_mask", torch.ones([0, 0, 0, 0], dtype=torch.bool), persistent=False)

    def forward(self, input_ids: torch.Tensor, num_steps_no_grad: int = None, num_steps_with_grad: int = None):
        """Run ``num_steps_no_grad`` recurrences without gradients, then ``num_steps_with_grad`` with gradients.

        ``num_steps_no_grad`` defaults to 0 and ``num_steps_with_grad`` to ``cfg.maximal_recurrence``.
        Returns the final states after ``head`` and ``final_norm`` (sequence-first when ``seq_first``).
        """
        # Compare with the mask's last dim, i.e. its sequence length. The old check used dim 1, which is always 1
        # for a built mask. As a result the mask was rebuilt on every call, and a length-1 input kept a stale mask
        # built for a different length.
        if input_ids.shape[1] != self.attention_mask.shape[-1]:
            self.attention_mask = get_causal_attention_mask(input_ids)
        hidden_states = self.adapter(self.embedding(input_ids))
        if self.seq_first:
            hidden_states = hidden_states.transpose(0, 1).contiguous()
        injected_state = hidden_states.clone()
        num_steps_prefix = 0 if num_steps_no_grad is None else num_steps_no_grad
        hidden_states = self.initialize_state(hidden_states)
        # Recur without gradients
        with torch.no_grad():
            for _ in range(num_steps_prefix):
                hidden_states = self.recurrent_block(hidden_states, injected_state, self.attention_mask).clone()
        num_steps_active = self.cfg.maximal_recurrence if num_steps_with_grad is None else num_steps_with_grad
        # Recur with gradients
        for _ in range(num_steps_active):
            hidden_states = self.recurrent_block(hidden_states, injected_state, self.attention_mask).clone()
        return self.final_norm(self.head(hidden_states))

    def initialize_state(self, hidden_states):
        """Optionally replace the embedding-derived starting state with a fresh one.

        This only happens when ``cfg.initial_hidden_randomized`` is set. ``state_init`` then selects one of:
        standard normal ("normal"), normal scaled by 0.02 ("embed"), zeros ("zero"), or a normal sample
        standardized over the last dim ("unit"). Otherwise ``hidden_states`` is returned unchanged.
        """
        if self.cfg.initial_hidden_randomized:
            if self.state_init == "normal":
                hidden_states = torch.randn_like(hidden_states)
            elif self.state_init == "embed":  # initialized like a BERT embedding
                hidden_states = torch.randn_like(hidden_states).mul(0.02)
            elif self.state_init == "zero":
                hidden_states = torch.zeros_like(hidden_states)
            elif self.state_init == "unit":
                hidden_states = torch.randn_like(hidden_states)
                std, mean = torch.std_mean(hidden_states, dim=-1, keepdim=True)
                hidden_states = (hidden_states - mean) / std
        return hidden_states
class ScriptableRecurrentLMReplicaConcat(PreTrainedModel):
    """Depth-recurrent model with skips inside the block.

    This is nearly the same as ScriptableRecurrentLM. The difference is that one recurrence applies
    ``layers_in_recurrent_block`` separately weighted single-layer ``TransformerBlock``s in sequence,
    and each of them receives the injected input state.
    """

    config_class = crammedDepthRecurrentConfig

    def __init__(self, config):
        super().__init__(config)
        self.cfg = OmegaConf.create(config.arch)
        self.embedding = EmbeddingComponent(self.cfg.embedding, self.cfg.norm, self.cfg.norm_eps)
        if self.cfg.embedding.embedding_dim != self.cfg.hidden_size:
            self.adapter = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.hidden_size, bias=False)
        else:
            self.adapter = torch.nn.Identity()
        self.state_init = self.cfg.state_init
        self.max_recurs = self.cfg.layers_in_recurrent_block
        self.recurrent_blocks = []
        print("Initializing feedforward blocks with recall connections")
        for _ in range(self.max_recurs):
            # NOTE(review): every layer is built with idx=1; confirm this is intended for get_attention_mechanism.
            self.recurrent_blocks.append(
                torch.compile(
                    TransformerBlock([TransformerLayer(1, self.cfg)], self.cfg),
                    mode="default",
                    disable=not self.cfg.local_compilation,
                )
            )
        self.recurrent_blocks = torch.nn.ModuleList(self.recurrent_blocks)
        print(f"Initialized feedforward blocks with recall connections. " f"It has the depth of {self.max_recurs}")
        self.seq_first = self.recurrent_blocks[0].seq_first
        if self.cfg.head == "identity":
            self.head = torch.nn.Identity()
        elif self.cfg.head == "ffn":
            self.head = FFNComponent(self.cfg.hidden_size, self.cfg.intermed_size, self.cfg)
        elif self.cfg.head == "linear":
            self.head = torch.nn.Linear(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.use_bias)
        else:
            raise ValueError(f"Invalid head layout {self.cfg.head} given.")
        if self.cfg.final_norm:
            self.final_norm = _get_norm_fn(self.cfg.norm)(self.cfg.hidden_size, eps=self.cfg.norm_eps)
        else:
            self.final_norm = torch.nn.Identity()
        # Cached causal mask, (1, 1, S, S) once built. It starts empty, so the first forward call builds it.
        self.register_buffer("attention_mask", torch.ones([0, 0, 0, 0], dtype=torch.bool), persistent=False)

    def apply_recurrent_block(self, hidden_states, injected_state, attention_mask):
        """One recurrence: run every block in order, each receiving the same injected state and mask."""
        for block in self.recurrent_blocks:
            hidden_states = block(hidden_states, injected_state, attention_mask)
        return hidden_states

    def forward(self, input_ids: torch.Tensor, num_steps_no_grad: int = None, num_steps_with_grad: int = None):
        """Run ``num_steps_no_grad`` recurrences without gradients, then ``num_steps_with_grad`` with gradients.

        ``num_steps_no_grad`` defaults to 0 and ``num_steps_with_grad`` to ``cfg.maximal_recurrence``.
        Returns the final states after ``head`` and ``final_norm`` (sequence-first when ``seq_first``).
        """
        # Compare with the mask's last dim (its sequence length). The old check used dim 1, which is always 1
        # for a built mask: the mask was rebuilt every call and kept stale for length-1 inputs.
        if input_ids.shape[1] != self.attention_mask.shape[-1]:
            self.attention_mask = get_causal_attention_mask(input_ids)
        hidden_states = self.adapter(self.embedding(input_ids))
        if self.seq_first:
            hidden_states = hidden_states.transpose(0, 1).contiguous()
        injected_state = hidden_states.clone()
        num_steps_prefix = 0 if num_steps_no_grad is None else num_steps_no_grad
        hidden_states = self.initialize_state(hidden_states)
        # Recur without gradients
        with torch.no_grad():
            for _ in range(num_steps_prefix):
                hidden_states = self.apply_recurrent_block(hidden_states, injected_state, self.attention_mask).clone()
        num_steps_active = self.cfg.maximal_recurrence if num_steps_with_grad is None else num_steps_with_grad
        # Recur with gradients
        for _ in range(num_steps_active):
            hidden_states = self.apply_recurrent_block(hidden_states, injected_state, self.attention_mask).clone()
        return self.final_norm(self.head(hidden_states))

    def initialize_state(self, hidden_states):
        """Optionally replace the embedding-derived starting state with a fresh one.

        This only happens when ``cfg.initial_hidden_randomized`` is set. ``state_init`` then selects one of:
        standard normal ("normal"), normal scaled by 0.02 ("embed"), zeros ("zero"), or a normal sample
        standardized over the last dim ("unit"). Otherwise ``hidden_states`` is returned unchanged.
        """
        if self.cfg.initial_hidden_randomized:
            if self.state_init == "normal":
                hidden_states = torch.randn_like(hidden_states)
            elif self.state_init == "embed":  # initialized like a BERT embedding
                hidden_states = torch.randn_like(hidden_states).mul(0.02)
            elif self.state_init == "zero":
                hidden_states = torch.zeros_like(hidden_states)
            elif self.state_init == "unit":
                hidden_states = torch.randn_like(hidden_states)
                std, mean = torch.std_mean(hidden_states, dim=-1, keepdim=True)
                hidden_states = (hidden_states - mean) / std
        return hidden_states
"""Generator fn for these models."""


def _recurrent_next_token(self, hidden_states, temperature, greedy):
    """Decode the last position of ``hidden_states`` and return ``(probs, predicted_token)``.

    Greedy decoding takes the argmax of the raw logits (``probs`` is the untempered softmax). Sampling draws
    from ``softmax(logits * temperature)`` via ``torch.multinomial``, and ``probs`` is that distribution.
    """
    output_states = self.encoder.final_norm(self.encoder.head(hidden_states.clone()))
    logits = self.decoder(self.adapter(output_states))
    logits = logits[-1, :, :] if self.encoder.seq_first else logits[:, -1, :]
    if greedy:
        probs = torch.softmax(logits, dim=-1)
        predicted_token = torch.argmax(logits, dim=1).unsqueeze(dim=0)
    else:
        # NOTE(review): temperature multiplies the logits, so larger values sharpen the distribution.
        # This is the inverse of the usual ``logits / temperature`` convention; kept as-is for compatibility.
        probs = torch.softmax(logits * temperature, dim=-1)
        predicted_token = torch.multinomial(probs, 1)
    return probs, predicted_token


@torch.no_grad()
def _generate(self, input_ids, token_limit=100, temperature=1.0, steps_at_generation_time=None, track_steps=False, greedy=False, quick=False, **kwargs):
    """Autoregressively generate ``token_limit`` tokens from the ``input_ids`` prompt.

    Shared by the model wrappers: ``self`` must provide ``encoder``, ``adapter``, ``decoder`` and ``cfg``.
    Every new token re-embeds the whole sequence and runs ``num_steps`` recurrences again.

    Args:
        input_ids: prompt token ids.
        token_limit: number of tokens to generate.
        temperature: multiplier applied to the logits when sampling (ignored when ``greedy``).
        steps_at_generation_time: recurrences per token; defaults to ``cfg.maximal_recurrence_in_eval``.
        track_steps: also decode after every recurrence, for making thinking plots.
        greedy: use argmax decoding instead of sampling.
        quick: append the transposed token and stack the outputs (batched greedy path).
        **kwargs: ignored; accepted for call-site compatibility.

    Returns:
        ``generated_ids``. When ``track_steps`` is set, returns ``(generated_ids, tracking, logit_tensor)``:
        ``tracking[i][r]`` is the token predicted for generated position ``i`` after recurrence ``r``, and
        ``logit_tensor`` of shape ``(token_limit, num_steps, vocab_size)`` holds the matching probabilities.
    """
    predicted_ids = []
    tracking = []
    num_steps = self.cfg.maximal_recurrence_in_eval if steps_at_generation_time is None else steps_at_generation_time
    # Only needed for the tracking output (previously allocated on every call).
    logit_tensor = torch.zeros(token_limit, num_steps, self.cfg.embedding.vocab_size) if track_steps else None
    encoder = self.encoder
    # Multi-block encoders apply all their blocks per recurrence. Single-block encoders call the uncompiled
    # module. When compilation is disabled, ``recurrent_block`` may be the plain module without ``_orig_mod``,
    # so fall back to it instead of failing.
    if hasattr(encoder, "recurrent_blocks"):
        blocks = list(encoder.recurrent_blocks)
    else:
        blocks = [getattr(encoder.recurrent_block, "_orig_mod", encoder.recurrent_block)]
    for gen_idx in range(token_limit):
        # The cached mask is (1, 1, S, S); compare its last dim so a stale mask is never reused.
        if input_ids.shape[1] != encoder.attention_mask.shape[-1]:
            encoder.attention_mask = get_causal_attention_mask(input_ids)
        hidden_states = encoder.adapter(encoder.embedding(input_ids))
        if encoder.seq_first:
            hidden_states = hidden_states.transpose(0, 1).contiguous()
        injected_state = hidden_states
        hidden_states = encoder.initialize_state(hidden_states)
        step = []
        for repeat in range(num_steps):
            for block in blocks:
                hidden_states = block(hidden_states, injected_state, encoder.attention_mask)
            if track_steps:
                # keep track of the intermediate probs
                probs, predicted_token = _recurrent_next_token(self, hidden_states, temperature, greedy)
                logit_tensor[gen_idx, repeat, :] = probs
                step.append(predicted_token)
        if track_steps:
            predicted_token = step[-1]
        else:
            _, predicted_token = _recurrent_next_token(self, hidden_states, temperature, greedy)
        if quick:
            input_ids = torch.cat((input_ids, torch.transpose(predicted_token, 0, 1)), dim=1)
        else:
            input_ids = torch.cat([input_ids, predicted_token], dim=-1)
        predicted_ids.append(predicted_token)
        tracking.append(step)
    if quick:
        generated_ids = torch.stack(predicted_ids, dim=1).squeeze()
    else:
        generated_ids = torch.cat(predicted_ids, dim=-1)
    if track_steps:
        return generated_ids, tracking, logit_tensor
    return generated_ids
class ScriptableRecurrentLMForPreTraining(PreTrainedModel):
    """Pretraining version: causal-LM loss on ``ScriptableRecurrentLM`` at a fixed ``maximal_recurrence`` depth."""

    config_class = crammedDepthRecurrentConfig

    def __init__(self, config):
        super().__init__(config)
        self.cfg = OmegaConf.create(config.arch)
        self.encoder = ScriptableRecurrentLM(config)
        embedding_dim = self.cfg.embedding.embedding_dim
        # Project back to the embedding width only when it differs from the hidden width.
        if embedding_dim == self.cfg.hidden_size:
            self.adapter = torch.nn.Identity()
        else:
            self.adapter = TransposedAdapter(embedding_dim, self.cfg.hidden_size, self.encoder.adapter, self.cfg.tie_weights)
        self.decoder = torch.nn.Linear(embedding_dim, self.cfg.embedding.vocab_size, bias=self.cfg.decoder_bias)
        if self.cfg.tie_weights:
            self.decoder.weight = self.encoder.embedding.word_embedding.weight
        # Mean-reduced cross-entropy; positions labelled -100 are ignored.
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
        self._init_weights()

    def _init_weights(self, module=None):
        """Apply ``_init_module`` to ``module``, or to every submodule when ``module`` is None."""
        targets = [module] if module is not None else self.modules()
        effective_depth = self.cfg.layers_in_recurrent_block * self.cfg.maximal_recurrence
        for target in targets:
            _init_module(target, self.cfg.init.type, self.cfg.init.std, self.cfg.hidden_size, effective_depth)

    def forward(self, input_ids: torch.Tensor, *args, **kwargs):
        """Next-token loss over the whole sequence; ``logits`` are the detached predictions at the last position."""
        final_states = self.encoder(input_ids, num_steps_no_grad=0, num_steps_with_grad=self.cfg.maximal_recurrence)
        outputs = self.decoder(self.adapter(final_states))
        if not self.encoder.seq_first:
            shifted_outputs = outputs[..., :-1, :].contiguous()
            shifted_labels = input_ids[..., 1:].contiguous()
            outputs = outputs.detach()
        else:
            # sequence-first outputs: shift along dim 0 and bring the labels into the same layout
            shifted_outputs = outputs[:-1]
            shifted_labels = input_ids.transpose(0, 1)[1:].contiguous()
            outputs = outputs.detach().transpose(0, 1)
        vocab_width = shifted_outputs.shape[-1]
        loss = self.loss_fn(shifted_outputs.view(-1, vocab_width), shifted_labels.view(-1))
        return {"loss": loss, "logits": outputs[:, -1, :], "log_perplexity": loss.clone().detach()}

    def _generate(self, input_ids, token_limit=100, temperature=0.7, steps_at_generation_time=None):
        # Delegates to the module-level generator function of the same name.
        return _generate(
            self,
            input_ids,
            token_limit=token_limit,
            temperature=temperature,
            steps_at_generation_time=steps_at_generation_time,
        )
class ScriptableRecurrentLMBPTT(PreTrainedModel):
"""Pretraining version with stochastic depth / trunc. BPTT"""
config_class = crammedDepthRecurrentConfig
def __init__(self, config, equals_token):
    """Build the truncated-BPTT / progressive-loss pretraining wrapper.

    Args:
        config: ``crammedDepthRecurrentConfig`` with the architecture dict in ``config.arch``.
        equals_token: token id of "=", stored for the loss masking in ``prog_model_call_with_masking``.
    """
    super().__init__(config)
    self.cfg = OmegaConf.create(config.arch)
    self.equals_token = equals_token
    self.max_recurrences_for_training = self.cfg.maximal_recurrence
    self.max_backprop = max(self.cfg.maximal_recurrence // 2 if self.cfg.max_backprop is None else self.cfg.max_backprop, 1)
    # The key is optional in older configs. Read it with a default rather than wrapping the encoder construction
    # in a bare ``except:``, which also silently replaced genuine construction errors with a different model.
    self.forward_only_model_with_skip = self.cfg.get("forward_only_model_with_skip", False)
    if self.forward_only_model_with_skip:
        print("Using forward only model with skip")
        self.encoder = ScriptableRecurrentLMReplicaConcat(config)
    else:
        self.encoder = ScriptableRecurrentLM(config)
    # NOTE(review): if embedding_dim == hidden_size the encoder adapter is an Identity (no weight), so
    # tie_weights=True would fail inside TransposedAdapter here.
    self.adapter = TransposedAdapter(self.cfg.embedding.embedding_dim, self.cfg.hidden_size, self.encoder.adapter, self.cfg.tie_weights)
    self.decoder = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.embedding.vocab_size, bias=self.cfg.decoder_bias)
    if self.cfg.tie_weights:
        self.decoder.weight = self.encoder.embedding.word_embedding.weight
    self.throttle = self.cfg.throttle
    self.alpha = self.cfg.alpha
    self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction=self.cfg.loss_reduction)  # positions labelled -100 are ignored
    self._init_weights()
    self.mask_before_equals = self.cfg.mask_before_equals
    self.model_call = self.prog_model_call_with_masking  # moved the logic for masking before equals into this function
def _init_weights(self, module=None):
    """Apply ``_init_module`` to ``module``, or to every submodule when ``module`` is None.

    The depth handed to the initializer is ``layers_in_recurrent_block * maximal_recurrence``.
    """
    targets = [module] if module is not None else self.modules()
    effective_depth = self.cfg.layers_in_recurrent_block * self.cfg.maximal_recurrence
    for target in targets:
        _init_module(target, self.cfg.init.type, self.cfg.init.std, self.cfg.hidden_size, effective_depth)
def set_max_recurrences_for_training(self, new_max):
    """Change the training recurrence budget and re-derive the backprop window from it.

    The window is ``cfg.max_backprop`` if configured, otherwise half the new budget, and never less than 1.
    """
    self.max_recurrences_for_training = new_max
    configured_window = self.cfg.max_backprop
    window = new_max // 2 if configured_window is None else configured_window
    self.max_backprop = max(window, 1)
def forward(self, input_ids: torch.Tensor, *args, **kwargs):
    """Training: progressive loss (optionally throttled). Eval: ``maximal_recurrence_in_eval`` no-grad steps.

    WARNING: max iters outputs is used for logits and entropy calcs.
    """
    if not self.training:
        loss, outputs = self.model_call(input_ids, n=self.cfg.maximal_recurrence_in_eval, k=0)
    else:
        loss, outputs = self.forward_progressive(input_ids)
        if self.throttle:
            # Rescale by an expected-depth factor relative to the backprop window.
            expected_depth = 1 + min(self.max_recurrences_for_training / 4, self.max_backprop / 2)
            loss = loss * (expected_depth / self.max_backprop)
    return {"loss": loss, "logits": outputs[:, -1, :], "log_perplexity": loss.clone().detach()}
def forward_progressive(self, input_ids):
    """Compute the progressive training loss.

    ``loss = (1 - alpha) * loss_max_iters + alpha * loss_progressive`` where

    * ``loss_max_iters`` comes from a ``model_call`` with
      ``n = max_recurrences_for_training - max_backprop`` and ``k = max_backprop``
      (skipped when ``alpha == 1``);
    * ``loss_progressive`` comes from a ``model_call`` with a random ``n`` in
      ``[0, max_recurrences_for_training)`` and a random ``k`` in
      ``[1, min(max_recurrences_for_training - n, max_backprop)]`` (skipped when ``alpha == 0``).

    A skipped term is replaced by a float32 zero tensor of shape ``(1,)`` on ``input_ids``' device.

    Returns:
        ``(loss, outputs)``. ``outputs`` come from the max-iterations pass, or from the random
        pass when ``alpha == 1``.
    """
    device = input_ids.device
    if self.alpha != 1:
        # Max-iterations pass, i.e. maximise the number of layers we backprop through.
        n = self.max_recurrences_for_training - self.max_backprop
        k = self.max_backprop
        loss_max_iters, outputs_max_iters = self.model_call(input_ids, n=n, k=k)
    else:
        # Allocate directly on the input's device. ``Tensor.get_device()`` returns -1 for CPU
        # tensors, which ``.to`` rejects, so the previous form failed for CPU inputs.
        loss_max_iters = torch.zeros(1, dtype=torch.float32, device=device)
    if self.alpha != 0:
        # Stochastic pass with a random split between no-grad and with-grad steps.
        n = torch.randint(low=0, high=self.max_recurrences_for_training, size=(1,))
        k = torch.randint(low=1, high=1 + min(self.max_recurrences_for_training - n, self.max_backprop), size=(1,))
        loss_progressive, outputs_progressive = self.model_call(input_ids, n=n, k=k)
        if self.alpha == 1:
            outputs_max_iters = outputs_progressive
    else:
        loss_progressive = torch.zeros(1, dtype=torch.float32, device=device)
    loss = (1 - self.alpha) * loss_max_iters + self.alpha * loss_progressive
    # Returning outputs max_iters to be used for logits, could try outputs_progressive
    return loss, outputs_max_iters
def prog_model_call_with_masking(self, input_ids, n, k):
    """Run encoder -> adapter -> decoder and return the masked next-token loss.

    Args:
        input_ids: token ids. The indexing below (``nonzero()[:, 1]``, ``input_ids[..., 1:]``)
            assumes a 2-D batch-first ``(batch, seq)`` tensor -- TODO confirm for all callers.
        n: forwarded to the encoder as ``num_steps_no_grad``.
        k: forwarded to the encoder as ``num_steps_with_grad``.

    Returns:
        ``(loss, outputs)`` where ``outputs`` are the detached decoder outputs in batch-first layout.

    Label masking: labels equal to 0 are always ignored. When ``mask_before_equals`` is set,
    labels at or before the '=' token are ignored too.
    """
    if self.mask_before_equals:  # mask before equals
        # NOTE(review): assumes exactly one ``self.equals_token`` per row; otherwise the length of
        # ``indices_of_equals`` no longer matches the batch size. ``equals_token`` is not assigned
        # in the visible part of this class -- presumably set externally; verify.
        indices_of_equals = (input_ids == self.equals_token).nonzero()[:, 1]  # gets the index of equals sign for each tensor in the batch
        max_indices = torch.arange(input_ids.size(1), device=input_ids.device)  # tensor for mask
        masks = max_indices.unsqueeze(0) > indices_of_equals.unsqueeze(1)  # True strictly after the '=' position in each row
    else:  # mask only the random padding
        masks = input_ids != 0
    outputs = self.decoder(self.adapter(self.encoder(input_ids, num_steps_no_grad=n, num_steps_with_grad=k)))
    if self.encoder.seq_first:
        # Outputs are treated as sequence-first here; labels and mask are transposed to match.
        shifted_outputs = outputs[:-1]
        shifted_labels = input_ids.transpose(0, 1)[1:].contiguous()
        outputs = outputs.detach().transpose(0, 1)
        masked = torch.mul(shifted_labels, masks[..., 1:].transpose(0, 1))
    else:
        shifted_outputs = outputs[..., :-1, :].contiguous()
        shifted_labels = input_ids[..., 1:].contiguous()
        outputs = outputs.detach()
        masked = torch.mul(shifted_labels, masks[..., 1:])
    # Labels that are 0 (token 0 itself, or positions zeroed by the mask) become ignore_index=-100.
    masked[masked == 0] = -100  # mask all 0's in loss
    shifted_outputs_shape = shifted_outputs.shape
    loss = self.loss_fn(shifted_outputs.view(-1, shifted_outputs.shape[-1]), masked.view(-1))  # CE_Loss(Input, Target)
    if self.cfg.loss_reduction=='none':  # per-position losses; ignored targets contribute 0
        # Mean over dim 1, then over everything. Every row has the same length, so this equals the
        # plain mean over all positions. Ignored positions still count in the denominator, unlike
        # reduction='mean', which divides by the number of non-ignored targets.
        loss = loss.view(shifted_outputs_shape[0],shifted_outputs_shape[1])
        loss = torch.mean(loss, dim=1)
        loss = torch.mean(loss)
    return loss, outputs
def _generate(self, input_ids, token_limit=100, temperature=1.0, steps_at_generation_time=None, track_steps=False, greedy=False, quick=False):
    """Generate from ``input_ids`` by delegating to the module-level ``_generate`` helper.

    Every argument is forwarded unchanged, with this model passed first.
    """
    return _generate(
        self,
        input_ids,
        token_limit,
        temperature,
        steps_at_generation_time,
        track_steps,
        greedy=greedy,
        quick=quick,
    )
# Register the recurrent architecture with the Hugging Face Auto* factories at import time.
# Model type "crammedDepthRecurrent" resolves to crammedDepthRecurrentConfig; AutoModel builds
# ScriptableRecurrentLM and AutoModelForCausalLM builds ScriptableRecurrentLMForPreTraining.
AutoConfig.register("crammedDepthRecurrent", crammedDepthRecurrentConfig)
AutoModel.register(crammedDepthRecurrentConfig, ScriptableRecurrentLM)
AutoModelForCausalLM.register(crammedDepthRecurrentConfig, ScriptableRecurrentLMForPreTraining)
================================================
FILE: cramming/architectures/crammed_transformer.py
================================================
"""Base file for modifications of the transformer architecture"""
import torch
from transformers import PretrainedConfig, PreTrainedModel
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from typing import Optional
from omegaconf import OmegaConf
from .components import (
_get_norm_fn,
_get_nonlin_fn,
NormalizedResidualConnection,
EmbeddingComponent,
GLU,
get_causal_attention_mask,
_init_module,
)
from .attention import get_attention_mechanism
class crammedTransformerConfig(PretrainedConfig):
    """Hugging Face config wrapper; the architecture description is stored in ``self.arch``."""

    model_type = "crammedTransformer"

    def __init__(self, cfg_arch_container: Optional[dict] = None, **kwargs):
        # Use a fresh dict per instance. A literal ``{}`` default is created once and would be
        # aliased by every config constructed without an argument.
        self.arch = {} if cfg_arch_container is None else cfg_arch_container
        super().__init__(**kwargs)
def construct_crammed_transformer(cfg_arch, vocab_size):
    """Build a ``ScriptableLMForPreTraining`` from an architecture config.

    See the config file for details on what is possible. Note that ``cfg_arch.embedding.vocab_size``
    is overwritten in place with ``vocab_size`` before the config is resolved into a plain container.
    """
    cfg_arch.embedding.vocab_size = vocab_size
    arch_container = OmegaConf.to_container(cfg_arch, resolve=True)
    return ScriptableLMForPreTraining(crammedTransformerConfig(arch_container))
class FFNComponent(torch.nn.Module):
    """Feed-forward block: dense_in -> nonlinearity -> (optional norm) -> dense_out.

    Note: The FF layer is not auto-scaled when using a GLU type activation.
    Better do this manually and choose a sensible intermed_size that is nicely divisible.
    The neox suggestion for approx. equal parameter count is int(4 * 2 / 3 * hidden_size) * 2 [this is ~5.33]
    """

    def __init__(self, hidden_size, intermed_size, cfg_arch, output_size=None):
        super().__init__()
        self.dense_in = torch.nn.Linear(hidden_size, intermed_size, bias=cfg_arch.use_bias)
        self.nonlin = _get_nonlin_fn(cfg_arch.nonlin)()
        # The layers after a GLU are sized for half the width (the code assumes GLU halves it).
        if isinstance(self.nonlin, GLU):
            width_after_nonlin = intermed_size // 2
        else:
            width_after_nonlin = intermed_size
        self.norm = (
            _get_norm_fn(cfg_arch.norm)(width_after_nonlin, eps=cfg_arch.norm_eps)
            if cfg_arch.sub_normalization
            else torch.nn.Identity()
        )
        out_features = hidden_size if output_size is None else output_size
        self.dense_out = torch.nn.Linear(width_after_nonlin, out_features, bias=cfg_arch.use_bias)

    def forward(self, hidden_states):
        expanded = self.dense_in(hidden_states)
        activated = self.nonlin(expanded)
        return self.dense_out(self.norm(activated))
class TransformerLayer(torch.nn.Module):
    """A transformer layer: attention and FFN sub-blocks, each wrapped in a NormalizedResidualConnection."""

    def __init__(self, idx, cfg_arch):
        super().__init__()
        self.residual1 = NormalizedResidualConnection(cfg_arch.hidden_size, cfg_arch)
        self.residual2 = NormalizedResidualConnection(cfg_arch.hidden_size, cfg_arch)
        if cfg_arch.attention.sub_normalization:
            # Bug fix: this previously referenced the undefined name ``get_norm_fn`` (NameError as
            # soon as the factory was called); the helper imported from .components is ``_get_norm_fn``.
            def sub_norm_fn():
                return _get_norm_fn(cfg_arch.norm)(cfg_arch.hidden_size, eps=cfg_arch.norm_eps)

        else:
            sub_norm_fn = torch.nn.Identity
        self.attn = get_attention_mechanism(idx, cfg_arch.hidden_size, cfg_arch.attention, sub_norm_fn)
        self.ffn = FFNComponent(cfg_arch.hidden_size, cfg_arch.intermed_size, cfg_arch)
        # Tensor layout used by the attention implementation (e.g. "[S B H]" = sequence first).
        self.LAYOUT = self.attn.LAYOUT

    def forward(self, states, attention_mask: Optional[torch.Tensor] = None):
        states = self.residual1(states, self.attn, states, attention_mask)
        states = self.residual2(states, self.ffn, states)
        return states
class ScriptableLM(PreTrainedModel):
    """Simplified transformer wrapper: embedding -> TransformerLayer stack -> final norm.

    When the layers use the "[S B H]" layout the hidden states are transposed to sequence-first
    before the stack and are returned in that layout; callers transpose back if necessary.
    """

    config_class = crammedTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.cfg = OmegaConf.create(config.arch)
        self.embedding = EmbeddingComponent(self.cfg.embedding, self.cfg.norm, self.cfg.norm_eps)
        self.layers = torch.nn.ModuleList(
            TransformerLayer(layer_idx, self.cfg) for layer_idx in range(self.cfg.num_transformer_layers)
        )
        self.seq_first = len(self.layers) > 0 and self.layers[0].LAYOUT == "[S B H]"
        self.final_norm = (
            _get_norm_fn(self.cfg.norm)(self.cfg.hidden_size, eps=self.cfg.norm_eps)
            if self.cfg.final_norm
            else torch.nn.Identity()
        )
        # Cached causal mask. It starts empty and is rebuilt in forward whenever
        # input_ids.shape[1] differs from attention_mask.shape[1].
        self.register_buffer("attention_mask", torch.ones([0, 0, 0, 0], dtype=torch.bool), persistent=False)

    def forward(self, input_ids: torch.Tensor):
        if input_ids.shape[1] != self.attention_mask.shape[1]:
            self.attention_mask = get_causal_attention_mask(input_ids)
        hidden_states = self.embedding(input_ids)
        if self.seq_first:
            hidden_states = hidden_states.transpose(0, 1).contiguous()
        for layer in self.layers:
            hidden_states = layer(hidden_states, self.attention_mask)
        # No transpose back here; this happens only in the output if necessary.
        return self.final_norm(hidden_states)
class ScriptableLMForPreTraining(PreTrainedModel):
    """Causal-LM pretraining wrapper around ``ScriptableLM``.

    A linear decoder whose weight is tied to the input word embedding produces logits, and the
    loss is plain next-token cross entropy over all positions.
    """

    config_class = crammedTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.cfg = OmegaConf.create(config.arch)
        self.encoder = ScriptableLM(config)
        self.decoder = torch.nn.Linear(self.cfg.embedding.embedding_dim, self.cfg.embedding.vocab_size, bias=self.cfg.decoder_bias)
        # The output projection shares its weight with the input word embedding.
        self.decoder.weight = self.encoder.embedding.word_embedding.weight
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self._init_weights()

    def _init_weights(self, module=None):
        """Initialise ``module``, or every submodule when ``module`` is None."""
        targets = self.modules() if module is None else (module,)
        for target in targets:
            _init_module(
                target,
                self.cfg.init.type,
                self.cfg.init.std,
                self.cfg.hidden_size,
                self.cfg.num_transformer_layers,
            )

    def forward(self, input_ids: torch.Tensor, *args, **kwargs):
        """Return next-token loss, last-position logits (batch-first) and a detached copy of the loss."""
        logits = self.decoder(self.encoder(input_ids))
        if self.encoder.seq_first:
            # Encoder output is sequence-first, so shift along dim 0 and align the labels to it.
            predictions = logits[:-1]
            targets = input_ids.transpose(0, 1)[1:].contiguous()
            batch_first_logits = logits.detach().transpose(0, 1)
        else:
            predictions = logits[..., :-1, :].contiguous()
            targets = input_ids[..., 1:].contiguous()
            batch_first_logits = logits.detach()
        vocab_width = predictions.shape[-1]
        loss = self.loss_fn(predictions.view(-1, vocab_width), targets.view(-1))
        return {"loss": loss, "logits": batch_first_logits[:, -1, :], "log_perplexity": loss.clone().detach()}
# Register the transformer architecture with the Hugging Face Auto* factories at import time.
# Model type "crammedTransformer" resolves to crammedTransformerConfig; AutoModel builds
# ScriptableLM and AutoModelForCausalLM builds ScriptableLMForPreTraining.
AutoConfig.register("crammedTransformer", crammedTransformerConfig)
AutoModel.register(crammedTransformerConfig, ScriptableLM)
AutoModelForCausalLM.register(crammedTransformerConfig, ScriptableLMForPreTraining)
================================================
FILE: cramming/architectures/embeddings.py
================================================
"""Non-standard embedding implementations."""
import torch
import math
from typing import Tuple
from einops import repeat
import random
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal position embedding in the style of Transformer-XL.

    https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15C1-L31C37
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = (1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))).float()
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        """Embed a tensor of positions.

        Args:
            pos_seq: positions of any shape ``(*P)``, e.g. ``(seq,)`` or ``(batch, seq)``.
            bsz: unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(*P, 2 * len(inv_freq))`` holding ``[sin(pos * f), cos(pos * f)]``.
        """
        # Outer product by broadcasting (equivalent to the original torch.ger). The previous
        # matmul + squeeze(2) dropped the frequency axis when demb <= 2 and rejected 1-D input.
        sinusoid_inp = pos_seq.float().unsqueeze(-1) * self.inv_freq
        return torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
class RandomNoise(torch.nn.Module):
    """Positional "embedding" made of Gaussian noise (mean 0, std 0.1), resampled on every call."""

    def __init__(self, embedding_dim, max_seq_length=5000):
        super().__init__()
        # max_seq_length is accepted for interface compatibility with the other embeddings but unused.
        self.embedding_dim = embedding_dim

    def forward(self, input_ids):
        """Return float noise of shape ``(input_ids.size(0), input_ids.size(1), embedding_dim)``.

        The result is placed on ``input_ids``' device.
        """
        shape = (input_ids.size(0), input_ids.size(1), self.embedding_dim)
        # Sample on the target device directly rather than sampling on the CPU and copying each call.
        return torch.normal(0, 0.1, size=shape, device=input_ids.device)
class RPE(torch.nn.Module):
    """Causal multi-head self-attention with learned relative-position logits.

    Relative logits are obtained with the pad / reshape / drop-first-row "skew" trick.
    Adapted from https://jaketae.github.io/study/relative-positional-encoding/
    """

    def __init__(self, d_model, num_heads, max_len=1024, dropout=0.1):
        super().__init__()
        d_head, remainder = divmod(d_model, num_heads)
        if remainder:
            raise ValueError("incompatible `d_model` and `num_heads`")
        self.max_len = max_len
        self.d_model = d_model
        self.num_heads = num_heads
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.query = torch.nn.Linear(d_model, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        # Row max_len - 1 - d holds the embedding for relative distance d = query_pos - key_pos.
        self.Er = torch.nn.Parameter(torch.randn(max_len, d_head))
        causal = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("mask", causal[None, None])  # (1, 1, max_len, max_len)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model); returns the same shape."""
        batch_size, seq_len, _ = x.shape
        if seq_len > self.max_len:
            raise ValueError("sequence length exceeds model capacity")
        head_shape = (batch_size, seq_len, self.num_heads, -1)
        k_t = self.key(x).reshape(head_shape).permute(0, 2, 3, 1)  # (batch, heads, d_head, seq)
        v = self.value(x).reshape(head_shape).transpose(1, 2)  # (batch, heads, seq, d_head)
        q = self.query(x).reshape(head_shape).transpose(1, 2)  # (batch, heads, seq, d_head)
        rel_emb_t = self.Er[self.max_len - seq_len :, :].transpose(0, 1)  # (d_head, seq)
        s_rel = self.skew(torch.matmul(q, rel_emb_t))  # (batch, heads, seq, seq)
        scores = (torch.matmul(q, k_t) + s_rel) / math.sqrt(q.size(-1))
        causal = self.mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(causal == 0, float("-inf"))
        weights = torch.nn.functional.softmax(scores, dim=-1)
        attended = torch.matmul(weights, v).transpose(1, 2)  # (batch, seq, heads, d_head)
        return self.dropout(attended.reshape(batch_size, seq_len, -1))

    def skew(self, QEr):
        """Map logits over absolute key columns to relative-position logits.

        Input and output shape: (batch, heads, seq_len, seq_len). Entries above the diagonal are
        not meaningful and are removed by the causal mask in ``forward``.
        """
        padded = torch.nn.functional.pad(QEr, (1, 0))  # (batch, heads, seq_len, seq_len + 1)
        batch_size, num_heads, num_rows, num_cols = padded.shape
        return padded.reshape(batch_size, num_heads, num_cols, num_rows)[:, :, 1:, :]
# Module partially adapted from the PyTorch examples.
class SinusoidalPositional(torch.nn.Module):
    r"""Fixed sine/cosine absolute positional encodings.

    The encodings have the same dimension as the embeddings so the two can be summed. Even
    feature columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``, with
    ``w_i = 10000 ** (-2 i / embedding_dim)``. Odd ``embedding_dim`` is supported.
    """

    def __init__(self, embedding_dim, max_seq_length=5000):
        super().__init__()
        pe = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd embedding_dim there is one fewer cosine column than sine column. Trim the
        # frequencies so the shapes match (this previously raised a shape-mismatch error).
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, input_ids):
        r"""Return the encodings for the first ``input_ids.shape[1]`` positions.

        Args:
            input_ids: batch-first tensor; only its second dimension (sequence length) is read.

        Returns:
            Tensor of shape ``(1, seq_len, embedding_dim)``, broadcastable over the batch.
        """
        return self.pe[:, : input_ids.shape[1], :]
class ScaledSinosoidal(SinusoidalPositional):
    """Sinusoidal encodings multiplied by a learnable scalar (see FLASH paper).

    The scale is initialised to ``1 / sqrt(embedding_dim)``.
    """

    def __init__(self, embedding_dim, max_seq_length):
        super().__init__(embedding_dim, max_seq_length)
        initial_scale = 1.0 / embedding_dim**0.5
        self.scale_factor = torch.nn.Parameter(torch.tensor([initial_scale]))

    def forward(self, input_ids):
        r"""Return scaled encodings for the first ``input_ids.shape[1]`` positions.

        Only the sequence length (second dimension) of ``input_ids`` is read. The result has
        shape ``(1, seq_len, embedding_dim)``.
        """
        seq_len = input_ids.shape[1]
        return self.scale_factor * self.pe[:, :seq_len, :]
class LearnablePositional(torch.nn.Module):
    """Learned absolute position embedding: one trainable vector per position index."""

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        all_positions = torch.arange(max_seq_length).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

    def forward(self, input_ids):
        """Return embeddings for positions ``0..seq_len-1`` with shape ``(1, seq_len, embedding_dim)``.

        Batch-first: only ``input_ids.shape[1]`` (the sequence length) is read.
        """
        seq_len = input_ids.shape[1]
        return self.embedding(self.position_ids[:, :seq_len])
class LearnablePositionalRand(torch.nn.Module):
    """Learnable positional embedding evaluated at randomly sampled, sorted position ids.

    Each call draws ``seq_len`` distinct ids from ``[0, max_seq_length)``, sorts them, and looks
    them up. The same ids are shared by every sequence in the batch, and sampling happens in
    both train and eval mode.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.max_length = max_seq_length
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        # Not read by forward; kept registered (persistent) so the state_dict layout is unchanged.
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    def forward(self, input_ids):
        """Return embeddings of shape ``(seq_len, embedding_dim)`` at sorted random positions.

        Batch-first: ``seq_len`` is ``input_ids.shape[1]``.

        Raises:
            IndexError: if ``seq_len`` exceeds ``max_seq_length``. The embedding table has no rows
                for such positions, so the lookup failed with an opaque index error before.
        """
        seq_length = input_ids.shape[1]
        if seq_length > self.max_length:
            raise IndexError(
                f"sequence length {seq_length} exceeds the {self.max_length} learned positions "
                "of LearnablePositionalRand"
            )
        sampled = torch.randperm(self.max_length, dtype=torch.long, device=input_ids.device)[:seq_length]
        return self.embedding(torch.sort(sampled).values)
# Adapted from GPT-X.
class Rotary(torch.nn.Module):
    """Rotary position embedding with cached cos/sin tables.

    The tables cover ``seq_len_cached`` positions (initially ``def_seq_length``). ``forward``
    rebuilds them via ``get_cos_sin_cache`` whenever the query's length along ``seq_dim`` differs.
    """

    def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int = 0):
        super().__init__()
        # Inverse frequencies base ** (-2i / dim) for i = 0 .. ceil(dim / 2) - 1.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.seq_len_cached = def_seq_length
        self.seq_dim = seq_dim
        cos_cache, sin_cache = self._get_cos_sin()
        # Non-persistent: not stored in the state_dict; rebuilt from inv_freq on construction.
        self.register_buffer("cos_cached", cos_cache, persistent=False)
        self.register_buffer("sin_cached", sin_cache, persistent=False)

        # Force fusions on batched version
        def rotate_half(x: torch.Tensor):
            # (x1, x2) -> (-x2, x1) on the two halves of the last dimension.
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster
            return torch.cat((-x2, x1), dim=-1)

        def rope_fn(cos: torch.Tensor, sin: torch.Tensor, query_layer: torch.Tensor, key_layer: torch.Tensor):
            # Rotate Q and K in one pass: concatenate on dim 1, rotate, then split back by Q's size on dim 1.
            # NOTE(review): the tables are sliced along dim 0, which matches seq_dim=0 (sequence on dim 0).
            # With seq_dim=1 the concatenation would run along the sequence axis -- confirm that path is unused.
            QK = torch.cat([query_layer, key_layer], dim=1)
            rotated = QK * cos[: QK.shape[0]] + rotate_half(QK) * sin[: QK.shape[0]]
            return torch.split(rotated, query_layer.shape[1], dim=1)

        self.rope_fn = rope_fn  # handle fusion on module level

    @torch.no_grad()
    def get_cos_sin_cache(self, x: torch.Tensor):
        """Return the (cos, sin) tables, rebuilding them on x's device if its length along seq_dim changed."""
        seq_len = x.shape[self.seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = x.shape[self.seq_dim]
            cos_cache, sin_cache = self._get_cos_sin()
            self.cos_cached = cos_cache.to(x.device)
            self.sin_cached = sin_cache.to(x.device)
        return self.cos_cached, self.sin_cached

    def _get_cos_sin(self):
        """Build cos/sin tables for ``seq_len_cached`` positions.

        Shapes are (seq, 1, 1, 2 * len(inv_freq)) when ``seq_dim == 0`` and
        (1, seq, 1, 2 * len(inv_freq)) otherwise; the frequencies are repeated for both halves.
        """
        t = torch.arange(self.seq_len_cached).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        if self.seq_dim == 0:
            return emb.cos()[:, None, None, :].detach(), emb.sin()[:, None, None, :].detach()
        else:
            return emb.cos()[None, :, None, :].detach(), emb.sin()[None, :, None, :].detach()

    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
        """Apply the rotary embedding to query and key.

        Returns the chunks produced by ``torch.split`` in ``rope_fn``: (rotated_query, rotated_key)
        when Q and K have the same size on dim 1.
        """
        cos_cached, sin_cached = self.get_cos_sin_cache(query_layer)
        return self.rope_fn(cos_cached, sin_cached, query_layer, key_layer)

    @torch.jit.export
    def single_forward(self, inputs: torch.Tensor):
        """For cases where shapes of Q and K do not match.

        Uses the currently cached tables without refreshing them.
        """
        cos, sin = self.cos_cached[: inputs.shape[0]], self.sin_cached[: inputs.shape[0]]
        return inputs * cos + self.rotate_half(inputs) * sin

    def rotate_half(self, x: torch.Tensor):
        # (x1, x2) -> (-x2, x1) on the two halves of the last dimension.
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster
class RotarySanityCheck(torch.nn.Module):
    """Reference rotary embedding without the fused path of ``Rotary``.

    ``forward`` always applies the tables built for ``def_seq_length`` positions.
    NOTE(review): inputs whose length along ``seq_dim`` differs will broadcast incorrectly or
    fail -- confirm callers always use ``def_seq_length``.
    """

    def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int = 0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=True)
        self.seq_len_cached = def_seq_length
        self.seq_dim = seq_dim
        cos_table, sin_table = self._get_cos_sin()
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)

    @torch.no_grad()
    def get_cos_sin_cache(self, x: torch.Tensor):
        """Return the cos/sin tables, rebuilding them when x's length along seq_dim changed."""
        if x.shape[self.seq_dim] != self.seq_len_cached:
            self.seq_len_cached = x.shape[self.seq_dim]
            new_cos, new_sin = self._get_cos_sin()
            self.cos_cached = new_cos.to(x.device)
            self.sin_cached = new_sin.to(x.device)
        return self.cos_cached, self.sin_cached

    def _get_cos_sin(self):
        """Tables shaped (seq, 1, 1, dim) for seq_dim == 0, else (1, seq, 1, dim)."""
        positions = torch.arange(self.seq_len_cached).type_as(self.inv_freq)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        emb = torch.cat((angles, angles), dim=-1)
        cos, sin = emb.cos(), emb.sin()
        if self.seq_dim == 0:
            index = (slice(None), None, None, slice(None))
        else:
            index = (None, slice(None), None, slice(None))
        return cos[index].detach(), sin[index].detach()

    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
        cos, sin = self.cos_cached, self.sin_cached
        rotated_query = (query_layer * cos) + (self.rotate_half(query_layer) * sin)
        rotated_key = (key_layer * cos) + (self.rotate_half(key_layer) * sin)
        return rotated_query, rotated_key

    def rotate_half(self, x: torch.Tensor):
        half = x.shape[-1] // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((-second, first), dim=-1)

    @torch.jit.export
    def single_forward(self, inputs: torch.Tensor):
        """For cases where shapes of Q and K do not match."""
        length = inputs.shape[0]
        return inputs * self.cos_cached[:length] + self.rotate_half(inputs) * self.sin_cached[:length]
# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/rotary.py (itself adapted from the xformers version below)
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
class RotaryEleutherAI(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    This variant rotates adjacent feature pairs (x[2i], x[2i+1]); each frequency is repeated
    twice in the tables to match. Extra constructor arguments are accepted and ignored.
    The cos/sin tables are built once in ``__init__`` for 128 positions and are not refreshed
    in ``forward``. NOTE(review): inputs longer than 128 along ``seq_dimension`` are therefore
    not covered by the tables and would fail to broadcast -- confirm they never occur.
    """

    _seq_len_cached: int
    # _cos_cached: Optional[torch.Tensor]
    # _sin_cached: Optional[torch.Tensor]

    def __init__(self, dim_model: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

        # Fixed 128-position tables. The dummy tensor only supplies a length, device (CPU) and dtype (float32).
        _cos_cached, _sin_cached = self._update_cos_sin_tables(torch.randn(1, 128, 1), seq_dimension=-2)
        self.register_buffer("_cos_cached", _cos_cached, persistent=False)
        self.register_buffer("_sin_cached", _sin_cached, persistent=False)

    @torch.jit.ignore
    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build (cos, sin) tables with one row per position along ``seq_dimension`` of ``x``.

        Each row holds the ``inv_freq`` angles with every entry repeated twice (adjacent), cast to x.dtype.
        """
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        # (the condition below is disabled, so this always rebuilds when called)
        # if seq_len != self._seq_len_cached:  # or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
        self._seq_len_cached = seq_len
        t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=self.inv_freq.dtype)
        # Don't do einsum, it converts fp32 to fp16
        # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        cos_cached = repeat(torch.cos(freqs).to(x.dtype), "... d -> ... (d 2)")
        sin_cached = repeat(torch.sin(freqs).to(x.dtype), "... d -> ... (d 2)")

        return cos_cached, sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (rotated q, rotated k) using the tables built in ``__init__``.

        ``seq_dimension`` is -2 for a (..., seq, dim) layout; -3 inserts a broadcast axis for (..., seq, heads, dim).
        """
        # assert seq_dimension in [-2, -3]  # Either (bs, h, s, d) or (bs, s, h, d)
        # Table refresh is disabled here, so the fixed 128-position tables are always used:
        # self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=seq_dimension)
        return (
            self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),
            self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),
        )

    def rotate_half(self, x: torch.Tensor):
        # Pairwise (x[2i], x[2i+1]) -> (-x[2i+1], x[2i]) on the last dimension.
        x = x.unflatten(dim=-1, sizes=(-1, 2))
        x1, x2 = x.unbind(dim=-1)
        rotated_x = torch.stack((-x2, x1), dim=-1)
        return rotated_x.flatten(start_dim=-2)

    def apply_rotary_pos_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seq_dimension: int = -2):
        # NOTE: This could probably be moved to Triton

        # Handle a possible sequence length mismatch in between q and k
        cos = cos[: x.shape[seq_dimension], :]
        sin = sin[: x.shape[seq_dimension], :]
        if seq_dimension == -3:
            # Broadcast over the heads axis that sits between seq and feature dims.
            cos = cos[:, None, :]
            sin = sin[:, None, :]
        return (x * cos) + (self.rotate_half(x) * sin)
class RotaryLLAMA(torch.nn.Module):
    """Facebook (LLaMA-style) rotary embeddings using a complex-number rotation table.

    ``freqs_cis`` holds ``2 * max_seq_length`` rows; ``forward`` uses the first rows matching the
    input's length along ``seq_dim``.
    """

    def __init__(self, hidden_per_head, base=10000, max_seq_length=512, seq_dim: int = 0):
        super().__init__()
        self.seq_dim: int = seq_dim
        table = self.precompute_freqs_cis(dim=hidden_per_head, end=max_seq_length * 2, theta=base)
        self.register_buffer("freqs_cis", table)

    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
        return self.apply_rotary_emb(query_layer, key_layer, freqs_cis=self.freqs_cis)

    def apply_rotary_emb(self, xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate adjacent feature pairs of ``xq``/``xk`` by multiplying with unit complex numbers."""
        q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        rotation = self.reshape_for_broadcast(freqs_cis, q_complex)
        q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
        k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
        return q_rotated.type_as(xq), k_rotated.type_as(xk)

    def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):
        """Trim the table to x's length along seq_dim and view it with 1s on every other axis but the last."""
        trimmed = freqs_cis[: x.shape[self.seq_dim]]
        last_axis = x.ndim - 1
        target_shape = [size if axis in (self.seq_dim, last_axis) else 1 for axis, size in enumerate(x.shape)]
        return trimmed.view(*target_shape)

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
        """Return a complex64 table of shape (end, dim // 2) with entries exp(i * pos * freq)."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta**exponents)
        positions = torch.arange(end, device=inv_freq.device)  # type: ignore
        angles = torch.outer(positions, inv_freq).float()  # type: ignore
        return torch.polar(torch.ones_like(angles), angles)
class FIRE(torch.nn.Module):
    def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512.0, eps=1e-6, max_length=0):
        """
        FIRE attention bias module (https://arxiv.org/abs/2310.04418).

        Args:
            num_heads: number of attention heads (one bias map per head).
            mlp_width: hidden width of the MLP mapping a normalised distance to per-head biases.
            init_c: initial value of the log-transformation parameter ``c``.
            init_L: initial value of the threshold ``L`` (frozen; scaled by the learnable ``L_multiplier``).
            eps: small constant for numerical stability.
            max_length: range of random positions used in training. 0 (or any value below the
                sequence length) means plain positions 0..seq_length-1.
        """
        super(FIRE, self).__init__()
        self.max_length = max_length  # using random PE
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(1, mlp_width),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_width, num_heads),
        )
        self.c = torch.nn.Parameter(torch.tensor(init_c))
        self.init_L = torch.nn.Parameter(torch.tensor(init_L), requires_grad=False)
        self.L_multiplier = torch.nn.Parameter(torch.tensor(1.0))  # learn a multiplier to L
        self.eps = eps

    def forward(self, seq_length, device):
        """
        Compute FIRE attention bias (https://arxiv.org/abs/2310.04418).

        Args:
            seq_length: number of query/key positions.
            device: device on which the positions are created.

        Returns:
            attention bias of shape [1, num_heads, seq_length, seq_length]
        """
        use_random_positions = self.training and seq_length <= self.max_length
        span = self.max_length if use_random_positions else seq_length
        # Sorted subset (size seq_length) of a random permutation of [0, span). When span equals
        # seq_length (e.g. in eval) this is simply 0..seq_length-1.
        positions = torch.sort(torch.randperm(span, dtype=torch.float, device=device)[:seq_length]).values
        relative = positions[:, None] - positions[None, :]
        # Threshold the normaliser for short sequence modeling.
        threshold = torch.abs(self.L_multiplier * self.init_L)
        normalizer = torch.max(positions, threshold)[:, None]
        # Log transform amplifies differences among local positions.
        relative = torch.log(torch.abs(self.c * relative) + 1)
        normalizer = torch.log(torch.abs(self.c * normalizer) + 1)
        # Progressive interpolation.
        normalized = relative / (normalizer + self.eps)
        bias = self.mlp(normalized.unsqueeze(-1)).unsqueeze(0)  # (1, seq, seq, heads)
        return bias.permute(0, 3, 1, 2)
class Abacus(torch.nn.Module):
    """Abacus embeddings: learned embeddings reused for each digit position within a number.

    Every maximal run of consecutive digit tokens is numbered 1, 2, 3, ... from its first token;
    all other tokens get index 0. During training a random offset ``k`` in ``[0, max_k]`` is added
    to every non-zero index before the embedding lookup.
    """

    def __init__(self, embedding_dim, max_seq_length=1024, max_k=99):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))
        # Default 99: the offset is added afterwards instead of generating positions with it.
        self.max_k = max_k

    def helper(self, mask, device):
        """Number every run of True values in each row of a 2-D boolean ``mask``.

        Returns a long tensor of the same shape with entries 1, 2, 3, ... inside each run and 0
        where ``mask`` is False.
        """
        batch, length = mask.shape
        # A run starts where the mask is on but was off at the previous position.
        shifted_mask = torch.cat([torch.zeros((batch, 1), device=device, dtype=mask.dtype), mask[:, :-1]], dim=1)
        starts = (shifted_mask != mask) & mask
        # Run id for every position (0 before the first run; gaps keep the previous id).
        segment_ids = torch.cumsum(starts, dim=1)
        index = torch.arange(length, device=device).repeat(batch, 1)
        # segment_start[r, s] = column where run s of row r begins. Run ids go up to the number of
        # runs, which reaches `length` for a one-token digit row, so the table needs `length + 1`
        # columns (with `length` columns a length-1 digit row indexed out of bounds).
        segment_start = torch.zeros((batch, length + 1), dtype=torch.long, device=mask.device)
        segment_start = segment_start.scatter_add(1, segment_ids, index * starts.long())
        positions = index - segment_start.gather(1, segment_ids) + 1
        # Zero out everything outside the runs.
        return positions * mask

    def forward(self, input_ids):
        """Return Abacus embeddings of shape ``(batch, seq_len, embedding_dim)`` for batch-first ``input_ids``.

        Designed to work with our tokenizers; for a more versatile implementation, look at the
        abacus.py file. The digits '0'..'9' are token ids 4..13 ('[PAD]': 0, '[UNK]': 1,
        '[BOS]': 2, '[EOS]': 3, then '0': 4 ... '9': 13, followed by non-digit symbols from 14 on).
        """
        digit_mask = (input_ids >= 4) & (input_ids <= 13)
        positions = self.helper(digit_mask, input_ids.device)
        if self.training:
            # Runs already start at 1, so shifted positions become k+1, k+2, ...
            k = random.randint(0, self.max_k)
            positions[positions > 0] += k
        return self.embedding(positions)
================================================
FILE: cramming/architectures/huggingface_interface.py
================================================
"""HF model variations based on reconfiguring their huggingface implementations."""
import transformers
def construct_huggingface_model(cfg_arch, vocab_size):
    """Construct a Hugging Face pretraining model from ``cfg_arch``.

    ``cfg_arch`` is either a ready ``transformers.PretrainedConfig`` or a mapping whose
    ``"model_type"`` names a hub architecture (only works if that arch exists on the hub); the
    mapping's entries are passed as overrides. The vocabulary size is set to ``vocab_size``, and
    ``forward`` is replaced by a wrapper that uses ``input_ids`` as labels.
    """
    if isinstance(cfg_arch, transformers.PretrainedConfig):
        configuration = cfg_arch
    else:
        configuration = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path=cfg_arch["model_type"], **cfg_arch
        )
    configuration.vocab_size = vocab_size
    model = transformers.AutoModelForPreTraining.from_config(configuration)
    model.vocab_size = model.config.vocab_size
    wrapped_forward = model.forward

    def modified_forward(input_ids, attention_mask=None, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility but not forwarded.
        return wrapped_forward(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)

    model.forward = modified_forward
    return model
================================================
FILE: cramming/architectures/losses.py
================================================
import torch
import math
class CosineLoss(torch.nn.Module):
    """Loss equal to one minus the mean cosine similarity of `x1` and `x2` along `dim`.

    Only the "mean" reduction is supported (enforced in `__init__`).
    """

    __constants__ = ["reduction"]
    reduction: str

    def __init__(self, reduction: str = "mean", dim=-1, eps=1e-8) -> None:
        super().__init__()
        self.reduction = reduction
        assert self.reduction == "mean"
        self.dim, self.eps = dim, eps

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        similarity = torch.nn.functional.cosine_similarity(x1, x2, self.dim, self.eps)
        return 1 - similarity.mean()
class CrossEntropyWithZLoss(torch.nn.Module):
    """Cross Entropy plus logit regularization via z_loss."""

    __constants__ = ["ignore_index", "z_loss_factor"]
    ignore_index: int
    z_loss_factor: float

    def __init__(self, ignore_index=-100, z_loss_factor=1e-4):
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.z_loss_factor = z_loss_factor
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """Return cross entropy plus `z_loss_factor * sum(log2(Z))`.

        log(Z) is taken to be log(sum(exp(logits))) over the last dimension. The z term is summed
        (not averaged) over every position, including positions whose label is `ignore_index`.

        NOTE(review): PaLM describes the auxiliary loss as 1e-4 * log^2(Z); this keeps the original
        unsquared log2 form. Confirm which one is intended.
        """
        # log2(sum(exp(x))) == logsumexp(x) / ln(2). logsumexp subtracts the max first, so it stays
        # finite where exp() would overflow (fp32 logits above ~88 become inf).
        log2_z = torch.logsumexp(inputs, dim=-1) / math.log(2.0)
        z_reg = log2_z.sum() * self.z_loss_factor
        return self.loss_fn(inputs, labels) + z_reg
class MSELoss(torch.nn.Module):
    """Squared-error loss usable in place of Cross Entropy Loss.

    Targets are one-hot vectors scaled by M = sqrt(num_classes). The squared error is summed over all
    positions whose label is not `ignore_index` and multiplied by 1 / (2 * M * num_classes).
    """

    def __init__(self, ignore_index=-100):
        """Loosely follows Hui & Belkin (2021), with k=1 and M=sqrt(C)."""
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """Compare the non-ignored rows of `inputs` with explicitly built, scaled one-hot targets."""
        keep = labels != self.ignore_index
        num_classes = inputs.shape[-1]
        scale = math.sqrt(num_classes)
        targets = self._label_to_onehot(labels[keep], scale, num_classes=num_classes)
        squared_error = (inputs[keep] - targets).pow(2).sum()
        return 1 / (2 * scale * num_classes) * squared_error

    @staticmethod
    @torch.jit.script
    def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):
        # one row per label, value M at the label's column and 0 elsewhere
        encoded = torch.zeros(target.shape[0], num_classes, device=target.device)
        encoded.scatter_(1, target.view(-1, 1), M)
        return encoded
class MSELossFast(torch.nn.Module):
    """Squared-error loss usable in place of Cross Entropy Loss, for 2-dim inputs and 1-dim labels.

    Gives the same value as an explicit scaled-one-hot MSE (scale M = sqrt(num_classes), overall
    factor 1 / (2 * M * num_classes)) without materialising the one-hot tensor.
    """

    def __init__(self, ignore_index=-100):
        """Loosely follows Hui & Belkin (2021), with k=1 and M=sqrt(C)."""
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """Use ||x - M*e_y||^2 = ||x||^2 - 2*M*x_y + M^2, summed over the kept rows."""
        _, num_classes = inputs.shape  # unpacking also rejects non-2-dim inputs
        keep = labels != self.ignore_index
        scale = math.sqrt(num_classes)
        kept_inputs, kept_labels = inputs[keep], labels[keep]
        count = kept_labels.shape[-1]
        sum_of_squares = kept_inputs.pow(2).sum()
        target_entries = kept_inputs[torch.arange(count), kept_labels].sum()
        return 1 / (2 * scale * num_classes) * (sum_of_squares - 2 * scale * target_entries + count * scale**2)
class L1Loss(torch.nn.Module):
    """Absolute-error loss usable in place of Cross Entropy Loss, for 2-dim inputs and 1-dim labels.

    Targets are one-hot vectors with value `num_classes` at the label. The absolute error over the
    non-ignored rows is divided by the total number of rows (ignored rows included) and by sqrt(num_classes).
    """

    def __init__(self, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """The best scaling for an L1 objective is unclear; this uses 1 / (rows * sqrt(C))."""
        num_classes = inputs.shape[-1]
        keep = labels != self.ignore_index
        root_c = math.sqrt(num_classes)
        targets = self._label_to_onehot(labels[keep], float(num_classes), num_classes=num_classes)
        return 1 / inputs.shape[0] / root_c * (inputs[keep] - targets).abs().sum()

    @staticmethod
    @torch.jit.script
    def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):
        # one row per label, value M at the label's column and 0 elsewhere
        encoded = torch.zeros(target.shape[0], num_classes, device=target.device)
        encoded.scatter_(1, target.view(-1, 1), M)
        return encoded
class SzegedyLoss(torch.nn.Module):
    """Regression directly back to input embedding. Remove the decoding layer if using this loss.

    As mentioned at https://twitter.com/ChrSzegedy/status/1533322132368728064?t=xz00T1YT3-WiE0id-h3MEA&s=19
    """

    def __init__(self, embedding_layer, ignore_index=-100, overrelaxation=2.0):
        """`overrelaxation` scales the target embeddings; its value is quite speculative."""
        super().__init__()
        self.embedding = embedding_layer
        self.ignore_index = ignore_index
        self.overrelaxation = overrelaxation

    def forward(self, inputs, labels):
        """L2(DNN(embed(x[:,:-1])), overrelaxation * stop_gradient(embed(x[:,1:]))).

        Expects 2-dim `inputs` (rows x embedding width) and 1-dim `labels`. The squared error over the
        non-ignored rows is divided by the total number of labels (ignored ones included) and by the width.
        """
        _, embedding_width = inputs.shape
        valid_mask = labels != self.ignore_index
        inputs = inputs[valid_mask]
        with torch.no_grad():
            # Embed only the valid labels: ignore_index (e.g. -100) is not a valid embedding row and
            # would make the lookup raise.
            embedded_labels = self.overrelaxation * self.embedding(labels[valid_mask])
        return (inputs - embedded_labels).pow(2).sum() / labels.shape[-1] / embedding_width
"""Focal Loss from https://github.com/clcarwin/focal_loss_pytorch (minimally modernized into pytorch 1.12)"""
"""
MIT License
Copyright (c) 2017 carwin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
class FocalLoss(torch.nn.Module):
    """Focal loss -(1 - p_t)^gamma * log(p_t) over targets that are not `ignore_index`.

    Args:
        gamma: focusing exponent; gamma=0 reduces to plain cross entropy.
        size_average: average over the kept targets if True, otherwise sum.
        ignore_index: target value whose positions are dropped.
    """

    def __init__(self, gamma: float = 5.0, size_average: bool = True, ignore_index: int = -100):
        super().__init__()
        self.register_buffer("gamma", torch.as_tensor(gamma, dtype=torch.float), persistent=False)
        self.size_average = size_average
        self.ignore_index = ignore_index

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        valid_mask = target != self.ignore_index
        # input[valid_mask] is (N, C); give log_softmax the class dim explicitly.
        log_probs = torch.nn.functional.log_softmax(input[valid_mask], dim=-1)
        # The gather index must be (N, 1) so that every row picks its own target class.
        log_probs = log_probs.gather(1, target[valid_mask].unsqueeze(1)).squeeze(1)
        loss = -1 * (1 - log_probs.exp()) ** self.gamma * log_probs
        if self.size_average:
            return loss.mean()
        return loss.sum()
class IncorrectCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """CrossEntropyLoss, but only on incorrectly classified examples.

    Rows whose argmax already equals the target are dropped before the loss is computed. With the
    "mean" reduction, if nothing non-ignored is left the loss is a zero that stays connected to the
    graph, instead of the NaN cross_entropy would produce (0 / 0).
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            incorrect_preds = input.argmax(dim=-1) != target
        selected_input = input[incorrect_preds]
        selected_target = target[incorrect_preds]
        if self.reduction == "mean" and not bool((selected_target != self.ignore_index).any()):
            # Multiplying by zero keeps requires_grad, so loss.backward() still works.
            return (selected_input * 0.0).sum()
        return torch.nn.functional.cross_entropy(
            selected_input,
            selected_target,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            label_smoothing=self.label_smoothing,
        )
================================================
FILE: cramming/architectures/sanity_check.py
================================================
"""Sanity Check architecture."""
import torch
from typing import Optional
class SanityCheckforPreTraining(torch.nn.Module):
    """Make big go fast: an embedding followed by one bias-free linear map, for plumbing checks."""

    def __init__(self, width, vocab_size):
        super().__init__()
        self.word_embedding = torch.nn.Embedding(vocab_size, width, padding_idx=0)
        self.transform = torch.nn.Linear(width, width, bias=False)

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """Return the transformed embeddings as "logits" and their mean as "loss".

        `attention_mask`, `labels` and `token_type_ids` exist only for interface compatibility and are unused.
        """
        logits = self.transform(self.word_embedding(input_ids))
        return dict(logits=logits, loss=logits.mean())
================================================
FILE: cramming/backend/__init__.py
================================================
"""This module implements interfaces to the various backends."""
from .prepare_backend import load_backend
from .utils import load_model_checkpoint, get_model_engine_tokenizer_dataloaders
# Public API of the backend package: names re-exported from the submodules imported above.
__all__ = [
    "load_backend",
    "load_model_checkpoint",
    "get_model_engine_tokenizer_dataloaders",
]
================================================
FILE: cramming/backend/optimizers/__init__.py
================================================
from .progressive_batching import ProgressiveBatching
from .optimizer_modifiers import SAM, LARS
from .schedulers import get_schedule_fn
================================================
FILE: cramming/backend/optimizers/optimizer_modifiers.py
================================================
"""This is the apex LARS implementation, from the apex repository.
It implements LARS + optional clipping
https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
I did rename it to "LARS".
"""
import torch
class MetaOptimizer(torch.optim.Optimizer):
    """Base class for a meta optimizer that wraps and modifies an existing pytorch optimizer.

    `torch.optim.Optimizer.__init__` is not called. The wrapper shares the wrapped optimizer's
    `param_groups`, and forwards to it every attribute it does not define itself.
    """

    def __init__(self, optimizer):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.__class__.__name__ + self.optim.__repr__()

    def __getattr__(self, name):
        """Call this only if all other attributes are exhausted.

        If `optim` itself is missing (an instance created without running `__init__`), raise
        AttributeError instead of recursing forever through `self.optim`.
        """
        if name == "optim":
            raise AttributeError(f"{type(self).__name__} has no wrapped optimizer ('optim' is not set)")
        return getattr(self.optim, name)

    @torch.no_grad()
    def step(self, closure=None):
        return self.optim.step(closure)
class LARS(MetaOptimizer):
    """
    :class:`LARS` [LARC in apex] is a pytorch implementation of both the scaling and clipping variants of LARS,
    in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
    local learning rate for each individual parameter. The algorithm is designed to improve
    convergence of large batch training.
    See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
    In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
    of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
    ```
    model = ...
    optim = torch.optim.Adam(model.parameters(), lr=...)
    optim = LARS(optim)
    ```
    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
        clip: Decides between clipping or scaling mode of LARC [LARS + clip].
            If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter.
            If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
        eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
    """

    def __init__(self, optimizer, trust_coefficient=0.02, clip=False, eps=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = eps
        self.clip = clip

    def step(self, closure=None):
        """Rescale every gradient by its adaptive local lr, then step the wrapped optimizer.

        Each group's weight decay is temporarily set to 0 in the wrapped optimizer and applied here
        instead, folded into the gradient before scaling. This happens only for parameters whose weight
        and gradient norms are both non-zero. The original weight decay values are restored afterwards,
        even if the wrapped step (or its closure) raises.
        """
        loss = None
        weight_decays = []
        try:
            with torch.no_grad():
                for group in self.optim.param_groups:
                    # absorb weight decay control from optimizer
                    weight_decay = group["weight_decay"] if "weight_decay" in group else 0
                    weight_decays.append(weight_decay)
                    group["weight_decay"] = 0
                    for p in group["params"]:
                        if p.grad is None:
                            continue
                        param_norm = torch.norm(p.data)
                        grad_norm = torch.norm(p.grad.data)
                        if param_norm != 0 and grad_norm != 0:
                            # calculate adaptive lr + weight decay
                            adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)
                            # clip learning rate for LARC
                            if self.clip:
                                # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
                                adaptive_lr = min(adaptive_lr / group["lr"], 1)
                            p.grad.data += weight_decay * p.data
                            p.grad.data *= adaptive_lr
                loss = self.optim.step(closure)
        finally:
            # return weight decay control to optimizer, also when the step above failed
            for group, weight_decay in zip(self.optim.param_groups, weight_decays):
                group["weight_decay"] = weight_decay
        return loss
"""This is the SAM pytorch implementation from https://github.com/davda54/sam
with a minor modification """
"""
MIT License
Copyright (c) 2021 David Samuel
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
class SAM(MetaOptimizer):
    """Sharpness-Aware Minimization wrapper around an existing optimizer instance.

    `step` runs the closure twice. The first run computes the gradient at w, which is used to move to
    w + e(w). The second run computes the gradient there. The parameters are then moved back to w and
    the wrapped optimizer steps with that second gradient. The offsets are stored under "e_w" in
    `self.state`, which this class never assigns itself (so lookup goes through the base class).
    """

    def __init__(self, base_optimizer_instance, rho=0.05):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        self.rho = rho
        self.optim = base_optimizer_instance
        self.param_groups = base_optimizer_instance.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Shift each parameter that has a gradient by rho * g / ||g|| and remember the shift."""
        scale = self.rho / (self._grad_norm() + 1e-12)
        params_with_grad = (q for group in self.param_groups for q in group["params"] if q.grad is not None)
        for p in params_with_grad:
            offset = p.grad * scale.to(p)
            p.add_(offset)  # climb to the local maximum "w + e(w)"
            self.state[p]["e_w"] = offset
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the shift from `first_step`, then take the wrapped optimizer's step."""
        for group in self.param_groups:
            for p in filter(lambda t: t.grad is not None, group["params"]):
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
        self.optim.step()  # do the actual "sharpness-aware" update
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        full_pass = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
        full_pass()
        self.first_step(zero_grad=True)
        loss = full_pass()
        self.second_step()
        return loss

    def _grad_norm(self):
        """Global L2 norm of all gradients, gathered on the first parameter's device (model parallelism)."""
        target_device = self.param_groups[0]["params"][0].device
        per_param_norms = [
            p.grad.norm(p=2).to(target_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        return torch.norm(torch.stack(per_param_norms), p=2)
================================================
FILE: cramming/backend/optimizers/progressive_batching.py
================================================
"""Implementation of a progressive batching meta optimizer.
The optimizer may defer an optimization step until gradient variance is small enough
"""
import torch
from collections import defaultdict
from .optimizer_modifiers import MetaOptimizer
import logging
# Module logger; ProgressiveBatching's progress tests report through it when DEBUG is enabled.
log = logging.getLogger(__name__)
# Set to True to log both sides of each progress-rule inequality at INFO level.
DEBUG = False
class ProgressiveBatching(MetaOptimizer):
    """Meta optimizer that may defer the wrapped optimizer's step until enough gradient samples agree.

    Every call to `step` folds the current `p.grad` of each parameter into a per-parameter running mean
    and an unnormalized running variance (Welford update), kept in `self.progress_state`. When the chosen
    test passes, the running mean is installed as `p.grad` and the wrapped optimizer steps.

    NOTE(review): assumes every parameter has a non-None `p.grad` holding only the latest micro-batch
    gradient (i.e. callers zero grads between calls). TODO confirm against the training loop.

    Args:
        optimizer: torch optimizer to wrap.
        progress_rule: "norm-based", "inner-product", "cov" or "cosine"; selects the test used in `step`.
        theta: threshold for the selected test.
        monotone: if True, after each real step `min_sample_guard` is raised to the number of samples
            that step used, so the accumulation size never shrinks.
        min_sample_guard: never step with fewer accumulated samples than this. The tests divide by
            (accumulated_steps - 1), so values below 2 can divide by zero.
        max_sample_guard: always step once more than this many samples have been accumulated.
    """

    def __init__(self, optimizer, progress_rule="norm-based", theta=0.9, monotone=False, min_sample_guard=2, max_sample_guard=128):
        super().__init__(optimizer)
        self.progress_rule = progress_rule
        self.theta = theta
        self.monotone = monotone
        self.min_sample_guard = min_sample_guard
        self.max_sample_guard = max_sample_guard
        self.progress_state = defaultdict(dict)
        self.accumulated_steps = 0
        self.reset_sample_statistics()

    @torch.no_grad()
    def step(self):
        """(Maybe) performs a single optimization step.

        The sample guards take precedence over `progress_rule`. Unlike the base class, no closure is accepted.
        """
        self.update_sample_statistics()
        if self.accumulated_steps < self.min_sample_guard:
            rule_check = False
        else:
            if self.accumulated_steps > self.max_sample_guard:
                rule_check = True
            else:
                if self.progress_rule == "norm-based":
                    rule_check = self.norm_test()
                elif self.progress_rule == "inner-product":
                    rule_check = self.inner_product_test()
                elif self.progress_rule == "cov":
                    rule_check = self.coefficient_of_variation()
                elif self.progress_rule == "cosine":
                    rule_check = self.cosine_test()
                else:
                    raise ValueError(f"Invalid progress rules {self.progress_rule} given.")
        if rule_check:
            self.copy_mean_grad()  # reference running mean in p.grad attributes
            if self.monotone:
                self.min_sample_guard = self.accumulated_steps  # raise lower limit if forcing monotone batch sizes
            # reset running mean; this allocates fresh tensors, so the p.grad references installed above
            # keep pointing at the finished mean. Order matters: copy first, then reset.
            self.reset_sample_statistics()
            super().step()
        else:
            # otherwise defer the step and accumulate more gradients
            pass

    def inner_product_test(self):
        """Inner product similar to description in Bollapragada,Byrd,Nocedal, "Adaptive Sampling Strategies for Stochastic Optimization".
        This is only a zero-memory inner product test.
        """
        global_inner_product, global_variance = 0, 0
        for group in self.param_groups:
            for p in group["params"]:
                state = self.progress_state[p]
                ndivn1 = self.accumulated_steps / (self.accumulated_steps - 1)
                # mean of the other samples: remove the latest gradient from the running mean and renormalize
                corrected_mean = (state["running_mean"] - p.grad / self.accumulated_steps) * ndivn1
                global_inner_product += (p.grad * corrected_mean).sum()
                global_variance += corrected_mean.pow(2).sum()
        final_v = (global_inner_product - global_variance).pow(2)
        if DEBUG:
            inequality_repr = f"{final_v / (self.accumulated_steps - 1):10.2f} < {self.theta * global_variance**2:10.2f}"
            log.info(f"{self.accumulated_steps} - {inequality_repr}")
        return final_v / (self.accumulated_steps - 1) < self.theta * global_variance**2

    def norm_test(self):
        """Sohams version."""
        sample_var, mean_norm = 0, 0
        for group in self.param_groups:
            for p in group["params"]:
                state = self.progress_state[p]
                sample_var += state["running_variance"].sum() / (self.accumulated_steps - 1)  # bessel-corrected variance
                mean_norm += state["running_mean"].pow(2).sum()
        if DEBUG:
            log.info(f"{self.accumulated_steps} - {sample_var / self.accumulated_steps:10.2f} < {self.theta * mean_norm:10.2f}")
        return sample_var / self.accumulated_steps < self.theta * mean_norm  # divide by |B| as in bigbatch, original version is theta=1

    def cosine_test(self):
        """Experimental: step when the average cosine between the latest gradient and the mean of the others exceeds theta."""
        total_angles, num_params = 0, 0
        for group in self.param_groups:
            for p in group["params"]:
                state = self.progress_state[p]
                ndivn1 = self.accumulated_steps / (self.accumulated_steps - 1)
                corrected_mean = (state["running_mean"] - p.grad / self.accumulated_steps) * ndivn1
                total_angles += (p.grad * corrected_mean).sum() / corrected_mean.norm() / p.grad.norm()
                num_params += 1
        average_angle = total_angles / num_params  # rather the average cosine, this not (yet) the angle
        if DEBUG:
            log.info(f"{self.accumulated_steps} - {average_angle:10.2f} > {self.theta:10.2f}")
        return average_angle > self.theta

    def coefficient_of_variation(self):
        """unbiased cov test."""
        cov, mean_norm, num_params = 0, 0, 0
        for group in self.param_groups:
            for p in group["params"]:
                state = self.progress_state[p]
                cov += (state["running_variance"].sum() / (self.accumulated_steps - 1)).sqrt() / (state["running_mean"].pow(2).sum() + 1e-6)
                mean_norm += state["running_mean"].pow(2).sum()
                num_params += 1
        unbiased_avg_cov = (1 + 1 / (4 * self.accumulated_steps)) * cov / num_params / self.accumulated_steps
        if DEBUG:
            log.info(f"{self.accumulated_steps} - {unbiased_avg_cov:10.2f} < {self.theta * 100:10.2f}")
        # threshold is theta scaled by 100
        return unbiased_avg_cov < self.theta * 100

    def update_sample_statistics(self):
        """Update sample statistics based on welford accumulation. At any step variance can be finalized via running_variance / count"""
        self.accumulated_steps += 1
        for group in self.param_groups:
            for p in group["params"]:
                state = self.progress_state[p]
                current_delta = p.grad - state["running_mean"]
                state["running_mean"] += current_delta / self.accumulated_steps
                corrected_delta = p.grad - state["running_mean"]
                state["running_variance"] += current_delta * corrected_delta

    def reset_sample_statistics(self):
        """Allocate new tensors, old references are still required for the optimizer step."""
        self.last_full_step_accumulation = self.accumulated_steps + 1
        self.accumulated_steps = 0
        for group in self.param_groups:
            for p in group["params"]:
                state = self.progress_state[p]
                state["running_mean"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["running_variance"] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def copy_mean_grad(self):
        """Point every parameter's `.grad` at its running-mean tensor (a reference, not a copy)."""
        for group in self.param_groups:
            for p in group["params"]:
                p.grad = self.progress_state[p]["running_mean"]
================================================
FILE: cramming/backend/optimizers/schedulers.py
================================================
"""Misc. optimizer implementations."""
import transformers
import math
from torch.optim.lr_scheduler import LambdaLR
import time
from functools import partial
def get_schedule_fn(cfg_train, elapsed_time: float = 0.0, true_budget: float = -1):
    """Returns a callable scheduler_fn(optimizer) for the schedule named by `cfg_train.scheduler`.

    Args:
        cfg_train: training config. Reads `scheduler`, `steps`, `warmup_steps`, `cooldown_steps` and
            `budget`. `warmup_steps`/`cooldown_steps` values strictly between 0 and 1 are treated as
            fractions of `steps` and are converted to step counts *in place* on `cfg_train`.
        elapsed_time: seconds of training already done (e.g. when resuming); forwarded to the
            wall-clock ("budget-*") schedulers.
        true_budget: hour budget for the budget schedulers. Values <= 0 fall back to `cfg_train.budget`.

    Raises:
        ValueError: if `cfg_train.scheduler` is not a known schedule name.

    Todo: Sanitize and unify these schedulers...
    """
    if true_budget <= 0:
        true_budget = cfg_train.budget
    if (cfg_train.warmup_steps) > 0 and (cfg_train.warmup_steps < 1):
        # warmup could be a percentage in which case this line converts to steps again
        cfg_train.warmup_steps = int(cfg_train.warmup_steps * cfg_train.steps)
    if (cfg_train.cooldown_steps) > 0 and (cfg_train.cooldown_steps < 1):
        # cooldown could be a percentage in which case this line converts to steps again
        cfg_train.cooldown_steps = int(cfg_train.cooldown_steps * cfg_train.steps)
    # Load huggingface schedulers based on total steps
    if cfg_train.scheduler == "polynomial-decay":
        scheduler_fn = partial(
            transformers.get_polynomial_decay_schedule_with_warmup,
            num_warmup_steps=cfg_train.warmup_steps,
            num_training_steps=cfg_train.steps,
            lr_end=1e-7,
            power=1.0,
        )
    elif cfg_train.scheduler == "cosine-decay":
        scheduler_fn = partial(
            transformers.get_cosine_schedule_with_warmup,
            num_warmup_steps=cfg_train.warmup_steps,
            num_training_steps=cfg_train.steps,
            num_cycles=0.5,
        )
    elif cfg_train.scheduler == "inverse-sqrt":
        scheduler_fn = partial(
            get_inverse_sqrt_scheduler,
            num_warmup_steps=cfg_train.warmup_steps,
            num_cooldown_steps=cfg_train.cooldown_steps,
            num_training_steps=cfg_train.steps,
        )
    elif cfg_train.scheduler == "one-cycle":  # this is a simplified one-cycle
        scheduler_fn = partial(
            get_one_cycle,
            num_training_steps=cfg_train.steps,
        )
    elif cfg_train.scheduler == "ramp":  # this is a simplified one-cycle
        scheduler_fn = partial(
            get_ramp,
            num_cooldown_steps=cfg_train.cooldown_steps,
            num_training_steps=cfg_train.steps,
        )
    # Budget (wall-clock) schedulers from here:
    elif cfg_train.scheduler == "budget-inverse-sqrt":
        scheduler_fn = partial(
            get_budget_inv_sqrt_scheduler,
            hour_budget=true_budget,
            num_warmup_steps=cfg_train.warmup_steps,
            num_cooldown_steps=cfg_train.cooldown_steps,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-constant":
        scheduler_fn = partial(
            get_budget_constant_scheduler,
            hour_budget=true_budget,
            num_warmup_steps=cfg_train.warmup_steps,
            num_cooldown_steps=cfg_train.cooldown_steps,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-cosine-decay":
        scheduler_fn = partial(
            get_budget_cosine_schedule_with_warmup,
            hour_budget=true_budget,
            num_warmup_steps=cfg_train.warmup_steps,
            num_training_steps=cfg_train.steps,
            num_cycles=0.5,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-cosine-annealing":
        scheduler_fn = partial(
            get_budget_cosine_half_cycles_with_warmup,
            hour_budget=true_budget,
            num_warmup_steps=cfg_train.warmup_steps,
            num_training_steps=cfg_train.steps,
            num_cycles=4,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-linear":
        scheduler_fn = partial(
            get_budget_linear_schedule_with_warmup,
            hour_budget=true_budget,
            num_warmup_steps=cfg_train.warmup_steps,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-polynomial":
        scheduler_fn = partial(
            get_budget_polynomial_decay_with_warmup,
            hour_budget=true_budget,
            num_warmup_steps=cfg_train.warmup_steps,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-one-cycle":  # this is a simplified one-cycle
        scheduler_fn = partial(
            get_budget_one_cycle,
            hour_budget=true_budget,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-multi-cycle":
        scheduler_fn = partial(
            get_budget_multi_cycle,
            hour_budget=true_budget,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-ramp":
        scheduler_fn = partial(
            get_budget_ramp,
            hour_budget=true_budget,
            num_cooldown_steps=cfg_train.cooldown_steps,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-inv-cosine":
        scheduler_fn = partial(
            get_budget_inv_cosine_schedule,
            hour_budget=true_budget,
            num_cooldown_steps=cfg_train.cooldown_steps,
            num_training_steps=cfg_train.steps,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-dive":
        scheduler_fn = partial(
            get_budget_dive,
            hour_budget=true_budget,
            num_training_steps=cfg_train.steps,
            num_warmup_steps=cfg_train.warmup_steps,
            falloff=0.5,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-dive-slow":
        scheduler_fn = partial(
            get_budget_dive,
            hour_budget=true_budget,
            num_training_steps=cfg_train.steps,
            num_warmup_steps=cfg_train.warmup_steps,
            falloff=0.75,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-dive-fast":
        scheduler_fn = partial(
            get_budget_dive,
            hour_budget=true_budget,
            num_training_steps=cfg_train.steps,
            num_warmup_steps=cfg_train.warmup_steps,
            falloff=0.25,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-triangle1":
        scheduler_fn = partial(
            get_budget_triangle,
            hour_budget=true_budget,
            num_training_steps=cfg_train.steps,
            falloff=0.25,
            base_percentage=0.5,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler == "budget-triangle2":
        scheduler_fn = partial(
            get_budget_triangle,
            hour_budget=true_budget,
            num_training_steps=cfg_train.steps,
            falloff=0.25,
            base_percentage=0.25,
            elapsed_time=elapsed_time,
        )
    elif cfg_train.scheduler in [
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
        "constant",
        "constant_with_warmup",
        "get_cosine_with_hard_restarts_schedule_with_warmup",
        "get_polynomial_decay_schedule_with_warmup",
    ]:
        # `transformers.get_scheduler` only accepts `SchedulerType` names. Translate the two
        # function-name style aliases accepted above into the scheduler types they refer to.
        hf_scheduler_name = {
            "get_cosine_with_hard_restarts_schedule_with_warmup": "cosine_with_restarts",
            "get_polynomial_decay_schedule_with_warmup": "polynomial",
        }.get(cfg_train.scheduler, cfg_train.scheduler)

        def scheduler_fn(optimizer):
            return transformers.get_scheduler(
                name=hf_scheduler_name,
                optimizer=optimizer,
                num_warmup_steps=cfg_train.warmup_steps,
                num_training_steps=cfg_train.steps,
            )

    elif cfg_train.scheduler == "none" or cfg_train.scheduler is None:
        scheduler_fn = DumbScheduler
    else:
        raise ValueError(f"Invalid schedule {cfg_train.scheduler} given.")
    return scheduler_fn
class DumbScheduler:
    """No-op learning-rate scheduler with a subset of the torch scheduler interface.

    `step` only counts calls, the state dict is empty, and learning-rate queries return NaN.
    """

    def __init__(self, *args, **kwargs):
        # Built like a real scheduler, e.g. `DumbScheduler(optimizer)`. Keep the optimizer (if any)
        # because `_initial_step` resets its step counter.
        if args:
            self.optimizer = args[0]
        else:
            self.optimizer = kwargs.get("optimizer")
        self._step_count = 0

    def step(self, *args, **kwargs):
        self._step_count += 1

    def _initial_step(self):
        if self.optimizer is not None:
            self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """Return last computed learning rate by current scheduler."""
        return float("NaN")

    def get_lr(self):
        return float("NaN")

    def print_lr(self, is_verbose, group, lr, epoch=None):
        print(float("NaN"))
"""FairSeq-like inverse-square-root scheduler:"""
def get_inverse_sqrt_scheduler(optimizer, num_warmup_steps, num_cooldown_steps, num_training_steps):
    """FairSeq-like inverse-square-root schedule, as a `LambdaLR` multiplier on the base LR.

    - warmup (step < num_warmup_steps): linear ramp step / num_warmup_steps, from 0 towards 1.
    - decay: sqrt(num_warmup_steps) / sqrt(step), which is 1 at the end of warmup.
    - cooldown (step > num_training_steps - num_cooldown_steps): linear decay from the value the
      decay phase has at the cooldown start, reaching 0 at num_training_steps and clamped at 0 afterwards.

    A warmup or cooldown length of 0 is treated as 1 step (the same `max(1, ...)` guard used by the
    budget schedulers). Previously a zero warmup raised ZeroDivisionError at construction, and a zero
    cooldown raised one past the end of training.
    """
    warmup_steps = max(1, num_warmup_steps)
    cooldown_steps = max(1, num_cooldown_steps)
    # linearly warmup for the first warmup_steps
    lr_step = 1 / warmup_steps
    # then, decay prop. to the inverse square root of the update number
    decay_factor = warmup_steps**0.5
    decayed_lr = decay_factor * (num_training_steps - num_cooldown_steps) ** -0.5

    def lr_lambda(current_step: int):
        if current_step < warmup_steps:
            return float(current_step * lr_step)
        elif current_step > (num_training_steps - num_cooldown_steps):
            return max(0.0, float(decayed_lr * (num_training_steps - current_step) / cooldown_steps))
        else:
            return float(decay_factor * current_step**-0.5)

    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
def get_one_cycle(optimizer, num_training_steps):
    """Simple single-cycle scheduler. Not including paper/fastai three-phase things or asymmetry.

    The multiplier rises linearly from 0 to 1 over the first half of `num_training_steps` and falls back
    to 0 over the second half. Beyond `num_training_steps` it stays at 0 instead of going negative.
    """
    half_cycle = num_training_steps / 2

    def lr_lambda(current_step):
        if current_step < half_cycle:
            return float(current_step / half_cycle)
        return max(0.0, float(2 - current_step / half_cycle))

    return LambdaLR(optimizer, lr_lambda, -1)
def get_ramp(optimizer, num_cooldown_steps, num_training_steps):
    """to the MOON.

    The multiplier grows linearly as step / num_training_steps until the cooldown starts. It then
    decays linearly to 0 at num_training_steps and stays at 0 afterwards. A cooldown of 0 steps is
    treated as 1 step in the denominator, so steps past the end give 0 instead of dividing by zero.
    """
    max_lr = (num_training_steps - num_cooldown_steps) / num_training_steps
    cooldown_steps = max(1, num_cooldown_steps)

    def lr_lambda(current_step):
        if current_step > (num_training_steps - num_cooldown_steps):
            return max(0.0, float(max_lr * (num_training_steps - current_step) / cooldown_steps))
        else:
            return float(current_step / num_training_steps)

    return LambdaLR(optimizer, lr_lambda, -1)
"""Wallclock time schedulers."""
def _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, prev_elapsed_time: float = 0.0):
elapsed_hours = (time.time() - initial_time + prev_elapsed_time) / 60 / 60
if current_step == 0:
fake_step = 0
else:
fake_step = int(elapsed_hours / hour_budget * num_training_steps)
# Warning: denominator could be bigger than 1 if passed original budget, so be careful with checkpointing
return fake_step
def get_budget_inv_sqrt_scheduler(optimizer, hour_budget, num_warmup_steps, num_cooldown_steps, num_training_steps, elapsed_time: float = 0.0):
    """Time-based scheduler as described in Iszak et al. plus inv_sqrt.

    Takes in num_warmup_steps and num_training_steps as normal, but actually squeezes the planned schedule into the
    budget given by hour_budget, based on wallclock measurements.
    Reference: https://github.com/IntelLabs/academic-budget-bert/blob/main/pretraining/schedules.py

    Warmup and cooldown lengths of 0 are treated as 1 step. Otherwise a zero warmup evaluates 0**-0.5
    (ZeroDivisionError) on the very first step, and a zero cooldown divides by zero once the budget is exceeded.
    """
    warmup_steps = max(1, num_warmup_steps)
    cooldown_steps = max(1, num_cooldown_steps)
    decay_factor = warmup_steps**0.5
    decayed_lr = decay_factor * (num_training_steps - num_cooldown_steps) ** -0.5
    initial_time = time.time()

    def lr_lambda(current_step: int):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        if fake_step < warmup_steps:
            return float(fake_step / warmup_steps)
        elif fake_step > (num_training_steps - num_cooldown_steps):
            return max(0.0, float(decayed_lr * (num_training_steps - fake_step) / cooldown_steps))
        else:
            return float(decay_factor * fake_step**-0.5)

    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
def get_budget_constant_scheduler(optimizer, hour_budget, num_warmup_steps, num_cooldown_steps, num_training_steps, elapsed_time: float = 0.0):
    """Time-based scheduler with optional warmup and cooldown (so technically a trapezoidal shape).

    A cooldown of 0 steps is treated as 1 step in the denominator, so once the wall-clock budget is
    exceeded the multiplier clamps to 0 instead of dividing by zero.
    """
    cooldown_steps = max(1, num_cooldown_steps)
    initial_time = time.time()

    def lr_lambda(current_step: int):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        if fake_step < num_warmup_steps:
            return float(fake_step / num_warmup_steps)
        elif fake_step > (num_training_steps - num_cooldown_steps):
            return max(0.0, float((num_training_steps - fake_step) / cooldown_steps))
        else:
            return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
def get_budget_linear_schedule_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):
    """Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget.

    The LR factor rises linearly from 0 to 1 over ``num_warmup_steps``. It then falls linearly to 0 at
    ``num_training_steps`` and is clamped at 0 afterwards. Positions are wallclock-derived "fake" steps.
    ``num_cycles`` is accepted for signature compatibility but not used.
    """
    start_time = time.time()
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        step = _get_fake_step(current_step, start_time, hour_budget, num_training_steps, elapsed_time)
        if step >= num_warmup_steps:
            return max(0.0, float(num_training_steps - step) / decay_span)
        return float(step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_cosine_schedule_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):
    """Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget.

    The LR factor rises linearly from 0 to 1 over ``num_warmup_steps``. After that it is
    ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``, clamped at 0. ``progress`` goes from 0 at the end of
    warmup to 1 at ``num_training_steps``. Positions are wallclock-derived "fake" steps.
    """
    start_time = time.time()

    def lr_lambda(current_step):
        step = _get_fake_step(current_step, start_time, hour_budget, num_training_steps, elapsed_time)
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_cosine_half_cycles_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):
    """Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget.

    The LR factor rises linearly from 0 to 1 over ``num_warmup_steps``. After that it is
    ``0.5 * (1 + cos(pi * ((num_cycles * progress) mod 1)))``, with ``progress`` going from 0 to 1 between the
    end of warmup and ``num_training_steps``. Each time ``num_cycles * progress`` crosses an integer, the
    half-cosine restarts at 1. With the default ``num_cycles=0.5`` the factor only reaches 0.5 at
    ``num_training_steps``. Positions are wallclock-derived "fake" steps.
    """
    start_time = time.time()

    def lr_lambda(current_step):
        step = _get_fake_step(current_step, start_time, hour_budget, num_training_steps, elapsed_time)
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        phase = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * phase)))

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_one_cycle(optimizer, hour_budget, num_training_steps, elapsed_time: float = 0.0):
    """Simple single-cycle scheduler. Not including paper/fastai three-phase things or asymmetry.

    The LR factor rises linearly from 0 to 1 over the first half of ``num_training_steps``. It then falls
    linearly back to 0 at ``num_training_steps``. Positions are wallclock-derived "fake" steps.

    The factor is clamped at 0 because the fake step can overshoot ``num_training_steps`` once the time budget is
    exceeded. Without the clamp the LR would turn negative.
    """
    initial_time = time.time()
    half_span = num_training_steps / 2

    def lr_lambda(current_step):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        if fake_step < half_span:
            return float(fake_step / half_span)
        else:
            return max(0.0, float(2 - fake_step / half_span))

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_multi_cycle(optimizer, hour_budget, num_training_steps, num_cycles=8, elapsed_time: float = 0.0):
    """Simple multi-cycle scheduler. Not including paper/fastai three-phase things or asymmetry.

    The planned ``num_training_steps`` are squeezed into ``hour_budget`` via wallclock time and split into
    ``num_cycles`` equal triangular cycles. Within each cycle the LR factor rises linearly from 0 to 1 over the
    first half, then falls linearly back to 0.

    Args:
        optimizer: optimizer wrapped by the returned ``LambdaLR``.
        hour_budget: wallclock budget in hours.
        num_training_steps: planned total length.
        num_cycles: number of triangular cycles.
        elapsed_time: seconds already spent before this scheduler was created.

    Returns:
        A ``LambdaLR`` scheduler.
    """
    initial_time = time.time()
    # At least one step per cycle so the modulo below can never divide by zero.
    cycle_length = max(1, int(num_training_steps / num_cycles))
    half_cycle = cycle_length / 2

    def lr_lambda(current_step):
        # Position inside the current cycle, in wallclock-derived "fake" steps.
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time) % cycle_length
        if fake_step < half_cycle:
            return float(fake_step / half_cycle)
        else:
            return float(2 - fake_step / half_cycle)

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_ramp(optimizer, hour_budget, num_cooldown_steps, num_training_steps, elapsed_time: float = 0.0):
    """to the moon.

    The LR factor ramps linearly as ``step / num_training_steps`` up to ``max_lr = (T - C) / T`` at the start of
    the cooldown. It then decays linearly from ``max_lr`` to 0 over the last ``num_cooldown_steps``.
    Positions are wallclock-derived "fake" steps, and the factor is clamped at 0 if the budget is overrun.
    """
    initial_time = time.time()
    max_lr = (num_training_steps - num_cooldown_steps) / num_training_steps

    def lr_lambda(current_step):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        if fake_step > (num_training_steps - num_cooldown_steps):
            # fake_step can exceed num_training_steps after the budget runs out; max(1, ...) keeps a zero-length
            # cooldown from raising ZeroDivisionError (the clamp then yields 0).
            return max(0.0, float(max_lr * (num_training_steps - fake_step) / max(1, num_cooldown_steps)))
        else:
            return float(fake_step / num_training_steps)

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_inv_cosine_schedule(optimizer, hour_budget, num_cooldown_steps, num_training_steps, num_cycles=0.5, elapsed_time: float = 0.0):
    """An inverse cosine schedule, with limited budget.

    Outside the cooldown, the LR factor is ``0.5 * (1 + cos(2 * pi * num_cycles * (1 - step / T)))``, clamped
    at 0. With the default ``num_cycles=0.5`` it rises from 0 at step 0 toward 1. During the last
    ``num_cooldown_steps`` it decays linearly from its value at the cooldown start (``max_lr``) to 0.
    Positions are wallclock-derived "fake" steps.
    """
    initial_time = time.time()
    ult_step = num_training_steps - num_cooldown_steps
    max_lr = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * (1 - ult_step / float(max(1, num_training_steps))))))

    def lr_lambda(current_step):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        progress = 1 - fake_step / float(max(1, num_training_steps))
        if fake_step > (num_training_steps - num_cooldown_steps):
            # fake_step can overshoot num_training_steps after the budget is exhausted; max(1, ...) keeps a
            # zero-length cooldown from raising ZeroDivisionError (the clamp then yields 0).
            return max(0.0, float(max_lr * (num_training_steps - fake_step) / max(1, num_cooldown_steps)))
        else:
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_triangle(optimizer, hour_budget, num_training_steps, base_percentage=0.5, falloff=0.5, elapsed_time: float = 0.0):
    """Linear increase from a percentage of the base learning rate, then linear decay.
    plot min(0.5 + x * (1 - 0.5)/(1-0.25) / 1000, 1/0.25 - x / (1000 * 0.25)) from 0 to 1000 in the plot range 0 to 1

    The factor has two phases, in wallclock-derived "fake" steps:
      * it starts at ``base_percentage`` and rises linearly to 1 at step ``(1 - falloff) * num_training_steps``;
      * it then decays linearly to 0 at ``num_training_steps``.
    The result is clamped at 0 because the fake step can overshoot ``num_training_steps`` once the time budget is
    exceeded. Without the clamp the LR would turn negative.
    """
    initial_time = time.time()

    def lr_lambda(current_step):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        triangle = min(
            base_percentage + fake_step * (1 - base_percentage) / (1 - falloff) / num_training_steps,
            float(1 / falloff - fake_step / (num_training_steps * falloff)),
        )
        return max(0.0, triangle)

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_dive(optimizer, hour_budget, num_training_steps, num_warmup_steps=0, falloff=0.5, elapsed_time: float = 0.0):
    """Constant, then linear decay.
    plot min(1, 1/0.5 - x / (1000 * 0.5)) from 0 to 1000 in the plot range 0 to 1

    The schedule has three phases, in wallclock-derived "fake" steps:
      * an optional linear warmup over ``num_warmup_steps``;
      * the factor stays at 1 until step ``(1 - falloff) * num_training_steps``;
      * it then decays linearly to 0 at ``num_training_steps``.
    The factor is clamped at 0 so a budget overrun cannot produce a negative LR.
    """
    initial_time = time.time()

    def lr_lambda(current_step):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        # Warmup is judged on fake_step, the same clock used for the returned value and by every other budget
        # scheduler. Mixing in the raw step count could return factors > 1 or end warmup at the wrong time.
        if fake_step < num_warmup_steps:
            return float(fake_step) / float(max(1, num_warmup_steps))
        else:
            return max(0.0, min(1.0, float(1 / falloff - fake_step / (num_training_steps * falloff))))

    return LambdaLR(optimizer, lr_lambda, -1)
def get_budget_polynomial_decay_with_warmup(optimizer, hour_budget, num_warmup_steps, num_training_steps, lr_end=0.0, power=1.0, elapsed_time: float = 0.0):
    """Follows the huggingface transformers scheduler with the same name, but gets an additional arg hour_budget.

    The schedule, in wallclock-derived "fake" steps:
      * linear warmup to the base LR (``optimizer.defaults["lr"]``) over ``num_warmup_steps``;
      * polynomial decay with exponent ``power`` down to ``lr_end`` at ``num_training_steps``;
      * held at ``lr_end`` afterwards.
    The returned factor is divided by the base LR because ``LambdaLR`` multiplies by it.
    """
    initial_time = time.time()
    lr_init = optimizer.defaults["lr"]

    def lr_lambda(current_step: int):
        fake_step = _get_fake_step(current_step, initial_time, hour_budget, num_training_steps, elapsed_time)
        if fake_step < num_warmup_steps:
            return float(fake_step) / float(max(1, num_warmup_steps))
        elif fake_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            # max(1, ...) only matters when num_warmup_steps == num_training_steps (fake_step equals both here).
            decay_steps = max(1, num_training_steps - num_warmup_steps)
            # Use fake_step, like the branches above. The raw optimizer step would decouple the decay from the
            # time budget and could make pct_remaining negative.
            pct_remaining = 1 - (fake_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, -1)
================================================
FILE: cramming/backend/prepare_backend.py
================================================
"""Instantiate backend objects in a congruent format."""
import torch
from .torch_default import initialize_torch
_default_setup = dict(device=torch.device("cpu"), dtype=torch.float)


def load_backend(model, tokenizer, cfg_train, cfg_impl, setup=_default_setup, init_compile_and_distribute=True):
    """Instantiate the training backend selected by ``cfg_impl.name``.

    Only ``"torch-default"`` is supported. It is delegated to ``initialize_torch`` with all arguments forwarded.

    Raises:
        ValueError: if ``cfg_impl.name`` names any other backend.
    """
    if cfg_impl.name != "torch-default":
        raise ValueError(f"Invalid backend {cfg_impl.name} given.")
    return initialize_torch(
        model,
        tokenizer,
        cfg_train,
        cfg_impl,
        setup=setup,
        init_compile_and_distribute=init_compile_and_distribute,
    )
================================================
FILE: cramming/backend/torch_default.py
================================================
"""Basic training backend engine for pytorch training with all bells and whistles.
Interface set up to be compliant with the deepspeed engine interface.
There are two versions here, the TorchEngineMinimal, which is the default, and TorchEngineFull which contains a few training variations
that were tested but ultimately discarded, so read that part only if you're interested.
"""
import json
import logging
import os
import time
from contextlib import nullcontext
from functools import partial
from typing import Any, Dict, Union
import torch
import torch._inductor.utils
import transformers
from omegaconf import OmegaConf
from safetensors.torch import save_file
from torch.distributed.optim import ZeroRedundancyOptimizer
from transformers.utils.generic import working_or_temp_dir
from .optimizers import LARS, SAM, ProgressiveBatching
from .optimizers.schedulers import get_schedule_fn
# from .utils import group_parameters, prepare_pretraining_dataloader, prepare_validation_dataloader
from .utils import group_parameters, load_model_checkpoint
# Module-level logger for the engine.
log = logging.getLogger(__name__)
# Default placement when no explicit setup dict is given: CPU device, torch.float dtype.
_default_setup = dict(device=torch.device("cpu"), dtype=torch.float)
import warnings
from ..utils import flatten
# Ignore UserWarnings whose message starts with "Detected call of ". This is presumably PyTorch's
# scheduler/optimizer call-order warning (confirm).
warnings.filterwarnings("ignore", "Detected call of ", UserWarning)  # schedulers are deliberately used differently
def initialize_torch(model, tokenizer, cfg_train, cfg_impl, setup=_default_setup, init_compile_and_distribute=True):
    """Wrap ``model`` in a ``TorchEngine`` and switch the engine to training mode.

    The engine's initial sequence length is read from ``tokenizer.model_max_length``. Returns the engine.
    """
    engine_options = dict(
        setup=setup,
        seq_length=tokenizer.model_max_length,
        init_compile_and_distribute=init_compile_and_distribute,
    )
    engine = TorchEngine(model, cfg_train, cfg_impl, **engine_options)
    engine.train()
    return engine
class TorchEngine(torch.nn.Module):
"""This class mirrors deepspeed functionality and hides variable batch sizes, microbatching, AMP details and compilation"""
def __init__(self, model, cfg_train, cfg_impl, setup=_default_setup, seq_length=128, init_compile_and_distribute=True):
    """Load Engine: move, compile and (optionally) distribute ``model``, then build its optimizer and scheduler.

    The model is always passed through ``torch.compile``. Compilation is skipped internally when
    ``cfg_impl.compile_torch`` is false, because ``disable`` is set then.

    Args:
        model: the module to train.
        cfg_train: training config. ``batch_size``, ``batch_size_ramp``, ``budget`` and ``overall_budget`` are
            read here; it is also passed to ``_load_optimizer``.
        cfg_impl: implementation config (microbatching, mixed precision, compile options). Note that
            ``cfg_impl.microbatch_size`` is written back when it is None.
        setup: dict forwarded to ``model.to(**setup)``; its ``"device"`` entry also selects AMP behaviour.
        seq_length: initial value of ``current_seq_length``.
        init_compile_and_distribute: if False while torch.distributed is initialized, the model is NOT wrapped for
            distributed training here. When loading in a checkpoint there is no point sending it across GPUs
            now, as this will be redone later.

    Raises:
        ValueError: if the microbatch size exceeds the batch size.
    """
    super().__init__()
    self.cfg_train = cfg_train
    self.cfg_impl = cfg_impl
    # Default to no microbatching (one microbatch == full batch).
    if self.cfg_impl.microbatch_size is None:
        self.cfg_impl.microbatch_size = self.cfg_train.batch_size
    if self.cfg_impl.microbatch_size > self.cfg_train.batch_size:
        raise ValueError(f"MBS is {self.cfg_impl.microbatch_size}, but BS is only {self.cfg_train.batch_size}.")
    self.current_seq_length = seq_length
    # Mixed Precision:
    enabled = self.cfg_impl.mixed_precision if setup["device"].type != "cpu" else False
    # Modules like LN are unsupported on CPU amp, so mixed precision args are disregarded on CPU
    # See https://pytorch.org/docs/stable/amp.html#cpu-op-specific-behavior and check for layer_norm
    enable_scaling = self.cfg_impl.grad_scaling and self.cfg_impl.mixed_precision and setup["device"].type != "cpu"
    self.scaler = torch.cuda.amp.GradScaler(enabled=enable_scaling)
    # On CPU the target dtype from the config is ignored and bfloat16 is used (autocast is disabled there anyway).
    amp_dtype = getattr(torch, self.cfg_impl.mixed_precision_target_dtype) if setup["device"].type != "cpu" else torch.bfloat16
    self.amp_settings = dict(device_type=setup["device"].type, enabled=enabled, dtype=amp_dtype)
    # Choose setup and move model
    self.setup = setup
    model.to(**self.setup)
    # Keep a handle on the uncompiled module; `model` is rebound to the compiled wrapper below.
    self._original_model = model
    log.info("Compiling model, in the Constructor of TorchEngine")
    model = torch.compile(
        model,
        mode=self.cfg_impl.mode,
        dynamic=self.cfg_impl.dynamic,
        fullgraph=self.cfg_impl.fullgraph,
        backend=self.cfg_impl.backend,
        disable=not cfg_impl.compile_torch,
        # detailed options; cannot be given at the same time as mode:
        options=flatten(cfg_impl._inductor_vars, parent_key="", sep=".") if cfg_impl._inductor_vars is not None else None,
    )
    if torch.distributed.is_initialized():
        if init_compile_and_distribute:
            log.info("Distributing model, in the Constructor of TorchEngine")
            self.model = self._init_distributed(model)
        else:
            log.info(
                "<WARNING> NOT Distirbuting model in the Constructor of TorchEngine, we will attempt to do this later as we are loading in a checkpoint"
            )
            self.model = model
        self.num_machines = torch.distributed.get_world_size()
    else:
        self.model = model
        # Single process: give the model a no-op no_sync so forward() can use it unconditionally.
        self.model.no_sync = nullcontext
        self.num_machines = 1
    # Microbatch accumulation settings and counters
    self.effective_mbs = self.cfg_impl.microbatch_size * self.num_machines  # across machines
    # With a batch-size ramp, start from a single (cross-machine) microbatch instead of the full batch size.
    self.current_batch_size = self.cfg_train.batch_size if self.cfg_train.batch_size_ramp == 0 else self.effective_mbs
    self.accumulation_steps_expected = self.current_batch_size // self.effective_mbs
    self.accumulated_samples = 0  # Record the number of samples seen, reset after triggering gradient update
    self.steps = 0  # Record the number of times "step" has been triggered
    self.steps_since_reset = 0  # Like self.steps, presumably since the last counter reset (reset logic not visible here)
    self.initial_time = time.time()
    # Must be set before get_true_budget() is called on the next line.
    self.previous_elapsed_time = 0.0
    # Note: the optimizer is built from the compiled (pre-distribution) `model`, not from self.model.
    self.optimizer, self.scheduler = _load_optimizer(model, cfg_train, cfg_impl, self.previous_elapsed_time, self.get_true_budget())
def get_true_budget(self):
    """Return the hour budget for the schedulers, measured from the very start of training.

    This session may run for ``cfg_train.budget`` hours, capped by whatever remains of
    ``cfg_train.overall_budget`` after ``previous_elapsed_time`` seconds. The hours already spent are then added
    back, so the result can be compared against total elapsed time.
    """
    hours_already_spent = self.previous_elapsed_time / 3600
    remaining_overall = self.cfg_train.overall_budget - hours_already_spent
    session_budget = min(self.cfg_train.budget, remaining_overall)
    return session_budget + hours_already_spent
def step(self, batch: dict[str, torch.Tensor]):
    """Run one training step on ``batch``.

    The step runs ``self.forward(**batch)``, backpropagates its ``"loss"`` via ``self.backward``, then calls
    ``self.optimizer_step()``. Returns the loss detached from the autograd graph.
    """
    outputs = self.forward(**batch)
    loss = outputs["loss"]
    self.backward(loss)
    self.optimizer_step()
    return loss.detach()
def to_device(self, batch: dict[str, torch.Tensor], keys: Union[list[str], None] = None):
    """Move the selected entries of ``batch`` onto ``self.setup["device"]``.

    Args:
        batch: mapping from names to tensors.
        keys: names of the entries to keep and move; defaults to ``["input_ids"]``. Entries not listed are left
            out of the result.

    Returns:
        A new dict with only the selected entries, each moved with ``non_blocking=True``. ``input_ids`` is also
        cast to ``torch.long``; other entries keep their dtype.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if keys is None:
        keys = ["input_ids"]
    device_batch = {
        k: v.to(device=self.setup["device"], dtype=torch.long if k == "input_ids" else None, non_blocking=True)
        for k, v in batch.items()
        if k in keys  # Add more keywords here if needed
    }
    return device_batch
def forward(self, *inputs, **kwargs):
self.accumulated_samples += self.effective_mbs
context = self.model.no_sync if self.accumulated_samples < self.current_batch_size else nullcontext
with context():
w
gitextract_shohcgjg/ ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── abacus.py ├── arithmetic_eval_quicker.py ├── cramming/ │ ├── __init__.py │ ├── architectures/ │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── components.py │ │ ├── construction.py │ │ ├── crammed_depthrecurrent.py │ │ ├── crammed_transformer.py │ │ ├── embeddings.py │ │ ├── huggingface_interface.py │ │ ├── losses.py │ │ └── sanity_check.py │ ├── backend/ │ │ ├── __init__.py │ │ ├── optimizers/ │ │ │ ├── __init__.py │ │ │ ├── optimizer_modifiers.py │ │ │ ├── progressive_batching.py │ │ │ └── schedulers.py │ │ ├── prepare_backend.py │ │ ├── torch_default.py │ │ └── utils.py │ ├── config/ │ │ ├── __init__.py │ │ ├── arch/ │ │ │ ├── __init__.py │ │ │ ├── albert.yaml │ │ │ ├── crammed-depthrecurrent.yaml │ │ │ ├── crammed-fakeRNN.yaml │ │ │ ├── crammed-janus.yaml │ │ │ ├── crammed-rnn.yaml │ │ │ ├── crammed-stack-janus.yaml │ │ │ ├── crammed-tiny.yaml │ │ │ ├── crammed-transformer.yaml │ │ │ ├── gpt2-base.yaml │ │ │ ├── hf-gpt2.yaml │ │ │ └── sanitycheck.yaml │ │ ├── cfg_eval.yaml │ │ ├── cfg_pretrain.yaml │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── arithmetic.yaml │ │ │ ├── c4-subset-processed.yaml │ │ │ ├── openweb.yaml │ │ │ ├── proofpile.yaml │ │ │ ├── sanity-check-1.yaml │ │ │ ├── sanity-check-2.yaml │ │ │ └── sources/ │ │ │ ├── ag_news.yaml │ │ │ ├── arithmetic.yaml │ │ │ ├── bookcorpus.yaml │ │ │ ├── c4.yaml │ │ │ ├── dash_books.yaml │ │ │ ├── fake.yaml │ │ │ ├── iwslt.yaml │ │ │ ├── local.yaml │ │ │ ├── no_code_stackexchange.yaml │ │ │ ├── openwebtext.yaml │ │ │ ├── oscar.yaml │ │ │ ├── proofpiledata.yaml │ │ │ ├── the_pile.yaml │ │ │ ├── the_pileCC.yaml │ │ │ ├── the_pile_dedup.yaml │ │ │ ├── the_pile_natural.yaml │ │ │ ├── the_pile_stream.yaml │ │ │ ├── uncorpus.yaml │ │ │ ├── uspto.yaml │ │ │ ├── wikibooks.yaml │ │ │ ├── wikinews.yaml │ │ │ ├── wikipedia.yaml │ │ │ ├── wikiquote.yaml │ │ │ ├── wikiversity.yaml │ │ │ └── wikivoyage.yaml │ │ ├── 
eval/ │ │ │ ├── __init__.py │ │ │ ├── pythia.yaml │ │ │ └── tasks/ │ │ │ ├── lambada_openai.yaml │ │ │ └── winogrande.yaml │ │ ├── hydra/ │ │ │ ├── __init__.py │ │ │ └── job_logging/ │ │ │ └── custom.yaml │ │ ├── impl/ │ │ │ ├── __init__.py │ │ │ ├── _default.yaml │ │ │ └── torch-default.yaml │ │ ├── train/ │ │ │ ├── __init__.py │ │ │ ├── common.yaml │ │ │ ├── cramming.yaml │ │ │ ├── janus-regime.yaml │ │ │ ├── optim/ │ │ │ │ ├── adafactor.yaml │ │ │ │ ├── adahessian.yaml │ │ │ │ ├── adam.yaml │ │ │ │ ├── adam8bit.yaml │ │ │ │ ├── adam_classic.yaml │ │ │ │ ├── adamscale.yaml │ │ │ │ ├── agd.yaml │ │ │ │ ├── lion.yaml │ │ │ │ ├── radam.yaml │ │ │ │ ├── sgd.yaml │ │ │ │ └── shampoo.yaml │ │ │ └── optim_mod/ │ │ │ ├── disabled.yaml │ │ │ ├── larc.yaml │ │ │ ├── lars.yaml │ │ │ ├── progressive.yaml │ │ │ └── sam.yaml │ │ └── wandb/ │ │ ├── default.yaml │ │ └── none.yaml │ ├── data/ │ │ ├── __init__.py │ │ ├── arithmetic_tokenizers.py │ │ ├── curriculum_sorting.py │ │ ├── deduplicate.py │ │ ├── pretraining_preparation.py │ │ ├── tokenizer_preparation.py │ │ └── utils.py │ └── utils.py ├── create_data_split.py ├── create_pos_or_variants.py ├── dataset_analysis.py ├── gen_eval_script.py ├── load_local_model.py ├── pretrain.py ├── pretty_plotter.py ├── pretty_plotter_big.py ├── pretty_plotter_sort.py ├── pyproject.toml ├── setup.cfg ├── shells/ │ ├── addition_ff.sh │ ├── addition_lt.sh │ ├── bitwise_or.sh │ ├── evaluation.sh │ ├── generate_and_tokenize_data.sh │ ├── multiplication.sh │ └── sorting.sh ├── sort_eval.py └── upload_processed_dataset.py
SYMBOL INDEX (450 symbols across 35 files)
FILE: abacus.py
class Abacus (line 5) | class Abacus(torch.nn.Module):
method __init__ (line 11) | def __init__(self, digit_tokens, embedding_dim, max_seq_length=1024, m...
method helper (line 24) | def helper(self, mask, device):
method forward (line 53) | def forward(self, input_ids):
FILE: arithmetic_eval_quicker.py
function grid_plotter (line 21) | def grid_plotter(data, type="accs", name='_large', extra_path=None):
function index_hints_helper (line 44) | def index_hints_helper(num, tokenizer):
function grid_logic (line 54) | def grid_logic(cfg):
function main (line 199) | def main(cfg):
function launch (line 509) | def launch(cfg):
FILE: cramming/__init__.py
function get_config (line 23) | def get_config(overrides=[]):
function get_model_config (line 31) | def get_model_config(arch="hf-bert-tiny", overrides=[]):
function get_backend_config (line 39) | def get_backend_config(backend="torch-default", overrides=[]):
FILE: cramming/architectures/attention.py
function get_attention_mechanism (line 11) | def get_attention_mechanism(idx, hidden_size, cfg_attention, norm_fn: to...
class Identity (line 33) | class Identity(torch.nn.Module):
method __init__ (line 39) | def __init__(self, hidden_size):
method forward (line 43) | def forward(self, hidden_states, attention_mask: Optional[torch.Tensor...
class RandomNoise (line 46) | class RandomNoise(torch.nn.Module):
method __init__ (line 52) | def __init__(self, hidden_size):
method forward (line 56) | def forward(self, hidden_states, attention_mask: Optional[torch.Tensor...
class BertAttentionWrapper (line 60) | class BertAttentionWrapper(BertSelfAttention):
method __init__ (line 66) | def __init__(self, hidden_size, cfg_attention):
method forward (line 81) | def forward(self, hidden_states, attention_mask: Optional[torch.Tensor...
class SelfAttentionPyTorch (line 85) | class SelfAttentionPyTorch(torch.nn.Module):
method __init__ (line 91) | def __init__(self, hidden_size, cfg_attention):
method forward (line 102) | def forward(self, hidden_states, attention_mask: Optional[torch.Tensor...
class SeqFirstSelfAttentionPyTorch (line 113) | class SeqFirstSelfAttentionPyTorch(torch.nn.Module):
method __init__ (line 119) | def __init__(self, hidden_size, cfg_attention):
method forward (line 130) | def forward(self, hidden_states, attention_mask: Optional[torch.Tensor...
class SeqFirstSelfAttention (line 141) | class SeqFirstSelfAttention(torch.nn.MultiheadAttention):
method __init__ (line 155) | def __init__(self, hidden_size: int, cfg_attention, norm_module=torch....
method attention (line 213) | def attention(self, query_layer, key_layer, value_layer, attention_mas...
method forward (line 262) | def forward(self, hidden_states, attention_mask: Optional[torch.Tensor...
class FourierMixing (line 302) | class FourierMixing(torch.nn.Module):
method __init__ (line 311) | def __init__(self, hidden_size, cfg_attention):
method forward (line 323) | def forward(self, hidden_states, attention_mask: Optional[torch.Tensor...
class TorchSoftmax (line 348) | class TorchSoftmax(torch.nn.Module):
method __init__ (line 349) | def __init__(self, seq_op_in_fp32=False):
method forward (line 353) | def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
class TorchShaped (line 363) | class TorchShaped(torch.nn.Module):
method __init__ (line 366) | def __init__(self, seq_op_in_fp32=False, hidden_size=768):
method forward (line 371) | def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
class SwinCosine (line 384) | class SwinCosine(torch.nn.Module):
method __init__ (line 387) | def __init__(self, seq_op_in_fp32=False, tau=0.1, eps=1e-8):
method forward (line 393) | def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
class TorchNormalize (line 408) | class TorchNormalize(torch.nn.Module):
method __init__ (line 409) | def __init__(self, num_attention_heads=1, seq_op_in_fp32=False):
method forward (line 416) | def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
class ScaledIdentity (line 430) | class ScaledIdentity(torch.nn.Module):
method __init__ (line 431) | def __init__(self, seq_op_in_fp32):
method forward (line 435) | def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
class Cumsum (line 443) | class Cumsum(torch.nn.Module):
method __init__ (line 444) | def __init__(self, seq_op_in_fp32):
method forward (line 448) | def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
class CumsumExp (line 456) | class CumsumExp(torch.nn.Module):
method __init__ (line 457) | def __init__(self, seq_op_in_fp32):
method forward (line 461) | def forward(self, inputs, attention_mask: Optional[torch.Tensor] = None):
FILE: cramming/architectures/components.py
class EmbeddingComponent (line 14) | class EmbeddingComponent(torch.nn.Module):
method __init__ (line 16) | def __init__(self, cfg_embedding, norm, norm_eps):
method forward (line 40) | def forward(self, input_ids):
class PredictionHeadComponent (line 54) | class PredictionHeadComponent(torch.nn.Module):
method __init__ (line 55) | def __init__(self, cfg_arch):
method forward (line 67) | def forward(self, hidden_states):
class NormalizedResidualConnection (line 72) | class NormalizedResidualConnection(torch.nn.Module):
method __init__ (line 75) | def __init__(self, input_dim, cfg_arch, output_dim=None, dropout=0.0):
method _simple_residual (line 108) | def _simple_residual(self, residual, layer, states, *args, **kwargs):
method _prenormalization_residual (line 111) | def _prenormalization_residual(self, residual, layer, states, *args, *...
method _postnormalization_residual (line 114) | def _postnormalization_residual(self, residual, layer, states, *args, ...
method _deepnorm_residual (line 117) | def _deepnorm_residual(self, residual, layer, states, *args, **kwargs):
method _prenorm_equalized_residual (line 120) | def _prenorm_equalized_residual(self, residual, layer, states, *args, ...
method _sandwich_residual (line 123) | def _sandwich_residual(self, residual, layer, states, *args, **kwargs):
method forward (line 126) | def forward(self, residual: torch.Tensor, layer_callable: torch.nn.Mod...
function _get_norm_fn (line 136) | def _get_norm_fn(norm_name):
function _get_nonlin_fn (line 150) | def _get_nonlin_fn(nonlin_name, use_gating=True):
class GLU (line 169) | class GLU(torch.nn.Module):
method __init__ (line 175) | def __init__(self, sub_activation):
method forward (line 179) | def forward(self, inputs):
class ScaleNorm (line 184) | class ScaleNorm(torch.nn.Module):
method __init__ (line 191) | def __init__(self, hidden_size: int, eps: float = 1e-5, elementwise_af...
method forward (line 199) | def forward(self, inputs):
class RMSNorm (line 204) | class RMSNorm(torch.nn.Module):
method __init__ (line 207) | def __init__(self, hidden_size: int, eps: float = 1e-6, elementwise_af...
method _legacy_forward (line 215) | def _legacy_forward(self, inputs):
method _norm (line 219) | def _norm(self, x):
method forward (line 223) | def forward(self, x):
function get_causal_attention_mask (line 228) | def get_causal_attention_mask(input_ids) -> torch.Tensor:
function get_extended_attention_mask (line 239) | def get_extended_attention_mask(attention_mask: torch.Tensor, input_shap...
function _init_module (line 292) | def _init_module(module, init_method="normal", init_std=0.02, hidden_siz...
FILE: cramming/architectures/construction.py
function construct_model (line 14) | def construct_model(cfg_arch, tokenizer):
FILE: cramming/architectures/crammed_depthrecurrent.py
class crammedDepthRecurrentConfig (line 21) | class crammedDepthRecurrentConfig(PretrainedConfig):
method __init__ (line 24) | def __init__(self, cfg_arch_container: dict = {}, **kwargs):
function construct_crammed_recurrent (line 29) | def construct_crammed_recurrent(cfg_arch, vocab_size, equals_token):
class FFNComponent (line 44) | class FFNComponent(torch.nn.Module):
method __init__ (line 51) | def __init__(self, hidden_size, intermed_size, cfg_arch, output_size=N...
method forward (line 66) | def forward(self, hidden_states):
class TransformerLayer (line 70) | class TransformerLayer(torch.nn.Module):
method __init__ (line 73) | def __init__(self, idx, cfg_arch):
method forward (line 85) | def forward(self, states, attention_mask: Optional[torch.Tensor] = None):
class TransformerBlock (line 91) | class TransformerBlock(torch.nn.Module):
method __init__ (line 94) | def __init__(self, layers, cfg_arch):
method forward (line 104) | def forward(self, states, injected_state, attention_mask: Optional[tor...
class TransposedAdapter (line 120) | class TransposedAdapter(torch.nn.Linear): # steal init
method __init__ (line 121) | def __init__(self, embedding_dim, hidden_size, original_adapter, tie_w...
method forward (line 132) | def forward(self, inputs):
class ScriptableRecurrentLM (line 136) | class ScriptableRecurrentLM(PreTrainedModel):
method __init__ (line 141) | def __init__(self, config):
method forward (line 172) | def forward(self, input_ids: torch.Tensor, num_steps_no_grad: int = No...
method initialize_state (line 194) | def initialize_state(self, hidden_states):
class ScriptableRecurrentLMReplicaConcat (line 210) | class ScriptableRecurrentLMReplicaConcat(PreTrainedModel):
method __init__ (line 216) | def __init__(self, config):
method apply_recurrent_block (line 258) | def apply_recurrent_block(self, hidden_states, injected_state, attenti...
method forward (line 264) | def forward(self, input_ids: torch.Tensor, num_steps_no_grad: int = No...
method initialize_state (line 286) | def initialize_state(self, hidden_states):
function _generate (line 304) | def _generate(self, input_ids, token_limit=100, temperature=1.0, steps_a...
class ScriptableRecurrentLMForPreTraining (line 372) | class ScriptableRecurrentLMForPreTraining(PreTrainedModel):
method __init__ (line 377) | def __init__(self, config):
method _init_weights (line 396) | def _init_weights(self, module=None):
method forward (line 407) | def forward(self, input_ids: torch.Tensor, *args, **kwargs):
method _generate (line 424) | def _generate(self, input_ids, token_limit=100, temperature=0.7, steps...
class ScriptableRecurrentLMBPTT (line 428) | class ScriptableRecurrentLMBPTT(PreTrainedModel):
method __init__ (line 433) | def __init__(self, config, equals_token):
method _init_weights (line 463) | def _init_weights(self, module=None):
method set_max_recurrences_for_training (line 474) | def set_max_recurrences_for_training(self, new_max):
method forward (line 479) | def forward(self, input_ids: torch.Tensor, *args, **kwargs):
method forward_progressive (line 493) | def forward_progressive(self, input_ids):
method prog_model_call_with_masking (line 517) | def prog_model_call_with_masking(self, input_ids, n, k):
method _generate (line 548) | def _generate(self, input_ids, token_limit=100, temperature=1.0, steps...
FILE: cramming/architectures/crammed_transformer.py
class crammedTransformerConfig (line 21) | class crammedTransformerConfig(PretrainedConfig):
method __init__ (line 24) | def __init__(self, cfg_arch_container: dict = {}, **kwargs):
function construct_crammed_transformer (line 29) | def construct_crammed_transformer(cfg_arch, vocab_size):
class FFNComponent (line 39) | class FFNComponent(torch.nn.Module):
method __init__ (line 46) | def __init__(self, hidden_size, intermed_size, cfg_arch, output_size=N...
method forward (line 61) | def forward(self, hidden_states):
class TransformerLayer (line 65) | class TransformerLayer(torch.nn.Module):
method __init__ (line 68) | def __init__(self, idx, cfg_arch):
method forward (line 80) | def forward(self, states, attention_mask: Optional[torch.Tensor] = None):
class ScriptableLM (line 86) | class ScriptableLM(PreTrainedModel):
method __init__ (line 91) | def __init__(self, config):
method forward (line 106) | def forward(self, input_ids: torch.Tensor):
class ScriptableLMForPreTraining (line 124) | class ScriptableLMForPreTraining(PreTrainedModel):
method __init__ (line 129) | def __init__(self, config):
method _init_weights (line 141) | def _init_weights(self, module=None):
method forward (line 152) | def forward(self, input_ids: torch.Tensor, *args, **kwargs):
FILE: cramming/architectures/embeddings.py
class PositionalEmbedding (line 11) | class PositionalEmbedding(torch.nn.Module):
method __init__ (line 13) | def __init__(self, demb):
method forward (line 21) | def forward(self, pos_seq, bsz=None):
class RandomNoise (line 35) | class RandomNoise(torch.nn.Module):
method __init__ (line 37) | def __init__(self, embedding_dim, max_seq_length=5000):
method forward (line 41) | def forward(self, input_ids):
class RPE (line 45) | class RPE(torch.nn.Module):
method __init__ (line 52) | def __init__(self, d_model, num_heads, max_len=1024, dropout=0.1):
method forward (line 68) | def forward(self, x):
method skew (line 105) | def skew(self, QEr):
class SinusoidalPositional (line 118) | class SinusoidalPositional(torch.nn.Module):
method __init__ (line 125) | def __init__(self, embedding_dim, max_seq_length=5000):
method forward (line 137) | def forward(self, input_ids):
class ScaledSinosoidal (line 150) | class ScaledSinosoidal(SinusoidalPositional):
method __init__ (line 153) | def __init__(self, embedding_dim, max_seq_length):
method forward (line 157) | def forward(self, input_ids):
class LearnablePositional (line 170) | class LearnablePositional(torch.nn.Module):
method __init__ (line 173) | def __init__(self, embedding_dim, max_seq_length=1024):
method forward (line 178) | def forward(self, input_ids):
class LearnablePositionalRand (line 184) | class LearnablePositionalRand(torch.nn.Module):
method __init__ (line 187) | def __init__(self, embedding_dim, max_seq_length=1024):
method forward (line 193) | def forward(self, input_ids):
class Rotary (line 206) | class Rotary(torch.nn.Module):
method __init__ (line 207) | def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int =...
method get_cos_sin_cache (line 230) | def get_cos_sin_cache(self, x: torch.Tensor):
method _get_cos_sin (line 239) | def _get_cos_sin(self):
method forward (line 248) | def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
method single_forward (line 253) | def single_forward(self, inputs: torch.Tensor):
method rotate_half (line 258) | def rotate_half(self, x: torch.Tensor):
class RotarySanityCheck (line 262) | class RotarySanityCheck(torch.nn.Module):
method __init__ (line 265) | def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int =...
method get_cos_sin_cache (line 276) | def get_cos_sin_cache(self, x: torch.Tensor):
method _get_cos_sin (line 285) | def _get_cos_sin(self):
method forward (line 294) | def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
method rotate_half (line 300) | def rotate_half(self, x: torch.Tensor):
method single_forward (line 305) | def single_forward(self, inputs: torch.Tensor):
class RotaryEleutherAI (line 313) | class RotaryEleutherAI(torch.nn.Module):
method __init__ (line 329) | def __init__(self, dim_model: int, *_, **__):
method _update_cos_sin_tables (line 340) | def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int =...
method forward (line 356) | def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension: int...
method rotate_half (line 365) | def rotate_half(self, x: torch.Tensor):
method apply_rotary_pos_emb (line 371) | def apply_rotary_pos_emb(self, x: torch.Tensor, cos: torch.Tensor, sin...
class RotaryLLAMA (line 383) | class RotaryLLAMA(torch.nn.Module):
method __init__ (line 386) | def __init__(self, hidden_per_head, base=10000, max_seq_length=512, se...
method forward (line 392) | def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
method apply_rotary_emb (line 395) | def apply_rotary_emb(self, xq: torch.Tensor, xk: torch.Tensor, freqs_c...
method reshape_for_broadcast (line 404) | def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tens...
method precompute_freqs_cis (line 412) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
class FIRE (line 419) | class FIRE(torch.nn.Module):
method __init__ (line 420) | def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512....
method forward (line 446) | def forward(self, seq_length, device):
class Abacus (line 482) | class Abacus(torch.nn.Module):
method __init__ (line 485) | def __init__(self, embedding_dim, max_seq_length=1024, max_k=99):
method helper (line 491) | def helper(self, mask, device):
method forward (line 517) | def forward(self, input_ids):
FILE: cramming/architectures/huggingface_interface.py
function construct_huggingface_model (line 6) | def construct_huggingface_model(cfg_arch, vocab_size):
FILE: cramming/architectures/losses.py
class CosineLoss (line 5) | class CosineLoss(torch.nn.Module):
method __init__ (line 9) | def __init__(self, reduction: str = "mean", dim=-1, eps=1e-8) -> None:
method forward (line 16) | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
class CrossEntropyWithZLoss (line 20) | class CrossEntropyWithZLoss(torch.nn.Module):
method __init__ (line 27) | def __init__(self, ignore_index=-100, z_loss_factor=1e-4):
method forward (line 33) | def forward(self, inputs, labels):
class MSELoss (line 44) | class MSELoss(torch.nn.Module):
method __init__ (line 49) | def __init__(self, ignore_index=-100):
method forward (line 54) | def forward(self, inputs, labels):
method _label_to_onehot (line 64) | def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):
class MSELossFast (line 70) | class MSELossFast(torch.nn.Module):
method __init__ (line 75) | def __init__(self, ignore_index=-100):
method forward (line 80) | def forward(self, inputs, labels):
class L1Loss (line 94) | class L1Loss(torch.nn.Module):
method __init__ (line 99) | def __init__(self, ignore_index=-100):
method forward (line 104) | def forward(self, inputs, labels):
method _label_to_onehot (line 114) | def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):
class SzegedyLoss (line 120) | class SzegedyLoss(torch.nn.Module):
method __init__ (line 126) | def __init__(self, embedding_layer, ignore_index=-100, overrelaxation=...
method forward (line 133) | def forward(self, inputs, labels):
class FocalLoss (line 173) | class FocalLoss(torch.nn.Module):
method __init__ (line 174) | def __init__(self, gamma: float = 5.0, size_average: bool = True, igno...
method forward (line 180) | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch....
class IncorrectCrossEntropyLoss (line 191) | class IncorrectCrossEntropyLoss(torch.nn.CrossEntropyLoss):
method forward (line 194) | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch....
FILE: cramming/architectures/sanity_check.py
class SanityCheckforPreTraining (line 6) | class SanityCheckforPreTraining(torch.nn.Module):
method __init__ (line 9) | def __init__(self, width, vocab_size):
method forward (line 14) | def forward(
FILE: cramming/backend/optimizers/optimizer_modifiers.py
class MetaOptimizer (line 14) | class MetaOptimizer(torch.optim.Optimizer):
method __init__ (line 17) | def __init__(self, optimizer):
method __getstate__ (line 21) | def __getstate__(self):
method __setstate__ (line 24) | def __setstate__(self, state):
method __repr__ (line 27) | def __repr__(self):
method __getattr__ (line 30) | def __getattr__(self, name):
method step (line 35) | def step(self, closure=None):
class LARS (line 39) | class LARS(MetaOptimizer):
method __init__ (line 66) | def __init__(self, optimizer, trust_coefficient=0.02, clip=False, eps=...
method step (line 73) | def step(self, closure=None):
class SAM (line 132) | class SAM(MetaOptimizer):
method __init__ (line 133) | def __init__(self, base_optimizer_instance, rho=0.05):
method first_step (line 141) | def first_step(self, zero_grad=False):
method second_step (line 157) | def second_step(self, zero_grad=False):
method step (line 170) | def step(self, closure=None):
method _grad_norm (line 180) | def _grad_norm(self):
FILE: cramming/backend/optimizers/progressive_batching.py
class ProgressiveBatching (line 17) | class ProgressiveBatching(MetaOptimizer):
method __init__ (line 18) | def __init__(self, optimizer, progress_rule="norm-based", theta=0.9, m...
method step (line 33) | def step(self):
method inner_product_test (line 63) | def inner_product_test(self):
method norm_test (line 85) | def norm_test(self):
method cosine_test (line 100) | def cosine_test(self):
method coefficient_of_variation (line 119) | def coefficient_of_variation(self):
method update_sample_statistics (line 136) | def update_sample_statistics(self):
method reset_sample_statistics (line 147) | def reset_sample_statistics(self):
method copy_mean_grad (line 157) | def copy_mean_grad(self):
FILE: cramming/backend/optimizers/schedulers.py
function get_schedule_fn (line 10) | def get_schedule_fn(cfg_train, elapsed_time: float=0.0, true_budget: flo...
class DumbScheduler (line 213) | class DumbScheduler:
method __init__ (line 214) | def __init__(self, *args, **kwargs):
method step (line 217) | def step(self, *args, **kwargs):
method _initial_step (line 220) | def _initial_step(self):
method state_dict (line 225) | def state_dict(self):
method load_state_dict (line 228) | def load_state_dict(self, state_dict):
method get_last_lr (line 231) | def get_last_lr(self):
method get_lr (line 235) | def get_lr(self):
method print_lr (line 238) | def print_lr(self, is_verbose, group, lr, epoch=None):
function get_inverse_sqrt_scheduler (line 245) | def get_inverse_sqrt_scheduler(optimizer, num_warmup_steps, num_cooldown...
function get_one_cycle (line 276) | def get_one_cycle(optimizer, num_training_steps):
function get_ramp (line 288) | def get_ramp(optimizer, num_cooldown_steps, num_training_steps):
function _get_fake_step (line 302) | def _get_fake_step(current_step, initial_time, hour_budget, num_training...
function get_budget_inv_sqrt_scheduler (line 312) | def get_budget_inv_sqrt_scheduler(optimizer, hour_budget, num_warmup_ste...
function get_budget_constant_scheduler (line 335) | def get_budget_constant_scheduler(optimizer, hour_budget, num_warmup_ste...
function get_budget_linear_schedule_with_warmup (line 351) | def get_budget_linear_schedule_with_warmup(optimizer, hour_budget, num_w...
function get_budget_cosine_schedule_with_warmup (line 364) | def get_budget_cosine_schedule_with_warmup(optimizer, hour_budget, num_w...
function get_budget_cosine_half_cycles_with_warmup (line 378) | def get_budget_cosine_half_cycles_with_warmup(optimizer, hour_budget, nu...
function get_budget_one_cycle (line 392) | def get_budget_one_cycle(optimizer, hour_budget, num_training_steps, ela...
function get_budget_multi_cycle (line 406) | def get_budget_multi_cycle(optimizer, hour_budget, num_training_steps, n...
function get_budget_ramp (line 421) | def get_budget_ramp(optimizer, hour_budget, num_cooldown_steps, num_trai...
function get_budget_inv_cosine_schedule (line 436) | def get_budget_inv_cosine_schedule(optimizer, hour_budget, num_cooldown_...
function get_budget_triangle (line 454) | def get_budget_triangle(optimizer, hour_budget, num_training_steps, base...
function get_budget_dive (line 471) | def get_budget_dive(optimizer, hour_budget, num_training_steps, num_warm...
function get_budget_polynomial_decay_with_warmup (line 487) | def get_budget_polynomial_decay_with_warmup(optimizer, hour_budget, num_...
FILE: cramming/backend/prepare_backend.py
function load_backend (line 9) | def load_backend(model, tokenizer, cfg_train, cfg_impl, setup=_default_s...
FILE: cramming/backend/torch_default.py
function initialize_torch (line 41) | def initialize_torch(model, tokenizer, cfg_train, cfg_impl, setup=_defau...
class TorchEngine (line 55) | class TorchEngine(torch.nn.Module):
method __init__ (line 58) | def __init__(self, model, cfg_train, cfg_impl, setup=_default_setup, s...
method get_true_budget (line 125) | def get_true_budget(self):
method step (line 131) | def step(self, batch: dict[str, torch.Tensor]):
method to_device (line 137) | def to_device(self, batch: dict[str, torch.Tensor], keys: list[str] = ...
method forward (line 146) | def forward(self, *inputs, **kwargs):
method backward (line 153) | def backward(self, loss):
method forward_inference (line 160) | def forward_inference(self, *inputs, **kwargs):
method dynamic_generation (line 168) | def dynamic_generation(self, *inputs, temperature=0.7, token_limit=100...
method optimizer_step (line 193) | def optimizer_step(self):
method set_train_batch_size (line 209) | def set_train_batch_size(self, batch_size):
method schedule_batch_size (line 214) | def schedule_batch_size(self):
method record_batch_size (line 235) | def record_batch_size(self):
method record_tokens_per_step (line 241) | def record_tokens_per_step(self):
method retrieve_model_state_dict (line 246) | def retrieve_model_state_dict(self):
method _init_distributed (line 261) | def _init_distributed(self, model):
method load_checkpoint (line 273) | def load_checkpoint(self, cfg_arch, file, skip_optim_state=False) -> D...
method load_metadata (line 332) | def load_metadata(self, metadata: Dict[str, Any]):
method save_training_checkpoint (line 337) | def save_training_checkpoint(self, checkpoint_directory: str, checkpoi...
method save_final_model (line 360) | def save_final_model(self, base_directory, identifier, tokenizer, cfg_...
method save_model (line 378) | def save_model(
method push_to_hub (line 420) | def push_to_hub(self, tokenizer, cfg, dryrun=False):
function _load_optimizer (line 476) | def _load_optimizer(model, cfg_train, cfg_impl, elapsed_time=0.0, true_b...
FILE: cramming/backend/utils.py
function group_parameters (line 14) | def group_parameters(model, cfg_train):
function get_model_engine_tokenizer_dataloaders (line 32) | def get_model_engine_tokenizer_dataloaders(cfg, setup, train_eval: bool ...
function load_model_checkpoint (line 113) | def load_model_checkpoint(model, model_dir, forward_only_model_with_skip...
FILE: cramming/data/arithmetic_tokenizers.py
class CustomCharLevelTokenizerForAddingPadding (line 11) | class CustomCharLevelTokenizerForAddingPadding(PreTrainedTokenizer):
method __init__ (line 13) | def __init__(self, **kwargs):
method vocab_size (line 46) | def vocab_size(self):
method get_vocab (line 49) | def get_vocab(self):
method _tokenize (line 52) | def _tokenize(self, text):
method _convert_token_to_id (line 59) | def _convert_token_to_id(self, token):
method _convert_id_to_token (line 62) | def _convert_id_to_token(self, index):
method __call__ (line 66) | def __call__(self, text, **kwargs):
method decode (line 72) | def decode(self, token_ids, **kwargs):
class CustomCharLevelTokenizerForAddingPaddingWithIndexHints (line 78) | class CustomCharLevelTokenizerForAddingPaddingWithIndexHints(PreTrainedT...
method __init__ (line 80) | def __init__(self, **kwargs):
method vocab_size (line 115) | def vocab_size(self):
method get_vocab (line 118) | def get_vocab(self):
method _tokenize (line 121) | def _tokenize(self, text):
method _convert_token_to_id (line 128) | def _convert_token_to_id(self, token):
method _convert_id_to_token (line 131) | def _convert_id_to_token(self, index):
method __call__ (line 135) | def __call__(self, text, **kwargs):
method decode (line 141) | def decode(self, token_ids, **kwargs):
class CustomCharLevelTokenizerSort (line 147) | class CustomCharLevelTokenizerSort(PreTrainedTokenizer):
method __init__ (line 149) | def __init__(self, **kwargs):
method vocab_size (line 190) | def vocab_size(self):
method get_vocab (line 193) | def get_vocab(self):
method _tokenize (line 196) | def _tokenize(self, text):
method _convert_token_to_id (line 202) | def _convert_token_to_id(self, token):
method _convert_id_to_token (line 205) | def _convert_id_to_token(self, index):
method __call__ (line 209) | def __call__(self, text, **kwargs):
method decode (line 215) | def decode(self, token_ids, **kwargs):
FILE: cramming/data/curriculum_sorting.py
function _sort_tokenized_dataset_by_unigram (line 10) | def _sort_tokenized_dataset_by_unigram(tokenized_dataset, tokenizer, num...
function _sort_tokenized_dataset_by_token (line 53) | def _sort_tokenized_dataset_by_token(tokenized_dataset, tokenizer, targe...
function _sort_tokenized_dataset_by_word_length (line 92) | def _sort_tokenized_dataset_by_word_length(tokenized_dataset, tokenizer,...
FILE: cramming/data/deduplicate.py
function deduplicate_huggingface_dataset (line 40) | def deduplicate_huggingface_dataset(dataset, threshold=100, original_cwd...
function _write_tmp_file (line 60) | def _write_tmp_file(dataset, dirname):
function _make_suffix_array (line 69) | def _make_suffix_array(text_file, tmpdir, path_to_rust_code):
function _finish_and_return_to_hf_dataset (line 139) | def _finish_and_return_to_hf_dataset(original_text_file, remove_file_cac...
FILE: cramming/data/pretraining_preparation.py
function get_num_workers (line 41) | def get_num_workers(cfg_impl):
function load_pretraining_corpus (line 50) | def load_pretraining_corpus(cfg_data, cfg_impl, data_dir: str = None):
function load_tokenized_data (line 147) | def load_tokenized_data(tokenized_dataset_path):
function convert_to_hf_dataset (line 151) | def convert_to_hf_dataset(tokenized_data):
function preprocess_dataset (line 162) | def preprocess_dataset(cfg_data, download_path, num_threads=1, max_raw_c...
function _move_stream_to_fixed_map (line 233) | def _move_stream_to_fixed_map(raw_data_streamed, max_entries_in_raw_data...
function _huggingface_preprocessing (line 263) | def _huggingface_preprocessing(raw_dataset, tokenizer, cfg_data, num_thr...
function _load_fake_dataset (line 357) | def _load_fake_dataset(cfg_data, details, path=None):
function _concatenate_entries (line 366) | def _concatenate_entries(dataset, num_entries_in_group, num_threads):
function raw_dataset_preprocessing (line 402) | def raw_dataset_preprocessing(raw_dataset, num_threads, cfg_data):
function main_process_first (line 454) | def main_process_first():
function _load_from_hub (line 478) | def _load_from_hub(cfg_data, data_path):
function prepare_dataloaders (line 498) | def prepare_dataloaders(datasets, tokenizer, cfg_train, cfg_impl) -> Dic...
function prepare_pretraining_dataloader (line 506) | def prepare_pretraining_dataloader(dataset, tokenizer, cfg_train, cfg_im...
function prepare_validation_dataloader (line 554) | def prepare_validation_dataloader(dataset, tokenizer, cfg_impl):
class FastDataCollatorForLanguageModeling (line 586) | class FastDataCollatorForLanguageModeling(transformers.DataCollatorForLa...
method __init__ (line 587) | def __init__(self, *args, create_labels_entry=False, **kwargs):
method torch_call (line 592) | def torch_call(self, examples):
class InfiniteDataLoader (line 616) | class InfiniteDataLoader(torch.utils.data.DataLoader):
method __init__ (line 619) | def __init__(self, *args, **kwargs):
method __iter__ (line 625) | def __iter__(self):
method __next__ (line 628) | def __next__(self):
method set_epoch (line 640) | def set_epoch(self, epoch: int):
class RuntimeInfiniteDataLoader (line 643) | class RuntimeInfiniteDataLoader(torch.utils.data.DataLoader):
method __init__ (line 646) | def __init__(self, tokenizer, device, *args, **kwargs):
method get_arithmetic (line 661) | def get_arithmetic(self, n, m):
method tokenize_batch (line 685) | def tokenize_batch(self, batch):
method __iter__ (line 696) | def __iter__(self):
method __next__ (line 699) | def __next__(self):
FILE: cramming/data/tokenizer_preparation.py
function get_tokenizer (line 12) | def get_tokenizer(tokenizer_type: str):
function load_tokenizer (line 27) | def load_tokenizer(tokenizer_path_or_name, seq_length=512, vocab_size=No...
function construct_tokenizer (line 38) | def construct_tokenizer(raw_datasets, cfg_data, path, known_tokens=[]):
function _download_tokenizer (line 48) | def _download_tokenizer(tokenizer_path_or_name, seq_length, cache_dir=No...
function _get_sane_token_args (line 57) | def _get_sane_token_args():
function _get_sane_normalizers (line 67) | def _get_sane_normalizers(force_english_keyboard=False, force_lowercase=...
function _construct_tokenizer (line 85) | def _construct_tokenizer(raw_datasets, cfg_data, known_tokens=[]):
FILE: cramming/data/utils.py
function checksum_config (line 17) | def checksum_config(cfg):
function stage_dataset (line 26) | def stage_dataset(data_directory_path, local_staging_dir):
function _get_size (line 65) | def _get_size(start_path="."):
function detailed_OSError (line 78) | def detailed_OSError(e):
FILE: cramming/utils.py
function main_launcher (line 37) | def main_launcher(cfg, main_fn, job_name=""):
function get_cpus (line 78) | def get_cpus() -> int:
function system_startup (line 90) | def system_startup(cfg):
function is_main_process (line 184) | def is_main_process():
function num_processes (line 188) | def num_processes():
function find_pretrained_checkpoint (line 194) | def find_pretrained_checkpoint(checkpoint: str, local_checkpoint_folder:...
function save_summary (line 256) | def save_summary(table_name, cfg, stats, local_time, setup, original_cwd...
function save_to_table (line 334) | def save_to_table(out_dir, table_name, dryrun, **kwargs):
function set_random_seed (line 367) | def set_random_seed(seed=233):
function set_deterministic (line 378) | def set_deterministic():
function avg_n_dicts (line 386) | def avg_n_dicts(dicts):
function dump_metrics (line 406) | def dump_metrics(cfg, metrics):
function _initialize_wandb (line 420) | def _initialize_wandb(setup, cfg):
function wandb_log (line 440) | def wandb_log(stats, cfg):
function flatten (line 448) | def flatten(d, parent_key="", sep="_"):
function collect_system_metrics (line 460) | def collect_system_metrics(cfg, metrics, kWh_counter, setup):
function get_kWh (line 479) | def get_kWh(kWh_counter, setup):
function pathfinder (line 486) | def pathfinder(cfg):
FILE: create_data_split.py
function generate_no_carry_addition (line 20) | def generate_no_carry_addition(n, m):
function has_carry (line 31) | def has_carry(num1, num2):
function generate_dataset (line 39) | def generate_dataset(dir_name, operation, n, m, num_examples, base_folde...
function tokenize_and_save_dataset (line 128) | def tokenize_and_save_dataset(dataset, tokenizer, directory, test_split_...
function character_histogram (line 190) | def character_histogram(dir_name, condense_white_space=False):
function token_histogram (line 238) | def token_histogram(dir_name, tokenizer_type="normal"):
function main_dataset_gen (line 296) | def main_dataset_gen(dir_name, op, n, m, num_samples, exact=False, keep_...
function tokenize_main (line 305) | def tokenize_main(dir_name, tokenizer_type, test_split_ratio=0.05):
function pick_char_set (line 334) | def pick_char_set(max_len):
function hints_helper (line 346) | def hints_helper(num_str, chars):
function bucket_method_gen (line 353) | def bucket_method_gen(n=3, m=3, operation='+', limit=1000, p=0, no_carry...
function bucket_method_main (line 412) | def bucket_method_main(n, m, operation, limit, dir_name, p=0, no_carry_a...
function uniform_distribution_sort_basic (line 437) | def uniform_distribution_sort_basic(maximum_number_of_digts, maximum_len...
function bucket_uniform_distribution (line 467) | def bucket_uniform_distribution(maximum_number_of_digts, maximum_length,...
function uniform_distribution_sort_main (line 476) | def uniform_distribution_sort_main(FLAGS, dir_name):
function main (line 507) | def main():
FILE: create_pos_or_variants.py
function one_hot_vector (line 6) | def one_hot_vector(length, index=None):
function zero_vector (line 14) | def zero_vector(length):
function main (line 19) | def main():
FILE: dataset_analysis.py
function read_dataset (line 9) | def read_dataset(dir_name, condense_white_space=False):
function remove_leading_zeros (line 26) | def remove_leading_zeros(match):
function count_digits (line 30) | def count_digits(dataset, remove_formatting=False):
function plot_pairs_heatmap (line 53) | def plot_pairs_heatmap(pairs, dir_name=".", remove_formatting=False):
function line_plotter (line 73) | def line_plotter(data, name, dir_name=".", remove_formatting=False):
function consecutive_digit_counts (line 89) | def consecutive_digit_counts(input_strings):
function create_repetition_heatmap (line 119) | def create_repetition_heatmap(data, dir_name=".", remove_formatting=False):
function main (line 134) | def main(dir_name):
FILE: load_local_model.py
function main_load_process (line 23) | def main_load_process(cfg, setup):
function launch (line 42) | def launch(cfg):
FILE: pretrain.py
function main_training_process (line 17) | def main_training_process(cfg, setup):
function get_time_elapsed (line 187) | def get_time_elapsed(start_time: float, additional_time: float = 0.0) ->...
function check_checkpointing (line 190) | def check_checkpointing(data_idx: int, cfg_impl, last_save_time) -> bool:
function check_deadline (line 196) | def check_deadline(launch_time, hour_limit, prev_budget: float = 0.0, ov...
function check_early_termination (line 205) | def check_early_termination(start_time, loss, early_termination, prev_bu...
function collect_stats (line 217) | def collect_stats(data_step, loss_vals, log_ppls, model_outputs, train_t...
function validate (line 265) | def validate(model_engine, validloader, setup, cfg):
function generate (line 313) | def generate(model_engine, tokenizer, example_prompts, token_limit=10, t...
function flag_communication (line 331) | def flag_communication(training_allowed):
function launch (line 345) | def launch(cfg):
FILE: pretty_plotter.py
function find_file (line 12) | def find_file(starting_directory, target_file):
function grid_plotter (line 18) | def grid_plotter(data, type="accs", path="", title=None, rect_size=20, u...
function main (line 44) | def main():
FILE: pretty_plotter_big.py
function grid_plotter (line 14) | def grid_plotter(data, type="accs", path="", title=None, rect_size=20):
function main (line 40) | def main():
FILE: pretty_plotter_sort.py
function grid_plotter (line 8) | def grid_plotter(data, title="", path=None):
function run (line 31) | def run(names, short_hand, base_dir, sort_plots_path):
FILE: sort_eval.py
function grid_plotter (line 21) | def grid_plotter(data, type="accs", name='_large', extra_path=None):
function grid_logic (line 44) | def grid_logic(cfg):
function main (line 189) | def main(cfg):
function launch (line 375) | def launch(cfg):
FILE: upload_processed_dataset.py
function upload (line 18) | def upload(cfg, setup):
function launch (line 81) | def launch(cfg):
Condensed preview — 132 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (516K chars).
[
{
"path": ".gitignore",
"chars": 2097,
"preview": "outputs\ntables/*/*.csv\ntables/*/*.csv#\ntables/*.csv\ntables/*.csv#\ntables/*.ods\n*.png\n*.pdf\n\n# torchdynamo debug\nisolate\n"
},
{
"path": ".pre-commit-config.yaml",
"chars": 839,
"preview": "# precommit hooks from https://github.com/ashleve/lightning-hydra-template\nrepos:\n - repo: https://github.com/pre-commi"
},
{
"path": "LICENSE",
"chars": 1084,
"preview": "MIT License\n\nCopyright (c) 2024 Sean McLeish, Jonas Geiping\n\nPermission is hereby granted, free of charge, to any person"
},
{
"path": "MANIFEST.in",
"chars": 168,
"preview": "# added by check-manifest\ninclude *.py\ninclude *.yaml\nrecursive-include cramming *.md\nrecursive-include cramming *.yaml\n"
},
{
"path": "README.md",
"chars": 7043,
"preview": "# Transformers Can Do Arithmetic with the Right Embeddings! [Link to arXiv paper](https://arxiv.org/abs/2405.17399)\n\nA j"
},
{
"path": "abacus.py",
"chars": 2777,
"preview": "\"\"\"Implementation of abacus embeddings\"\"\"\n# Example of how to extract digit tokens to pass into constructor\n# digit_toke"
},
{
"path": "arithmetic_eval_quicker.py",
"chars": 25639,
"preview": "import logging\nimport hydra\nfrom omegaconf import OmegaConf\nimport cramming\nimport torch\nfrom safetensors.torch import l"
},
{
"path": "cramming/__init__.py",
"chars": 1329,
"preview": "\"\"\"Initialize cramming\"\"\"\n\nfrom cramming import utils\nfrom cramming.architectures import construct_model\nfrom cramming.b"
},
{
"path": "cramming/architectures/__init__.py",
"chars": 137,
"preview": "\"\"\"This module handles all questions of model architecture.\"\"\"\n\nfrom .construction import construct_model\n\n__all__ = [\"c"
},
{
"path": "cramming/architectures/attention.py",
"chars": 19297,
"preview": "\"\"\"Attention modules. Most code heavily stolen from the GPT-neoX implementation\"\"\"\nimport torch\nfrom transformers.models"
},
{
"path": "cramming/architectures/components.py",
"chars": 21997,
"preview": "\"\"\"Basic transformer components.\"\"\"\n\nimport torch\n\nfrom typing import Tuple\nfrom functools import partial\n\nfrom .embeddi"
},
{
"path": "cramming/architectures/construction.py",
"chars": 1779,
"preview": "\"\"\"Interface to construct models.\"\"\"\n\nfrom .huggingface_interface import construct_huggingface_model\nfrom .sanity_check "
},
{
"path": "cramming/architectures/crammed_depthrecurrent.py",
"chars": 27018,
"preview": "\"\"\"Variant for modifications of the transformer architecture that are depth-recurrent\"\"\"\nimport torch\nfrom transformers "
},
{
"path": "cramming/architectures/crammed_transformer.py",
"chars": 6882,
"preview": "\"\"\"Base file for modifications of the transformer architecture\"\"\"\nimport torch\nfrom transformers import PretrainedConfig"
},
{
"path": "cramming/architectures/embeddings.py",
"chars": 25047,
"preview": "\"\"\"Non-standard embedding implementations.\"\"\"\n\nimport torch\nimport math\n\nfrom typing import Tuple\nfrom einops import rep"
},
{
"path": "cramming/architectures/huggingface_interface.py",
"chars": 943,
"preview": "\"\"\"HF model variations based on reconfiguring their huggingface implementations.\"\"\"\n\nimport transformers\n\n\ndef construct"
},
{
"path": "cramming/architectures/losses.py",
"chars": 8298,
"preview": "import torch\nimport math\n\n\nclass CosineLoss(torch.nn.Module):\n __constants__ = [\"reduction\"]\n reduction: str\n\n "
},
{
"path": "cramming/architectures/sanity_check.py",
"chars": 787,
"preview": "\"\"\"Sanity Check architecture.\"\"\"\nimport torch\nfrom typing import Optional\n\n\nclass SanityCheckforPreTraining(torch.nn.Mod"
},
{
"path": "cramming/backend/__init__.py",
"chars": 299,
"preview": "\"\"\"This module implements interfaces to the various backends.\"\"\"\n\nfrom .prepare_backend import load_backend\nfrom .utils "
},
{
"path": "cramming/backend/optimizers/__init__.py",
"chars": 137,
"preview": "from .progressive_batching import ProgressiveBatching\nfrom .optimizer_modifiers import SAM, LARS\nfrom .schedulers import"
},
{
"path": "cramming/backend/optimizers/optimizer_modifiers.py",
"chars": 7267,
"preview": "\"\"\"This is the apex LARS implementation, from the apex repository.\n\nIt implements LARS + optional clipping\n\nhttps://gith"
},
{
"path": "cramming/backend/optimizers/progressive_batching.py",
"chars": 6932,
"preview": "\"\"\"Implementation of a progressive batching meta optimizer.\nThe optimizer may defer an optimization step until gradient "
},
{
"path": "cramming/backend/optimizers/schedulers.py",
"chars": 21421,
"preview": "\"\"\"Misc. optimizer implementations.\"\"\"\nimport transformers\nimport math\n\nfrom torch.optim.lr_scheduler import LambdaLR\nim"
},
{
"path": "cramming/backend/prepare_backend.py",
"chars": 560,
"preview": "\"\"\"Instantiate backend objects in a congruent format.\"\"\"\nimport torch\n\nfrom .torch_default import initialize_torch\n\n_def"
},
{
"path": "cramming/backend/torch_default.py",
"chars": 25748,
"preview": "\"\"\"Basic training backend engine for pytorch training with all bells and whistles.\n\nInterface set up to be compliant wit"
},
{
"path": "cramming/backend/utils.py",
"chars": 6121,
"preview": "import logging\nimport os\nimport torch\n\nimport logging\n\nfrom safetensors.torch import load_file, save_file\nimport crammin"
},
{
"path": "cramming/config/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cramming/config/arch/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cramming/config/arch/albert.yaml",
"chars": 1833,
"preview": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This is set up to be as close to ALBERT-large (Lan et a"
},
{
"path": "cramming/config/arch/crammed-depthrecurrent.yaml",
"chars": 2267,
"preview": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This inherits architecture changes from the crammed-ber"
},
{
"path": "cramming/config/arch/crammed-fakeRNN.yaml",
"chars": 905,
"preview": "# Instantiates a (non-huggingface) scriptable encoder-based LM with BERT as baseline\n# Modernized version of bert-c5\n\n# "
},
{
"path": "cramming/config/arch/crammed-janus.yaml",
"chars": 2134,
"preview": "# Instantiates a (non-huggingface) scriptable janus-type RNN, right now with all tested bells-and-whistles\n\n# These are "
},
{
"path": "cramming/config/arch/crammed-rnn.yaml",
"chars": 1181,
"preview": "# Instantiates a (non-huggingface) scriptable encoder-based LM with BERT as baseline\n# Modernized version of bert-c5\n\n# "
},
{
"path": "cramming/config/arch/crammed-stack-janus.yaml",
"chars": 2520,
"preview": "# Instantiates a (non-huggingface) scriptable janus-type RNN, right now with all tested bells-and-whistles\n\n# These are "
},
{
"path": "cramming/config/arch/crammed-tiny.yaml",
"chars": 1557,
"preview": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This is the tiny setting, modified from bert-tiny with "
},
{
"path": "cramming/config/arch/crammed-transformer.yaml",
"chars": 1534,
"preview": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This inherits architecture changes from the crammed-ber"
},
{
"path": "cramming/config/arch/gpt2-base.yaml",
"chars": 1362,
"preview": "# Instantiates a (non-huggingface) scriptable decoder-based LM\n# This matches the gpt2 settings in the custom implementa"
},
{
"path": "cramming/config/arch/hf-gpt2.yaml",
"chars": 519,
"preview": "# These are the huggingface bert parameters\n\nmodel_type: \"gpt2\"\n\nn_ctx: 1024\nn_embd: 768\nn_head: 12\nn_layer: 12\nn_positi"
},
{
"path": "cramming/config/arch/sanitycheck.yaml",
"chars": 46,
"preview": "model_type: SanityCheckLM\n\nwidth: 1024 # 8352\n"
},
{
"path": "cramming/config/cfg_eval.yaml",
"chars": 1721,
"preview": "# Configuration defaults\n# Settings are separated into hyperparameters for architecture, data, implementation and train/"
},
{
"path": "cramming/config/cfg_pretrain.yaml",
"chars": 838,
"preview": "# Configuration defaults\n# Settings are separated into hyperparameters for architecture, data, implementation and train/"
},
{
"path": "cramming/config/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cramming/config/data/arithmetic.yaml",
"chars": 949,
"preview": "name: arithmetic\ndefaults:\n - sources:\n - arithmetic\n\n\n\n# all the below stuff may not be required\n# Preprocessing\n"
},
{
"path": "cramming/config/data/c4-subset-processed.yaml",
"chars": 919,
"preview": "# This would be a slice of C4\nname: c4-subset\ndefaults:\n - sources:\n - c4\n\n# Preprocessing\nnormalizer:\n force_low"
},
{
"path": "cramming/config/data/openweb.yaml",
"chars": 936,
"preview": "# Selection of English sources from the ROOTS project\nname: openweb\ndefaults:\n - sources:\n - openwebtext\n\n# Prepro"
},
{
"path": "cramming/config/data/proofpile.yaml",
"chars": 921,
"preview": "name: proofpile\ndefaults:\n - sources:\n - proofpiledata\n\n# Preprocessing\nnormalizer:\n force_lowercase: False\n str"
},
{
"path": "cramming/config/data/sanity-check-1.yaml",
"chars": 908,
"preview": "# Just a bunch of fake data ...\nname: sanity-check-1\ndefaults:\n - sources:\n - fake\n\n#\n# Preprocessing\nnormalizer: "
},
{
"path": "cramming/config/data/sanity-check-2.yaml",
"chars": 1042,
"preview": "# Just a tiny test dataset ...\nname: sanity-check-2\n# https://hydra.cc/docs/patterns/select_multiple_configs_from_config"
},
{
"path": "cramming/config/data/sources/ag_news.yaml",
"chars": 171,
"preview": "# For sanity testing\nag_news:\n provider: huggingface\n partition: default\n split: train\n\n streaming: False\n\n remove_"
},
{
"path": "cramming/config/data/sources/arithmetic.yaml",
"chars": 282,
"preview": "# Just a bunch of fake data ...\narithmetic:\n provider: arithmetic\n split:\n\n randgen_seed: 0\n size: 2048\n\n tokenized"
},
{
"path": "cramming/config/data/sources/bookcorpus.yaml",
"chars": 246,
"preview": "# The bookcorpus dataset, drawn from it huggingface mirror\nbookcorpus:\n provider: huggingface\n partition: plain_text\n "
},
{
"path": "cramming/config/data/sources/c4.yaml",
"chars": 230,
"preview": "# The wikipedia en dataset, drawn from it huggingface mirror\nc4:\n provider: huggingface\n partition: en\n split: train\n"
},
{
"path": "cramming/config/data/sources/dash_books.yaml",
"chars": 222,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_book_dash_books:\n provider: huggingface\n partition:\n split: train\n\n strea"
},
{
"path": "cramming/config/data/sources/fake.yaml",
"chars": 96,
"preview": "# Just a bunch of fake data ...\nfake:\n provider: fake\n split:\n\n randgen_seed: 0\n size: 2048\n"
},
{
"path": "cramming/config/data/sources/iwslt.yaml",
"chars": 222,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_ted_talks_iwslt:\n provider: huggingface\n partition:\n split: train\n\n strea"
},
{
"path": "cramming/config/data/sources/local.yaml",
"chars": 98,
"preview": "# Just a bunch of fake data ...\nlocal:\n provider: local\n split:\n\n randgen_seed: 0\n size: 2048\n"
},
{
"path": "cramming/config/data/sources/no_code_stackexchange.yaml",
"chars": 228,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_no_code_stackexchange:\n provider: huggingface\n partition:\n split: train\n\n "
},
{
"path": "cramming/config/data/sources/openwebtext.yaml",
"chars": 236,
"preview": "# The open webtext replication, as mirrored on HF\nopenwebtext:\n provider: huggingface\n partition: plain_text\n split: "
},
{
"path": "cramming/config/data/sources/oscar.yaml",
"chars": 327,
"preview": "# The oscar dataset, drawn from it huggingface mirror\n# should be 1.2T in this deduplicated version\noscar:\n provider: h"
},
{
"path": "cramming/config/data/sources/proofpiledata.yaml",
"chars": 316,
"preview": "# The open webtext replication, as mirrored on HF\nEleutherAI/proof-pile-2:\n provider: huggingface\n partition: open-web"
},
{
"path": "cramming/config/data/sources/the_pile.yaml",
"chars": 3004,
"preview": "#\nthe_pile:\n provider: local\n file_type: json\n files:\n - \"/fs/cml-datasets/Pile/train/00.jsonl.zst\"\n - \"/fs/cml"
},
{
"path": "cramming/config/data/sources/the_pileCC.yaml",
"chars": 2973,
"preview": "#\nthe_pileCC:\n provider: local\n file_type: json\n files:\n - \"/fs/cml-datasets/Pile/train/00.jsonl.zst\"\n - \"/fs/c"
},
{
"path": "cramming/config/data/sources/the_pile_dedup.yaml",
"chars": 235,
"preview": "# The EleutherAI/the_pile_deduplicated\nEleutherAI/the_pile_deduplicated:\n provider: huggingface\n partition:\n split: t"
},
{
"path": "cramming/config/data/sources/the_pile_natural.yaml",
"chars": 3019,
"preview": "#\nthe_pile_natural:\n provider: local\n file_type: json\n files:\n - \"/fs/cml-datasets/Pile/train/00.jsonl.zst\"\n - "
},
{
"path": "cramming/config/data/sources/the_pile_stream.yaml",
"chars": 348,
"preview": "# Pile streaming from huggingface with new streaming tech :>\n# should be 1.2T in this deduplicated version\nEleutherAI/th"
},
{
"path": "cramming/config/data/sources/uncorpus.yaml",
"chars": 215,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_uncorpus:\n provider: huggingface\n partition:\n split: train\n\n streaming: T"
},
{
"path": "cramming/config/data/sources/uspto.yaml",
"chars": 221,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_the_pile_uspto:\n provider: huggingface\n partition:\n split: train\n\n stream"
},
{
"path": "cramming/config/data/sources/wikibooks.yaml",
"chars": 217,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_wikibooks:\n provider: huggingface\n partition:\n split: train\n\n streaming: "
},
{
"path": "cramming/config/data/sources/wikinews.yaml",
"chars": 216,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_wikinews:\n provider: huggingface\n partition:\n split: train\n\n streaming: F"
},
{
"path": "cramming/config/data/sources/wikipedia.yaml",
"chars": 253,
"preview": "# The wikipedia en dataset, drawn from it huggingface mirror\nwikipedia:\n provider: huggingface\n partition: 20220301.en"
},
{
"path": "cramming/config/data/sources/wikiquote.yaml",
"chars": 216,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_wikiquote:\n provider: huggingface\n partition:\n split: train\n\n streaming: "
},
{
"path": "cramming/config/data/sources/wikiversity.yaml",
"chars": 218,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_wikiversity:\n provider: huggingface\n partition:\n split: train\n\n streaming"
},
{
"path": "cramming/config/data/sources/wikivoyage.yaml",
"chars": 217,
"preview": "# A part of ROOTS\nbigscience-data/roots_en_wikivoyage:\n provider: huggingface\n partition:\n split: train\n\n streaming:"
},
{
"path": "cramming/config/eval/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cramming/config/eval/pythia.yaml",
"chars": 407,
"preview": "# defaults:\n# - optim: adam\n# - tasks:\n # - winogrande\n # - lambada_openai\n # - piqa\n # - winogr"
},
{
"path": "cramming/config/eval/tasks/lambada_openai.yaml",
"chars": 44,
"preview": "# dataset-specific settings\nlambada_openai:\n"
},
{
"path": "cramming/config/eval/tasks/winogrande.yaml",
"chars": 40,
"preview": "# dataset-specific settings\nwinogrande:\n"
},
{
"path": "cramming/config/hydra/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cramming/config/hydra/job_logging/custom.yaml",
"chars": 445,
"preview": "# python logging configuration for tasks\nversion: 1\nformatters:\n simple:\n format: \"[%(asctime)s] %(message)s\"\nhandle"
},
{
"path": "cramming/config/impl/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cramming/config/impl/_default.yaml",
"chars": 3788,
"preview": "# Settings for implementation details\n# These settings \"should\" not influence the outcome of the computation in major wa"
},
{
"path": "cramming/config/impl/torch-default.yaml",
"chars": 2879,
"preview": "# Settings for implementation details\n# These settings \"should\" not influence the outcome of the computation in major wa"
},
{
"path": "cramming/config/train/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cramming/config/train/common.yaml",
"chars": 865,
"preview": "# Basic hyperparameter for normal BERT pretraining\n# working hard here to separate \"impl\" implementation details and \"tr"
},
{
"path": "cramming/config/train/cramming.yaml",
"chars": 1136,
"preview": "# Version 4 of changes to bert training hyperparameters\n# Optimizes MLM rate for torch.compile, includes improved weight"
},
{
"path": "cramming/config/train/janus-regime.yaml",
"chars": 834,
"preview": "# Version 4 of changes to bert training hyperparameters\n# Optimizes MLM rate for torch.compile, includes improved weight"
},
{
"path": "cramming/config/train/optim/adafactor.yaml",
"chars": 177,
"preview": "type: Adafactor\n\nlr: 0.001\neps:\n - 1e-30\n - 0.001\nclip_threshold: 1.0\ndecay_rate: -0.8\nbeta1:\nweight_decay: 0.0\nscale_"
},
{
"path": "cramming/config/train/optim/adahessian.yaml",
"chars": 100,
"preview": "type: AdaHessian\n\nlr: 0.15\nbetas:\n - 0.9\n - 0.98\neps: 1e-12\nweight_decay: 0.01\nhessian_power: 1.0\n"
},
{
"path": "cramming/config/train/optim/adam.yaml",
"chars": 100,
"preview": "type: AdamW\n\nlr: 0.0005\nbetas:\n - 0.9\n - 0.98\neps: 1e-12\nweight_decay: 0.01\namsgrad: False\nfused:\n"
},
{
"path": "cramming/config/train/optim/adam8bit.yaml",
"chars": 96,
"preview": "type: Adam8bit\n\nlr: 0.0005\nbetas:\n - 0.9\n - 0.98\neps: 1e-12\nweight_decay: 0.01\namsgrad: False\n"
},
{
"path": "cramming/config/train/optim/adam_classic.yaml",
"chars": 92,
"preview": "type: Adam\n\nlr: 0.0005\nbetas:\n - 0.9\n - 0.999\neps: 1e-8\nweight_decay: 0.01\namsgrad: False\n"
},
{
"path": "cramming/config/train/optim/adamscale.yaml",
"chars": 114,
"preview": "type: AdamWScale\n\nlr: 0.0005\nbetas:\n - 0.9\n - 0.98\neps: 1e-12\nweight_decay: 0.01\ncorrect_bias: True # adamw fix\n"
},
{
"path": "cramming/config/train/optim/agd.yaml",
"chars": 21,
"preview": "type: AGD\n\ngain: 1.0\n"
},
{
"path": "cramming/config/train/optim/lion.yaml",
"chars": 92,
"preview": "type: Lion\n\nlr: 1e-4\nbetas:\n - 0.9\n - 0.99\n# use 0.95, 0.98 if unstable\nweight_decay: 0.1\n"
},
{
"path": "cramming/config/train/optim/radam.yaml",
"chars": 78,
"preview": "type: RAdam\n\nlr: 0.0005\nbetas:\n - 0.9\n - 0.98\neps: 1e-12\nweight_decay: 0.01\n"
},
{
"path": "cramming/config/train/optim/sgd.yaml",
"chars": 85,
"preview": "type: SGD\n\nlr: 0.0005\nmomentum: 0.9\ndampening: 0.0\nweight_decay: 0.01\nnesterov: True\n"
},
{
"path": "cramming/config/train/optim/shampoo.yaml",
"chars": 1421,
"preview": "type: Shampoo\n\nlr: 0.0005\nbetas:\n - 0.9\n - 0.98\nepsilon: 1e-12\nuse_bias_correction: True\nadam_w_mode: True\nweight_deca"
},
{
"path": "cramming/config/train/optim_mod/disabled.yaml",
"chars": 11,
"preview": "name: none\n"
},
{
"path": "cramming/config/train/optim_mod/larc.yaml",
"chars": 57,
"preview": "name: LARC\n\ntrust_coefficient: 0.02\nclip: True\neps: 1e-8\n"
},
{
"path": "cramming/config/train/optim_mod/lars.yaml",
"chars": 58,
"preview": "name: LARS\n\ntrust_coefficient: 0.02\nclip: False\neps: 1e-8\n"
},
{
"path": "cramming/config/train/optim_mod/progressive.yaml",
"chars": 125,
"preview": "name: progressive-batching\n\nprogress_rule: norm-based\n\nmonotone: False\ntheta: 0.9\n\nmin_sample_guard: 2\nmax_sample_guard:"
},
{
"path": "cramming/config/train/optim_mod/sam.yaml",
"chars": 20,
"preview": "name: SAM\nrho: 0.05\n"
},
{
"path": "cramming/config/wandb/default.yaml",
"chars": 90,
"preview": "enabled: True\nentity: placeholder # change this obviously ;>\nproject: arithmetic\ntags: []\n"
},
{
"path": "cramming/config/wandb/none.yaml",
"chars": 41,
"preview": "enabled: False\nentity:\nproject:\ntags: []\n"
},
{
"path": "cramming/data/__init__.py",
"chars": 136,
"preview": "\"\"\"This module handles and hides the data away ;)\"\"\"\n\nfrom .pretraining_preparation import load_pretraining_corpus, prep"
},
{
"path": "cramming/data/arithmetic_tokenizers.py",
"chars": 9164,
"preview": "\"\"\"\nCharacter level tokenizers for arithemtic projects\nMultiple tokenizers for different tasks\n\"\"\"\n\nfrom transformers im"
},
{
"path": "cramming/data/curriculum_sorting.py",
"chars": 4738,
"preview": "\"\"\"Baseline curricula.\"\"\"\nimport torch\nimport numpy as np\n\nimport logging\n\nlog = logging.getLogger(__name__)\n\n\ndef _sort"
},
{
"path": "cramming/data/deduplicate.py",
"chars": 6458,
"preview": "\"\"\"This is glue code to connect to the rust-based deduplication of https://github.com/google-research/deduplicate-text-d"
},
{
"path": "cramming/data/pretraining_preparation.py",
"chars": 30607,
"preview": "\"\"\"Prepare and preprocess datasets.\"\"\"\n\nimport torch\nimport datasets\nimport hydra\nimport pandas as pd\nimport os\nimport c"
},
{
"path": "cramming/data/tokenizer_preparation.py",
"chars": 9935,
"preview": "\"\"\"Tokenizer functionality.\n\nNote: CANNOT name this file \"tokenizers.py ;>\n\"\"\"\n\nfrom transformers import AutoTokenizer, "
},
{
"path": "cramming/data/utils.py",
"chars": 4220,
"preview": "\"\"\"Various utilities.\"\"\"\nimport os\nfrom omegaconf import OmegaConf\nimport hashlib\nimport json\nimport shutil\nimport subpr"
},
{
"path": "cramming/utils.py",
"chars": 20673,
"preview": "\"\"\"System utilities.\"\"\"\n\nimport socket\nimport sys\n\nimport os\nimport csv\nimport yaml\nimport psutil\nimport pynvml\n\nimport "
},
{
"path": "create_data_split.py",
"chars": 26878,
"preview": "from transformers import PreTrainedTokenizer\nimport random\nimport os\nimport torch\nfrom transformers import AutoTokenizer"
},
{
"path": "create_pos_or_variants.py",
"chars": 3890,
"preview": "import numpy as np\nimport argparse\nimport random\nimport os\n\ndef one_hot_vector(length, index=None):\n \"\"\"return a one "
},
{
"path": "dataset_analysis.py",
"chars": 6498,
"preview": "import os\nimport re\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport seaborn as sns\nimport pandas as pd\nimport "
},
{
"path": "gen_eval_script.py",
"chars": 5582,
"preview": "# input your model name and base_dir\nname = \"sort_bucket_uniform_distribution_max_digits_n_10_max_length_m_10_20000000_p"
},
{
"path": "load_local_model.py",
"chars": 1567,
"preview": "\"\"\"Example for a script to load a local saved model.\n\nUse as e.g.\n\npython load_local_model.py name=A6000amp_b4096_c5_o3_"
},
{
"path": "pretrain.py",
"chars": 16167,
"preview": "\"\"\"Script for a pretraining run.\"\"\"\n\nimport torch\nimport hydra\n\nimport os\nimport time\nimport datetime\nimport logging\nfro"
},
{
"path": "pretty_plotter.py",
"chars": 4874,
"preview": "## combine multiple testing plots and make a pretty one \n\nimport os\nimport numpy as np\nimport json\nimport matplotlib.pat"
},
{
"path": "pretty_plotter_big.py",
"chars": 4877,
"preview": "## combine multiple testing plots and make a pretty one \n\nimport os\nimport numpy as np\nimport json\nimport matplotlib.pat"
},
{
"path": "pretty_plotter_sort.py",
"chars": 5446,
"preview": "import numpy as np\nimport os\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport cv2\n\ndef g"
},
{
"path": "pyproject.toml",
"chars": 113,
"preview": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.black]\nline-length = 140\n"
},
{
"path": "setup.cfg",
"chars": 1952,
"preview": "\n\n[metadata]\nname = cramming\nversion = 0.1.0\nauthor = Sean McLeish\nauthor_email = smcleish@umd.edu\nurl = https://github."
},
{
"path": "shells/addition_ff.sh",
"chars": 7226,
"preview": "## FF\n# nope\npython pretrain.py name=add_bucket_20_20_reverse_all_pad_00_depthrec_16_1_TBPTT_1024_batch_size_512_mask_be"
},
{
"path": "shells/addition_lt.sh",
"chars": 2582,
"preview": "### Looped Transformer experiments\n# vary number of layers in recurrent_block: arch.layers_in_recurrent_block\n# vary num"
},
{
"path": "shells/bitwise_or.sh",
"chars": 7757,
"preview": "# bitwise or is sometimes refered to as pos_arth in the code\n\n## LT\n# NOPE\npython pretrain.py name=pos_or_one_vec_zeros_"
},
{
"path": "shells/evaluation.sh",
"chars": 1504,
"preview": "# there is an automated helper in gen_eval_script.py for generating these evaluation scripts\n\n# Addition\npython arithmet"
},
{
"path": "shells/generate_and_tokenize_data.sh",
"chars": 2216,
"preview": "## Training Data -- these commands approximately correspond to the zipped data we provide\n\n# bitwise or\npython create_po"
},
{
"path": "shells/multiplication.sh",
"chars": 2872,
"preview": "## only Looped Transformer experiments for multiplication\ntorchrun --nproc_per_node=8 --standalone pretrain.py name=mul_"
},
{
"path": "shells/sorting.sh",
"chars": 9279,
"preview": "# REMINDER SET BASE DIR\n\n\n## fire reverse\n## fire reverse recall\n## fire reverse recurrence\n\ntorchrun --nproc_per_node=1"
},
{
"path": "sort_eval.py",
"chars": 17373,
"preview": "import logging\nimport hydra\nfrom omegaconf import OmegaConf\nimport cramming\nimport torch\nfrom safetensors.torch import l"
},
{
"path": "upload_processed_dataset.py",
"chars": 3488,
"preview": "\"\"\"Script to upload a processed dataset to the huggingface hub. You probably don't need this :)\"\"\"\n\n\nimport hydra\nimport"
}
]
About this extraction
This page contains the full source code of the mcleish7/arithmetic GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 132 files (479.5 KB), approximately 123.4k tokens, and a symbol index with 450 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.