[
  {
    "path": ".github/workflows/python-publish.yml",
    "content": "\n  \n# This workflow will upload a Python Package using Twine when a release is created\n# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries\n\n# This workflow uses actions that are not certified by GitHub.\n# They are provided by a third-party and are governed by\n# separate terms of service, privacy policy, and support\n# documentation.\n\nname: Upload Python Package\n\non:\n  release:\n    types: [published]\n\njobs:\n  deploy:\n\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v2\n    - name: Set up Python\n      uses: actions/setup-python@v2\n      with:\n        python-version: '3.x'\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install build\n    - name: Build package\n      run: python -m build\n    - name: Publish package\n      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29\n      with:\n        user: __token__\n        password: ${{ secrets.PYPI_API_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Phil Wang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<img src=\"./block-recurrent-transformer.png\" width=\"450px\"></img>\n\n## Block Recurrent Transformer - Pytorch\n\nImplementation of <a href=\"https://arxiv.org/abs/2203.07852\">Block Recurrent Transformer</a> - Pytorch. The highlight of the paper is its reported ability to remember something up to 60k tokens ago.\n\nThis design is SOTA for recurrent transformers line of research, afaict.\n\nIt will also include <a href=\"https://arxiv.org/abs/2205.14135\">flash attention</a> as well as routed memories of up to 250k tokens using ideas from <a href=\"https://github.com/lucidrains/CoLT5-attention\">this paper</a>\n\n## Appreciation\n\n- <a href=\"https://stability.ai/\">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research\n\n## Install\n\n```bash\n$ pip install block-recurrent-transformer-pytorch\n```\n\n## Usage\n\n```python\nimport torch\nfrom block_recurrent_transformer_pytorch import BlockRecurrentTransformer\n\nmodel = BlockRecurrentTransformer(\n    num_tokens = 20000,             # vocab size\n    dim = 512,                      # model dimensions\n    depth = 6,                      # depth\n    dim_head = 64,                  # attention head dimensions\n    heads = 8,                      # number of attention heads\n    max_seq_len = 1024,             # the total receptive field of the transformer, in the paper this was 2 * block size\n    block_width = 512,              # block size - total receptive field is max_seq_len, 2 * block size in paper. the block furthest forwards becomes the new cached xl memories, which is a block size of 1 (please open an issue if i am wrong)\n    num_state_vectors = 512,        # number of state vectors, i believe this was a single block size in the paper, but can be any amount\n    recurrent_layers = (4,),        # where to place the recurrent layer(s) for states with fixed simple gating\n    use_compressed_mem = False,     # whether to use compressed memories of a single block width, from https://arxiv.org/abs/1911.05507\n    compressed_mem_factor = 4,      # compression factor of compressed memories\n    use_flash_attn = True           # use flash attention, if on pytorch 2.0\n)\n\nseq = torch.randint(0, 2000, (1, 1024))\n\nout, mems1, states1 = model(seq)\nout, mems2, states2 = model(seq, xl_memories = mems1, states = states1)\nout, mems3, states3 = model(seq, xl_memories = mems2, states = states2)\n```\n\n## Test on Enwik8\n\nFirst `pip install -r requirements.txt`, then\n\n```bash\n$ python train.py\n```\n\n## Todo\n\n- [x] use dynamic positional bias\n- [x] add enhanced recurrence\n- [x] setup local attention blocks, as in the paper\n- [x] wrapper transformer class for training\n- [x] take care of generation with recurrence in `RecurrentTrainWrapper`\n- [x] add ability to dropout to entire memories and states during each segment step during trainng\n- [x] test full system on enwik8 locally and ablate states and memories and see effects first  hand\n- [x] make sure attention allow for single head key / values too\n- [x] run a few experiments of fixed gating in regular transformers - does not work\n- [x] integrate <a href=\"https://github.com/hazyresearch/flash-attention\">flash attention</a>\n- [x] cache attention mask + rotary embeddings\n- [x] add <a href=\"https://github.com/lucidrains/compressive-transformer-pytorch\">compressed memories</a>\n\n- [ ] revisit <a href=\"https://github.com/lucidrains/memformer\">memformer</a>\n- [ ] try routing long distance 
memories of up to 250k using coordinate descent (Wright et al.)\n\n## Citations\n\n```bibtex\n@article{Hutchins2022BlockRecurrentT,\n    title   = {Block-Recurrent Transformers},\n    author  = {DeLesley S. Hutchins and Imanol Schlag and Yuhuai Wu and Ethan Dyer and Behnam Neyshabur},\n    journal = {ArXiv},\n    year    = {2022},\n    volume  = {abs/2203.07852}\n}\n```\n\n```bibtex\n@article{Shazeer2019FastTD,\n    title   = {Fast Transformer Decoding: One Write-Head is All You Need},\n    author  = {Noam M. Shazeer},\n    journal = {ArXiv},\n    year    = {2019},\n    volume  = {abs/1911.02150}\n}\n```\n\n```bibtex\n@inproceedings{Sun2022ALT,\n    title     = {A Length-Extrapolatable Transformer},\n    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},\n    year      = {2022}\n}\n```\n\n```bibtex\n@inproceedings{dao2022flashattention,\n    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},\n    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\\'e}, Christopher},\n    booktitle = {Advances in Neural Information Processing Systems},\n    year    = {2022}\n}\n```\n\n```bibtex\n@inproceedings{Ainslie2023CoLT5FL,\n    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},\n    author  = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Onta{\\~n}{\\'o}n and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},\n    year    = {2023}\n}\n```\n\n*Memory is Attention through Time* - Alex Graves\n"
  },
  {
    "path": "block_recurrent_transformer_pytorch/__init__.py",
    "content": "import torch\nfrom packaging import version\n\nif version.parse(torch.__version__) >= version.parse('2.0.0'):\n    from einops._torch_specific import allow_ops_in_compiled_graph\n    allow_ops_in_compiled_graph()\n\nfrom block_recurrent_transformer_pytorch.block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper\n"
  },
  {
    "path": "block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py",
    "content": "import math\nfrom random import random\nfrom functools import wraps, partial\nfrom itertools import zip_longest\nfrom collections import namedtuple, defaultdict\nfrom packaging import version\n\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\n\nfrom einops import rearrange, repeat, pack, unpack\nfrom einops.layers.torch import Rearrange\n\nfrom beartype import beartype\nfrom beartype.door import is_bearable\nfrom beartype.typing import Optional, List, Tuple\n\n# helpers\n\ndef exists(val):\n    return val is not None\n\ndef default(val, d):\n    return val if exists(val) else d\n\ndef is_empty(t: torch.Tensor):\n    return t.numel() == 0\n\ndef cast_tuple(t, length = 1):\n    return t if isinstance(t, tuple) else ((t,) * length)\n\ndef all_unique(arr):\n    return len(arr) == len(set(arr))\n\ndef eval_decorator(fn):\n    def inner(self, *args, **kwargs):\n        was_training = self.training\n        self.eval()\n        out = fn(self, *args, **kwargs)\n        self.train(was_training)\n        return out\n    return inner\n\ndef once(fn):\n    called = False\n    @wraps(fn)\n    def inner(x):\n        nonlocal called\n        if called:\n            return\n        called = True\n        return fn(x)\n    return inner\n\nprint_once = once(print)\n\ndef compact(arr):\n    return [*filter(exists, arr)]\n\ndef and_reduce(arr: List[torch.Tensor]):\n    if len(arr) == 0:\n        return None\n    head, *rest = arr\n    for t in rest:\n        head = head & t\n    return head\n\ndef safe_cat(*args, dim = 1):\n    args = compact(args)\n\n    if len(args) == 0:\n        return None\n\n    return torch.cat(args, dim = dim)\n\ndef divisible_by(numer, denom):\n    return (numer % denom) == 0\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n\ndef pack_one(t, pattern):\n    return pack([t], pattern)\n\ndef unpack_one(t, ps, pattern):\n    return unpack(t, ps, pattern)[0]\n\ndef pad_at_dim(t, pad, dim = -1, value = 0.):\n    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)\n    zeros = ((0, 0) * dims_from_right)\n    return F.pad(t, (*zeros, *pad), value = value)\n\n# bias-less layernorm\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.gamma = nn.Parameter(torch.ones(dim))\n        self.register_buffer(\"beta\", torch.zeros(dim))\n\n    def forward(self, x):\n        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)\n\n# sampling helpers\n\ndef log(t, eps = 1e-20):\n    return torch.log(t.clamp(min = eps))\n\ndef gumbel_noise(t):\n    noise = torch.zeros_like(t).uniform_(0, 1)\n    return -log(-log(noise))\n\ndef gumbel_sample(t, temperature = 1., dim = -1):\n    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)\n\ndef top_k(logits, thres = 0.9):\n    k = math.ceil((1 - thres) * logits.shape[-1])\n    val, ind = torch.topk(logits, k)\n    probs = torch.full_like(logits, float('-inf'))\n    probs.scatter_(1, ind, val)\n    return probs\n\n# rotary positional embedding w/ xpos\n# https://arxiv.org/abs/2104.09864\n# https://arxiv.org/abs/2212.10554v1\n\nclass RotaryEmbedding(nn.Module):\n    def __init__(\n        self,\n        dim,\n        width,\n        scale_base = 512,\n        theta = 10000\n    ):\n        super().__init__()\n        self.width = width\n\n        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent = False)\n\n        self.scale_base = 
scale_base\n        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)\n        self.register_buffer('scale', scale, persistent = False)\n\n        self.register_buffer('cached_freqs', None, persistent = False)\n        self.register_buffer('cached_scales', None, persistent = False)\n\n    @property\n    def device(self):\n        return next(self.buffers()).device\n\n    def forward(self):\n        device, seq_len = self.device, self.width\n\n        if exists(self.cached_freqs):\n            cached_seq_len = self.cached_freqs.shape[-2]\n            if cached_seq_len >= seq_len:\n                return self.cached_freqs[:seq_len], self.cached_scales[:seq_len]\n\n        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)\n        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)\n        freqs = torch.cat((freqs, freqs), dim = -1)\n\n        power = (t - (seq_len // 2)) / self.scale_base\n        scale = self.scale ** rearrange(power, 'n -> n 1')\n        scale = torch.cat((scale, scale), dim = -1)\n\n        self.register_buffer('cached_freqs', freqs, persistent = False)\n        self.register_buffer('cached_scales', scale, persistent = False)\n        return freqs, scale\n\ndef rotate_half(x):\n    x1, x2 = x.chunk(2, dim=-1)\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(t, pos, scale = 1.):\n    scale = default(scale, 1.)\n\n    seq_len = t.shape[-2]\n\n    assert pos.shape[-2] >= seq_len\n\n    pos = pos[-seq_len:]\n\n    if isinstance(scale, torch.Tensor):\n        assert scale.shape[-2] >= seq_len\n        scale = scale[-seq_len:]\n\n    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)\n\n# memory management\n\nclass MemoryManager(nn.Module):\n    def __init__(\n        self,\n        dim,\n        *,\n        layers = 1,\n        mem_lengths = 512,\n        compress_factors = 1\n    ):\n        super().__init__()\n        mem_lengths = cast_tuple(mem_lengths)\n        compress_factors = cast_tuple(compress_factors)\n\n        assert all([mem_length > 0 for mem_length in mem_lengths])\n        assert len(mem_lengths) == len(compress_factors)\n        assert layers >= 1\n\n        self.mem_lengths = mem_lengths\n        self.compress_factors = compress_factors\n\n        self.layers = nn.ModuleList([])\n\n        for _ in range(layers):\n            compress_fns = nn.ModuleList([])\n\n            for compress_factor in compress_factors:\n                compress_fn = nn.Identity()\n                if compress_factor > 1:\n                    compress_fn = nn.Sequential(\n                        Rearrange('b n d -> b d n'),\n                        nn.Conv1d(\n                            dim * 2,\n                            dim * 2,\n                            compress_factor,\n                            stride = compress_factor,\n                            groups = 2\n                        ),\n                        Rearrange('b d n -> b n d'),\n                    )\n\n                compress_fns.append(compress_fn)\n\n            self.layers.append(compress_fns)\n\n    def forward(\n        self,\n        past_memories: List[torch.Tensor],\n        new_memories: List[torch.Tensor]\n    ):\n        next_memories = []\n\n        for past_memory, new_memory, compress_fns in zip_longest(past_memories, new_memories, self.layers):\n\n            # edge case if neither memories exist\n\n            if not (exists(past_memory) or exists(new_memory)):\n                next_memories.append(None)\n               
 continue\n\n            next_memory = None\n\n            for mem_length, compress_factor, compress_fn in zip(self.mem_lengths, self.compress_factors, compress_fns):\n\n                # first get the memories for the given compression factor \"current_memory\"\n\n                current_memory = None\n                if exists(past_memory):\n                    past_memory, current_memory = past_memory[..., :-mem_length, :], past_memory[..., -mem_length:, :]\n\n                # compress the new memories coming in, based on the compression factors set at init\n\n                if (not is_empty(new_memory)) and compress_factor > 1:\n                    # make sure memory length is divisible by compression factor\n\n                    new_mem_length = new_memory.shape[-2]\n\n                    curtailed_length = (new_mem_length // compress_factor) * compress_factor\n\n                    curtailed_slice = slice(-curtailed_length, None) if curtailed_length > 0 else slice(0, 0)\n                    new_memory = new_memory[..., curtailed_slice, :]\n\n                    # compress the memory pushed to the next stage\n\n                    if new_memory.shape[-2] > 0:\n                        new_memory = rearrange(new_memory, 'm b n d -> b n (m d)')\n                        new_memory = compress_fn(new_memory)\n                        new_memory = rearrange(new_memory, 'b n (m d) -> m b n d', m = 2)\n\n                # fifo memory queue\n                # add the new memory on the right\n\n                current_memory = safe_cat(current_memory, new_memory, dim = -2)\n                # \"new\" memory is new with respect to the next compressed segment\n\n                new_memory, current_memory = current_memory[..., :-mem_length, :], current_memory[..., -mem_length:, :]\n                # concat the new memory to the left into the past\n\n                next_memory = safe_cat(current_memory, next_memory, dim = -2)\n\n            next_memories.append(next_memory)\n\n        return next_memories\n\n# maybe flash attention, if using pytorch 2.0\n\n# constants\n\nConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])\n\n# state container\n\nclass StateContainer(nn.Module):\n    def __init__(\n        self,\n        dim,\n        *,\n        num_state_vectors,\n        dim_head = 64,\n        heads = 8,\n        qk_rmsnorm = False,\n        qk_rmsnorm_scale = 8,\n        use_flash_attn = False\n    ):\n        super().__init__()\n        assert num_state_vectors > 0\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.state_norm = LayerNorm(dim)\n\n        self.q_to_state = nn.Linear(dim, inner_dim, bias = False)\n        self.q_from_state = nn.Linear(dim, inner_dim, bias = False)\n\n        self.state_to_q = nn.Linear(dim, inner_dim, bias = False)\n        self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)\n\n        self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))\n        torch.nn.init.normal_(self.init_state, 0, .1)\n        self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))\n        # NOTE: the state position id embeddings are drawn from N(0,1) since they are added after a layer norm\n        torch.nn.init.normal_(self.state_pos_ids, 0, 1)\n\n        self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)\n\n        self.to_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = 
use_flash_attn)\n\n        self.state_self_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)\n        self.from_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)\n\n        # gating related parameters - using the fixed simple config\n\n        self.state_out_to_gate = nn.Linear(dim, dim)\n        self.learned_ema_beta = nn.Parameter(torch.randn(dim))\n        torch.nn.init.normal_(self.learned_ema_beta, 0, .1)\n\n        # since each read should be followed by a write, just store cache in the container\n\n        self.cache = None\n        self.next_read_state = None\n\n    def set_next_read_state(\n        self,\n        states\n    ):\n        if not exists(states):\n            states = self.init_state\n\n        self.next_read_state = (states,)\n\n    def read(self, x):\n        assert exists(self.next_read_state), 'states to be read must be set with .set_next_read_state'\n\n        states, = self.next_read_state\n        self.next_read_state = None\n\n        # pre norm state for attention\n\n        normed_states = self.state_norm(states)\n\n        # add the positional ids, as stated in the paper critical for it to work\n\n        normed_states = normed_states + self.state_pos_ids\n\n        # get queries for cross attention, which they do not share, although they share key / values. another intriguing detail\n\n        q_to_state = self.q_to_state(x)\n        q_to_state = rearrange(q_to_state, '... n (h d) -> ... h n d', h = self.heads)\n\n        # self attention qkv for states\n\n        state_k, state_v = self.state_to_kv(normed_states).chunk(2, dim = -1)\n\n        # cross attend to the past states key values\n\n        to_state_out = self.to_state_cross_attn(q_to_state, state_k, state_v)\n\n        to_state_out = rearrange(to_state_out, 'b h n d -> b n (h d)')\n\n        # cache for next write\n\n        self.cache = (states, normed_states, state_k, state_v)\n\n        return to_state_out\n\n    def write(\n        self,\n        *,\n        memories\n    ):\n        assert exists(self.cache)\n\n        k, v = memories\n        batch = k.shape[0]\n\n        # get cached values from the previous read\n\n        states, normed_states, state_k, state_v = self.cache\n\n        self.cache = None\n\n        # derive queries\n\n        q_from_state = self.q_from_state(normed_states)\n        q_from_state = rearrange(q_from_state, '... n (h d) -> ... h n d', h = self.heads)\n\n        state_q = self.state_to_q(normed_states)\n        state_q_einsum = 'n (h d)' if state_q.ndim == 2 else 'b n (h d)'\n        state_q = repeat(state_q, f'{state_q_einsum} -> b h n d', h = self.heads, b = batch)\n\n        # states must also undergo self attention\n\n        if q_from_state.ndim == 3:\n            q_from_state = repeat(q_from_state, '... 
-> b ...', b = batch)\n\n        state_out = self.state_self_attn(state_q, state_k, state_v)\n\n        from_state_out = self.from_state_cross_attn(q_from_state, k, v)\n\n        state_out = torch.cat((state_out, from_state_out), dim = -1)\n        state_out = rearrange(state_out, 'b h n d -> b n (h d)')\n\n        state_out = self.to_state_out(state_out)\n\n        # use the best performing configuration\n        # fixed simple gate - nothing more than a learned EMA with some resemblance to highway networks\n\n        z = self.state_out_to_gate(state_out)\n        learned_ema_decay = self.learned_ema_beta.sigmoid()\n\n        # set new state with the learned EMA gating\n\n        return learned_ema_decay * z + (1 - learned_ema_decay) * states\n\n    def forward(self, x):\n        raise NotImplementedError\n\n# main class\n\nclass Attend(nn.Module):\n    def __init__(\n        self,\n        causal = False,\n        use_flash_attn = False\n    ):\n        super().__init__()\n        self.causal = causal\n        self.register_buffer(\"mask\", None, persistent=False)\n\n        self.use_flash_attn = use_flash_attn\n        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'\n\n        # determine efficient attention configs for cuda and cpu\n\n        self.cpu_config = Config(True, True, True)\n        self.cuda_config = None\n\n        if not torch.cuda.is_available() or not use_flash_attn:\n            return\n\n        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))\n\n        if device_properties.major == 8 and device_properties.minor == 0:\n            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')\n            self.cuda_config = Config(True, False, False)\n        else:\n            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')\n            self.cuda_config = Config(False, True, True)\n\n    def get_mask(self, n, device):\n        if exists(self.mask) and self.mask.shape[-1] >= n:\n            return self.mask[:n, :n]\n\n        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)\n        self.register_buffer(\"mask\", mask, persistent=False)\n        return mask\n\n    def flash_attn(self, q, k, v, mask = None):\n        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda\n\n        # Recommended for multi-query single-key-value attention by Tri Dao\n        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])\n\n        if k.ndim == 3:\n            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])\n\n        if v.ndim == 3:\n            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])\n\n        # Check if mask exists and expand to compatible shape\n        # The mask is B L, so it would have to be expanded to B H N L\n\n        masks = []\n\n        if self.causal:\n            i, j = q_len, k_len\n            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)\n            masks.append(~causal_mask)\n\n        if exists(mask):\n            if mask.ndim != 2:\n                mask = repeat(mask, 'w ... 
-> (b w) ...', b = q.shape[0] // mask.shape[0])\n\n            masks.append(mask)\n\n        attn_mask = and_reduce(masks)\n\n        # Check if there is a compatible device for flash attention\n\n        config = self.cuda_config if is_cuda else self.cpu_config\n\n        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale\n\n        with torch.backends.cuda.sdp_kernel(**config._asdict()):\n            out = F.scaled_dot_product_attention(\n                q, k, v,\n                attn_mask = attn_mask\n            )\n\n        return out\n\n    def forward(self, q, k, v, mask = None, use_flash_attn = None):\n        use_flash_attn = default(use_flash_attn, self.use_flash_attn)\n\n        b, n, device = q.shape[0], q.shape[-2], q.device\n\n        q, ps = pack_one(q, '* h n d')\n        k, _ = pack_one(k, '* n d')\n        v, _ = pack_one(v, '* n d')\n\n        if use_flash_attn:\n            out = self.flash_attn(q, k, v, mask = mask)\n            return unpack_one(out, ps, '* h n d')\n\n        scale = q.shape[-1] ** -0.5\n\n        k_einsum = 'b j d' if k.ndim == 3 else 'b h j d'\n        v_einsum = 'b j d' if v.ndim == 3 else 'b h j d'\n\n        # similarity\n\n        sim = einsum(f\"b h i d, {k_einsum} -> b h i j\", q, k) * scale\n\n        # key padding mask\n\n        if exists(mask):\n            if mask.ndim != 2:\n                mask = repeat(mask, 'w ... -> (b w) ...', b = b)\n\n            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # causal mask\n\n        if self.causal:\n            i, j = sim.shape[-2:]\n            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)\n            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)\n\n        # attention\n\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum(f\"b h i j, {v_einsum} -> b h i d\", attn, v)\n\n        return unpack_one(out, ps, '* h n d')\n\n# geglu feedforward\n\nclass GEGLU(nn.Module):\n    def forward(self, x):\n        x, gate = x.chunk(2, dim = -1)\n        return F.gelu(gate) * x\n\ndef FeedForward(dim, mult = 4):\n    inner_dim = int(dim * mult * 2 / 3)\n    return nn.Sequential(\n        LayerNorm(dim),\n        nn.Linear(dim, inner_dim * 2, bias = False),\n        GEGLU(),\n        nn.Linear(inner_dim, dim, bias = False)\n    )\n\n# attention\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim_head,\n        causal = False,\n        qk_rmsnorm = False,\n        qk_rmsnorm_scale = 8,\n        use_flash_attn = False\n    ):\n        super().__init__()\n        self.causal = causal\n\n        self.qk_rmsnorm = qk_rmsnorm\n        self.qk_rmsnorm_scale = qk_rmsnorm_scale\n\n        self.attend = Attend(causal = causal, use_flash_attn = use_flash_attn)\n\n        if qk_rmsnorm:\n            self.q_scale = nn.Parameter(torch.ones(dim_head))\n            self.k_scale = nn.Parameter(torch.ones(dim_head))\n\n    def forward(\n        self,\n        q, k, v,\n        mask = None,\n        rotary_pos_emb = None,\n        xpos_scale = None\n    ):\n\n        scale = q.shape[-1] ** -0.5\n\n        if self.qk_rmsnorm:\n            q, k = map(l2norm, (q, k))\n            scale = self.qk_rmsnorm_scale\n\n        if self.qk_rmsnorm:\n            q = q * self.q_scale\n            k = k * self.k_scale\n\n        # rotary positional embedding with xpos for length extrapolation\n\n        if exists(rotary_pos_emb):\n            q = apply_rotary_pos_emb(q, 
rotary_pos_emb, xpos_scale)\n            k = apply_rotary_pos_emb(k, rotary_pos_emb, xpos_scale ** -1)\n\n        # attention\n\n        out = self.attend(q, k, v, mask = mask)\n\n        return out\n\nclass AttentionBlock(nn.Module):\n    def __init__(\n        self,\n        dim,\n        block_width,\n        dim_head = 64,\n        heads = 8,\n        qk_rmsnorm = False,\n        qk_rmsnorm_scale = 8,\n        use_flash_attn = False,\n        num_state_vectors = 0,\n        num_external_state_reads = 0,\n        state_read_before_write = True  # this will be defaulted to on as in the paper, but will be turned off in the case the researcher wants to test out reading the state at a lower layer\n    ):\n        super().__init__()\n        inner_dim = dim_head * heads\n        self.heads = heads\n\n        self.norm = LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias = False)\n\n        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)\n\n        self.attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)\n\n        self.block_width = block_width\n        self.is_recurrent_layer = num_state_vectors > 0\n\n        # decide how many states this attention layer is going to read from\n\n        num_state_reads = int(self.is_recurrent_layer and state_read_before_write) + num_external_state_reads\n\n        self.to_out = nn.Linear(inner_dim * (1 + num_state_reads), dim, bias = False)\n\n        if not self.is_recurrent_layer:\n            return\n\n        self.state_read_before_write = state_read_before_write\n\n        self.state_container = StateContainer(\n            dim,\n            dim_head = dim_head,\n            heads = heads,\n            num_state_vectors = num_state_vectors,\n            qk_rmsnorm = qk_rmsnorm,\n            qk_rmsnorm_scale = qk_rmsnorm_scale,\n            use_flash_attn = use_flash_attn\n        )\n\n    @property\n    def device(self):\n        return next(self.parameters()).device\n\n    def forward(\n        self,\n        x,\n        rotary_pos_emb = None,\n        xpos_scale = None,\n        attn_mask = None,\n        xl_memories: Optional[torch.Tensor] = None,\n        read_from_state_containers: List[StateContainer] = []\n    ):\n        batch, seq_len, _, width, device = *x.shape, self.block_width, self.device\n\n        # pre normalization\n\n        x = self.norm(x)\n\n        # queries, keys, values and split out heads\n\n        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))\n\n        split_head = partial(rearrange, pattern = 'b n (h d) -> b h n d', h = self.heads)\n        q = split_head(q)\n\n        # save the last key / values as memories for recurrence\n\n        memories = torch.stack((k, v))\n\n        mem_len = 0\n\n        if exists(xl_memories):\n            # if past memories are passed in, concat as the first bucket\n            mem_len = xl_memories.shape[-2]\n            past_k, past_v = xl_memories\n            k = torch.cat((past_k, k), dim = 1)\n            v = torch.cat((past_v, v), dim = 1)\n\n        # handle cropping of attention mask and positional embeddings\n\n        if exists(attn_mask):\n            attn_mask = attn_mask[:seq_len, :seq_len]\n            attn_mask = F.pad(attn_mask, (mem_len, 0), value = True)\n\n        # attention, but of course\n\n        out = self.attn(\n            q, k, v,\n            rotary_pos_emb = rotary_pos_emb,\n            xpos_scale = xpos_scale,\n            mask = attn_mask\n        
)\n\n        # merge heads\n\n        out = rearrange(out, 'b h n d -> b n (h d)')\n\n        # early return if not a recurrent layer\n\n        if not self.is_recurrent_layer and len(read_from_state_containers) == 0:\n            return self.to_out(out), memories, None\n\n        # whether to read from own state container, default to on, but may pass in more\n\n        if self.is_recurrent_layer and self.state_read_before_write:\n            read_from_state_containers = [self.state_container, *read_from_state_containers]\n\n        for read_state_container in read_from_state_containers:\n            # read from the states ...\n\n            to_state_out = read_state_container.read(x)\n\n            # and concat it to the output of self-attention\n\n            out = torch.cat((out, to_state_out), dim = -1)\n\n        new_states = None\n\n        if self.is_recurrent_layer:\n            # then write to the states as well if need be\n\n            new_states = self.state_container.write(memories = memories)\n\n        return self.to_out(out), memories, new_states\n\n# classes\n\n@beartype\nclass BlockRecurrentTransformer(nn.Module):\n    def __init__(\n        self,\n        *,\n        num_tokens,\n        dim,\n        depth,\n        dim_head = 64,\n        heads = 8,\n        all_layers_qk_rmsnorm = False,\n        ff_mult = 4,\n        max_seq_len = 1024,\n        block_width = 512,\n        recurrent_layers: Optional[Tuple[int, ...]] = None,\n        read_recurrent_layers: Optional[Tuple[int, ...]] = None,\n        num_state_vectors = None,\n        ignore_index = -100,\n        use_flash_attn = False,\n        use_compressed_mem = False,\n        compressed_mem_factor = 4\n    ):\n        super().__init__()\n        num_state_vectors = default(num_state_vectors, block_width)\n\n        # set recurrent layers\n\n        recurrent_layers = default(recurrent_layers, (depth // 2,)) # default to one recurent layer at middle of the network\n\n        assert all([0 < layer <= depth for layer in recurrent_layers]), f'recurrent layers must range from 1 to the depth {depth}'\n        assert all_unique(recurrent_layers), 'recurrent layers must be all unique. 
no duplicate layers'\n\n        self.recurrent_layers = recurrent_layers\n\n        # set read recurrent layers\n\n        read_recurrent_layers = default(read_recurrent_layers, recurrent_layers)\n\n        assert all([read_layer <= write_layer for read_layer, write_layer in zip(read_recurrent_layers, recurrent_layers)]), 'the recurrent read layer must be always less than or equal to the write layer'\n        assert all([0 < layer <= depth for layer in read_recurrent_layers])\n        assert len(read_recurrent_layers) == len(recurrent_layers)\n\n        self.read_recurrent_layers = read_recurrent_layers\n\n        # token embedding\n\n        self.token_emb = nn.Embedding(num_tokens, dim)\n\n        self.rotary_pos_emb = RotaryEmbedding(dim = dim_head, width = (2 if not use_compressed_mem else 3) * block_width)\n\n        self.layers = nn.ModuleList([])\n\n        self.write_to_read_map = {write_layer: read_layer for write_layer, read_layer in zip(recurrent_layers, read_recurrent_layers)}\n\n        self.read_state_router = defaultdict(list)\n\n        for layer in range(1, depth + 1):\n            is_recurrent_layer = layer in self.recurrent_layers\n\n            layer_num_state_vectors = num_state_vectors if is_recurrent_layer else 0\n\n            num_external_state_reads = sum([int(layer == read_layer) for read_layer in read_recurrent_layers])\n\n            # only layers with xl memories\n            # or has recurrence in horizontal direction\n            # use qk rmsnorm (in paper, they use cosine sim attention, but i think qk rmsnorm is more proven given Vit 22B paper)\n            # one can also override to use all qk rmsnorm by setting all_layers_qk_rmsnorm = True\n\n            qk_rmsnorm = all_layers_qk_rmsnorm or is_recurrent_layer\n\n            attn_block = AttentionBlock(\n                dim,\n                block_width = block_width,\n                dim_head = dim_head,\n                heads = heads,\n                qk_rmsnorm = qk_rmsnorm,\n                num_state_vectors = layer_num_state_vectors,\n                use_flash_attn = use_flash_attn,\n                num_external_state_reads = num_external_state_reads,\n                state_read_before_write = False,\n            )\n\n            ff_block = FeedForward(dim, mult = ff_mult)\n\n            if is_recurrent_layer:\n                read_layer = self.write_to_read_map[layer]\n                self.read_state_router[read_layer].append(attn_block.state_container)\n\n            self.layers.append(nn.ModuleList([\n                attn_block,\n                ff_block\n            ]))\n\n        # (compressed) memory management\n\n        self.mem_manager = MemoryManager(\n            dim = dim_head,\n            layers = depth,\n            mem_lengths = block_width if not use_compressed_mem else (block_width, block_width // 2),\n            compress_factors = 1 if not use_compressed_mem else (1, compressed_mem_factor)\n        )\n\n        # to logits\n\n        self.to_logits = nn.Sequential(\n            LayerNorm(dim),\n            nn.Linear(dim, num_tokens, bias = False)\n        )\n\n        self.max_seq_len = max_seq_len\n        self.block_width = block_width\n\n        assert divisible_by(max_seq_len, block_width)\n\n        self.ignore_index = ignore_index\n\n        self.register_buffer('cached_causal_attn_mask', None, persistent = False)\n\n    @property\n    def device(self):\n        return next(self.parameters()).device\n\n    def get_causal_attn_mask(self, width):\n        if 
exists(self.cached_causal_attn_mask):\n            cached_mask = self.cached_causal_attn_mask\n            cached_width = cached_mask.shape[-2]\n            padding = (width - cached_width) // 2\n            j_slice = Ellipsis if padding == 0 else slice(padding, -padding)\n            return cached_mask[:cached_width, j_slice]\n\n        device = self.device\n        causal_mask = torch.ones((width, width), device = device, dtype = torch.bool).triu(1)\n        return ~causal_mask\n\n    @torch.no_grad()\n    @eval_decorator\n    def generate(\n        self,\n        prime,\n        length = None,\n        xl_memories: List[torch.Tensor] = [],\n        states: List[torch.Tensor] = [],\n        temperature = 1.,\n        filter_thres = 0.9,\n        return_memories_and_states = False\n    ):\n        length = default(length, self.max_seq_len + 1)\n        start_len = prime.shape[-1]\n\n        assert start_len < self.max_seq_len\n        assert length <= (self.max_seq_len + 1)\n        assert start_len < length\n\n        output = prime\n\n        memories = []\n\n        for ind in range(length - start_len):\n\n            logits, next_memories, next_states = self.forward(\n                output,\n                xl_memories = xl_memories,\n                states = states\n            )\n\n            logits = logits[:, -1]\n\n            filtered_logits = top_k(logits, thres = filter_thres)\n            sampled = gumbel_sample(filtered_logits, temperature = temperature)\n            sampled = rearrange(sampled, 'b -> b 1')\n\n            output = torch.cat((output, sampled), dim = -1)\n\n            if divisible_by(output.shape[-1] - 1, self.max_seq_len): # on the sampling of the last token in the current window, set new memories and states\n                memories = next_memories\n                states = next_states\n\n        output = output[:, start_len:]\n\n        if return_memories_and_states:\n            return output, memories, states\n\n        return output\n\n    def forward(\n        self,\n        x,\n        return_loss = False,\n        xl_memories: List[torch.Tensor] = [],\n        states: List[torch.Tensor] = [],\n        return_memories_and_states = None  # can force to either return memory + state or not. 
by default will only return when number of tokens == max_seq_len\n    ):\n        device = x.device\n\n        if return_loss:\n            x, labels = x[:, :-1], x[:, 1:]\n\n        # get sequence length i and j for dynamic pos bias\n\n        assert x.shape[-1] <= self.max_seq_len\n\n        w = self.block_width\n\n        # token embedding\n\n        x = self.token_emb(x)\n\n        # dynamic pos bias\n\n        attn_mask = self.get_causal_attn_mask(w)\n        rotary_pos_emb, xpos_scale = self.rotary_pos_emb()\n\n        # only return memories and state if at the full block width, but can be overridden\n\n        return_memories_and_states = default(return_memories_and_states, self.max_seq_len == x.shape[-2])\n\n        # ready output tensor, to be concatted to block by block\n\n        batch, _, dim = x.shape\n\n        out = torch.empty(batch, 0, dim, dtype = x.dtype, device = self.device)\n\n        # split input into blocks of width w\n\n        input_blocks = x.split(w, dim = -2)\n\n        # process each block at a time\n\n        for input_block in input_blocks:\n            input_block_length = input_block.shape[-2]\n\n            # ready xl memories and states\n\n            iter_xl_memories = iter(xl_memories)\n            iter_states = iter(states)\n\n            next_xl_memories = []\n            next_states = []\n\n            # set the states on the appropriate state containers\n\n            for attn, _ in self.layers:\n                if not attn.is_recurrent_layer:\n                    continue\n\n                attn.state_container.set_next_read_state(next(iter_states, None))\n\n            # go through layers\n\n            for ind, (attn, ff) in enumerate(self.layers):\n\n                # determine if the layer requires transformer xl memories\n\n                layer = ind + 1\n\n                # whether to pass in xl memories\n\n                attn_kwargs = dict(\n                    rotary_pos_emb = rotary_pos_emb,\n                    xpos_scale = xpos_scale,\n                    attn_mask = attn_mask,\n                    xl_memories = next(iter_xl_memories, None),\n                    read_from_state_containers = self.read_state_router[layer]\n                )\n\n                # attention layer\n\n                residual = input_block\n                attn_branch_out, layer_xl_memories, layer_next_states = attn(input_block, **attn_kwargs)\n\n                if exists(layer_xl_memories):\n                    next_xl_memories.append(layer_xl_memories)\n\n                if exists(layer_next_states):\n                    next_states.append(layer_next_states)\n\n                input_block = attn_branch_out + residual\n\n                # feedforward layer\n\n                input_block = ff(input_block) + input_block\n\n            # concat to output\n\n            out = torch.cat((out, input_block), dim = -2)\n\n            # set new xl memories and states\n\n            states = next_states\n\n            if input_block_length == w:\n                xl_memories = self.mem_manager(xl_memories, next_xl_memories)\n\n\n        # project to logits\n\n        logits = self.to_logits(out)\n\n        # detach the states and memories\n\n        returned_next_states = list(map(torch.detach, states)) if return_memories_and_states else None\n        returned_next_xl_memories = list(map(torch.detach, xl_memories)) if return_memories_and_states else None\n\n        # whether to return logits\n\n        if not return_loss:\n            return logits, 
returned_next_xl_memories, returned_next_states\n\n        # cross entropy loss\n\n        logits = rearrange(logits, 'b n c -> b c n')\n        loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)\n\n        return loss, returned_next_xl_memories, returned_next_states\n\n# recurrent trainer wrapper\n\n@beartype\nclass RecurrentTrainerWrapper(nn.Module):\n    def __init__(\n        self,\n        transformer: BlockRecurrentTransformer,\n        xl_memories_dropout = 0.,\n        state_dropout = 0.\n    ):\n        super().__init__()\n        self.transformer = transformer\n        self.seq_len = transformer.max_seq_len\n\n        self.xl_memories_dropout = xl_memories_dropout\n        self.state_dropout = state_dropout\n\n    @eval_decorator\n    @torch.no_grad()\n    def generate(\n        self,\n        prime,\n        length,\n        **kwargs\n    ):\n        seq_len = self.seq_len\n        start_len = prime.shape[-1]\n        assert start_len < length\n\n        output = prime\n        current_len = start_len\n\n        memories = []\n        states = []\n\n        # determine lengths\n\n        has_remainder = not divisible_by(length, seq_len)\n        remainder_amount = length % seq_len\n        total_segments = math.ceil(length / seq_len)\n\n        if not has_remainder:\n            lengths = (*((seq_len + 1,) * (total_segments - 1)), seq_len)\n        elif remainder_amount == 1:\n            lengths = (seq_len + 1,) * (total_segments - 1)\n        else:\n            lengths = (*((seq_len + 1,) * (total_segments - 1)), remainder_amount)\n\n        # loop through lengths\n\n        for next_length in lengths:\n\n            segment_output, memories, states = self.transformer.generate(\n                output[:, -current_len:],\n                length = next_length,\n                xl_memories = memories,\n                states = states,\n                return_memories_and_states = True,\n                **kwargs\n            )\n\n            output = torch.cat((output, segment_output), dim = -1)\n            current_len = 1\n\n        return output[:, start_len:]\n\n    def forward(\n        self,\n        x,\n        return_memories_and_states = False\n    ):\n        total_seq_len, seq_len = x.shape[1], self.seq_len\n\n        assert divisible_by(total_seq_len - 1, seq_len), f'length of sequence ({total_seq_len}) must be equal to a multiple of {seq_len} + 1 (one extra token) during training'\n        segments = total_seq_len // seq_len\n\n        total_loss = 0.\n\n        memories = []\n        states = []\n\n        for ind in range(segments):\n            start = ind * seq_len\n            end = start + seq_len + 1\n\n            if self.training and random() < self.xl_memories_dropout:\n                memories.clear()\n\n            if self.training and random() < self.state_dropout:\n                states.clear()\n\n            loss, memories, states = self.transformer(\n                x[:, start:end],\n                xl_memories = memories,\n                states = states,\n                return_loss = True\n            )\n\n            total_loss = total_loss + (loss / segments)\n\n        if return_memories_and_states:\n            return total_loss, memories, states\n\n        return total_loss\n"
  },
  {
    "path": "data/README.md",
    "content": "# Data source\n\nThe enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate\ntqdm\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n  name = 'block-recurrent-transformer-pytorch',\n  packages = find_packages(exclude=[]),\n  version = '0.4.4',\n  license='MIT',\n  description = 'Block Recurrent Transformer - Pytorch',\n  author = 'Phil Wang',\n  author_email = 'lucidrains@gmail.com',\n  long_description_content_type = 'text/markdown',\n  url = 'https://github.com/lucidrains/block-recurrent-transformer-pytorch',\n  keywords = [\n    'artificial intelligence',\n    'deep learning',\n    'transformers',\n    'attention mechanism',\n    'recurrence'\n  ],\n  install_requires=[\n    'beartype',\n    'einops>=0.6.1',\n    'memorizing-transformers-pytorch>=0.4.0',\n    'torch>=1.6',\n  ],\n  classifiers=[\n    'Development Status :: 4 - Beta',\n    'Intended Audience :: Developers',\n    'Topic :: Scientific/Engineering :: Artificial Intelligence',\n    'License :: OSI Approved :: MIT License',\n    'Programming Language :: Python :: 3.6',\n  ],\n)\n"
  },
  {
    "path": "train.py",
    "content": "import gzip\nimport random\nimport tqdm\nimport numpy as np\n\nimport torch\nfrom torch.optim import Adam\nfrom torch.nn import functional as F\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom accelerate import Accelerator\nfrom block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper\n\n# constants\n\nNUM_BATCHES = int(1e5)\nBATCH_SIZE = 4\nGRADIENT_ACCUMULATE_EVERY = 4\nLEARNING_RATE = 1e-4\nVALIDATE_EVERY = 100\nPRIME_LENGTH = 128\nGENERATE_EVERY = 250\nGENERATE_LENGTH = 2048\nSEQ_LEN = 2048\n\n# helpers\n\ndef cycle(loader):\n    while True:\n        for data in loader:\n            yield data\n\ndef decode_token(token):\n    return str(chr(max(32, token)))\n\ndef decode_tokens(tokens):\n    return \"\".join(list(map(decode_token, tokens)))\n\n\n# accelerator\n\naccelerator = Accelerator()\n\ndevice = accelerator.device\nacc_print = accelerator.print\n\n# instantiate palm\n\nmodel = BlockRecurrentTransformer(\n    num_tokens = 256,\n    dim = 512,\n    depth = 6,\n    dim_head = 64,\n    heads = 8,\n    max_seq_len = 1024,\n    block_width = 512,\n    num_state_vectors = 512,\n    recurrent_layers = (4,),\n    use_flash_attn = True\n)\n\ntrain_wrapper = RecurrentTrainerWrapper(\n    model,\n    xl_memories_dropout = 0.1,\n    state_dropout = 0.1,\n)\n\nmodel.to(device)\n\n# prepare enwik8 data\n\nwith gzip.open(\"./data/enwik8.gz\") as file:\n    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()\n    np_train, np_valid = np.split(data, [int(90e6)])\n    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)\n\nclass TextSamplerDataset(Dataset):\n    def __init__(self, data, seq_len):\n        super().__init__()\n        self.data = data\n        self.seq_len = seq_len\n\n    def __getitem__(self, index):\n        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))\n        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()\n        return full_seq.to(device)\n\n    def __len__(self):\n        return self.data.size(0) // self.seq_len\n\ntrain_dataset = TextSamplerDataset(data_train, SEQ_LEN)\nval_dataset = TextSamplerDataset(data_val, SEQ_LEN)\ntrain_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))\nval_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))\n\n# optimizer\n\noptim = Adam(model.parameters(), lr = LEARNING_RATE)\n\nmodel, optim, train_loader, val_loader = accelerator.prepare(\n    model, optim, train_loader, val_loader\n)\n\n# training\n\nfor i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc=\"training\"):\n    model.train()\n\n    for _ in range(GRADIENT_ACCUMULATE_EVERY):\n        loss = train_wrapper(next(train_loader))\n        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)\n\n    acc_print(f\"training loss: {loss.item()}\")\n    accelerator.clip_grad_norm_(model.parameters(), 0.5)\n\n    optim.step()\n    optim.zero_grad()\n\n    if i % VALIDATE_EVERY == 0:\n        model.eval()\n        with torch.no_grad():\n            loss = train_wrapper(next(val_loader))\n            acc_print(f\"validation loss: {loss.item()}\")\n\n    if i % GENERATE_EVERY == 0:\n        model.eval()\n        inp = random.choice(val_dataset)[:PRIME_LENGTH]\n        prime = decode_tokens(inp)\n        acc_print(f\"%s \\n\\n %s\", (prime, \"*\" * 100))\n\n        sample = train_wrapper.generate(inp[None, ...], length = GENERATE_LENGTH)\n        output_str = decode_tokens(sample[0])\n        acc_print(output_str, 
\"\\n\")\n"
  }
]