Repository: lucidrains/glom-pytorch
Branch: main
Commit: f30f62165d0c
Files: 7
Total size: 13.7 KB
Directory structure:
gitextract_fkzu5mpf/
├── .github/
│   └── workflows/
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── glom_pytorch/
│   ├── __init__.py
│   └── glom_pytorch.py
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/python-publish.yml
================================================
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Upload Python Package
on:
  release:
    types: [created]
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 Phil Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
<img src="./glom2.png" width="400px"></img>
<img src="./glom1.png" width="600px"></img>
## GLOM - Pytorch
An implementation of <a href="https://arxiv.org/abs/2102.12627">Glom</a>, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.
<a href="https://www.youtube.com/watch?v=cllFzkvrYmE">Yannic Kilcher's video</a> was instrumental in helping me to understand this paper.
## Install
```bash
$ pip install glom-pytorch
```
## Usage
```python
import torch
from glom_pytorch import Glom
model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)
img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)
```
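The output shape in the usage comment follows from the patch grid: a 224 x 224 image cut into 14 x 14 patches gives `(224 // 14) ** 2 = 256` columns, each holding 6 levels of dimension 512. A minimal sketch of that arithmetic, assuming a hypothetical helper named `glom_output_shape` (it is not part of the library):

```python
def glom_output_shape(batch, image_size, patch_size, levels, dim):
    # hypothetical helper, not part of glom_pytorch: mirrors the patch arithmetic
    # used when the model is constructed (num_patches_side = image_size // patch_size)
    num_patches_side = image_size // patch_size
    num_patches = num_patches_side ** 2
    return (batch, num_patches, levels, dim)

shape = glom_output_shape(batch = 1, image_size = 224, patch_size = 14, levels = 6, dim = 512)
print(shape)  # (1, 256, 6, 512)
```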
Pass `return_all = True` to the forward call to get back the column and level states for every iteration (including the initial state, so number of iterations + 1 states in total). You can then attach losses to any level output at any time step.
It also gives you access to all the level data across iterations for clustering, which lets you look for the theorized islands described in the paper.
```python
import torch
from glom_pytorch import Glom
model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)
img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)
# index 0 is the initial state, so index 7 holds the top level outputs after the 7th iteration (iteration 6 when counting from 0)
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)
```
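One simple way to start looking for islands is to measure how much neighboring columns agree at a given level. The sketch below is an assumption about how one might do this, not something the library provides: the helper `neighbor_agreement` is hypothetical, and a random tensor stands in for `all_levels[7, :, :, -1]` so the block runs on its own. It relies on the patches being laid out row by row, which is how the model tokenizes the image.

```python
import torch
import torch.nn.functional as F

def neighbor_agreement(level_state, num_patches_side):
    # hypothetical helper, not part of glom_pytorch
    # level_state: (batch, patches, dim) embeddings of one level at one time step
    b, n, d = level_state.shape
    grid = level_state.reshape(b, num_patches_side, num_patches_side, d)
    grid = F.normalize(grid, dim = -1)
    # cosine similarity between each patch and its right / lower neighbor
    horizontal = (grid[:, :, :-1] * grid[:, :, 1:]).sum(dim = -1)  # (b, side, side - 1)
    vertical = (grid[:, :-1, :] * grid[:, 1:, :]).sum(dim = -1)    # (b, side - 1, side)
    return horizontal, vertical

# random stand-in for all_levels[7, :, :, -1] from a real model
top_level_output = torch.randn(1, 256, 512)
horizontal, vertical = neighbor_agreement(top_level_output, num_patches_side = 16)
print(horizontal.shape, vertical.shape)
```

Regions where both maps stay close to 1 are candidates for islands of near-identical vectors; a proper clustering method can be substituted for this neighbor check.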
Denoising self-supervised learning for encouraging emergence, as described by Hinton:
```python
import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange
from glom_pytorch import Glom
model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)
img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)
all_levels = model(noised_img, return_all = True)
patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)
top_level = all_levels[7, :, :, -1] # top level embeddings at index 7 (index 0 is the initial state), i.e. after the 7th iteration
recon_img = patches_to_images(top_level)
# do self-supervised learning by denoising
loss = F.mse_loss(img, recon_img)
loss.backward()
```
You can pass the state of the columns and levels back into the model to continue where you left off (for example, when processing consecutive frames of a slow video, as mentioned in the paper):
```python
import torch
from glom_pytorch import Glom
model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)
img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)
levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations
```
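For a longer sequence of frames, the same pattern can be written as a loop that threads the state through. This is only a sketch: `run_frames` and its default iteration counts are hypothetical, and a stand-in function replaces the real model so the block runs without the package installed. It assumes the model is called as `model(img, levels = ..., iters = ...)`, which matches the forward signature used above.

```python
def run_frames(model, frames, first_iters = 12, next_iters = 6):
    # hypothetical helper, not part of glom_pytorch
    # the first frame starts from the model's initial state (levels = None),
    # every later frame continues from the state left by the previous one
    levels = None
    outputs = []
    for index, frame in enumerate(frames):
        iters = first_iters if index == 0 else next_iters
        levels = model(frame, levels = levels, iters = iters)
        outputs.append(levels)
    return outputs

# stand-in for a Glom instance: it only accumulates the number of iterations run so far
def fake_model(img, levels = None, iters = None):
    return (levels or 0) + iters

print(run_frames(fake_model, ['frame1', 'frame2', 'frame3']))  # [12, 18, 24]
```

With a real model, the call would look like `run_frames(model, [img1, img2, img3])`, returning one level state per frame.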
### Appreciation
Thanks go out to <a href="https://github.com/cfoster0">Cfoster0</a> for reviewing the code
### Todo
- [ ] contrastive / consistency regularization of top-ish levels
## Citations
```bibtex
@misc{hinton2021represent,
title = {How to represent part-whole hierarchies in a neural network},
author = {Geoffrey Hinton},
year = {2021},
eprint = {2102.12627},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
================================================
FILE: glom_pytorch/__init__.py
================================================
from glom_pytorch.glom_pytorch import Glom
================================================
FILE: glom_pytorch/glom_pytorch.py
================================================
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# constants
TOKEN_ATTEND_SELF_VALUE = -5e-4
# helpers
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# class
class GroupedFeedForward(nn.Module):
    def __init__(self, *, dim, groups, mult = 4):
        super().__init__()
        total_dim = dim * groups # levels * dim
        self.net = nn.Sequential(
            Rearrange('b n l d -> b (l d) n'),
            nn.Conv1d(total_dim, total_dim * mult, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(total_dim * mult, total_dim, 1, groups = groups),
            Rearrange('b (l d) n -> b n l d', l = groups)
        )

    def forward(self, levels):
        return self.net(levels)

class ConsensusAttention(nn.Module):
    def __init__(self, num_patches_side, attend_self = True, local_consensus_radius = 0):
        super().__init__()
        self.attend_self = attend_self
        self.local_consensus_radius = local_consensus_radius

        if self.local_consensus_radius > 0:
            coors = torch.stack(torch.meshgrid(
                torch.arange(num_patches_side),
                torch.arange(num_patches_side)
            )).float()

            coors = rearrange(coors, 'c h w -> (h w) c')
            dist = torch.cdist(coors, coors)
            mask_non_local = dist > self.local_consensus_radius
            mask_non_local = rearrange(mask_non_local, 'i j -> () i j')
            self.register_buffer('non_local_mask', mask_non_local)

    def forward(self, levels):
        _, n, _, d, device = *levels.shape, levels.device
        q, k, v = levels, F.normalize(levels, dim = -1), levels

        sim = einsum('b i l d, b j l d -> b l i j', q, k) * (d ** -0.5)

        if not self.attend_self:
            self_mask = torch.eye(n, device = device, dtype = torch.bool)
            self_mask = rearrange(self_mask, 'i j -> () () i j')
            sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)

        if self.local_consensus_radius > 0:
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(self.non_local_mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b l i j, b j l d -> b i l d', attn, levels)
        return out

# main class
class Glom(nn.Module):
    def __init__(
        self,
        *,
        dim = 512,
        levels = 6,
        image_size = 224,
        patch_size = 14,
        consensus_self = False,
        local_consensus_radius = 0
    ):
        super().__init__()
        # bottom level - incoming image, tokenize and add position
        num_patches_side = (image_size // patch_size)
        num_patches = num_patches_side ** 2
        self.levels = levels

        self.image_to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_size ** 2 * 3, dim)
        )
        self.pos_emb = nn.Embedding(num_patches, dim)

        # initial embeddings for all levels of a column
        self.init_levels = nn.Parameter(torch.randn(levels, dim))

        # bottom-up and top-down
        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

        # consensus attention
        self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)

    def forward(self, img, iters = None, levels = None, return_all = False):
        b, device = img.shape[0], img.device
        iters = default(iters, self.levels * 2) # need to have twice the number of levels of iterations in order for information to propagate up and back down. can be overridden

        tokens = self.image_to_tokens(img)
        n = tokens.shape[1]

        pos_embs = self.pos_emb(torch.arange(n, device = device))
        pos_embs = rearrange(pos_embs, 'n d -> () n () d')

        bottom_level = tokens
        bottom_level = rearrange(bottom_level, 'b n d -> b n () d')

        if not exists(levels):
            levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)

        hiddens = [levels]

        num_contributions = torch.empty(self.levels, device = device).fill_(4)
        num_contributions[-1] = 3 # top level does not get a top-down contribution, so have to account for this when doing the weighted mean

        for _ in range(iters):
            levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input at the most bottom level, to be bottomed-up

            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks
            top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

            consensus = self.attention(levels)

            levels_sum = torch.stack((levels, bottom_up_out, top_down_out, consensus)).sum(dim = 0) # hinton said to use the weighted mean of (1) bottom up (2) top down (3) previous level value {t - 1} (4) consensus value
            levels_mean = levels_sum / rearrange(num_contributions, 'l -> () () l ()')

            levels = levels_mean # set for next iteration
            hiddens.append(levels)

        if return_all:
            return torch.stack(hiddens) # return (time step, batch, num columns, levels, dimension)

        return levels
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
setup(
  name = 'glom-pytorch',
  packages = find_packages(),
  version = '0.0.14',
  license='MIT',
  description = 'Glom - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/glom-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning'
  ],
  install_requires=[
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
SYMBOL INDEX (11 symbols across 1 files)
FILE: glom_pytorch/glom_pytorch.py
function exists (line 15) | def exists(val):
function default (line 18) | def default(val, d):
class GroupedFeedForward (line 23) | class GroupedFeedForward(nn.Module):
method __init__ (line 24) | def __init__(self, *, dim, groups, mult = 4):
method forward (line 35) | def forward(self, levels):
class ConsensusAttention (line 38) | class ConsensusAttention(nn.Module):
method __init__ (line 39) | def __init__(self, num_patches_side, attend_self = True, local_consens...
method forward (line 56) | def forward(self, levels):
class Glom (line 77) | class Glom(nn.Module):
method __init__ (line 78) | def __init__(
method forward (line 110) | def forward(self, img, iters = None, levels = None, return_all = False):