Repository: lucidrains/block-recurrent-transformer-pytorch
Branch: main
Commit: 9f28fcdb387d
Files: 10
Total size: 50.2 KB

Directory structure:
gitextract_0788ko4s/

├── .github/
│   └── workflows/
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── block_recurrent_transformer_pytorch/
│   ├── __init__.py
│   └── block_recurrent_transformer_pytorch.py
├── data/
│   └── README.md
├── requirements.txt
├── setup.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/python-publish.yml
================================================

  
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
<img src="./block-recurrent-transformer.png" width="450px"></img>

## Block Recurrent Transformer - Pytorch

Implementation of <a href="https://arxiv.org/abs/2203.07852">Block Recurrent Transformer</a> - Pytorch. The highlight of the paper is its reported ability to remember something up to 60k tokens ago.

This design is SOTA for the recurrent transformer line of research, afaict.

It will also include <a href="https://arxiv.org/abs/2205.14135">flash attention</a>, as well as routed memories of up to 250k tokens using ideas from <a href="https://github.com/lucidrains/CoLT5-attention">this paper</a>.

## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research

## Install

```bash
$ pip install block-recurrent-transformer-pytorch
```

## Usage

```python
import torch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer

model = BlockRecurrentTransformer(
    num_tokens = 20000,             # vocab size
    dim = 512,                      # model dimensions
    depth = 6,                      # depth
    dim_head = 64,                  # attention head dimensions
    heads = 8,                      # number of attention heads
    max_seq_len = 1024,             # the total receptive field of the transformer, in the paper this was 2 * block size
    block_width = 512,              # block size - total receptive field is max_seq_len, 2 * block size in paper. the block furthest forward becomes the new cached xl memories, which is one block's worth (please open an issue if i am wrong)
    num_state_vectors = 512,        # number of state vectors, i believe this was a single block size in the paper, but can be any amount
    recurrent_layers = (4,),        # where to place the recurrent layer(s) for states with fixed simple gating
    use_compressed_mem = False,     # whether to use compressed memories of a single block width, from https://arxiv.org/abs/1911.05507
    compressed_mem_factor = 4,      # compression factor of compressed memories
    use_flash_attn = True           # use flash attention, if on pytorch 2.0
)

seq = torch.randint(0, 20000, (1, 1024))

out, mems1, states1 = model(seq)
out, mems2, states2 = model(seq, xl_memories = mems1, states = states1)
out, mems3, states3 = model(seq, xl_memories = mems2, states = states2)
```
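
For training over longer sequences with recurrence, wrap the model in `RecurrentTrainerWrapper`, which splits the sequence into windows, carries the memories and states between them, and can randomly drop both for robustness. A minimal sketch, mirroring the hyperparameters used in `train.py`:

```python
import torch
from block_recurrent_transformer_pytorch import (
    BlockRecurrentTransformer,
    RecurrentTrainerWrapper
)

model = BlockRecurrentTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 1024,
    block_width = 512,
    recurrent_layers = (4,)
)

train_wrapper = RecurrentTrainerWrapper(
    model,
    xl_memories_dropout = 0.1,   # randomly drop the xl memories between segments
    state_dropout = 0.1          # randomly drop the recurrent states between segments
)

# sequence length must be a multiple of max_seq_len, plus one extra target token

seq = torch.randint(0, 20000, (1, 1024 * 2 + 1))

loss = train_wrapper(seq)        # cross entropy, averaged over all segments
loss.backward()

# generation, with memories and states carried across windows internally

prime = torch.randint(0, 20000, (1, 128))
sampled = train_wrapper.generate(prime, length = 2048)
```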

## Test on Enwik8

First `pip install -r requirements.txt`, then

```bash
$ python train.py
```
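
`train.py` expects the enwik8 data at `./data/enwik8.gz` (see `data/README.md` for the source).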

## Todo

- [x] use dynamic positional bias
- [x] add enhanced recurrence
- [x] setup local attention blocks, as in the paper
- [x] wrapper transformer class for training
- [x] take care of generation with recurrence in `RecurrentTrainerWrapper`
- [x] add ability to dropout entire memories and states during each segment step during training
- [x] test full system on enwik8 locally and ablate states and memories and see effects first hand
- [x] make sure attention allows for single head key / values too
- [x] run a few experiments of fixed gating in regular transformers - does not work
- [x] integrate <a href="https://github.com/hazyresearch/flash-attention">flash attention</a>
- [x] cache attention mask + rotary embeddings
- [x] add <a href="https://github.com/lucidrains/compressive-transformer-pytorch">compressed memories</a>

- [ ] revisit <a href="https://github.com/lucidrains/memformer">memformer</a>
- [ ] try routing long distance memories of up to 250k using coordinate descent (Wright et al.)

## Citations

```bibtex
@article{Hutchins2022BlockRecurrentT,
    title   = {Block-Recurrent Transformers},
    author  = {DeLesley S. Hutchins and Imanol Schlag and Yuhuai Wu and Ethan Dyer and Behnam Neyshabur},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.07852}
}
```

```bibtex
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
```

```bibtex
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
```

```bibtex
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
```

```bibtex
@inproceedings{Ainslie2023CoLT5FL,
    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author    = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Onta{\~n}{\'o}n and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year    = {2023}
}
```

*Memory is Attention through Time* - Alex Graves


================================================
FILE: block_recurrent_transformer_pytorch/__init__.py
================================================
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('2.0.0'):
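    # register einops ops with torch.compile (dynamo) so they can be used inside compiled graphs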
    from einops._torch_specific import allow_ops_in_compiled_graph
    allow_ops_in_compiled_graph()

from block_recurrent_transformer_pytorch.block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper


================================================
FILE: block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py
================================================
import math
from random import random
from functools import wraps, partial
from itertools import zip_longest
from collections import namedtuple, defaultdict
from packaging import version


import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, List, Tuple

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def is_empty(t: torch.Tensor):
    return t.numel() == 0

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

def all_unique(arr):
    return len(arr) == len(set(arr))

def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def compact(arr):
    return [*filter(exists, arr)]

def and_reduce(arr: List[torch.Tensor]):
    if len(arr) == 0:
        return None
    head, *rest = arr
    for t in rest:
        head = head & t
    return head

def safe_cat(*args, dim = 1):
    args = compact(args)

    if len(args) == 0:
        return None

    return torch.cat(args, dim = dim)

def divisible_by(numer, denom):
    return (numer % denom) == 0

def l2norm(t):
    return F.normalize(t, dim = -1)

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def pad_at_dim(t, pad, dim = -1, value = 0.):
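    # F.pad expects (left, right) pairs ordered from the last dimension backwards, so fill trailing dims with zeros up to the target dim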
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# bias-less layernorm

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# sampling helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.9):
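    # keep only the top (1 - thres) fraction of logits (thres = 0.9 keeps the top 10%), masking the rest to -inf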
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# rotary positional embedding w/ xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        width,
        scale_base = 512,
        theta = 10000
    ):
        super().__init__()
        self.width = width

        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent = False)

        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

        self.register_buffer('cached_freqs', None, persistent = False)
        self.register_buffer('cached_scales', None, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self):
        device, seq_len = self.device, self.width

        if exists(self.cached_freqs):
            cached_seq_len = self.cached_freqs.shape[-2]
            if cached_seq_len >= seq_len:
                return self.cached_freqs[:seq_len], self.cached_scales[:seq_len]

        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        self.register_buffer('cached_freqs', freqs, persistent = False)
        self.register_buffer('cached_scales', scale, persistent = False)
        return freqs, scale

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, pos, scale = 1.):
    scale = default(scale, 1.)

    seq_len = t.shape[-2]

    assert pos.shape[-2] >= seq_len

    pos = pos[-seq_len:]

    if isinstance(scale, torch.Tensor):
        assert scale.shape[-2] >= seq_len
        scale = scale[-seq_len:]

    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)

# memory management

class MemoryManager(nn.Module):
    def __init__(
        self,
        dim,
        *,
        layers = 1,
        mem_lengths = 512,
        compress_factors = 1
    ):
        super().__init__()
        mem_lengths = cast_tuple(mem_lengths)
        compress_factors = cast_tuple(compress_factors)

        assert all([mem_length > 0 for mem_length in mem_lengths])
        assert len(mem_lengths) == len(compress_factors)
        assert layers >= 1

        self.mem_lengths = mem_lengths
        self.compress_factors = compress_factors

        self.layers = nn.ModuleList([])

        for _ in range(layers):
            compress_fns = nn.ModuleList([])

            for compress_factor in compress_factors:
                compress_fn = nn.Identity()
                if compress_factor > 1:
                    compress_fn = nn.Sequential(
                        Rearrange('b n d -> b d n'),
                        nn.Conv1d(
                            dim * 2,
                            dim * 2,
                            compress_factor,
                            stride = compress_factor,
                            groups = 2
                        ),
                        Rearrange('b d n -> b n d'),
                    )

                compress_fns.append(compress_fn)

            self.layers.append(compress_fns)

    def forward(
        self,
        past_memories: List[torch.Tensor],
        new_memories: List[torch.Tensor]
    ):
        next_memories = []

        for past_memory, new_memory, compress_fns in zip_longest(past_memories, new_memories, self.layers):

            # edge case if neither memories exist

            if not (exists(past_memory) or exists(new_memory)):
                next_memories.append(None)
                continue

            next_memory = None

            for mem_length, compress_factor, compress_fn in zip(self.mem_lengths, self.compress_factors, compress_fns):

                # first get the memories for the given compression factor "current_memory"

                current_memory = None
                if exists(past_memory):
                    past_memory, current_memory = past_memory[..., :-mem_length, :], past_memory[..., -mem_length:, :]

                # compress the new memories coming in, based on the compression factors set at init

                if (not is_empty(new_memory)) and compress_factor > 1:
                    # make sure memory length is divisible by compression factor

                    new_mem_length = new_memory.shape[-2]

                    curtailed_length = (new_mem_length // compress_factor) * compress_factor

                    curtailed_slice = slice(-curtailed_length, None) if curtailed_length > 0 else slice(0, 0)
                    new_memory = new_memory[..., curtailed_slice, :]

                    # compress the memory pushed to the next stage

                    if new_memory.shape[-2] > 0:
                        new_memory = rearrange(new_memory, 'm b n d -> b n (m d)')
                        new_memory = compress_fn(new_memory)
                        new_memory = rearrange(new_memory, 'b n (m d) -> m b n d', m = 2)

                # fifo memory queue
                # add the new memory on the right

                current_memory = safe_cat(current_memory, new_memory, dim = -2)
                # "new" memory is new with respect to the next compressed segment

                new_memory, current_memory = current_memory[..., :-mem_length, :], current_memory[..., -mem_length:, :]
                # concat the new memory to the left into the past

                next_memory = safe_cat(current_memory, next_memory, dim = -2)

            next_memories.append(next_memory)

        return next_memories

# maybe flash attention, if using pytorch 2.0

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# state container

class StateContainer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_state_vectors,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        assert num_state_vectors > 0
        self.heads = heads
        inner_dim = dim_head * heads

        self.state_norm = LayerNorm(dim)

        self.q_to_state = nn.Linear(dim, inner_dim, bias = False)
        self.q_from_state = nn.Linear(dim, inner_dim, bias = False)

        self.state_to_q = nn.Linear(dim, inner_dim, bias = False)
        self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
        torch.nn.init.normal_(self.init_state, 0, .1)
        self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))
        # NOTE: the state position id embeddings are drawn from N(0,1) since they are added after a layer norm
        torch.nn.init.normal_(self.state_pos_ids, 0, 1)

        self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)

        self.to_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.state_self_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
        self.from_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        # gating related parameters - using the fixed simple config

        self.state_out_to_gate = nn.Linear(dim, dim)
        self.learned_ema_beta = nn.Parameter(torch.randn(dim))
        torch.nn.init.normal_(self.learned_ema_beta, 0, .1)

        # since each read should be followed by a write, just store cache in the container

        self.cache = None
        self.next_read_state = None

    def set_next_read_state(
        self,
        states
    ):
        if not exists(states):
            states = self.init_state

        self.next_read_state = (states,)

    def read(self, x):
        assert exists(self.next_read_state), 'states to be read must be set with .set_next_read_state'

        states, = self.next_read_state
        self.next_read_state = None

        # pre norm state for attention

        normed_states = self.state_norm(states)

        # add the positional ids, as stated in the paper critical for it to work

        normed_states = normed_states + self.state_pos_ids

        # get queries for cross attention, which they do not share, although they share key / values. another intriguing detail

        q_to_state = self.q_to_state(x)
        q_to_state = rearrange(q_to_state, '... n (h d) -> ... h n d', h = self.heads)

        # self attention qkv for states

        state_k, state_v = self.state_to_kv(normed_states).chunk(2, dim = -1)

        # cross attend to the past states key values

        to_state_out = self.to_state_cross_attn(q_to_state, state_k, state_v)

        to_state_out = rearrange(to_state_out, 'b h n d -> b n (h d)')

        # cache for next write

        self.cache = (states, normed_states, state_k, state_v)

        return to_state_out

    def write(
        self,
        *,
        memories
    ):
        assert exists(self.cache)

        k, v = memories
        batch = k.shape[0]

        # get cached values from the previous read

        states, normed_states, state_k, state_v = self.cache

        self.cache = None

        # derive queries

        q_from_state = self.q_from_state(normed_states)
        q_from_state = rearrange(q_from_state, '... n (h d) -> ... h n d', h = self.heads)

        state_q = self.state_to_q(normed_states)
        state_q_einsum = 'n (h d)' if state_q.ndim == 2 else 'b n (h d)'
        state_q = repeat(state_q, f'{state_q_einsum} -> b h n d', h = self.heads, b = batch)

        # states must also undergo self attention

        if q_from_state.ndim == 3:
            q_from_state = repeat(q_from_state, '... -> b ...', b = batch)

        state_out = self.state_self_attn(state_q, state_k, state_v)

        from_state_out = self.from_state_cross_attn(q_from_state, k, v)

        state_out = torch.cat((state_out, from_state_out), dim = -1)
        state_out = rearrange(state_out, 'b h n d -> b n (h d)')

        state_out = self.to_state_out(state_out)

        # use the best performing configuration
        # fixed simple gate - nothing more than a learned EMA with some resemblance to highway networks

        z = self.state_out_to_gate(state_out)
        learned_ema_decay = self.learned_ema_beta.sigmoid()

        # set new state with the learned EMA gating

        return learned_ema_decay * z + (1 - learned_ema_decay) * states

    def forward(self, x):
        raise NotImplementedError

# main class

class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash_attn = use_flash_attn
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash_attn:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        masks = []

        if self.causal:
            i, j = q_len, k_len
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            masks.append(~causal_mask)

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = q.shape[0] // mask.shape[0])

            masks.append(mask)

        attn_mask = and_reduce(masks)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = attn_mask
            )

        return out

    def forward(self, q, k, v, mask = None, use_flash_attn = None):
        use_flash_attn = default(use_flash_attn, self.use_flash_attn)

        b, n, device = q.shape[0], q.shape[-2], q.device

        q, ps = pack_one(q, '* h n d')
        k, _ = pack_one(k, '* n d')
        v, _ = pack_one(v, '* n d')

        if use_flash_attn:
            out = self.flash_attn(q, k, v, mask = mask)
            return unpack_one(out, ps, '* h n d')

        scale = q.shape[-1] ** -0.5

        k_einsum = 'b j d' if k.ndim == 3 else 'b h j d'
        v_einsum = 'b j d' if v.ndim == 3 else 'b h j d'

        # similarity

        sim = einsum(f"b h i d, {k_einsum} -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = b)

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum(f"b h i j, {v_einsum} -> b h i d", attn, v)

        return unpack_one(out, ps, '* h n d')

# geglu feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult * 2 / 3)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim_head,
        causal = False,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal

        self.qk_rmsnorm = qk_rmsnorm
        self.qk_rmsnorm_scale = qk_rmsnorm_scale

        self.attend = Attend(causal = causal, use_flash_attn = use_flash_attn)

        if qk_rmsnorm:
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            self.k_scale = nn.Parameter(torch.ones(dim_head))

    def forward(
        self,
        q, k, v,
        mask = None,
        rotary_pos_emb = None,
        xpos_scale = None
    ):

        scale = q.shape[-1] ** -0.5

        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            scale = self.qk_rmsnorm_scale

            q = q * self.q_scale
            k = k * self.k_scale

        # rotary positional embedding with xpos for length extrapolation

        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(q, rotary_pos_emb, xpos_scale)
            k = apply_rotary_pos_emb(k, rotary_pos_emb, xpos_scale ** -1)

        # attention

        out = self.attend(q, k, v, mask = mask)

        return out

class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        block_width,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False,
        num_state_vectors = 0,
        num_external_state_reads = 0,
        state_read_before_write = True  # this will be defaulted to on as in the paper, but will be turned off in the case the researcher wants to test out reading the state at a lower layer
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads

        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        # single-headed key / values (one write-head, https://arxiv.org/abs/1911.02150), shared across all query heads
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.block_width = block_width
        self.is_recurrent_layer = num_state_vectors > 0

        # decide how many states this attention layer is going to read from

        num_state_reads = int(self.is_recurrent_layer and state_read_before_write) + num_external_state_reads

        self.to_out = nn.Linear(inner_dim * (1 + num_state_reads), dim, bias = False)

        if not self.is_recurrent_layer:
            return

        self.state_read_before_write = state_read_before_write

        self.state_container = StateContainer(
            dim,
            dim_head = dim_head,
            heads = heads,
            num_state_vectors = num_state_vectors,
            qk_rmsnorm = qk_rmsnorm,
            qk_rmsnorm_scale = qk_rmsnorm_scale,
            use_flash_attn = use_flash_attn
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        rotary_pos_emb = None,
        xpos_scale = None,
        attn_mask = None,
        xl_memories: Optional[torch.Tensor] = None,
        read_from_state_containers: List[StateContainer] = []
    ):
        batch, seq_len, _, width, device = *x.shape, self.block_width, self.device

        # pre normalization

        x = self.norm(x)

        # queries, keys, values and split out heads

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        split_head = partial(rearrange, pattern = 'b n (h d) -> b h n d', h = self.heads)
        q = split_head(q)

        # save the last key / values as memories for recurrence

        memories = torch.stack((k, v))

        mem_len = 0

        if exists(xl_memories):
            # if past memories are passed in, concat as the first bucket
            mem_len = xl_memories.shape[-2]
            past_k, past_v = xl_memories
            k = torch.cat((past_k, k), dim = 1)
            v = torch.cat((past_v, v), dim = 1)

        # handle cropping of attention mask and positional embeddings

        if exists(attn_mask):
            attn_mask = attn_mask[:seq_len, :seq_len]
            attn_mask = F.pad(attn_mask, (mem_len, 0), value = True)

        # attention, but of course

        out = self.attn(
            q, k, v,
            rotary_pos_emb = rotary_pos_emb,
            xpos_scale = xpos_scale,
            mask = attn_mask
        )

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # early return if not a recurrent layer

        if not self.is_recurrent_layer and len(read_from_state_containers) == 0:
            return self.to_out(out), memories, None

        # whether to read from own state container, default to on, but may pass in more

        if self.is_recurrent_layer and self.state_read_before_write:
            read_from_state_containers = [self.state_container, *read_from_state_containers]

        for read_state_container in read_from_state_containers:
            # read from the states ...

            to_state_out = read_state_container.read(x)

            # and concat it to the output of self-attention

            out = torch.cat((out, to_state_out), dim = -1)

        new_states = None

        if self.is_recurrent_layer:
            # then write to the states as well if need be

            new_states = self.state_container.write(memories = memories)

        return self.to_out(out), memories, new_states

# classes

@beartype
class BlockRecurrentTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        all_layers_qk_rmsnorm = False,
        ff_mult = 4,
        max_seq_len = 1024,
        block_width = 512,
        recurrent_layers: Optional[Tuple[int, ...]] = None,
        read_recurrent_layers: Optional[Tuple[int, ...]] = None,
        num_state_vectors = None,
        ignore_index = -100,
        use_flash_attn = False,
        use_compressed_mem = False,
        compressed_mem_factor = 4
    ):
        super().__init__()
        num_state_vectors = default(num_state_vectors, block_width)

        # set recurrent layers

        recurrent_layers = default(recurrent_layers, (depth // 2,)) # default to one recurrent layer at the middle of the network

        assert all([0 < layer <= depth for layer in recurrent_layers]), f'recurrent layers must range from 1 to the depth {depth}'
        assert all_unique(recurrent_layers), 'recurrent layers must be all unique. no duplicate layers'

        self.recurrent_layers = recurrent_layers

        # set read recurrent layers

        read_recurrent_layers = default(read_recurrent_layers, recurrent_layers)

        assert all([read_layer <= write_layer for read_layer, write_layer in zip(read_recurrent_layers, recurrent_layers)]), 'the recurrent read layer must be always less than or equal to the write layer'
        assert all([0 < layer <= depth for layer in read_recurrent_layers])
        assert len(read_recurrent_layers) == len(recurrent_layers)

        self.read_recurrent_layers = read_recurrent_layers

        # token embedding

        self.token_emb = nn.Embedding(num_tokens, dim)

        self.rotary_pos_emb = RotaryEmbedding(dim = dim_head, width = (2 if not use_compressed_mem else 3) * block_width)

        self.layers = nn.ModuleList([])

        self.write_to_read_map = {write_layer: read_layer for write_layer, read_layer in zip(recurrent_layers, read_recurrent_layers)}

        self.read_state_router = defaultdict(list)

        for layer in range(1, depth + 1):
            is_recurrent_layer = layer in self.recurrent_layers

            layer_num_state_vectors = num_state_vectors if is_recurrent_layer else 0

            num_external_state_reads = sum([int(layer == read_layer) for read_layer in read_recurrent_layers])

            # only layers with xl memories
            # or has recurrence in horizontal direction
            # use qk rmsnorm (in paper, they use cosine sim attention, but i think qk rmsnorm is more proven given Vit 22B paper)
            # one can also override to use all qk rmsnorm by setting all_layers_qk_rmsnorm = True

            qk_rmsnorm = all_layers_qk_rmsnorm or is_recurrent_layer

            attn_block = AttentionBlock(
                dim,
                block_width = block_width,
                dim_head = dim_head,
                heads = heads,
                qk_rmsnorm = qk_rmsnorm,
                num_state_vectors = layer_num_state_vectors,
                use_flash_attn = use_flash_attn,
                num_external_state_reads = num_external_state_reads,
                state_read_before_write = False,
            )

            ff_block = FeedForward(dim, mult = ff_mult)

            if is_recurrent_layer:
                read_layer = self.write_to_read_map[layer]
                self.read_state_router[read_layer].append(attn_block.state_container)

            self.layers.append(nn.ModuleList([
                attn_block,
                ff_block
            ]))

        # (compressed) memory management

        self.mem_manager = MemoryManager(
            dim = dim_head,
            layers = depth,
            mem_lengths = block_width if not use_compressed_mem else (block_width, block_width // 2),
            compress_factors = 1 if not use_compressed_mem else (1, compressed_mem_factor)
        )

        # to logits

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        self.max_seq_len = max_seq_len
        self.block_width = block_width

        assert divisible_by(max_seq_len, block_width)

        self.ignore_index = ignore_index

        self.register_buffer('cached_causal_attn_mask', None, persistent = False)

    @property
    def device(self):
        return next(self.parameters()).device

    def get_causal_attn_mask(self, width):
        if exists(self.cached_causal_attn_mask):
            cached_mask = self.cached_causal_attn_mask
            cached_width = cached_mask.shape[-2]
            padding = (width - cached_width) // 2
            j_slice = Ellipsis if padding == 0 else slice(padding, -padding)
            return cached_mask[:cached_width, j_slice]

        device = self.device
        causal_mask = torch.ones((width, width), device = device, dtype = torch.bool).triu(1)
        return ~causal_mask

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prime,
        length = None,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        temperature = 1.,
        filter_thres = 0.9,
        return_memories_and_states = False
    ):
        length = default(length, self.max_seq_len + 1)
        start_len = prime.shape[-1]

        assert start_len < self.max_seq_len
        assert length <= (self.max_seq_len + 1)
        assert start_len < length

        output = prime

        memories = []

        for ind in range(length - start_len):

            logits, next_memories, next_states = self.forward(
                output,
                xl_memories = xl_memories,
                states = states
            )

            logits = logits[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature)
            sampled = rearrange(sampled, 'b -> b 1')

            output = torch.cat((output, sampled), dim = -1)

            if divisible_by(output.shape[-1] - 1, self.max_seq_len): # on the sampling of the last token in the current window, set new memories and states
                memories = next_memories
                states = next_states

        output = output[:, start_len:]

        if return_memories_and_states:
            return output, memories, states

        return output

    def forward(
        self,
        x,
        return_loss = False,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        return_memories_and_states = None  # can force to either return memory + state or not. by default will only return when number of tokens == max_seq_len
    ):
        device = x.device

        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # get sequence length i and j for dynamic pos bias

        assert x.shape[-1] <= self.max_seq_len

        w = self.block_width

        # token embedding

        x = self.token_emb(x)

        # dynamic pos bias

        attn_mask = self.get_causal_attn_mask(w)
        rotary_pos_emb, xpos_scale = self.rotary_pos_emb()

        # only return memories and state if at the full block width, but can be overridden

        return_memories_and_states = default(return_memories_and_states, self.max_seq_len == x.shape[-2])

        # ready output tensor, to be concatted to block by block

        batch, _, dim = x.shape

        out = torch.empty(batch, 0, dim, dtype = x.dtype, device = self.device)

        # split input into blocks of width w

        input_blocks = x.split(w, dim = -2)

        # process each block at a time

        for input_block in input_blocks:
            input_block_length = input_block.shape[-2]

            # ready xl memories and states

            iter_xl_memories = iter(xl_memories)
            iter_states = iter(states)

            next_xl_memories = []
            next_states = []

            # set the states on the appropriate state containers

            for attn, _ in self.layers:
                if not attn.is_recurrent_layer:
                    continue

                attn.state_container.set_next_read_state(next(iter_states, None))

            # go through layers

            for ind, (attn, ff) in enumerate(self.layers):

                # determine if the layer requires transformer xl memories

                layer = ind + 1

                # whether to pass in xl memories

                attn_kwargs = dict(
                    rotary_pos_emb = rotary_pos_emb,
                    xpos_scale = xpos_scale,
                    attn_mask = attn_mask,
                    xl_memories = next(iter_xl_memories, None),
                    read_from_state_containers = self.read_state_router[layer]
                )

                # attention layer

                residual = input_block
                attn_branch_out, layer_xl_memories, layer_next_states = attn(input_block, **attn_kwargs)

                if exists(layer_xl_memories):
                    next_xl_memories.append(layer_xl_memories)

                if exists(layer_next_states):
                    next_states.append(layer_next_states)

                input_block = attn_branch_out + residual

                # feedforward layer

                input_block = ff(input_block) + input_block

            # concat to output

            out = torch.cat((out, input_block), dim = -2)

            # set new xl memories and states

            states = next_states

            if input_block_length == w:
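                # only roll the fifo memory queue forward when a full block width was processed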
                xl_memories = self.mem_manager(xl_memories, next_xl_memories)


        # project to logits

        logits = self.to_logits(out)

        # detach the states and memories

        returned_next_states = list(map(torch.detach, states)) if return_memories_and_states else None
        returned_next_xl_memories = list(map(torch.detach, xl_memories)) if return_memories_and_states else None

        # whether to return logits

        if not return_loss:
            return logits, returned_next_xl_memories, returned_next_states

        # cross entropy loss

        logits = rearrange(logits, 'b n c -> b c n')
        loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)

        return loss, returned_next_xl_memories, returned_next_states

# recurrent trainer wrapper

@beartype
class RecurrentTrainerWrapper(nn.Module):
    def __init__(
        self,
        transformer: BlockRecurrentTransformer,
        xl_memories_dropout = 0.,
        state_dropout = 0.
    ):
        super().__init__()
        self.transformer = transformer
        self.seq_len = transformer.max_seq_len

        self.xl_memories_dropout = xl_memories_dropout
        self.state_dropout = state_dropout

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        prime,
        length,
        **kwargs
    ):
        seq_len = self.seq_len
        start_len = prime.shape[-1]
        assert start_len < length

        output = prime
        current_len = start_len

        memories = []
        states = []

        # determine lengths - after the first segment, each call to transformer.generate
        # is primed with just the last generated token, so full segments request
        # seq_len + 1 tokens; a remainder of 1 is already covered by the previous
        # segment's extra token

        has_remainder = not divisible_by(length, seq_len)
        remainder_amount = length % seq_len
        total_segments = math.ceil(length / seq_len)

        if not has_remainder:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), seq_len)
        elif remainder_amount == 1:
            lengths = (seq_len + 1,) * (total_segments - 1)
        else:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), remainder_amount)

        # loop through lengths

        for next_length in lengths:

            segment_output, memories, states = self.transformer.generate(
                output[:, -current_len:],
                length = next_length,
                xl_memories = memories,
                states = states,
                return_memories_and_states = True,
                **kwargs
            )

            output = torch.cat((output, segment_output), dim = -1)
            current_len = 1

        return output[:, start_len:]

    def forward(
        self,
        x,
        return_memories_and_states = False
    ):
        total_seq_len, seq_len = x.shape[1], self.seq_len

        assert divisible_by(total_seq_len - 1, seq_len), f'length of sequence ({total_seq_len}) must be equal to a multiple of {seq_len} + 1 (one extra token) during training'
        segments = total_seq_len // seq_len

        total_loss = 0.

        memories = []
        states = []

        for ind in range(segments):
            start = ind * seq_len
            end = start + seq_len + 1

            if self.training and random() < self.xl_memories_dropout:
                memories.clear()

            if self.training and random() < self.state_dropout:
                states.clear()

            loss, memories, states = self.transformer(
                x[:, start:end],
                xl_memories = memories,
                states = states,
                return_loss = True
            )

            total_loss = total_loss + (loss / segments)

        if return_memories_and_states:
            return total_loss, memories, states

        return total_loss


================================================
FILE: data/README.md
================================================
# Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

================================================
FILE: requirements.txt
================================================
accelerate
tqdm


================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages

setup(
  name = 'block-recurrent-transformer-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.4',
  license='MIT',
  description = 'Block Recurrent Transformer - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/block-recurrent-transformer-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'recurrence'
  ],
  install_requires=[
    'beartype',
    'einops>=0.6.1',
    'memorizing-transformers-pytorch>=0.4.0',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)


================================================
FILE: train.py
================================================
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 250
GENERATE_LENGTH = 2048
SEQ_LEN = 2048
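# with max_seq_len = 1024 below, each training sample of SEQ_LEN + 1 tokens spans 2 recurrent segments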

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


# accelerator

accelerator = Accelerator()

device = accelerator.device
acc_print = accelerator.print

# instantiate block recurrent transformer

model = BlockRecurrentTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    max_seq_len = 1024,
    block_width = 512,
    num_state_vectors = 512,
    recurrent_layers = (4,),
    use_flash_attn = True
)

train_wrapper = RecurrentTrainerWrapper(
    model,
    xl_memories_dropout = 0.1,
    state_dropout = 0.1,
)

model.to(device)

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
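        # sample seq_len + 1 tokens so the trainer wrapper can form shifted input / target pairs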
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr = LEARNING_RATE)

model, optim, train_loader, val_loader = accelerator.prepare(
    model, optim, train_loader, val_loader
)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = train_wrapper(next(train_loader))
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    acc_print(f"training loss: {loss.item()}")
    accelerator.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = train_wrapper(next(val_loader))
            acc_print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        acc_print(f"{prime} \n\n {'*' * 100}")

        sample = train_wrapper.generate(inp[None, ...], length = GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        acc_print(output_str, "\n")