Repository: lucidrains/block-recurrent-transformer-pytorch
Branch: main
Commit: 9f28fcdb387d
Files: 10
Total size: 50.2 KB

Directory structure:
gitextract_0788ko4s/
├── .github/
│   └── workflows/
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── block_recurrent_transformer_pytorch/
│   ├── __init__.py
│   └── block_recurrent_transformer_pytorch.py
├── data/
│   └── README.md
├── requirements.txt
├── setup.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/python-publish.yml
================================================
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
## Block Recurrent Transformer - Pytorch

Implementation of Block Recurrent Transformer in Pytorch. The highlight of the paper is its reported ability to remember information up to 60k tokens in the past. This design is state of the art for the recurrent-transformers line of research, as far as I can tell.

It will also include flash attention as well as routed memories of up to 250k tokens, using ideas from the CoLT5 paper (Ainslie et al., cited below).

## Appreciation

- Stability.ai for the generous sponsorship to work on and open source cutting edge artificial intelligence research

## Install

```bash
$ pip install block-recurrent-transformer-pytorch
```

## Usage

```python
import torch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer

model = BlockRecurrentTransformer(
    num_tokens = 20000,             # vocab size
    dim = 512,                      # model dimensions
    depth = 6,                      # depth
    dim_head = 64,                  # attention head dimensions
    heads = 8,                      # number of attention heads
    max_seq_len = 1024,             # the total receptive field of the transformer, in the paper this was 2 * block size
    block_width = 512,              # block size - the total receptive field is max_seq_len, i.e. 2 * block size in the paper. the block furthest forward becomes the new cached xl memories, which is one block in size (please open an issue if i am wrong)
    num_state_vectors = 512,        # number of state vectors - i believe this was a single block size in the paper, but it can be any amount
    recurrent_layers = (4,),        # where to place the recurrent layer(s) for states with fixed simple gating
    use_compressed_mem = False,     # whether to use compressed memories of a single block width, from https://arxiv.org/abs/1911.05507
    compressed_mem_factor = 4,      # compression factor of compressed memories
    use_flash_attn = True           # use flash attention, if on pytorch 2.0
)

seq = torch.randint(0, 2000, (1, 1024))

out, mems1, states1 = model(seq)
out, mems2, states2 = model(seq, xl_memories = mems1, states = states1)
out, mems3, states3 = model(seq, xl_memories = mems2, states = states2)
```
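For training on sequences longer than `max_seq_len`, the `RecurrentTrainerWrapper` (used by `train.py`) splits the input into segments and carries the XL memories and recurrent states across them. A minimal sketch, assuming the configuration above; the batch size and document length here are arbitrary:

```python
import torch
from block_recurrent_transformer_pytorch import (
    BlockRecurrentTransformer,
    RecurrentTrainerWrapper
)

model = BlockRecurrentTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 1024,
    block_width = 512
)

trainer = RecurrentTrainerWrapper(
    model,
    xl_memories_dropout = 0.1,  # chance of dropping the xl memories between segments
    state_dropout = 0.1         # chance of dropping the recurrent states between segments
)

# training length must be a multiple of max_seq_len, plus 1 token for the labels

seq = torch.randint(0, 20000, (2, 2049))

loss = trainer(seq)  # average of the losses over the two 1024 token segments
loss.backward()

# generation beyond max_seq_len is likewise handled segment by segment internally

sampled = trainer.generate(seq[:, :128], length = 2048)
```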
## Test on Enwik8

First `pip install -r requirements.txt`, then

```bash
$ python train.py
```

## Todo

- [x] use dynamic positional bias
- [x] add enhanced recurrence
- [x] setup local attention blocks, as in the paper
- [x] wrapper transformer class for training
- [x] take care of generation with recurrence in `RecurrentTrainerWrapper`
- [x] add ability to dropout entire memories and states during each segment step during training
- [x] test full system on enwik8 locally and ablate states and memories and see effects first hand
- [x] make sure attention allows for single head key / values too
- [x] run a few experiments of fixed gating in regular transformers - does not work
- [x] integrate flash attention
- [x] cache attention mask + rotary embeddings
- [x] add compressed memories
- [ ] revisit memformer
- [ ] try routing long distance memories of up to 250k using coordinate descent (Wright et al.)

## Citations

```bibtex
@article{Hutchins2022BlockRecurrentT,
    title   = {Block-Recurrent Transformers},
    author  = {DeLesley S. Hutchins and Imanol Schlag and Yuhuai Wu and Ethan Dyer and Behnam Neyshabur},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.07852}
}
```

```bibtex
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
```

```bibtex
@inproceedings{Sun2022ALT,
    title  = {A Length-Extrapolatable Transformer},
    author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year   = {2022}
}
```

```bibtex
@inproceedings{dao2022flashattention,
    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2022}
}
```

```bibtex
@inproceedings{Ainslie2023CoLT5FL,
    title  = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Onta{\~n}{\'o}n and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year   = {2023}
}
```

*Memory is Attention through Time* - Alex Graves

================================================
FILE: block_recurrent_transformer_pytorch/__init__.py
================================================
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('2.0.0'):
    from einops._torch_specific import allow_ops_in_compiled_graph
    allow_ops_in_compiled_graph()

from block_recurrent_transformer_pytorch.block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper

================================================
FILE: block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py
================================================
import math
from random import random
from functools import wraps, partial
from itertools import zip_longest
from collections import namedtuple, defaultdict

from packaging import version

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, List, Tuple
# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def is_empty(t: torch.Tensor):
    return t.numel() == 0

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

def all_unique(arr):
    return len(arr) == len(set(arr))

def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def compact(arr):
    return [*filter(exists, arr)]

def and_reduce(arr: List[torch.Tensor]):
    if len(arr) == 0:
        return None
    head, *rest = arr
    for t in rest:
        head = head & t
    return head

def safe_cat(*args, dim = 1):
    args = compact(args)
    if len(args) == 0:
        return None
    return torch.cat(args, dim = dim)

def divisible_by(numer, denom):
    return (numer % denom) == 0

def l2norm(t):
    return F.normalize(t, dim = -1)

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# bias-less layernorm

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# sampling helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
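# added usage note (illustrative, not from the original source): thres = 0.9 keeps only
# the top 10% of logits, e.g. for a 256 token vocab, k = ceil(0.1 * 256) = 26 logits
# survive and the rest become -inf, so gumbel sampling only ever picks from that slice:
#
#   logits  = torch.randn(1, 256)
#   sampled = gumbel_sample(top_k(logits, thres = 0.9), temperature = 1.)  # shape (1,)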
# rotary positional embedding w/ xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        width,
        scale_base = 512,
        theta = 10000
    ):
        super().__init__()
        self.width = width

        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent = False)

        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

        self.register_buffer('cached_freqs', None, persistent = False)
        self.register_buffer('cached_scales', None, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self):
        device, seq_len = self.device, self.width

        if exists(self.cached_freqs):
            cached_seq_len = self.cached_freqs.shape[-2]
            if cached_seq_len >= seq_len:
                return self.cached_freqs[:seq_len], self.cached_scales[:seq_len]

        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        self.register_buffer('cached_freqs', freqs, persistent = False)
        self.register_buffer('cached_scales', scale, persistent = False)
        return freqs, scale

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(t, pos, scale = 1.):
    scale = default(scale, 1.)

    seq_len = t.shape[-2]

    assert pos.shape[-2] >= seq_len

    pos = pos[-seq_len:]

    if isinstance(scale, torch.Tensor):
        assert scale.shape[-2] >= seq_len
        scale = scale[-seq_len:]

    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)

# memory management

class MemoryManager(nn.Module):
    def __init__(
        self,
        dim,
        *,
        layers = 1,
        mem_lengths = 512,
        compress_factors = 1
    ):
        super().__init__()
        mem_lengths = cast_tuple(mem_lengths)
        compress_factors = cast_tuple(compress_factors)

        assert all([mem_length > 0 for mem_length in mem_lengths])
        assert len(mem_lengths) == len(compress_factors)
        assert layers >= 1

        self.mem_lengths = mem_lengths
        self.compress_factors = compress_factors

        self.layers = nn.ModuleList([])

        for _ in range(layers):
            compress_fns = nn.ModuleList([])

            for compress_factor in compress_factors:
                compress_fn = nn.Identity()

                if compress_factor > 1:
                    compress_fn = nn.Sequential(
                        Rearrange('b n d -> b d n'),
                        nn.Conv1d(
                            dim * 2,
                            dim * 2,
                            compress_factor,
                            stride = compress_factor,
                            groups = 2
                        ),
                        Rearrange('b d n -> b n d'),
                    )

                compress_fns.append(compress_fn)

            self.layers.append(compress_fns)

    def forward(
        self,
        past_memories: List[torch.Tensor],
        new_memories: List[torch.Tensor]
    ):
        next_memories = []

        for past_memory, new_memory, compress_fns in zip_longest(past_memories, new_memories, self.layers):

            # edge case if neither memories exist

            if not (exists(past_memory) or exists(new_memory)):
                next_memories.append(None)
                continue

            next_memory = None

            for mem_length, compress_factor, compress_fn in zip(self.mem_lengths, self.compress_factors, compress_fns):

                # first get the memories for the given compression factor "current_memory"

                current_memory = None
                if exists(past_memory):
                    past_memory, current_memory = past_memory[..., :-mem_length, :], past_memory[..., -mem_length:, :]

                # compress the new memories coming in, based on the compression factors set at init

                if (not is_empty(new_memory)) and compress_factor > 1:
                    # make sure memory length is divisible by compression factor

                    new_mem_length = new_memory.shape[-2]
                    curtailed_length = (new_mem_length // compress_factor) * compress_factor
                    curtailed_slice = slice(-curtailed_length, None) if curtailed_length > 0 else slice(0, 0)
                    new_memory = new_memory[..., curtailed_slice, :]

                    # compress the memory pushed to the next stage

                    if new_memory.shape[-2] > 0:
                        new_memory = rearrange(new_memory, 'm b n d -> b n (m d)')
                        new_memory = compress_fn(new_memory)
                        new_memory = rearrange(new_memory, 'b n (m d) -> m b n d', m = 2)

                # fifo memory queue
                # add the new memory on the right

                current_memory = safe_cat(current_memory, new_memory, dim = -2)

                # "new" memory is new with respect to the next compressed segment

                new_memory, current_memory = current_memory[..., :-mem_length, :], current_memory[..., -mem_length:, :]

                # concat the new memory to the left into the past

                next_memory = safe_cat(current_memory, next_memory, dim = -2)

            next_memories.append(next_memory)

        return next_memories
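# added shape note (illustrative, not from the original source): BlockRecurrentTransformer
# below instantiates this manager with mem_lengths = (512, 256) and compress_factors = (1, 4)
# when use_compressed_mem = True and block_width = 512. each new stacked (k, v) memory of
# shape (2, b, 512, d) enters the uncompressed fifo on the right; whatever overflows that
# 512 token window is compressed 4x by the strided Conv1d and queued in the 256-slot
# compressed fifo, which therefore summarizes a further 1024 tokens of history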
# maybe flash attention, if using pytorch 2.0

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# state container

class StateContainer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_state_vectors,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        assert num_state_vectors > 0
        self.heads = heads
        inner_dim = dim_head * heads

        self.state_norm = LayerNorm(dim)

        self.q_to_state = nn.Linear(dim, inner_dim, bias = False)
        self.q_from_state = nn.Linear(dim, inner_dim, bias = False)

        self.state_to_q = nn.Linear(dim, inner_dim, bias = False)
        self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
        torch.nn.init.normal_(self.init_state, 0, .1)

        self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))
        # NOTE: the state position id embeddings are drawn from N(0,1) since they are added after a layer norm
        torch.nn.init.normal_(self.state_pos_ids, 0, 1)

        self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)

        self.to_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.state_self_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
        self.from_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        # gating related parameters - using the fixed simple config

        self.state_out_to_gate = nn.Linear(dim, dim)
        self.learned_ema_beta = nn.Parameter(torch.randn(dim))
        torch.nn.init.normal_(self.learned_ema_beta, 0, .1)

        # since each read should be followed by a write, just store cache in the container

        self.cache = None
        self.next_read_state = None

    def set_next_read_state(
        self,
        states
    ):
        if not exists(states):
            states = self.init_state

        self.next_read_state = (states,)

    def read(self, x):
        assert exists(self.next_read_state), 'states to be read must be set with .set_next_read_state'

        states, = self.next_read_state
        self.next_read_state = None

        # pre norm state for attention

        normed_states = self.state_norm(states)

        # add the positional ids, as stated in the paper critical for it to work

        normed_states = normed_states + self.state_pos_ids

        # get queries for cross attention, which they do not share, although they share key / values. another intriguing detail

        q_to_state = self.q_to_state(x)
        q_to_state = rearrange(q_to_state, '... n (h d) -> ... h n d', h = self.heads)

        # self attention qkv for states

        state_k, state_v = self.state_to_kv(normed_states).chunk(2, dim = -1)

        # cross attend to the past states key values

        to_state_out = self.to_state_cross_attn(q_to_state, state_k, state_v)

        to_state_out = rearrange(to_state_out, 'b h n d -> b n (h d)')

        # cache for next write

        self.cache = (states, normed_states, state_k, state_v)

        return to_state_out

    def write(
        self,
        *,
        memories
    ):
        assert exists(self.cache)

        k, v = memories
        batch = k.shape[0]

        # get cached values from the previous read

        states, normed_states, state_k, state_v = self.cache

        self.cache = None

        # derive queries

        q_from_state = self.q_from_state(normed_states)
        q_from_state = rearrange(q_from_state, '... n (h d) -> ... h n d', h = self.heads)

        state_q = self.state_to_q(normed_states)
        state_q_einsum = 'n (h d)' if state_q.ndim == 2 else 'b n (h d)'
        state_q = repeat(state_q, f'{state_q_einsum} -> b h n d', h = self.heads, b = batch)

        # states must also undergo self attention

        if q_from_state.ndim == 3:
            q_from_state = repeat(q_from_state, '... -> b ...', b = batch)

        state_out = self.state_self_attn(state_q, state_k, state_v)

        from_state_out = self.from_state_cross_attn(q_from_state, k, v)

        state_out = torch.cat((state_out, from_state_out), dim = -1)
        state_out = rearrange(state_out, 'b h n d -> b n (h d)')

        state_out = self.to_state_out(state_out)

        # use the best performing configuration
        # fixed simple gate - nothing more than a learned EMA with some resemblance to highway networks

        z = self.state_out_to_gate(state_out)
        learned_ema_decay = self.learned_ema_beta.sigmoid()

        # set new state with the learned EMA gating

        return learned_ema_decay * z + (1 - learned_ema_decay) * states

    def forward(self, x):
        raise NotImplementedError
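# added note (not from the original source): write() above is the paper's best performing
# 'fixed simple' gate, an elementwise learned EMA
#   new_state = sigmoid(beta) * candidate + (1 - sigmoid(beta)) * old_state
# where the candidate is a linear projection of the concatenated state self attention and
# state-to-input cross attention outputs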
# main class

class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash_attn = use_flash_attn
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash_attn:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        masks = []

        if self.causal:
            i, j = q_len, k_len
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            masks.append(~causal_mask)

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = q.shape[0] // mask.shape[0])

            masks.append(mask)

        attn_mask = and_reduce(masks)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = attn_mask
            )

        return out

    def forward(self, q, k, v, mask = None, use_flash_attn = None):
        use_flash_attn = default(use_flash_attn, self.use_flash_attn)

        b, n, device = q.shape[0], q.shape[-2], q.device

        q, ps = pack_one(q, '* h n d')
        k, _ = pack_one(k, '* n d')
        v, _ = pack_one(v, '* n d')

        if use_flash_attn:
            out = self.flash_attn(q, k, v, mask = mask)
            return unpack_one(out, ps, '* h n d')

        scale = q.shape[-1] ** -0.5

        k_einsum = 'b j d' if k.ndim == 3 else 'b h j d'
        v_einsum = 'b j d' if v.ndim == 3 else 'b h j d'

        # similarity

        sim = einsum(f"b h i d, {k_einsum} -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = b)

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum(f"b h i j, {v_einsum} -> b h i d", attn, v)

        return unpack_one(out, ps, '* h n d')
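# added note: Attend accepts single headed key / values (ndim == 3) and broadcasts them
# across every query head - the multi-query attention of the Shazeer citation in the
# readme, which is what AttentionBlock's to_kv projection (to just dim_head * 2) relies on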
# geglu feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult * 2 / 3)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim_head,
        causal = False,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal

        self.qk_rmsnorm = qk_rmsnorm
        self.qk_rmsnorm_scale = qk_rmsnorm_scale

        self.attend = Attend(causal = causal, use_flash_attn = use_flash_attn)

        if qk_rmsnorm:
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            self.k_scale = nn.Parameter(torch.ones(dim_head))

    def forward(
        self,
        q, k, v,
        mask = None,
        rotary_pos_emb = None,
        xpos_scale = None
    ):
        scale = q.shape[-1] ** -0.5

        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            scale = self.qk_rmsnorm_scale

        if self.qk_rmsnorm:
            q = q * self.q_scale
            k = k * self.k_scale

        # rotary positional embedding with xpos for length extrapolation

        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(q, rotary_pos_emb, xpos_scale)
            k = apply_rotary_pos_emb(k, rotary_pos_emb, xpos_scale ** -1)

        # attention

        out = self.attend(q, k, v, mask = mask)

        return out
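# added note: the Attention module above is projection-free - queries / keys / values
# arrive already projected and split into heads - which is why it can be shared between
# AttentionBlock below and the three state read / write attentions of StateContainer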
class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        block_width,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False,
        num_state_vectors = 0,
        num_external_state_reads = 0,
        state_read_before_write = True  # this will be defaulted to on as in the paper, but will be turned off in the case the researcher wants to test out reading the state at a lower layer
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads

        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.block_width = block_width
        self.is_recurrent_layer = num_state_vectors > 0

        # decide how many states this attention layer is going to read from

        num_state_reads = int(self.is_recurrent_layer and state_read_before_write) + num_external_state_reads

        self.to_out = nn.Linear(inner_dim * (1 + num_state_reads), dim, bias = False)

        if not self.is_recurrent_layer:
            return

        self.state_read_before_write = state_read_before_write

        self.state_container = StateContainer(
            dim,
            dim_head = dim_head,
            heads = heads,
            num_state_vectors = num_state_vectors,
            qk_rmsnorm = qk_rmsnorm,
            qk_rmsnorm_scale = qk_rmsnorm_scale,
            use_flash_attn = use_flash_attn
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        rotary_pos_emb = None,
        xpos_scale = None,
        attn_mask = None,
        xl_memories: Optional[torch.Tensor] = None,
        read_from_state_containers: List[StateContainer] = []
    ):
        batch, seq_len, _, width, device = *x.shape, self.block_width, self.device

        # pre normalization

        x = self.norm(x)

        # queries, keys, values and split out heads

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        split_head = partial(rearrange, pattern = 'b n (h d) -> b h n d', h = self.heads)

        q = split_head(q)

        # save the last key / values as memories for recurrence

        memories = torch.stack((k, v))

        mem_len = 0

        if exists(xl_memories):
            # if past memories are passed in, concat as the first bucket
            mem_len = xl_memories.shape[-2]
            past_k, past_v = xl_memories
            k = torch.cat((past_k, k), dim = 1)
            v = torch.cat((past_v, v), dim = 1)

        # handle cropping of attention mask and positional embeddings

        if exists(attn_mask):
            attn_mask = attn_mask[:seq_len, :seq_len]
            attn_mask = F.pad(attn_mask, (mem_len, 0), value = True)

        # attention, but of course

        out = self.attn(
            q, k, v,
            rotary_pos_emb = rotary_pos_emb,
            xpos_scale = xpos_scale,
            mask = attn_mask
        )

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # early return if not a recurrent layer

        if not self.is_recurrent_layer and len(read_from_state_containers) == 0:
            return self.to_out(out), memories, None

        # whether to read from own state container, default to on, but may pass in more

        if self.is_recurrent_layer and self.state_read_before_write:
            read_from_state_containers = [self.state_container, *read_from_state_containers]

        for read_state_container in read_from_state_containers:
            # read from the states ...
            to_state_out = read_state_container.read(x)

            # and concat it to the output of self-attention
            out = torch.cat((out, to_state_out), dim = -1)

        new_states = None

        if self.is_recurrent_layer:
            # then write to the states as well if need be
            new_states = self.state_container.write(memories = memories)

        return self.to_out(out), memories, new_states
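# added note, a hypothetical wiring example for the routing set up in __init__ below:
# with depth = 6, recurrent_layers = (4,) and read_recurrent_layers = (2,), layer 4 owns
# the StateContainer and writes the next states, while layer 2 receives that same container
# through read_from_state_containers (num_external_state_reads = 1) and only reads from it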
# classes

@beartype
class BlockRecurrentTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        all_layers_qk_rmsnorm = False,
        ff_mult = 4,
        max_seq_len = 1024,
        block_width = 512,
        recurrent_layers: Optional[Tuple[int, ...]] = None,
        read_recurrent_layers: Optional[Tuple[int, ...]] = None,
        num_state_vectors = None,
        ignore_index = -100,
        use_flash_attn = False,
        use_compressed_mem = False,
        compressed_mem_factor = 4
    ):
        super().__init__()
        num_state_vectors = default(num_state_vectors, block_width)

        # set recurrent layers

        recurrent_layers = default(recurrent_layers, (depth // 2,))  # default to one recurrent layer at the middle of the network

        assert all([0 < layer <= depth for layer in recurrent_layers]), f'recurrent layers must range from 1 to the depth {depth}'
        assert all_unique(recurrent_layers), 'recurrent layers must be all unique. no duplicate layers'

        self.recurrent_layers = recurrent_layers

        # set read recurrent layers

        read_recurrent_layers = default(read_recurrent_layers, recurrent_layers)

        assert all([read_layer <= write_layer for read_layer, write_layer in zip(read_recurrent_layers, recurrent_layers)]), 'the recurrent read layer must always be less than or equal to the write layer'
        assert all([0 < layer <= depth for layer in read_recurrent_layers])
        assert len(read_recurrent_layers) == len(recurrent_layers)

        self.read_recurrent_layers = read_recurrent_layers

        # token embedding

        self.token_emb = nn.Embedding(num_tokens, dim)

        self.rotary_pos_emb = RotaryEmbedding(dim = dim_head, width = (2 if not use_compressed_mem else 3) * block_width)

        self.layers = nn.ModuleList([])

        self.write_to_read_map = {write_layer: read_layer for write_layer, read_layer in zip(recurrent_layers, read_recurrent_layers)}

        self.read_state_router = defaultdict(list)

        for layer in range(1, depth + 1):
            is_recurrent_layer = layer in self.recurrent_layers

            layer_num_state_vectors = num_state_vectors if is_recurrent_layer else 0

            num_external_state_reads = sum([int(layer == read_layer) for read_layer in read_recurrent_layers])

            # only layers with xl memories
            # or has recurrence in horizontal direction
            # use qk rmsnorm (in paper, they use cosine sim attention, but i think qk rmsnorm is more proven given Vit 22B paper)
            # one can also override to use all qk rmsnorm by setting all_layers_qk_rmsnorm = True

            qk_rmsnorm = all_layers_qk_rmsnorm or is_recurrent_layer

            attn_block = AttentionBlock(
                dim,
                block_width = block_width,
                dim_head = dim_head,
                heads = heads,
                qk_rmsnorm = qk_rmsnorm,
                num_state_vectors = layer_num_state_vectors,
                use_flash_attn = use_flash_attn,
                num_external_state_reads = num_external_state_reads,
                state_read_before_write = False,
            )

            ff_block = FeedForward(dim, mult = ff_mult)

            if is_recurrent_layer:
                read_layer = self.write_to_read_map[layer]
                self.read_state_router[read_layer].append(attn_block.state_container)

            self.layers.append(nn.ModuleList([
                attn_block,
                ff_block
            ]))

        # (compressed) memory management

        self.mem_manager = MemoryManager(
            dim = dim_head,
            layers = depth,
            mem_lengths = block_width if not use_compressed_mem else (block_width, block_width // 2),
            compress_factors = 1 if not use_compressed_mem else (1, compressed_mem_factor)
        )

        # to logits

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        self.max_seq_len = max_seq_len
        self.block_width = block_width

        assert divisible_by(max_seq_len, block_width)

        self.ignore_index = ignore_index

        self.register_buffer('cached_causal_attn_mask', None, persistent = False)

    @property
    def device(self):
        return next(self.parameters()).device

    def get_causal_attn_mask(self, width):
        if exists(self.cached_causal_attn_mask):
            cached_mask = self.cached_causal_attn_mask
            cached_width = cached_mask.shape[-2]
            padding = (width - cached_width) // 2
            j_slice = Ellipsis if padding == 0 else slice(padding, -padding)
            return cached_mask[:cached_width, j_slice]

        device = self.device
        causal_mask = torch.ones((width, width), device = device, dtype = torch.bool).triu(1)
        return ~causal_mask
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prime,
        length = None,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        temperature = 1.,
        filter_thres = 0.9,
        return_memories_and_states = False
    ):
        length = default(length, self.max_seq_len + 1)
        start_len = prime.shape[-1]

        assert start_len < self.max_seq_len
        assert length <= (self.max_seq_len + 1)
        assert start_len < length

        output = prime

        memories = []

        for ind in range(length - start_len):

            logits, next_memories, next_states = self.forward(
                output,
                xl_memories = xl_memories,
                states = states
            )

            logits = logits[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature)
            sampled = rearrange(sampled, 'b -> b 1')

            output = torch.cat((output, sampled), dim = -1)

            if divisible_by(output.shape[-1] - 1, self.max_seq_len):
                # on the sampling of the last token in the current window, set new memories and states
                memories = next_memories
                states = next_states

        output = output[:, start_len:]

        if return_memories_and_states:
            return output, memories, states

        return output

    def forward(
        self,
        x,
        return_loss = False,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        return_memories_and_states = None  # can force to either return memory + state or not. by default will only return when number of tokens == max_seq_len
    ):
        device = x.device

        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # get sequence length i and j for dynamic pos bias

        assert x.shape[-1] <= self.max_seq_len

        w = self.block_width

        # token embedding

        x = self.token_emb(x)

        # dynamic pos bias

        attn_mask = self.get_causal_attn_mask(w)
        rotary_pos_emb, xpos_scale = self.rotary_pos_emb()

        # only return memories and state if at the full block width, but can be overridden

        return_memories_and_states = default(return_memories_and_states, self.max_seq_len == x.shape[-2])

        # ready output tensor, to be concatted to block by block

        batch, _, dim = x.shape
        out = torch.empty(batch, 0, dim, dtype = x.dtype, device = self.device)

        # split input into blocks of width w

        input_blocks = x.split(w, dim = -2)

        # process each block at a time

        for input_block in input_blocks:
            input_block_length = input_block.shape[-2]

            # ready xl memories and states

            iter_xl_memories = iter(xl_memories)
            iter_states = iter(states)

            next_xl_memories = []
            next_states = []

            # set the states on the appropriate state containers

            for attn, _ in self.layers:
                if not attn.is_recurrent_layer:
                    continue

                attn.state_container.set_next_read_state(next(iter_states, None))

            # go through layers

            for ind, (attn, ff) in enumerate(self.layers):

                # determine if the layer requires transformer xl memories

                layer = ind + 1

                # whether to pass in xl memories

                attn_kwargs = dict(
                    rotary_pos_emb = rotary_pos_emb,
                    xpos_scale = xpos_scale,
                    attn_mask = attn_mask,
                    xl_memories = next(iter_xl_memories, None),
                    read_from_state_containers = self.read_state_router[layer]
                )

                # attention layer

                residual = input_block
                attn_branch_out, layer_xl_memories, layer_next_states = attn(input_block, **attn_kwargs)

                if exists(layer_xl_memories):
                    next_xl_memories.append(layer_xl_memories)

                if exists(layer_next_states):
                    next_states.append(layer_next_states)

                input_block = attn_branch_out + residual

                # feedforward layer

                input_block = ff(input_block) + input_block

            # concat to output

            out = torch.cat((out, input_block), dim = -2)

            # set new xl memories and states

            states = next_states

            if input_block_length == w:
                xl_memories = self.mem_manager(xl_memories, next_xl_memories)

        # project to logits

        logits = self.to_logits(out)

        # detach the states and memories

        returned_next_states = list(map(torch.detach, states)) if return_memories_and_states else None
        returned_next_xl_memories = list(map(torch.detach, xl_memories)) if return_memories_and_states else None

        # whether to return logits

        if not return_loss:
            return logits, returned_next_xl_memories, returned_next_states

        # cross entropy loss

        logits = rearrange(logits, 'b n c -> b c n')
        loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)

        return loss, returned_next_xl_memories, returned_next_states
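# added note: the memories and states returned to the caller above are detached, so
# recurrence across successive forward calls is trained with truncated backpropagation
# through time - gradients flow between blocks within a segment, never across segments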
# recurrent trainer wrapper

@beartype
class RecurrentTrainerWrapper(nn.Module):
    def __init__(
        self,
        transformer: BlockRecurrentTransformer,
        xl_memories_dropout = 0.,
        state_dropout = 0.
    ):
        super().__init__()
        self.transformer = transformer
        self.seq_len = transformer.max_seq_len

        self.xl_memories_dropout = xl_memories_dropout
        self.state_dropout = state_dropout

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        prime,
        length,
        **kwargs
    ):
        seq_len = self.seq_len
        start_len = prime.shape[-1]
        assert start_len < length

        output = prime
        current_len = start_len

        memories = []
        states = []

        # determine lengths

        has_remainder = not divisible_by(length, seq_len)
        remainder_amount = length % seq_len
        total_segments = math.ceil(length / seq_len)

        if not has_remainder:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), seq_len)
        elif remainder_amount == 1:
            lengths = (seq_len + 1,) * (total_segments - 1)
        else:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), remainder_amount)

        # loop through lengths

        for next_length in lengths:
            segment_output, memories, states = self.transformer.generate(
                output[:, -current_len:],
                length = next_length,
                xl_memories = memories,
                states = states,
                return_memories_and_states = True,
                **kwargs
            )

            output = torch.cat((output, segment_output), dim = -1)
            current_len = 1

        return output[:, start_len:]

    def forward(
        self,
        x,
        return_memories_and_states = False
    ):
        total_seq_len, seq_len = x.shape[1], self.seq_len

        assert divisible_by(total_seq_len - 1, seq_len), f'length of sequence ({total_seq_len}) must be equal to a multiple of {seq_len} + 1 (one extra token) during training'
        segments = total_seq_len // seq_len

        total_loss = 0.

        memories = []
        states = []

        for ind in range(segments):
            start = ind * seq_len
            end = start + seq_len + 1

            if self.training and random() < self.xl_memories_dropout:
                memories.clear()

            if self.training and random() < self.state_dropout:
                states.clear()

            loss, memories, states = self.transformer(
                x[:, start:end],
                xl_memories = memories,
                states = states,
                return_loss = True
            )

            total_loss = total_loss + (loss / segments)

        if return_memories_and_states:
            return total_loss, memories, states

        return total_loss

================================================
FILE: data/README.md
================================================
# Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

================================================
FILE: requirements.txt
================================================
accelerate
tqdm

================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages

setup(
  name = 'block-recurrent-transformer-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.4',
  license='MIT',
  description = 'Block Recurrent Transformer - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/block-recurrent-transformer-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'recurrence'
  ],
  install_requires=[
    'beartype',
    'einops>=0.6.1',
    'memorizing-transformers-pytorch>=0.4.0',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
================================================
FILE: train.py
================================================
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator

from block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 250
GENERATE_LENGTH = 2048
SEQ_LEN = 2048

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# accelerator

accelerator = Accelerator()
device = accelerator.device
acc_print = accelerator.print

# instantiate model

model = BlockRecurrentTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    max_seq_len = 1024,
    block_width = 512,
    num_state_vectors = 512,
    recurrent_layers = (4,),
    use_flash_attn = True
)

train_wrapper = RecurrentTrainerWrapper(
    model,
    xl_memories_dropout = 0.1,
    state_dropout = 0.1,
)

model.to(device)

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr = LEARNING_RATE)

model, optim, train_loader, val_loader = accelerator.prepare(
    model, optim, train_loader, val_loader
)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = train_wrapper(next(train_loader))
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    acc_print(f"training loss: {loss.item()}")
    accelerator.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = train_wrapper(next(val_loader))
            acc_print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        acc_print(f"{prime} \n\n {'*' * 100}")

        sample = train_wrapper.generate(inp[None, ...], length = GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        acc_print(output_str, "\n")
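# added note: plain `python train.py` runs a single process; since the script already
# uses huggingface accelerate, multi-gpu or mixed precision runs can also be launched with
#   $ accelerate config
#   $ accelerate launch train.py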