[
  {
    "path": ".github/workflows/python-publish.yml",
    "content": "# This workflow will upload a Python Package using Twine when a release is created\n# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries\n\nname: Upload Python Package\n\non:\n  release:\n    types: [created]\n\njobs:\n  deploy:\n\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v2\n    - name: Set up Python\n      uses: actions/setup-python@v2\n      with:\n        python-version: '3.x'\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install setuptools wheel twine\n    - name: Build and publish\n      env:\n        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n      run: |\n        python setup.py sdist bdist_wheel\n        twine upload dist/*\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 Phil Wang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<img src=\"./glom2.png\" width=\"400px\"></img>\n\n<img src=\"./glom1.png\" width=\"600px\"></img>\n\n## GLOM - Pytorch\n\nAn implementation of <a href=\"https://arxiv.org/abs/2102.12627\">Glom</a>, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.\n\n<a href=\"https://www.youtube.com/watch?v=cllFzkvrYmE\">Yannic Kilcher's video</a> was instrumental in helping me understand this paper.\n\n## Install\n\n```bash\n$ pip install glom-pytorch\n```\n\n## Usage\n\n```python\nimport torch\nfrom glom_pytorch import Glom\n\nmodel = Glom(\n    dim = 512,         # dimension\n    levels = 6,        # number of levels\n    image_size = 224,  # image size\n    patch_size = 14    # patch size\n)\n\nimg = torch.randn(1, 3, 224, 224)\nlevels = model(img, iters = 12) # (1, 256, 6, 512) - (batch, patches, levels, dimension)\n```\n\nPass the `return_all = True` keyword argument to `forward` and you will get back the column and level states at every iteration, including the initial state (so `iters + 1` states in total). You can then use this to attach any losses to any level outputs at any time step.\n\nIt also gives you access to all the level data across iterations for clustering, which you can inspect for the islands theorized in the paper.\n\n```python\nimport torch\nfrom glom_pytorch import Glom\n\nmodel = Glom(\n    dim = 512,         # dimension\n    levels = 6,        # number of levels\n    image_size = 224,  # image size\n    patch_size = 14    # patch size\n)\n\nimg = torch.randn(1, 3, 224, 224)\nall_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)\n\n# get the top level outputs after 7 iterations (index 0 is the initial state)\ntop_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)\n```\n\nDenoising self-supervised learning to encourage emergence, as described by Hinton:\n\n```python\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom einops.layers.torch import Rearrange\n\nfrom glom_pytorch import Glom\n\nmodel = Glom(\n    dim = 512,         # dimension\n    levels = 6,        # number of levels\n    image_size = 224,  # image size\n    patch_size = 14    # patch size\n)\n\nimg = torch.randn(1, 3, 224, 224)\nnoised_img = img + torch.randn_like(img)\n\nall_levels = model(noised_img, return_all = True)\n\npatches_to_images = nn.Sequential(\n    nn.Linear(512, 14 * 14 * 3),\n    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))\n)\n\ntop_level = all_levels[7, :, :, -1]  # get the top level embeddings after 7 iterations\nrecon_img = patches_to_images(top_level)\n\n# do self-supervised learning by denoising\n\nloss = F.mse_loss(recon_img, img)\nloss.backward()\n```\n\nYou can pass the state of the columns and levels back into the model to continue where you left off (for example, when processing consecutive frames of a slow video, as mentioned in the paper).\n\n```python\nimport torch\nfrom glom_pytorch import Glom\n\nmodel = Glom(\n    dim = 512,\n    levels = 6,\n    image_size = 224,\n    patch_size = 14\n)\n\nimg1 = torch.randn(1, 3, 224, 224)\nimg2 = torch.randn(1, 3, 224, 224)\nimg3 = torch.randn(1, 3, 224, 224)\n\nlevels1 = model(img1, iters = 12)                   # image 1 for 12 iterations\nlevels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations\nlevels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations\n```\n\n### Appreciation\n\nThanks go out to <a href=\"https://github.com/cfoster0\">Cfoster0</a> for reviewing the code.\n\n### Todo\n\n- [ ] contrastive / consistency regularization of top-ish levels\n\n## Citations\n\n```bibtex\n@misc{hinton2021represent,\n    title   = {How to represent part-whole hierarchies in a neural network},\n    author  = {Geoffrey Hinton},\n    year    = {2021},\n    eprint  = {2102.12627},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.CV}\n}\n```\n"
  },
  {
    "path": "glom_pytorch/__init__.py",
    "content": "from glom_pytorch.glom_pytorch import Glom\n"
  },
  {
    "path": "glom_pytorch/glom_pytorch.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\n\nfrom einops import rearrange, repeat\nfrom einops.layers.torch import Rearrange\n\n# constants\n\nTOKEN_ATTEND_SELF_VALUE = -5e-4\n\n# helpers\n\ndef exists(val):\n    return val is not None\n\ndef default(val, d):\n    return val if exists(val) else d\n\n# classes\n\nclass GroupedFeedForward(nn.Module):\n    def __init__(self, *, dim, groups, mult = 4):\n        super().__init__()\n        total_dim = dim * groups # levels * dim\n        # 1x1 grouped convolutions give each level (group) its own feedforward weights\n        self.net = nn.Sequential(\n            Rearrange('b n l d -> b (l d) n'),\n            nn.Conv1d(total_dim, total_dim * mult, 1, groups = groups),\n            nn.GELU(),\n            nn.Conv1d(total_dim * mult, total_dim, 1, groups = groups),\n            Rearrange('b (l d) n -> b n l d', l = groups)\n        )\n\n    def forward(self, levels):\n        return self.net(levels)\n\nclass ConsensusAttention(nn.Module):\n    def __init__(self, num_patches_side, attend_self = True, local_consensus_radius = 0):\n        super().__init__()\n        self.attend_self = attend_self\n        self.local_consensus_radius = local_consensus_radius\n\n        if self.local_consensus_radius > 0:\n            # mask out columns farther than local_consensus_radius (Euclidean distance on the patch grid)\n            coors = torch.stack(torch.meshgrid(\n                torch.arange(num_patches_side),\n                torch.arange(num_patches_side)\n            )).float()\n\n            coors = rearrange(coors, 'c h w -> (h w) c')\n            dist = torch.cdist(coors, coors)\n            mask_non_local = dist > self.local_consensus_radius\n            mask_non_local = rearrange(mask_non_local, 'i j -> () i j')\n            self.register_buffer('non_local_mask', mask_non_local)\n\n    def forward(self, levels):\n        _, n, _, d, device = *levels.shape, levels.device\n        # queries and values are the raw level embeddings, keys are L2-normalized\n        q, k, v = levels, F.normalize(levels, dim = -1), levels\n\n        sim = einsum('b i l d, b j l d -> b l i j', q, k) * (d ** -0.5)\n\n        if not self.attend_self:\n            # replace each column's similarity to itself with a small constant so it does not dominate the attention\n            self_mask = torch.eye(n, device = device, dtype = torch.bool)\n            self_mask = rearrange(self_mask, 'i j -> () () i j')\n            sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)\n\n        if self.local_consensus_radius > 0:\n            max_neg_value = -torch.finfo(sim.dtype).max\n            sim.masked_fill_(self.non_local_mask, max_neg_value)\n\n        attn = sim.softmax(dim = -1)\n        out = einsum('b l i j, b j l d -> b i l d', attn, v)\n        return out\n\n# main class\n\nclass Glom(nn.Module):\n    def __init__(\n        self,\n        *,\n        dim = 512,\n        levels = 6,\n        image_size = 224,\n        patch_size = 14,\n        consensus_self = False,\n        local_consensus_radius = 0\n    ):\n        super().__init__()\n        # bottom level - incoming image is split into patches and tokenized (positional embeddings are given to the top-down network)\n        num_patches_side = (image_size // patch_size)\n        num_patches = num_patches_side ** 2\n        self.levels = levels\n\n        self.image_to_tokens = nn.Sequential(\n            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),\n            nn.Linear(patch_size ** 2 * 3, dim)\n        )\n        self.pos_emb = nn.Embedding(num_patches, dim)\n\n        # initial embeddings for all levels of a column\n        self.init_levels = nn.Parameter(torch.randn(levels, dim))\n\n        # bottom-up and top-down\n        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)\n        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)\n\n        # consensus attention\n        self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)\n\n    def forward(self, img, iters = None, levels = None, return_all = False):\n        b, device = img.shape[0], img.device\n        iters = default(iters, self.levels * 2)   # default to twice the number of levels so information can propagate all the way up and back down; can be overridden\n\n        tokens = self.image_to_tokens(img)\n        n = tokens.shape[1]\n\n        pos_embs = self.pos_emb(torch.arange(n, device = device))\n        pos_embs = rearrange(pos_embs, 'n d -> () n () d')\n\n        bottom_level = tokens\n        bottom_level = rearrange(bottom_level, 'b n d -> b n () d')\n\n        if not exists(levels):\n            levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)\n\n        hiddens = [levels]\n\n        num_contributions = torch.empty(self.levels, device = device).fill_(4)\n        num_contributions[-1] = 3  # the top level gets no top-down contribution, so its sum is divided by 3 instead of 4 when taking the mean\n\n        for _ in range(iters):\n            levels_with_input = torch.cat((bottom_level, levels), dim = -2)  # each iteration, prepend the original input tokens as the bottom-most level, to be fed bottom-up\n\n            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])\n\n            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks\n            top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.) # zero top-down input for the top level, which has no level above it\n\n            consensus = self.attention(levels)\n\n            levels_sum = torch.stack((levels, bottom_up_out, top_down_out, consensus)).sum(dim = 0) # Hinton suggests a weighted mean of (1) bottom-up (2) top-down (3) the level's previous value at {t - 1} (4) consensus; here all contributions are weighted equally\n            levels_mean = levels_sum / rearrange(num_contributions, 'l -> () () l ()')\n\n            levels = levels_mean  # set for next iteration\n            hiddens.append(levels)\n\n        if return_all:\n            return torch.stack(hiddens)  # return (time step, batch, num columns, levels, dimension)\n\n        return levels\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n  name = 'glom-pytorch',\n  packages = find_packages(),\n  version = '0.0.14',\n  license = 'MIT',\n  description = 'Glom - Pytorch',\n  author = 'Phil Wang',\n  author_email = 'lucidrains@gmail.com',\n  url = 'https://github.com/lucidrains/glom-pytorch',\n  keywords = [\n    'artificial intelligence',\n    'deep learning'\n  ],\n  install_requires = [\n    'einops>=0.3',\n    'torch>=1.6'\n  ],\n  classifiers = [\n    'Development Status :: 4 - Beta',\n    'Intended Audience :: Developers',\n    'Topic :: Scientific/Engineering :: Artificial Intelligence',\n    'License :: OSI Approved :: MIT License',\n    'Programming Language :: Python :: 3.6',\n  ],\n)\n"
  }
]