Repository: kyegomez/RT-X
Branch: main
Commit: 1393e85e18fb
Files: 36
Total size: 82.5 KB
Directory structure:
gitextract_yvc81ecv/
├── .github/
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ └── feature_request.md
│ ├── PULL_REQUEST_TEMPLATE.yml
│ ├── dependabot.yml
│ ├── labeler.yml
│ └── workflows/
│ ├── docs.yml
│ ├── label.yml
│ ├── publish.yml
│ ├── pull-request-links.yml
│ ├── pylint.yml
│ ├── python-publish.yml
│ ├── stale.yml
│ ├── test.yml
│ ├── unit_test.yml
│ └── welcome.yml
├── .gitignore
├── LICENSE
├── README.md
├── examples/
│ ├── __init__.py
│ ├── efficient_net_example.py
│ ├── rtx1_example.py
│ └── train_example.py
├── pyproject.toml
├── requirements.txt
├── rtx/
│ ├── __init__.py
│ ├── data_util.py
│ ├── efficient_net.py
│ ├── rtx1.py
│ └── rtx2.py
├── rtx2_example.py
├── run.py
└── tests/
├── __init__.py
├── test_data_utils.py
├── test_rtx1.py
└── tests.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms
github: [kyegomez]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: #Nothing
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a detailed report on the bug and it's root cause. Conduct root cause error analysis
title: "[BUG] "
labels: bug
assignees: kyegomez
---
**Describe the bug**
A clear and concise description of what the bug is and what the main root cause error is. Test very thoroughly before submitting.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: 'kyegomez'
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.yml
================================================
<!-- Thank you for contributing to Zeta!
Replace this comment with:
- Description: a description of the change,
- Issue: the issue # it fixes (if applicable),
- Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer (see below),
- Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out!
If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on network access,
2. an example notebook showing its use.
Maintainer responsibilities:
- nn / Misc / if you don't know who to tag: kye@apac.ai
- tokenizers: kye@apac.ai
- training / Prompts: kye@apac.ai
- models: kye@apac.ai
If no one reviews your PR within a few days, feel free to kye@apac.ai
See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/kyegomez/zeta
================================================
FILE: .github/dependabot.yml
================================================
# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "weekly"
================================================
FILE: .github/labeler.yml
================================================
Documentation:
- changed-files:
- any-glob-to-any-file: '**/*.md'
- any-glob-to-any-file: 'docs/**'
# Add 'feature' label to any PR where the head branch name starts with `feature` or has a `feature` section in the name
feature:
- head-branch: ['^feature', 'feature']
# Add 'bug' label to any PR where the head branch name starts with `bug` or has a `bug` section in the name
bug:
- head-branch: ['^bug', 'bug']
================================================
FILE: .github/workflows/docs.yml
================================================
name: Docs WorkFlow
on:
push:
branches:
- master
- main
- develop
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: pip install mkdocs-material
- run: pip install "mkdocstrings[python]"
- run: mkdocs gh-deploy --force
================================================
FILE: .github/workflows/label.yml
================================================
# This workflow will triage pull requests and apply a label based on the
# paths that are modified in the pull request.
#
# To use this workflow, you will need to set up a .github/labeler.yml
# file with configuration. For more information, see:
# https://github.com/actions/labeler
name: Labeler
on: [pull_request_target]
jobs:
label:
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- uses: actions/labeler@v5
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
================================================
FILE: .github/workflows/publish.yml
================================================
name: Supervision Releases to PyPi
on:
push:
tags:
- '[0-9]+.[0-9]+[0-9]+.[0-9]'
- '[0-9]+.[0-9]+[0-9]+.[0-9]'
- '[0-9]+.[0-9]+[0-9]+.[0-9]'
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- name: 🛎️ Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}
- name: 🐍 Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: 🏗️ Build source and wheel distributions
run: |
python -m pip install --upgrade build twine
python -m build
twine check --strict dist/*
- name: 🚀 Publish to PyPi
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: ${{ secrets.PYPI_USERNAME }}
password: ${{ secrets.PYPI_PASSWORD }}
- name: 🚀 Publish to Test-PyPi
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
user: ${{ secrets.PYPI_TEST_USERNAME }}
password: ${{ secrets.PYPI_TEST_PASSWORD }}
================================================
FILE: .github/workflows/pull-request-links.yml
================================================
name: readthedocs/actions
on:
pull_request_target:
types:
- opened
paths:
- "docs/**"
permissions:
pull-requests: write
jobs:
pull-request-links:
runs-on: ubuntu-latest
steps:
- uses: readthedocs/actions/preview@v1
with:
project-slug: zeta
================================================
FILE: .github/workflows/pylint.yml
================================================
name: Pylint
on: [push]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
================================================
FILE: .github/workflows/python-publish.yml
================================================
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@81e9d935c883d0b210363ab89cf05f3894778450
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .github/workflows/stale.yml
================================================
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
#
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: Mark stale issues and pull requests
on:
schedule:
- cron: '26 12 * * *'
jobs:
stale:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'
stale-pr-message: 'Stale pull request message'
stale-issue-label: 'no-issue-activity'
stale-pr-label: 'no-pr-activity'
================================================
FILE: .github/workflows/test.yml
================================================
# name: test
# on:
# push:
# branches: [master]
# pull_request:
# workflow_dispatch:
# env:
# POETRY_VERSION: "1.4.2"
# jobs:
# build:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# python-version:
# - "3.8"
# - "3.9"
# - "3.10"
# - "3.11"
# test_type:
# - "core"
# - "extended"
# name: Python ${{ matrix.python-version }} ${{ matrix.test_type }}
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
# uses: "./.github/actions/poetry_setup"
# with:
# python-version: ${{ matrix.python-version }}
# poetry-version: "1.4.2"
# cache-key: ${{ matrix.test_type }}
# install-command: |
# if [ "${{ matrix.test_type }}" == "core" ]; then
# echo "Running core tests, installing dependencies with poetry..."
# poetry install
# else
# echo "Running extended tests, installing dependencies with poetry..."
# poetry install -E extended_testing
# fi
# - name: Run ${{matrix.test_type}} tests
# run: |
# if [ "${{ matrix.test_type }}" == "core" ]; then
# make test
# else
# make extended_tests
# fi
# shell: bash
================================================
FILE: .github/workflows/unit_test.yml
================================================
name: "python 3.11 | 3.10"
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build_and_test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- '3.11'
- '3.10'
steps:
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install -r requirements.txt
- name: Verify integration test results
run: python3 -m unittest
================================================
FILE: .github/workflows/welcome.yml
================================================
name: Welcome WorkFlow
on:
issues:
types: [opened]
pull_request_target:
types: [opened]
jobs:
build:
name: 👋 Welcome
runs-on: ubuntu-latest
steps:
- uses: actions/first-interaction@v1.3.0
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
issue-message: "Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you asap."
pr-message: "Hello there, thank you for opening an PR ! 🙏🏻 The team was notified and they will get back to you asap."
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
.ruff_cache/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
datasets_cache/
**/*checkpoints**/*
.DS_Store
runs/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 Eternal Reclaimer
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
[](https://discord.gg/qUtxnK2NMf)
# RT-X
Pytorch implementation of the models RT-1-X and RT-2-X from the paper: "Open X-Embodiment: Robotic Learning Datasets and RT-X Models".
Here we implement both model architectures, RTX-1 and RTX-2
[Paper Link](https://robotics-transformer-x.github.io/)
- The RTX-2 Implementation does not natively output for simplicity a 7 dimensional vector but rather text tokens, if you wanted to output 7 dimensional vector you could implement the same token learner as in RTX1
# Appreciation
* Lucidrains
* Agorians
# Install
`pip install rtx-torch `
# Usage
To see detailed usage, run `python run.py --help`.
## RTX1
- RTX1 Usage takes in text and videos
- Does not use Efficient Net yet, we're integrating it now then the implementation will be complete
- Uses SOTA transformer architecture
```python
import torch
from rtx.rtx1 import RTX1, FilmViTConfig
# Use a pre-trained MaxVit model from pytorch
model = RTX1(film_vit_config=FilmViTConfig(pretrained=pretrained))
video = torch.randn(2, 3, 6, 224, 224)
instructions = ["bring me that apple sitting on the table", "please pass the butter"]
# compute the train logits
train_logits = model.train(video, instructions)
# set the model to evaluation mode
model.model.eval()
# compute the eval logits with a conditional scale of 3
eval_logits = model.run(video, instructions, cond_scale=3.0)
print(eval_logits.shape)
```
## RTX-2
- RTX-2 takes in images and text and interleaves them to form multi-modal sentences and outputs text tokens not a 7 dimensional vector of x,y,z,roll,pitch,yaw,and gripper
```python
import torch
from rtx import RTX2
# usage
img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 1024))
model = RTX2()
output = model(img, text)
print(output)
```
## EfficientNetFilm
- Extracts the feature from the given image
```python
from rtx import EfficientNetFilm
model = EfficientNetFilm("efficientnet-b0", 10)
out = model("img.jpeg")
```
# Model Differences from the Paper Implementation
## RT-1
The main difference here is the substitution of a Film-EfficientNet backbone (pre-trained EfficientNet-B3 with Film layers inserted) with a MaxViT model.
# Tests
I created a single tests file that uses pytest to run tests on all the modules, RTX1, RTX2, EfficientNetFil, first git clone and get into the repository, install the requirements.txt with pip then run this:
`python -m pytest tests/tests.py`
# License
MIT
# Citations
```bibtex
@misc{open_x_embodiment_rt_x_2023,
title={Open {X-E}mbodiment: Robotic Learning Datasets and {RT-X} Models},
author = {Open X-Embodiment Collaboration and Abhishek Padalkar and Acorn Pooley and Ajinkya Jain and Alex Bewley and Alex Herzog and Alex Irpan and Alexander Khazatsky and Anant Rai and Anikait Singh and Anthony Brohan and Antonin Raffin and Ayzaan Wahid and Ben Burgess-Limerick and Beomjoon Kim and Bernhard Schölkopf and Brian Ichter and Cewu Lu and Charles Xu and Chelsea Finn and Chenfeng Xu and Cheng Chi and Chenguang Huang and Christine Chan and Chuer Pan and Chuyuan Fu and Coline Devin and Danny Driess and Deepak Pathak and Dhruv Shah and Dieter Büchler and Dmitry Kalashnikov and Dorsa Sadigh and Edward Johns and Federico Ceola and Fei Xia and Freek Stulp and Gaoyue Zhou and Gaurav S. Sukhatme and Gautam Salhotra and Ge Yan and Giulio Schiavi and Hao Su and Hao-Shu Fang and Haochen Shi and Heni Ben Amor and Henrik I Christensen and Hiroki Furuta and Homer Walke and Hongjie Fang and Igor Mordatch and Ilija Radosavovic and Isabel Leal and Jacky Liang and Jaehyung Kim and Jan Schneider and Jasmine Hsu and Jeannette Bohg and Jeffrey Bingham and Jiajun Wu and Jialin Wu and Jianlan Luo and Jiayuan Gu and Jie Tan and Jihoon Oh and Jitendra Malik and Jonathan Tompson and Jonathan Yang and Joseph J. Lim and João Silvério and Junhyek Han and Kanishka Rao and Karl Pertsch and Karol Hausman and Keegan Go and Keerthana Gopalakrishnan and Ken Goldberg and Kendra Byrne and Kenneth Oslund and Kento Kawaharazuka and Kevin Zhang and Keyvan Majd and Krishan Rana and Krishnan Srinivasan and Lawrence Yunliang Chen and Lerrel Pinto and Liam Tan and Lionel Ott and Lisa Lee and Masayoshi Tomizuka and Maximilian Du and Michael Ahn and Mingtong Zhang and Mingyu Ding and Mohan Kumar Srirama and Mohit Sharma and Moo Jin Kim and Naoaki Kanazawa and Nicklas Hansen and Nicolas Heess and Nikhil J Joshi and Niko Suenderhauf and Norman Di Palo and Nur Muhammad Mahi Shafiullah and Oier Mees and Oliver Kroemer and Pannag R Sanketi and Paul Wohlhart and Peng Xu and Pierre Sermanet and Priya Sundaresan and Quan Vuong and Rafael Rafailov and Ran Tian and Ria Doshi and Roberto Martín-Martín and Russell Mendonca and Rutav Shah and Ryan Hoque and Ryan Julian and Samuel Bustamante and Sean Kirmani and Sergey Levine and Sherry Moore and Shikhar Bahl and Shivin Dass and Shuran Song and Sichun Xu and Siddhant Haldar and Simeon Adebola and Simon Guist and Soroush Nasiriany and Stefan Schaal and Stefan Welker and Stephen Tian and Sudeep Dasari and Suneel Belkhale and Takayuki Osa and Tatsuya Harada and Tatsuya Matsushima and Ted Xiao and Tianhe Yu and Tianli Ding and Todor Davchev and Tony Z. Zhao and Travis Armstrong and Trevor Darrell and Vidhi Jain and Vincent Vanhoucke and Wei Zhan and Wenxuan Zhou and Wolfram Burgard and Xi Chen and Xiaolong Wang and Xinghao Zhu and Xuanlin Li and Yao Lu and Yevgen Chebotar and Yifan Zhou and Yifeng Zhu and Ying Xu and Yixuan Wang and Yonatan Bisk and Yoonyoung Cho and Youngwoon Lee and Yuchen Cui and Yueh-hua Wu and Yujin Tang and Yuke Zhu and Yunzhu Li and Yusuke Iwasawa and Yutaka Matsuo and Zhuo Xu and Zichen Jeff Cui},
howpublished = {\url{https://arxiv.org/abs/2310.08864}},
year = {2023},
}
```
# Todo
- Integrate EfficientNetFilm with RTX-1
- Create training script for RTX-1 by unrolling observations and do basic cross entropy in first rt-1
- Use RTX-2 dataset on huggingface
- [Check out the project board for more tasks](https://github.com/users/kyegomez/projects/10/views/1)
================================================
FILE: examples/__init__.py
================================================
================================================
FILE: examples/efficient_net_example.py
================================================
from rtx.efficient_net import EfficientNetFilm
model = EfficientNetFilm("efficientnet-b0", 10)
out = model("img.jpeg")
================================================
FILE: examples/rtx1_example.py
================================================
import torch
from rtx.rtx1 import RTX1, FilmViTConfig
def run(pretrained=False):
"""Run RT-X1 example.
Args:
pretrained (bool, optional): Whether or not to use a pretrained MaxVit with film (downloads from pytorch).
Defaults to False.
"""
model = RTX1(vit_config=FilmViTConfig(pretrained=pretrained))
video = torch.randn(2, 3, 6, 224, 224)
instructions = [
"bring me that apple sitting on the table",
"please pass the butter",
]
# compute the train logits
model.train(video, instructions)
# set the model to evaluation mode
model.model.eval()
# compute the eval logits with a conditional scale of 3
eval_logits = model.run(video, instructions, cond_scale=3.0)
print(eval_logits.shape)
if __name__ == "__main__":
run()
================================================
FILE: examples/train_example.py
================================================
import torch
from absl import logging
def run(
model: torch.nn.Module,
):
logging.fatal("Not yet implemented.")
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "rtx-torch"
version = "0.1.3"
description = "rtx - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]
homepage = "https://github.com/kyegomez/rt-x"
documentation = "https://github.com/kyegomez/rt-x" # Replace if you have documentation.
readme = "README.md" # Assuming you have a README.md
repository = "https://github.com/kyegomez/rtx"
keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.6"
]
packages = [
{ include = "rtx" },
{ include = "rtx/**/*.py" },
]
[tool.poetry.dependencies]
python = "^3.9,<3.12"
torch = "*"
torchvision = "^0.16.2"
einops = "0.7.0"
efficientnet_pytorch = "0.7.1"
zetascale = "1.2.5"
classifier-free-guidance-pytorch = "0.5.3"
lz4 = "^4.3.2"
torch-tb-profiler = "^0.4.3"
tensorboardX = "^2.6.2.2"
tensorboard = "^2.15.1"
olefile = "^0.47"
================================================
FILE: requirements.txt
================================================
torch
torchvision
torch-tb-profiler==0.4.3
tensorboardx==2.6.2.2
tensorboard==2.16.2
olefile==0.47
einops==0.7.0
efficientnet_pytorch==0.7.1
zetascale==1.2.5
classifier-free-guidance-pytorch==0.5.3
tensorboardX
================================================
FILE: rtx/__init__.py
================================================
from rtx.rtx2 import RTX2
from rtx.rtx1 import RTX1
from rtx.efficient_net import EfficientNetFilm
from rtx.data_util import describe, format_imgs, preprocess
__all__ = [
"RTX2",
"RTX1",
"EfficientNetFilm",
"describe",
"format_imgs",
"preprocess",
]
================================================
FILE: rtx/data_util.py
================================================
import io
import torch
import numpy as np
from PIL import Image
from tensorboardX import SummaryWriter
ArrayLike = np.ndarray | list | torch.Tensor
def map_np(input: np.ndarray, idxs: list[int], fn: callable) -> None:
"""Maps a function through a numpy array.
Args:
input (np.ndarray): Input.
fn (callable): Function to map.
Returns: None
"""
if sum(input.shape) <= 1:
fn(input, idxs)
idxs.pop()
return
for i, x in enumerate(input):
idxs.append(i)
map_np(x, idxs, fn)
def write_dict_to(name: str, writer: SummaryWriter, input: dict, step: int):
"""Writes a dictionary to tensorboard.
Args:
name (str): Name of group to identify values in dict with.
writer (SummaryWriter): Tensorboard writer.
input (dict): Input dictionary.
step (int): Global step value.
"""
for k, v in input.items():
v = np.array(v).squeeze()
if sum(v.shape) <= 1:
writer.add_scalar(name + "_" + k, v, step)
continue
map_np(
v,
[],
lambda x, idxs: writer.add_scalar(
"{}_{}-{}".format(name, k, "-".join([str(i) for i in idxs])), x, step
),
)
def describe(dic, prefix="", str_built=[]) -> str:
"""Useful to print out the structure of TF Record. ds.info can also be used
but it does not show lengths of lists and dicts.
Args:
dic (dict): Input
prefix (str, optional): Prefix used for nested indentation. Defaults to ''.
str_built (str, optional): Desription string built so far. Defaults to ''.
"""
if not isinstance(dic, dict):
return ""
def describe_img(img: bytes):
img = Image.open(io.BytesIO(img))
return f"{img.__class__.__name__} sz: { img.size}"
for k, v in dic.items():
if isinstance(v, list):
list_type = ""
if len(v) > 0:
v_description = ""
if isinstance(v[0], torch.Tensor):
v_description = f"({tuple(v[0].size())}, {v[0].dtype})"
elif isinstance(v[0], bytes):
v_description = describe_img(v[0])
list_type = f"({v[0].__class__.__name__ }{v_description})"
print(f"{prefix} {k}, {v.__class__.__name__}{list_type} sz:" f" {len(v)}")
if len(v) > 0:
str_built.append(describe(v[0], prefix + " "))
elif isinstance(v, dict):
print(f"{prefix} {k}, {v.__class__.__name__} sz:" f" {len(v.items())}")
describe(v, prefix + " ")
elif isinstance(v, bytes):
print(f"{prefix} {k}, {describe_img( v)}")
elif isinstance(v, str):
str_built.append(f"{prefix} {k}, {v.__class__.__name__} v: {v}\n")
else:
tensor_type = ""
if isinstance(v, torch.Tensor):
tensor_type = f"({tuple(v[0].size())}, {v[0].dtype})"
print(f"{prefix} {k}, {v.__class__.__name__} {tensor_type} ")
def preprocess(dic: any, height=224, width=224):
"""Remove nonetypes from a dict, convert images to numpy arrays and return.
Args:
dic (dict): Input.
Returns:
dict: Output.
"""
if isinstance(dic, bytes):
img = Image.open(io.BytesIO(dic))
return np.array(img.resize((width, height)))
if not isinstance(dic, dict):
return dic
to_remove = []
for k, v in dic.items():
if isinstance(v, list):
processed = []
for vv in v:
processed.append(preprocess(vv, height, width))
dic[k] = processed
elif v is None:
to_remove.append(k)
else:
dic[k] = preprocess(v, height, width)
for k in to_remove:
del dic[k]
return dic
def format_imgs(dic: any, sz: int):
"""Resizes images to sz as a numpy array.
Args:
dic (dict): Input.
Returns:
dict: Output.
"""
if isinstance(dic, bytes):
img = Image.open(io.BytesIO(dic))
return np.array(img.resize((sz, sz)))
return np.array(img.resize((sz, sz)))
if not isinstance(dic, dict):
return dic
for k, v in dic.items():
if isinstance(v, list):
for i in range(len(v)):
v[i] = format_imgs(v, sz)
else:
dic[k] = format_imgs(v, sz)
return dic
================================================
FILE: rtx/efficient_net.py
================================================
from torch import nn
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from PIL import Image
class EfficientNetFilm(nn.Module):
"""
EfficientNet with FiLM layer
Args:
model (str): EfficientNet model name
num_classes (int): Number of classes
num_features (int): Number of features to output from the model
resize (int): Size to resize the image to
Attributes:
model (EfficientNet): EfficientNet model
num_classes (int): Number of classes
num_features (int): Number of features to output from the model
resize (int): Size to resize the image to
transform (torchvision.transforms.Compose): Image transformations
Example:
>>> model = EfficientNetFilm('efficientnet-b0', 10)
>>> img = Image.open('img.jpeg')
>>> features = model(img)
>>> features.shape
torch.Size([1, 1280])
"""
def __init__(
self,
model,
num_classes,
num_features=1280,
resize=224,
):
super().__init__()
self.model = model
self.num_classes = num_classes
self.num_features = num_features
self.resize = resize
self.model = EfficientNet.from_pretrained(model)
self.transform = transforms.Compose(
[
transforms.Resize(resize),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
def __call__(self, img: str):
"""
Extract the feature embeddings from the image
Args:
img (str): Path to image
"""
img = Image.open(img)
img = self.transform(img).unsqueeze(0)
print(img.shape)
features = self.model.extract_features(img)
print(features.shape)
================================================
FILE: rtx/rtx1.py
================================================
from functools import partial
import torch
from torch import nn, einsum, Tensor
from typing import List, Optional, Callable, Tuple
# from beartype import beartype
from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce
from classifier_free_guidance_pytorch import (
TextConditioner as FilmTextConditioner,
AttentionTextConditioner as FilmAttentionTextConditioner,
classifier_free_guidance,
)
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length=1):
return val if isinstance(val, tuple) else ((val,) * length)
def pack_one(x, pattern):
return pack([x], pattern)
def unpack_one(x, ps, pattern):
return unpack(x, ps, pattern)[0]
# sinusoidal positions
def posemb_sincos_1d(seq, dim, temperature=10000, device=None, dtype=torch.float32):
n = torch.arange(seq, device=device)
omega = torch.arange(dim // 2, device=device) / (dim // 2 - 1)
omega = 1.0 / (temperature**omega)
n = n[:, None] * omega[None, :]
pos_emb = torch.cat((n.sin(), n.cos()), dim=1)
return pos_emb.type(dtype)
# helper classes
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class FeedForward(nn.Module):
def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
self.norm = nn.LayerNorm(dim)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x, cond_fn=None):
x = self.norm(x)
if exists(cond_fn):
# adaptive layernorm
x = cond_fn(x)
return self.net(x)
# MBConv
class SqueezeExcitation(nn.Module):
def __init__(self, dim, shrinkage_rate=0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = nn.Sequential(
Reduce("b c h w -> b c", "mean"),
nn.Linear(dim, hidden_dim, bias=False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias=False),
nn.Sigmoid(),
Rearrange("b c -> b c 1 1"),
)
def forward(self, x):
return x * self.gate(x)
class MBConvResidual(nn.Module):
def __init__(self, fn, dropout=0.0):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
class Dropsample(nn.Module):
def __init__(self, prob=0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0.0 or (not self.training):
return x
keep_mask = (
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
> self.prob
)
return x * keep_mask / (1 - self.prob)
def MBConv(
dim_in,
dim_out,
*,
downsample,
expansion_rate=4,
shrinkage_rate=0.25,
dropout=0.0,
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = nn.Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(
hidden_dim,
hidden_dim,
3,
stride=stride,
padding=1,
groups=hidden_dim,
),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
nn.BatchNorm2d(dim_out),
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout=dropout)
return net
# attention related classes
class Attention(nn.Module):
def __init__(self, dim, dim_head=32, dropout=0.0, window_size=7):
super().__init__()
assert (
dim % dim_head
) == 0, "dimension should be divisible by dimension per head"
self.norm = nn.LayerNorm(dim)
self.heads = dim // dim_head
self.scale = dim_head**-0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
)
# relative positional bias
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing="ij"))
grid = rearrange(grid, "c i j -> (i j) c")
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
grid, "j ... -> 1 j ..."
)
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
def forward(self, x):
(
batch,
height,
width,
window_height,
window_width,
_,
device,
h,
) = (
*x.shape,
x.device,
self.heads,
)
x = self.norm(x)
# flatten
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
# project for queries, keys, values
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# split heads
q, k, v = map(
lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h),
(q, k, v),
)
# scale
q = q * self.scale
# sim
sim = einsum("b h i d, b h j d -> b h i j", q, k)
# add positional bias
bias = self.rel_pos_bias(self.rel_pos_indices)
sim = sim + rearrange(bias, "i j h -> h i j")
# attention
attn = self.attend(sim)
# aggregate
out = einsum("b h i j, b h j d -> b h i d", attn, v)
# merge heads
out = rearrange(
out,
"b h (w1 w2) d -> b w1 w2 (h d)",
w1=window_height,
w2=window_width,
)
# combine heads out
out = self.to_out(out)
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
class FilmViTConfig:
"""Configuration class to store the configuration of a `FilmMaxVit`."""
def __init__(
self,
num_classes=1000, # 1000 for ImageNet
input_channels=3,
stem_channels_in=64, # Number of stem channels
dim_head=32, # Attention head dimension
block_channel_ins: List = [
64,
128,
256,
512,
], # Number of channels for each ViT block
block_layers=[
2,
2,
5,
2,
], # Number of layers for each ViT block
window_size=7, # Partition size
mbconv_expansion_rate=4,
mbconv_shrinkage_rate=0.25, # MBConv squeeze ratio
dropout=0.1,
norm_layer: nn.Module = None,
activation_layer=nn.GELU,
stochastic_depth_prob=0.2,
pretrained=False,
):
"""
Constructs a MaxVit architecture with optional film layers from
`MaxVit: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.
Parameters
----------
num_classes : int
Number of classes for the classification task
input_channels : int
Number of input channels
stem_channels_in : int
Number of stem channels
dim_head : int
Dimension of the head
block_channel_ins : List
Number of channels for each ViT block
block_layers : List
Number of layers for each ViT block
window_size : int
Partition size
mbconv_expansion_rate : int
MBConv expansion rate
mbconv_shrinkage_rate : float
MBConv squeeze ratio
dropout : float
Dropout probability
norm_layer : nn.Module
Normalization layer
activation_layer : nn.Module
Activation layer
stochastic_depth_prob : float
Stochastic depth probability
"""
self.num_classes = num_classes
self.input_channels = input_channels
self.stem_channels_in = stem_channels_in
self.block_channel_ins = block_channel_ins
self.block_layers = block_layers
self.dim_head = dim_head
self.stem_channels_in = stem_channels_in
self.window_size = window_size
self.mbconv_expansion_rate = mbconv_expansion_rate
self.mbconv_shrinkage_rate = mbconv_shrinkage_rate
self.dropout = dropout
self.norm_layer = norm_layer
if self.norm_layer is None:
self.norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99)
self.activation_layer = activation_layer
self.pretrained = pretrained
self.stochastic_depth_prob = stochastic_depth_prob
class FilmMaxVit(nn.Module):
def __init__(
self,
config: FilmViTConfig,
):
super().__init__()
assert isinstance(config.block_layers, tuple | list), (
"depth needs to be tuple if integers indicating number of"
" transformer blocks at that stage"
)
# List of number of input and output channels for each ViT block.
in_channels: List = [config.stem_channels_in] + config.block_channel_ins[:-1]
out_channels: List = config.block_channel_ins
# Condition after each layer starting with the input to the stem block.
self.cond_hidden_dims = [config.stem_channels_in] # Used by FilmTextConditioner
for block_in_channels, block_layers in zip(out_channels, config.block_layers):
for _ in range(block_layers):
self.cond_hidden_dims.append(block_in_channels)
self.cond_hidden_dims = self.cond_hidden_dims[
:-1
] # Don't condition on last embedding.
self.embed_dim = out_channels[-1]
if config.pretrained:
from torchvision.models import maxvit_t, MaxVit_T_Weights
self._vit = maxvit_t(weights=MaxVit_T_Weights.DEFAULT)
self.conv_stem = self._vit.stem
self.mlp_head = self._vit.classifier
self.layers = nn.ModuleList([])
for block in self._vit.blocks:
for layer in block.layers:
self.layers.append(layer)
return
# convolutional stem
self.conv_stem = nn.Sequential(
nn.Conv2d(
config.input_channels,
config.stem_channels_in,
3,
stride=2,
padding=1,
),
nn.Conv2d(
config.stem_channels_in,
config.stem_channels_in,
3,
padding=1,
),
)
self.layers = nn.ModuleList([])
for (
block_channels_in,
block_channels_out,
block_num_layers,
) in zip(in_channels, out_channels, config.block_layers):
for i in range(block_num_layers):
layer_channels_in = block_channels_in if i == 0 else block_channels_out
layer = nn.Sequential(
MBConv(
layer_channels_in,
block_channels_out,
downsample=(i == 0),
expansion_rate=config.mbconv_expansion_rate,
shrinkage_rate=config.mbconv_shrinkage_rate,
),
Rearrange(
"b d (x w1) (y w2) -> b x y w1 w2 d",
w1=config.window_size,
w2=config.window_size,
), # block-like attention
Residual(
Attention(
dim=block_channels_out,
dim_head=config.dim_head,
dropout=config.dropout,
window_size=config.window_size,
)
),
Residual(
FeedForward(
dim=block_channels_out,
dropout=config.dropout,
)
),
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
Rearrange(
"b d (w1 x) (w2 y) -> b x y w1 w2 d",
w1=config.window_size,
w2=config.window_size,
), # grid-like attention
Residual(
Attention(
dim=block_channels_out,
dim_head=config.dim_head,
dropout=config.dropout,
window_size=config.window_size,
)
),
Residual(
FeedForward(
dim=block_channels_out,
dropout=config.dropout,
)
),
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
)
self.layers.append(layer)
# mlp head out
self.mlp_head = nn.Sequential(
Reduce("b d h w -> b d", "mean"),
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, config.num_classes, bias=False),
)
# @beartype
def forward(
self,
x,
texts: Optional[List[str]] = None,
cond_fns: Optional[Tuple[Callable, ...]] = None,
cond_drop_prob=0.0,
return_embeddings=False,
):
x = self.conv_stem(x)
cond_fns = iter(default(cond_fns, []))
for stage in self.layers:
cond_fn = next(cond_fns, None)
if exists(cond_fn):
x = cond_fn(x)
x = stage(x)
if return_embeddings:
return x
return self.mlp_head(x)
# attention
class TransformerAttention(nn.Module):
def __init__(
self,
dim,
causal=False,
dim_head=64,
dim_context=None,
heads=8,
norm_context=False,
dropout=0.1,
):
super().__init__()
self.heads = heads
self.scale = dim_head**-0.5
self.causal = causal
inner_dim = dim_head * heads
dim_context = default(dim_context, dim)
self.norm = nn.LayerNorm(dim)
self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()
self.attn_dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim_context, dim_head * 2, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout)
)
def forward(
self,
x,
context=None,
mask=None,
attn_bias=None,
attn_mask=None,
cond_fn: Optional[Callable] = None,
):
x.shape[0]
if exists(context):
context = self.context_norm(context)
kv_input = default(context, x)
x = self.norm(x)
if exists(cond_fn):
# adaptive layer-norm
x = cond_fn(x)
q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)
q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
q = q * self.scale
sim = einsum("b h i d, b j d -> b h i j", q, k)
if exists(attn_bias):
sim = sim + attn_bias
if exists(attn_mask):
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
if exists(mask):
mask = rearrange(mask, "b j -> b 1 1 j")
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(
j - i + 1
)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
out = einsum("b h i j, b j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class Transformer(nn.Module):
def __init__(
self,
dim,
dim_head=64,
heads=8,
depth=6,
attn_dropout=0.0,
ff_dropout=0.0,
):
super().__init__()
self.layers = nn.ModuleList([])
self.norm = nn.LayerNorm(dim)
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
TransformerAttention(
dim=dim, heads=heads, dropout=attn_dropout
),
FeedForward(dim=dim, dropout=ff_dropout),
]
)
)
def forward(
self,
x,
cond_fns: Optional[Tuple[Callable, ...]] = None,
attn_mask=None,
):
cond_fns = iter(default(cond_fns, []))
x = self.norm(x)
for attn, ff in self.layers:
x = (
attn(
self.norm(x),
attn_mask=attn_mask,
cond_fn=next(cond_fns, None),
)
+ x
)
x = ff(self.norm(x), cond_fn=next(cond_fns, None)) + x
return x
# token learner module
class TokenLearner(nn.Module):
"""
https://arxiv.org/abs/2106.11297
using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map
"""
def __init__(self, *, dim, ff_mult=2, num_output_tokens=8, num_layers=2):
super().__init__()
inner_dim = dim * ff_mult * num_output_tokens
self.num_output_tokens = num_output_tokens
self.net = nn.Sequential(
nn.Conv2d(
dim * num_output_tokens,
inner_dim,
1,
groups=num_output_tokens,
),
nn.GELU(),
nn.Conv2d(
inner_dim,
num_output_tokens,
1,
groups=num_output_tokens,
),
)
def forward(self, x):
x, ps = pack_one(x, "* c h w")
x = repeat(x, "b c h w -> b (g c) h w", g=self.num_output_tokens)
attn = self.net(x)
attn = rearrange(attn, "b g h w -> b 1 g h w")
x = rearrange(x, "b (g c) h w -> b c g h w", g=self.num_output_tokens)
x = reduce(x * attn, "b c g h w -> b c g", "mean")
x = unpack_one(x, ps, "* c n")
return x
# Robotic Transformer
class RT1Config:
def __init__(
self,
num_actions=11,
action_bins=256,
depth=6,
heads=8,
dim_head=64,
token_learner_ff_mult=2,
token_learner_num_layers=2,
token_learner_num_output_tokens=8,
cond_drop_prob=0.2,
use_attn_conditioner=False,
):
"""Configuration class to store the configuration of a `RT1`.
Args:
num_actions (int): Number of actions for the classification task
action_bins (int): Number of bins for each action
depth (int): Number of transformer blocks
heads (int): Number of heads for the transformer
dim_head (int): Dimension of the head
token_learner_ff_mult (int): Multiplier for the token learner
token_learner_num_layers (int): Number of layers for the token learner
token_learner_num_output_tokens (int): Number of output tokens for the token learner
cond_drop_prob (float): Dropout probability
use_attn_conditioner (bool): Whether to use the attention conditioner
"""
self.num_actions = num_actions
self.action_bins = action_bins
self.depth = depth
self.heads = heads
self.dim_head = dim_head
self.token_learner_ff_mult = token_learner_ff_mult
self.token_learner_num_layers = token_learner_num_layers
self.token_learner_num_output_tokens = token_learner_num_output_tokens
self.cond_drop_prob = cond_drop_prob
self.use_attn_conditioner = use_attn_conditioner
# @beartype
class RT1(nn.Module):
def __init__(
self,
config: RT1Config,
vit: FilmMaxVit,
conditioner_kwargs: dict = dict(),
):
super().__init__()
self.vit = vit
self.num_vit_stages = len(vit.cond_hidden_dims)
film_layer = (
FilmAttentionTextConditioner
if config.use_attn_conditioner
else FilmTextConditioner
)
self.conditioner = film_layer(
hidden_dims=(
*tuple(vit.cond_hidden_dims),
*((vit.embed_dim,) * config.depth * 2),
),
hiddens_channel_first=(
*((True,) * self.num_vit_stages),
*((False,) * config.depth * 2),
),
cond_drop_prob=config.cond_drop_prob,
**conditioner_kwargs,
)
self.token_learner = TokenLearner(
dim=vit.embed_dim,
ff_mult=config.token_learner_ff_mult,
num_output_tokens=config.token_learner_num_output_tokens,
num_layers=config.token_learner_num_layers,
)
self.num_learned_tokens = config.token_learner_num_output_tokens
self.transformer_depth = config.depth
self.transformer = Transformer(
dim=vit.embed_dim,
dim_head=config.dim_head,
heads=config.heads,
depth=config.depth,
)
self.norm = nn.LayerNorm(config.embed_dim)
self.cond_drop_prob = config.cond_drop_prob
self.to_logits = nn.Sequential(
nn.LayerNorm(vit.embed_dim),
nn.Linear(vit.embed_dim, config.num_actions * config.action_bins),
Rearrange("... (a b) -> ... a b", b=config.action_bins),
)
def embed_texts(self, texts: List[str]):
return self.conditioner.embed_texts(texts)
@classifier_free_guidance
def forward(
self,
video,
texts: Optional[List[str]] = None,
text_embeds: Optional[Tensor] = None,
cond_drop_prob=0.0,
):
assert exists(texts) ^ exists(text_embeds)
cond_kwargs = dict(texts=texts, text_embeds=text_embeds)
depth = self.transformer_depth
cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
frames, device = video.shape[2], video.device
cond_fns, _ = self.conditioner(
**cond_kwargs,
cond_drop_prob=cond_drop_prob,
repeat_batch=(
*((frames,) * self.num_vit_stages),
*((1,) * self.transformer_depth * 2),
),
)
vit_cond_fns, transformer_cond_fns = (
cond_fns[: -(depth * 2)],
cond_fns[-(depth * 2) :],
)
video = rearrange(video, "b c f h w -> b f c h w")
images, packed_shape = pack_one(video, "* c h w")
tokens = self.vit(
images,
texts=texts,
cond_fns=vit_cond_fns,
cond_drop_prob=cond_drop_prob,
return_embeddings=True,
)
tokens = unpack_one(tokens, packed_shape, "* c h w")
learned_tokens = self.token_learner(tokens)
learned_tokens = rearrange(learned_tokens, "b f c n -> b (f n) c")
# causal attention mask
attn_mask = torch.ones((frames, frames), dtype=torch.bool, device=device).triu(
1
)
attn_mask = repeat(
attn_mask,
"i j -> (i r1) (j r2)",
r1=self.num_learned_tokens,
r2=self.num_learned_tokens,
)
# sinusoidal positional embedding
pos_emb = posemb_sincos_1d(
frames,
learned_tokens.shape[-1],
dtype=learned_tokens.dtype,
device=learned_tokens.device,
)
learned_tokens = learned_tokens + repeat(
pos_emb, "n d -> (n r) d", r=self.num_learned_tokens
)
# attention
attended_tokens = self.transformer(
learned_tokens,
cond_fns=transformer_cond_fns,
attn_mask=~attn_mask,
)
pooled = reduce(attended_tokens, "b (f n) d -> b f d", "mean", f=frames)
logits = self.to_logits(pooled)
return logits
class RTX1(nn.Module):
"""
A class for real-time video processing using Vision Transformers (ViT) and Reinforcement Learning (RT1) models.
...
Attributes
----------
vit : FilmMaxVit
a Vision Transformer model
model : RT1
a reinforcement learning model
Methods
-------
train(video, instructions):
Computes the logits for the given video and instructions using the RT1 model in training mode.
eval(video, instructions, cond_scale=1.0):
Computes the logits for the given video and instructions using the RT1 model in evaluation mode.
"""
def __init__(
self,
rt1_config: RT1Config = None,
vit_config: FilmViTConfig = None,
):
"""
Constructs all the necessary attributes for the RTX1 object.
Parameters
----------
rt1_config : RT1Config, optional
a configuration object for the RT1 model (default is None)
vit_config : FilmViTConfig, optional
a configuration object for the ViT model (default is None)
Example:
import torch
from rtx import RTX1
model = RTX1()
video = torch.randn(2, 3, 6, 224, 224)
instructions = ["bring me that apple sitting on the table", "please pass the butter"]
# compute the train logits
train_logits = model.train(video, instructions)
# set the model to evaluation mode
model.model.eval()
# compute the eval logits with a conditional scale of 3
eval_logits = model.run(video, instructions, cond_scale=3.0)
print(eval_logits.shape)
"""
super().__init__()
if rt1_config is None:
rt1_config = RT1Config()
if vit_config is None:
vit_config = FilmViTConfig()
self.vit = FilmMaxVit(vit_config)
self.model = RT1(
config=rt1_config,
vit=self.vit,
)
def train(self, video, instructions):
"""
Computes the logits for the given video and instructions using the RT1 model in training mode.
Parameters
----------
video : torch.Tensor
a tensor containing the video data
instructions : torch.Tensor
a tensor containing the instructions
Returns
-------
torch.Tensor
a tensor containing the computed logits
"""
try:
train_logits = self.model(video, instructions)
return train_logits
except Exception as e:
raise RuntimeError("Error in training: {}".format(e))
def run(self, video, instructions, cond_scale=1.0):
"""
Computes the logits for the given video and instructions using the RT1 model in evaluation mode.
Parameters
----------
video : torch.Tensor
a tensor containing the video data
instructions : torch.Tensor
a tensor containing the instructions
cond_scale : float, optional
a scale factor for the conditional scaling (default is 1.0)
Returns
-------
torch.Tensor
a tensor containing the computed logits
"""
try:
self.model.eval()
# shape => 2, 3, 6, 224, 224
eval_logits = self.model(video, instructions, cond_scale=cond_scale)
return eval_logits
except Exception as e:
raise RuntimeError("Error in evaluation: {}".format(e))
================================================
FILE: rtx/rtx2.py
================================================
import torch
from torch import nn
from zeta.structs import (
AutoregressiveWrapper,
Decoder,
Encoder,
Transformer,
ViTransformerWrapper,
)
class RTX2(torch.nn.Module):
"""
RTX2 is a transformer architecture that uses a ViT encoder and a transformer decoder.
Args:
image_size (int): Size of the image.
patch_size (int): Size of the patch.
encoder_dim (int): Dimension of the encoder.
encoder_depth (int): Depth of the encoder.
encoder_heads (int): Number of heads in the encoder.
num_tokens (int): Number of tokens.
max_seq_len (int): Maximum sequence length.
decoder_dim (int): Dimension of the decoder.
decoder_depth (int): Depth of the decoder.
decoder_heads (int): Number of heads in the decoder.
alibi_num_heads (int): Number of heads in the alibi attention.
attn_kv_heads (int): Number of heads in the attention key-value projection.
use_abs_pos_emb (bool): Whether to use absolute positional embeddings.
cross_attend (bool): Whether to cross attend in the decoder.
alibi_pos_bias (bool): Whether to use positional bias in the alibi attention.
rotary_xpos (bool): Whether to use rotary positional embeddings.
attn_flash (bool): Whether to use attention flash.
qk_norm (bool): Whether to normalize the query and key in the attention layer.
Returns:
torch.Tensor: The output of the model.
Usage:
>>> img = torch.randn(1, 3, 256, 256)
>>> text = torch.randint(0, 20000, (1, 1024))
>>> model = RTX2()
>>> output = model(img, text)
>>> print(output)
"""
def __init__(
self,
image_size=256,
patch_size=32,
encoder_dim=512,
encoder_depth=6,
encoder_heads=8,
num_tokens=20000,
max_seq_len=1024,
decoder_dim=512,
decoder_depth=6,
decoder_heads=8,
alibi_num_heads=4,
attn_kv_heads=2,
use_abs_pos_emb=False,
cross_attend=True,
alibi_pos_bias=True,
rotary_xpos=True,
attn_flash=True,
qk_norm=True,
*args,
**kwargs,
):
super(RTX2, self).__init__()
# vit architecture
self.encoder = ViTransformerWrapper(
image_size=image_size,
patch_size=patch_size,
attn_layers=Encoder(
dim=encoder_dim,
depth=encoder_depth,
heads=encoder_heads,
),
)
# palm model architecture
self.decoder = Transformer(
num_tokens=num_tokens,
max_seq_len=max_seq_len,
use_abs_pos_emb=use_abs_pos_emb,
attn_layers=Decoder(
dim=decoder_dim,
depth=decoder_depth,
heads=decoder_heads,
cross_attend=cross_attend,
alibi_pos_bias=alibi_pos_bias,
alibi_num_heads=alibi_num_heads,
rotary_xpos=rotary_xpos,
attn_kv_heads=attn_kv_heads,
attn_flash=attn_flash,
qk_norm=qk_norm,
*args,
**kwargs,
),
)
# autoregressive wrapper to enable generation of tokens
self.decoder = AutoregressiveWrapper(self.decoder)
# Norm
self.norm = nn.LayerNorm(encoder_dim)
def forward(self, img: torch.Tensor, text: torch.Tensor):
"""Forward pass of the model."""
try:
encoded = self.encoder(img, return_embeddings=True)
encoded = self.norm(encoded)
encoded = self.norm(encoded)
return self.decoder(text, context=encoded)
except Exception as error:
print(f"Failed in forward method: {error}")
raise
================================================
FILE: rtx2_example.py
================================================
import torch
from rtx import RTX2
def run():
# usage
img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 1024))
model = RTX2()
output = model(img, text)
print(output)
if __name__ == "__main__":
run()
================================================
FILE: run.py
================================================
from examples import rtx1_example, train_example
from rtx import RTX1, RTX2
from rtx.rtx1 import FilmViTConfig
from absl import app, flags, logging
from . import rtx2_example
REGISTRY = {
"rtx1": RTX1,
"rtx2": RTX2,
}
MODES = ["inference", "train"]
EXAMPLE_SCRIPTS = {
"rtx1": rtx1_example,
"rtx2": rtx2_example,
}
FLAGS = flags.FLAGS
flags.DEFINE_boolean(
"pretrained_vit", False, "Whether to use a pretrained ViT as a backbone or not."
)
flags.DEFINE_enum("model", "rtx1", REGISTRY.keys(), "Model to choose from.")
flags.DEFINE_enum("mode", "inference", MODES, "Experiment mode to run.")
def main(_):
if FLAGS.mode == "inference":
EXAMPLE_SCRIPTS[FLAGS.model].run()
elif FLAGS.mode == "train":
if FLAGS.pretrained_vit and FLAGS.model == "rtx2":
logging.fatal(
"Option `pretrained_vit` is not available for model {} ".format(
FLAGS.model
)
)
model = REGISTRY[FLAGS.model](
vit_config=FilmViTConfig(pretrained=FLAGS.pretrained_vit)
)
train_example.run(model)
if __name__ == "__main__":
app.run(main)
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/test_data_utils.py
================================================
import io
import numpy as np
import torch
from PIL import Image
from rtx.data_util import describe, format_imgs, preprocess
def test_describe():
dic = {
"key1": "value1",
"key2": [1, 2, 3],
"key3": {"nested_key": "nested_value"},
}
describe(dic)
def test_describe_empty():
dic = {}
describe(dic)
def test_describe_non_dict():
non_dict = "not a dict"
describe(non_dict)
def test_preprocess():
dic = {
"key1": "value1",
"key2": [1, 2, 3],
"key3": {"nested_key": "nested_value"},
}
result = preprocess(dic)
assert result == dic
def test_preprocess_empty():
dic = {}
result = preprocess(dic)
assert result == dic
def test_preprocess_non_dict():
non_dict = "not a dict"
result = preprocess(non_dict)
assert result == non_dict
def test_preprocess_none_value():
dic = {"key1": None}
result = preprocess(dic)
assert result == {}
def test_preprocess_image():
img = Image.new("RGB", (60, 30), color="red")
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
result = preprocess(img_byte_arr)
assert isinstance(result, np.ndarray)
def test_format_imgs():
dic = {
"key1": "value1",
"key2": [1, 2, 3],
"key3": {"nested_key": "nested_value"},
}
result = format_imgs(dic, 224)
assert result == dic
def test_format_imgs_empty():
dic = {}
result = format_imgs(dic, 224)
assert result == dic
def test_format_imgs_non_dict():
non_dict = "not a dict"
result = format_imgs(non_dict, 224)
assert result == non_dict
def test_format_imgs_image():
img = Image.new("RGB", (60, 30), color="red")
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
result = format_imgs(img_byte_arr, 224)
assert isinstance(result, np.ndarray)
def test_format_imgs_tensor():
tensor = torch.tensor([1, 2, 3])
result = format_imgs(tensor, 224)
assert isinstance(result, torch.Tensor)
def test_format_imgs_list():
list_val = [1, 2, 3]
result = format_imgs(list_val, 224)
assert result == list_val
def test_format_imgs_nested_dict():
dic = {"key1": {"nested_key": "nested_value"}}
result = format_imgs(dic, 224)
assert result == dic
def test_format_imgs_nested_list():
dic = {"key1": [1, 2, 3]}
result = format_imgs(dic, 224)
assert result == dic
def test_format_imgs_nested_image():
img = Image.new("RGB", (60, 30), color="red")
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
dic = {"key1": img_byte_arr}
result = format_imgs(dic, 224)
assert isinstance(result["key1"], np.ndarray)
def test_format_imgs_nested_tensor():
tensor = torch.tensor([1, 2, 3])
dic = {"key1": tensor}
result = format_imgs(dic, 224)
assert isinstance(result["key1"], torch.Tensor)
def test_format_imgs_nested_list():
list_val = [1, 2, 3]
dic = {"key1": list_val}
result = format_imgs(dic, 224)
assert result == dic
================================================
FILE: tests/test_rtx1.py
================================================
import unittest
import torch
from rtx.rtx1 import RTX1, FilmViTConfig, RT1Config
class RTX1Test(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.num_frames = 6
self.num_actions = 11
self.num_action_bins = 256
self.video = torch.randn(self.batch_size, 3, self.num_frames, 224, 224)
self.instructions = [
"bring me that apple sitting on the table",
"please pass the butter",
]
rt1_config = RT1Config(
num_actions=self.num_actions,
action_bins=self.num_action_bins,
)
self.rtx1 = RTX1(rt1_config)
self.rtx1_pretrained = RTX1(rt1_config, FilmViTConfig(pretrained=True))
self.expected_logits_shape = torch.Size(
[
self.batch_size,
self.num_frames,
self.num_actions,
self.num_action_bins,
]
)
def test_default_pretrained_has_same_shape(self):
# Tests the general shape as the pretrained version from pytorch has
# different layernorm and conv2dnorm implementations.
assert len(self.rtx1.vit.layers) == len(self.rtx1_pretrained.vit.layers)
def test_default_train_eval(self):
train_logits = self.rtx1.train(self.video, self.instructions)
assert train_logits.shape == self.expected_logits_shape
self.rtx1.model.eval()
# compute the eval logits with a conditional scale of 3
eval_logits = self.rtx1.run(self.video, self.instructions, cond_scale=3.0)
assert eval_logits.shape == self.expected_logits_shape
def test_pretrained_train_eval(self):
train_logits = self.rtx1_pretrained.train(self.video, self.instructions)
assert train_logits.shape == self.expected_logits_shape
self.rtx1.model.eval()
# compute the eval logits with a conditional scale of 3
eval_logits = self.rtx1_pretrained.run(
self.video, self.instructions, cond_scale=3.0
)
assert eval_logits.shape == self.expected_logits_shape
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/tests.py
================================================
import pytest
import torch
from PIL import Image
from zeta.structs import (
AutoregressiveWrapper,
ViTransformerWrapper,
)
from rtx.efficient_net import EfficientNetFilm
from rtx.rtx1 import RT1, RTX1, FilmMaxVit
from rtx.rtx2 import RTX2
########################### EfficientNetFilm ###########################
img = "img.jpeg"
# Fixture to create an instance of the EfficientNetFilm class
@pytest.fixture
def efficientnet_model():
model = EfficientNetFilm("efficientnet-b0", 10)
return model
# Test case to check if EfficientNetFilm initializes correctly
def test_efficientnet_init(efficientnet_model):
assert efficientnet_model is not None
# Test case to check if EfficientNetFilm processes an image correctly
def test_efficientnet_process_image(efficientnet_model):
# Load a sample image
image_path = img
Image.open(image_path)
# Process the image using the model
features = efficientnet_model(image_path)
# Check if the output features are of the correct shape
assert isinstance(features, torch.Tensor)
assert features.shape == (1, efficientnet_model.num_features)
# Test case to check if EfficientNetFilm handles image resizing correctly
def test_efficientnet_image_resize(efficientnet_model):
# Load a sample image
image_path = img
image = Image.open(image_path)
# Process the image using the model
efficientnet_model(image_path)
# Check if the input image was resized to the specified size
assert image.size == (
efficientnet_model.resize,
efficientnet_model.resize,
)
# Test case to check if EfficientNetFilm handles model loading correctly
def test_efficientnet_model_loading(efficientnet_model):
# Check if the model was loaded successfully
assert efficientnet_model.model is not None
# Test case to check if EfficientNetFilm handles image transformations correctly
def test_efficientnet_image_transformations(efficientnet_model):
# Load a sample image
image_path = img
Image.open(image_path)
# Process the image using the model
features = efficientnet_model(image_path)
# Check if image transformations were applied correctly
assert torch.max(features).item() <= 1.0
assert torch.min(features).item() >= -1.0
# Test case to check if EfficientNetFilm handles the number of classes correctly
def test_efficientnet_num_classes(efficientnet_model):
# Check if the number of classes is set correctly
assert efficientnet_model.num_classes == 10
# Test case to check if EfficientNetFilm handles missing image file correctly
def test_efficientnet_missing_image(efficientnet_model):
with pytest.raises(FileNotFoundError):
efficientnet_model("non_existent_image.jpg")
# Test case to check if EfficientNetFilm handles incorrect image file format correctly
def test_efficientnet_incorrect_image_format(efficientnet_model):
with pytest.raises(ValueError):
efficientnet_model("sample_image.txt")
# Test case to check if EfficientNetFilm handles model selection correctly
def test_efficientnet_model_selection():
# Check if different EfficientNet models can be selected
model_names = [
"efficientnet-b0",
"efficientnet-b1",
"efficientnet-b2",
]
for model_name in model_names:
model = EfficientNetFilm(model_name, 10)
assert model is not None
assert model.model is not None
# Test case to check if EfficientNetFilm handles invalid model name correctly
def test_efficientnet_invalid_model_name():
with pytest.raises(ValueError):
EfficientNetFilm("invalid_model", 10)
# Test case to check if EfficientNetFilm handles invalid number of classes correctly
def test_efficientnet_invalid_num_classes():
with pytest.raises(ValueError):
EfficientNetFilm("efficientnet-b0", -10)
# Test case to check if EfficientNetFilm handles invalid resize size correctly
def test_efficientnet_invalid_resize_size():
with pytest.raises(ValueError):
EfficientNetFilm("efficientnet-b0", 10, resize=-100)
# Test case to check if EfficientNetFilm handles input image with incorrect channels correctly
def test_efficientnet_incorrect_image_channels(efficientnet_model):
# Create an image with incorrect number of channels (4 channels)
image = Image.new(
"RGBA",
(efficientnet_model.resize, efficientnet_model.resize),
(255, 0, 0, 255),
)
image_path = "incorrect_channels_image.png"
image.save(image_path)
with pytest.raises(ValueError):
efficientnet_model(image_path)
# Test case to check if EfficientNetFilm handles input image with incorrect size correctly
def test_efficientnet_incorrect_image_size(efficientnet_model):
# Create an image with incorrect size (smaller than resize size)
image = Image.new(
"RGB",
(
efficientnet_model.resize - 1,
efficientnet_model.resize - 1,
),
(255, 0, 0),
)
image_path = "incorrect_size_image.jpg"
image.save(image_path)
with pytest.raises(ValueError):
efficientnet_model(image_path)
########################### RTX1 ###########################
# Fixture to create an instance of the RTX1 class
@pytest.fixture
def rtx1_model():
model = RTX1()
return model
# Test case to check if RTX1 initializes correctly
def test_rtx1_initialization(rtx1_model):
assert isinstance(rtx1_model, RTX1)
assert isinstance(rtx1_model.vit, FilmMaxVit)
assert isinstance(rtx1_model.model, RT1)
# Test case to check if RTX1 handles training with video and instructions correctly
def test_rtx1_train(rtx1_model):
video = torch.randn(2, 3, 6, 224, 224)
instructions = [
"bring me that apple sitting on the table",
"please pass the butter",
]
train_logits = rtx1_model.train(video, instructions)
assert isinstance(train_logits, torch.Tensor)
assert train_logits.shape == (2, rtx1_model.num_actions)
# Test case to check if RTX1 handles evaluation with video and instructions correctly
def test_rtx1_eval(rtx1_model):
video = torch.randn(2, 3, 6, 224, 224)
instructions = [
"bring me that apple sitting on the table",
"please pass the butter",
]
eval_logits = rtx1_model.run(video, instructions, cond_scale=3.0)
assert isinstance(eval_logits, torch.Tensor)
assert eval_logits.shape == (2, rtx1_model.num_actions)
# Test case to check if RTX1 raises an error when training with invalid inputs
def test_rtx1_train_with_invalid_inputs(rtx1_model):
with pytest.raises(RuntimeError):
video = torch.randn(2, 3, 6, 224, 224)
instructions = [
"bring me that apple sitting on the table",
"please pass the butter",
]
# Intentionally set an invalid shape for instructions
instructions = instructions[:1] # Instructions shape should be (2,)
rtx1_model.train(video, instructions)
# Test case to check if RTX1 raises an error when evaluating with invalid inputs
def test_rtx1_eval_with_invalid_inputs(rtx1_model):
with pytest.raises(RuntimeError):
video = torch.randn(2, 3, 6, 224, 224)
instructions = [
"bring me that apple sitting on the table",
"please pass the butter",
]
# Intentionally set an invalid shape for video
video = video[:, :, :5] # Video shape should be (2, 3, 6, 224, 224)
rtx1_model.run(video, instructions, cond_scale=3.0)
# Test case to check if RTX1 handles conditional scaling correctly
def test_rtx1_conditional_scaling(rtx1_model):
video = torch.randn(2, 3, 6, 224, 224)
instructions = [
"bring me that apple sitting on the table",
"please pass the butter",
]
eval_logits = rtx1_model.run(video, instructions, cond_scale=3.0)
eval_logits_without_scaling = rtx1_model.run(video, instructions)
# Check if the logits with and without scaling are different
assert not torch.allclose(eval_logits, eval_logits_without_scaling)
# Test case to check if RTX1 handles model selection correctly
def test_rtx1_model_selection():
model_names = [
"efficientnet-b0",
"efficientnet-b1",
"efficientnet-b2",
]
for model_name in model_names:
model = RTX1(model_name=model_name)
assert isinstance(model, RTX1)
# Test case to check if RTX1 raises an error for an invalid model name
def test_rtx1_invalid_model_name():
with pytest.raises(ValueError):
RTX1(model_name="invalid_model")
# Test case to check if RTX1 handles negative number of classes correctly
def test_rtx1_negative_num_classes():
with pytest.raises(ValueError):
RTX1(num_classes=-100)
# Test case to check if RTX1 handles negative dimension correctly
def test_rtx1_negative_dimension():
with pytest.raises(ValueError):
RTX1(dim=-96)
# Test case to check if RTX1 handles negative dimension of convolutional stem correctly
def test_rtx1_negative_dim_conv_stem():
with pytest.raises(ValueError):
RTX1(dim_conv_stem=-64)
# Test case to check if RTX1 handles negative dimension of head for ViT correctly
def test_rtx1_negative_dim_head_vit():
with pytest.raises(ValueError):
RTX1(dim_head_vit=-32)
# Test case to check if RTX1 handles negative depth of ViT correctly
def test_rtx1_negative_depth_vit():
with pytest.raises(ValueError):
RTX1(depth_vit=(-2, 2, 5, 2))
# Test case to check if RTX1 handles negative window size for ViT correctly
def test_rtx1_negative_window_size():
with pytest.raises(ValueError):
RTX1(window_size=-7)
# Test case to check if RTX1 handles negative expansion rate for mbconv correctly
def test_rtx1_negative_mbconv_expansion_rate():
with pytest.raises(ValueError):
RTX1(mbconv_expansion_rate=-4)
# Test case to check if RTX1 handles negative shrinkage rate for mbconv correctly
def test_rtx1_negative_mbconv_shrinkage_rate():
with pytest.raises(ValueError):
RTX1(mbconv_shrinkage_rate=-0.25)
# Test case to check if RTX1 handles negative dropout rate for ViT correctly
def test_rtx1_negative_dropout_vit():
with pytest.raises(ValueError):
RTX1(dropout_vit=-0.1)
# Test case to check if RTX1 handles negative number of actions correctly
def test_rtx1_negative_num_actions():
with pytest.raises(ValueError):
RTX1(num_actions=-11)
# Test case to check if RTX1 handles negative depth of RT1 correctly
def test_rtx1_negative_depth_rt1():
with pytest.raises(ValueError):
RTX1(depth_rt1=-6)
# Test case to check if RTX1 handles negative number of heads for RT1 correctly
def test_rtx1_negative_heads():
with pytest.raises(ValueError):
RTX1(heads=-8)
# Test case to check if RTX1 handles negative dimension of head for RT1 correctly
def test_rtx1_negative_dim_head_rt1():
with pytest.raises(ValueError):
RTX1(dim_head_rt1=-64)
# Test case to check if RTX1 handles negative conditional drop probability for RT1 correctly
def test_rtx1_negative_cond_drop_prob():
with pytest.raises(ValueError):
RTX1(cond_drop_prob=-0.2)
########################### RTX2 ###########################
# Fixture to create an instance of the RTX2 class
@pytest.fixture
def rtx2_model():
model = RTX2()
return model
# Test case to check if RTX2 initializes correctly
def test_rtx2_initialization(rtx2_model):
assert isinstance(rtx2_model, RTX2)
assert isinstance(rtx2_model.encoder, ViTransformerWrapper)
assert isinstance(rtx2_model.decoder, AutoregressiveWrapper)
# Test case to check if RTX2 handles forward pass with image and text correctly
def test_rtx2_forward_pass(rtx2_model):
img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 1024))
output = rtx2_model(img, text)
assert isinstance(output, torch.Tensor)
# Test case to check if RTX2 raises an error when forwarding with invalid inputs
def test_rtx2_forward_with_invalid_inputs(rtx2_model):
with pytest.raises(Exception):
img = torch.randn(1, 3, 256, 256)
text = torch.randn(1, 1024, 512) # Invalid shape for text input
rtx2_model(img, text)
# Test case to check if RTX2 handles various model configurations correctly
def test_rtx2_with_different_configs():
config_combinations = [
{"encoder_depth": 6, "decoder_depth": 6},
{"encoder_depth": 4, "decoder_depth": 8},
{"encoder_heads": 8, "decoder_heads": 8},
{"encoder_dim": 512, "decoder_dim": 768},
]
for config in config_combinations:
model = RTX2(**config)
assert isinstance(model, RTX2)
assert model.encoder.attn_layers.depth == config["encoder_depth"]
assert model.decoder.attn_layers.depth == config["decoder_depth"]
if "encoder_heads" in config:
assert model.encoder.attn_layers.heads == config["encoder_heads"]
if "decoder_heads" in config:
assert model.decoder.attn_layers.heads == config["decoder_heads"]
if "encoder_dim" in config:
assert model.encoder.attn_layers.dim == config["encoder_dim"]
if "decoder_dim" in config:
assert model.decoder.attn_layers.dim == config["decoder_dim"]
# Test case to check if RTX2 handles negative image size correctly
def test_rtx2_negative_image_size():
with pytest.raises(ValueError):
RTX2(image_size=-256)
# Test case to check if RTX2 handles negative patch size correctly
def test_rtx2_negative_patch_size():
with pytest.raises(ValueError):
RTX2(patch_size=-32)
# Test case to check if RTX2 handles negative encoder dimension correctly
def test_rtx2_negative_encoder_dim():
with pytest.raises(ValueError):
RTX2(encoder_dim=-512)
# Test case to check if RTX2 handles negative encoder depth correctly
def test_rtx2_negative_encoder_depth():
with pytest.raises(ValueError):
RTX2(encoder_depth=-6)
# Test case to check if RTX2 handles negative decoder dimension correctly
def test_rtx2_negative_decoder_dim():
with pytest.raises(ValueError):
RTX2(decoder_dim=-512)
# Test case to check if RTX2 handles negative decoder depth correctly
def test_rtx2_negative_decoder_depth():
with pytest.raises(ValueError):
RTX2(decoder_depth=-6)
# Test case to check if RTX2 handles negative encoder heads correctly
def test_rtx2_negative_encoder_heads():
with pytest.raises(ValueError):
RTX2(encoder_heads=-8)
# Test case to check if RTX2 handles negative decoder heads correctly
def test_rtx2_negative_decoder_heads():
with pytest.raises(ValueError):
RTX2(decoder_heads=-8)
gitextract_yvc81ecv/
├── .github/
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ └── feature_request.md
│ ├── PULL_REQUEST_TEMPLATE.yml
│ ├── dependabot.yml
│ ├── labeler.yml
│ └── workflows/
│ ├── docs.yml
│ ├── label.yml
│ ├── publish.yml
│ ├── pull-request-links.yml
│ ├── pylint.yml
│ ├── python-publish.yml
│ ├── stale.yml
│ ├── test.yml
│ ├── unit_test.yml
│ └── welcome.yml
├── .gitignore
├── LICENSE
├── README.md
├── examples/
│ ├── __init__.py
│ ├── efficient_net_example.py
│ ├── rtx1_example.py
│ └── train_example.py
├── pyproject.toml
├── requirements.txt
├── rtx/
│ ├── __init__.py
│ ├── data_util.py
│ ├── efficient_net.py
│ ├── rtx1.py
│ └── rtx2.py
├── rtx2_example.py
├── run.py
└── tests/
├── __init__.py
├── test_data_utils.py
├── test_rtx1.py
└── tests.py
SYMBOL INDEX (139 symbols across 11 files)
FILE: examples/rtx1_example.py
function run (line 5) | def run(pretrained=False):
FILE: examples/train_example.py
function run (line 5) | def run(
FILE: rtx/data_util.py
function map_np (line 10) | def map_np(input: np.ndarray, idxs: list[int], fn: callable) -> None:
function write_dict_to (line 29) | def write_dict_to(name: str, writer: SummaryWriter, input: dict, step: i...
function describe (line 52) | def describe(dic, prefix="", str_built=[]) -> str:
function preprocess (line 95) | def preprocess(dic: any, height=224, width=224):
function format_imgs (line 127) | def format_imgs(dic: any, sz: int):
FILE: rtx/efficient_net.py
class EfficientNetFilm (line 7) | class EfficientNetFilm(nn.Module):
method __init__ (line 33) | def __init__(
method __call__ (line 56) | def __call__(self, img: str):
FILE: rtx/rtx1.py
function exists (line 20) | def exists(val):
function default (line 24) | def default(val, d):
function cast_tuple (line 28) | def cast_tuple(val, length=1):
function pack_one (line 32) | def pack_one(x, pattern):
function unpack_one (line 36) | def unpack_one(x, ps, pattern):
function posemb_sincos_1d (line 43) | def posemb_sincos_1d(seq, dim, temperature=10000, device=None, dtype=tor...
class Residual (line 56) | class Residual(nn.Module):
method __init__ (line 57) | def __init__(self, fn):
method forward (line 61) | def forward(self, x):
class FeedForward (line 65) | class FeedForward(nn.Module):
method __init__ (line 66) | def __init__(self, dim, mult=4, dropout=0.0):
method forward (line 79) | def forward(self, x, cond_fn=None):
class SqueezeExcitation (line 92) | class SqueezeExcitation(nn.Module):
method __init__ (line 93) | def __init__(self, dim, shrinkage_rate=0.25):
method forward (line 106) | def forward(self, x):
class MBConvResidual (line 110) | class MBConvResidual(nn.Module):
method __init__ (line 111) | def __init__(self, fn, dropout=0.0):
method forward (line 116) | def forward(self, x):
class Dropsample (line 122) | class Dropsample(nn.Module):
method __init__ (line 123) | def __init__(self, prob=0):
method forward (line 127) | def forward(self, x):
function MBConv (line 140) | def MBConv(
class Attention (line 180) | class Attention(nn.Module):
method __init__ (line 181) | def __init__(self, dim, dim_head=32, dropout=0.0, window_size=7):
method forward (line 215) | def forward(self, x):
class FilmViTConfig (line 284) | class FilmViTConfig:
method __init__ (line 287) | def __init__(
class FilmMaxVit (line 365) | class FilmMaxVit(nn.Module):
method __init__ (line 366) | def __init__(
method forward (line 489) | def forward(
class TransformerAttention (line 518) | class TransformerAttention(nn.Module):
method __init__ (line 519) | def __init__(
method forward (line 548) | def forward(
class Transformer (line 604) | class Transformer(nn.Module):
method __init__ (line 605) | def __init__(
method forward (line 630) | def forward(
class TokenLearner (line 654) | class TokenLearner(nn.Module):
method __init__ (line 660) | def __init__(self, *, dim, ff_mult=2, num_output_tokens=8, num_layers=2):
method forward (line 681) | def forward(self, x):
class RT1Config (line 697) | class RT1Config:
method __init__ (line 698) | def __init__(
class RT1 (line 738) | class RT1(nn.Module):
method __init__ (line 739) | def __init__(
method embed_texts (line 796) | def embed_texts(self, texts: List[str]):
method forward (line 800) | def forward(
class RTX1 (line 884) | class RTX1(nn.Module):
method __init__ (line 905) | def __init__(
method train (line 955) | def train(self, video, instructions):
method run (line 978) | def run(self, video, instructions, cond_scale=1.0):
FILE: rtx/rtx2.py
class RTX2 (line 12) | class RTX2(torch.nn.Module):
method __init__ (line 51) | def __init__(
method forward (line 114) | def forward(self, img: torch.Tensor, text: torch.Tensor):
FILE: rtx2_example.py
function run (line 5) | def run():
FILE: run.py
function main (line 28) | def main(_):
FILE: tests/test_data_utils.py
function test_describe (line 10) | def test_describe():
function test_describe_empty (line 19) | def test_describe_empty():
function test_describe_non_dict (line 24) | def test_describe_non_dict():
function test_preprocess (line 29) | def test_preprocess():
function test_preprocess_empty (line 39) | def test_preprocess_empty():
function test_preprocess_non_dict (line 45) | def test_preprocess_non_dict():
function test_preprocess_none_value (line 51) | def test_preprocess_none_value():
function test_preprocess_image (line 57) | def test_preprocess_image():
function test_format_imgs (line 66) | def test_format_imgs():
function test_format_imgs_empty (line 76) | def test_format_imgs_empty():
function test_format_imgs_non_dict (line 82) | def test_format_imgs_non_dict():
function test_format_imgs_image (line 88) | def test_format_imgs_image():
function test_format_imgs_tensor (line 97) | def test_format_imgs_tensor():
function test_format_imgs_list (line 103) | def test_format_imgs_list():
function test_format_imgs_nested_dict (line 109) | def test_format_imgs_nested_dict():
function test_format_imgs_nested_list (line 115) | def test_format_imgs_nested_list():
function test_format_imgs_nested_image (line 121) | def test_format_imgs_nested_image():
function test_format_imgs_nested_tensor (line 131) | def test_format_imgs_nested_tensor():
function test_format_imgs_nested_list (line 138) | def test_format_imgs_nested_list():
FILE: tests/test_rtx1.py
class RTX1Test (line 6) | class RTX1Test(unittest.TestCase):
method setUp (line 7) | def setUp(self):
method test_default_pretrained_has_same_shape (line 34) | def test_default_pretrained_has_same_shape(self):
method test_default_train_eval (line 40) | def test_default_train_eval(self):
method test_pretrained_train_eval (line 50) | def test_pretrained_train_eval(self):
FILE: tests/tests.py
function efficientnet_model (line 20) | def efficientnet_model():
function test_efficientnet_init (line 26) | def test_efficientnet_init(efficientnet_model):
function test_efficientnet_process_image (line 31) | def test_efficientnet_process_image(efficientnet_model):
function test_efficientnet_image_resize (line 45) | def test_efficientnet_image_resize(efficientnet_model):
function test_efficientnet_model_loading (line 61) | def test_efficientnet_model_loading(efficientnet_model):
function test_efficientnet_image_transformations (line 67) | def test_efficientnet_image_transformations(efficientnet_model):
function test_efficientnet_num_classes (line 81) | def test_efficientnet_num_classes(efficientnet_model):
function test_efficientnet_missing_image (line 87) | def test_efficientnet_missing_image(efficientnet_model):
function test_efficientnet_incorrect_image_format (line 93) | def test_efficientnet_incorrect_image_format(efficientnet_model):
function test_efficientnet_model_selection (line 99) | def test_efficientnet_model_selection():
function test_efficientnet_invalid_model_name (line 113) | def test_efficientnet_invalid_model_name():
function test_efficientnet_invalid_num_classes (line 119) | def test_efficientnet_invalid_num_classes():
function test_efficientnet_invalid_resize_size (line 125) | def test_efficientnet_invalid_resize_size():
function test_efficientnet_incorrect_image_channels (line 131) | def test_efficientnet_incorrect_image_channels(efficientnet_model):
function test_efficientnet_incorrect_image_size (line 146) | def test_efficientnet_incorrect_image_size(efficientnet_model):
function rtx1_model (line 168) | def rtx1_model():
function test_rtx1_initialization (line 174) | def test_rtx1_initialization(rtx1_model):
function test_rtx1_train (line 181) | def test_rtx1_train(rtx1_model):
function test_rtx1_eval (line 195) | def test_rtx1_eval(rtx1_model):
function test_rtx1_train_with_invalid_inputs (line 209) | def test_rtx1_train_with_invalid_inputs(rtx1_model):
function test_rtx1_eval_with_invalid_inputs (line 222) | def test_rtx1_eval_with_invalid_inputs(rtx1_model):
function test_rtx1_conditional_scaling (line 235) | def test_rtx1_conditional_scaling(rtx1_model):
function test_rtx1_model_selection (line 250) | def test_rtx1_model_selection():
function test_rtx1_invalid_model_name (line 262) | def test_rtx1_invalid_model_name():
function test_rtx1_negative_num_classes (line 268) | def test_rtx1_negative_num_classes():
function test_rtx1_negative_dimension (line 274) | def test_rtx1_negative_dimension():
function test_rtx1_negative_dim_conv_stem (line 280) | def test_rtx1_negative_dim_conv_stem():
function test_rtx1_negative_dim_head_vit (line 286) | def test_rtx1_negative_dim_head_vit():
function test_rtx1_negative_depth_vit (line 292) | def test_rtx1_negative_depth_vit():
function test_rtx1_negative_window_size (line 298) | def test_rtx1_negative_window_size():
function test_rtx1_negative_mbconv_expansion_rate (line 304) | def test_rtx1_negative_mbconv_expansion_rate():
function test_rtx1_negative_mbconv_shrinkage_rate (line 310) | def test_rtx1_negative_mbconv_shrinkage_rate():
function test_rtx1_negative_dropout_vit (line 316) | def test_rtx1_negative_dropout_vit():
function test_rtx1_negative_num_actions (line 322) | def test_rtx1_negative_num_actions():
function test_rtx1_negative_depth_rt1 (line 328) | def test_rtx1_negative_depth_rt1():
function test_rtx1_negative_heads (line 334) | def test_rtx1_negative_heads():
function test_rtx1_negative_dim_head_rt1 (line 340) | def test_rtx1_negative_dim_head_rt1():
function test_rtx1_negative_cond_drop_prob (line 346) | def test_rtx1_negative_cond_drop_prob():
function rtx2_model (line 356) | def rtx2_model():
function test_rtx2_initialization (line 362) | def test_rtx2_initialization(rtx2_model):
function test_rtx2_forward_pass (line 369) | def test_rtx2_forward_pass(rtx2_model):
function test_rtx2_forward_with_invalid_inputs (line 379) | def test_rtx2_forward_with_invalid_inputs(rtx2_model):
function test_rtx2_with_different_configs (line 387) | def test_rtx2_with_different_configs():
function test_rtx2_negative_image_size (line 411) | def test_rtx2_negative_image_size():
function test_rtx2_negative_patch_size (line 417) | def test_rtx2_negative_patch_size():
function test_rtx2_negative_encoder_dim (line 423) | def test_rtx2_negative_encoder_dim():
function test_rtx2_negative_encoder_depth (line 429) | def test_rtx2_negative_encoder_depth():
function test_rtx2_negative_decoder_dim (line 435) | def test_rtx2_negative_decoder_dim():
function test_rtx2_negative_decoder_depth (line 441) | def test_rtx2_negative_decoder_depth():
function test_rtx2_negative_encoder_heads (line 447) | def test_rtx2_negative_encoder_heads():
function test_rtx2_negative_decoder_heads (line 453) | def test_rtx2_negative_decoder_heads():
Condensed preview — 36 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (90K chars).
[
{
"path": ".github/FUNDING.yml",
"chars": 673,
"preview": "# These are supported funding model platforms\n\ngithub: [kyegomez]\npatreon: # Replace with a single Patreon username\nopen"
},
{
"path": ".github/ISSUE_TEMPLATE/bug_report.md",
"chars": 682,
"preview": "---\nname: Bug report\nabout: Create a detailed report on the bug and it's root cause. Conduct root cause error analysis\nt"
},
{
"path": ".github/ISSUE_TEMPLATE/feature_request.md",
"chars": 603,
"preview": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: 'kyegomez'\n\n---\n\n**Is "
},
{
"path": ".github/PULL_REQUEST_TEMPLATE.yml",
"chars": 1004,
"preview": "<!-- Thank you for contributing to Zeta!\n\nReplace this comment with:\n - Description: a description of the change, \n - "
},
{
"path": ".github/dependabot.yml",
"chars": 366,
"preview": "# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configu"
},
{
"path": ".github/labeler.yml",
"chars": 420,
"preview": "Documentation:\n- changed-files:\n - any-glob-to-any-file: '**/*.md'\n - any-glob-to-any-file: 'docs/**'\n\n# Add 'feature'"
},
{
"path": ".github/workflows/docs.yml",
"chars": 387,
"preview": "name: Docs WorkFlow\n\non:\n push:\n branches:\n - master\n - main\n - develop\njobs:\n deploy:\n runs-on: "
},
{
"path": ".github/workflows/label.yml",
"chars": 539,
"preview": "# This workflow will triage pull requests and apply a label based on the\n# paths that are modified in the pull request.\n"
},
{
"path": ".github/workflows/publish.yml",
"chars": 1288,
"preview": "name: Supervision Releases to PyPi\non:\n push:\n tags:\n - '[0-9]+.[0-9]+[0-9]+.[0-9]'\n - '[0-9]+.[0-9]+[0-9]"
},
{
"path": ".github/workflows/pull-request-links.yml",
"chars": 299,
"preview": "name: readthedocs/actions\non:\n pull_request_target:\n types:\n - opened\n paths:\n - \"docs/**\"\n\npermissions"
},
{
"path": ".github/workflows/pylint.yml",
"chars": 553,
"preview": "name: Pylint\n\non: [push]\n\njobs:\n build:\n runs-on: ubuntu-latest\n strategy:\n matrix:\n python-version: "
},
{
"path": ".github/workflows/python-publish.yml",
"chars": 653,
"preview": "\nname: Upload Python Package\n\non:\n release:\n types: [published]\n\npermissions:\n contents: read\n\njobs:\n deploy:\n\n "
},
{
"path": ".github/workflows/stale.yml",
"chars": 716,
"preview": "# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.\n#\n# You c"
},
{
"path": ".github/workflows/test.yml",
"chars": 1404,
"preview": "# name: test\n\n# on:\n# push:\n# branches: [master]\n# pull_request:\n# workflow_dispatch:\n\n# env:\n# POETRY_VERSI"
},
{
"path": ".github/workflows/unit_test.yml",
"chars": 571,
"preview": "name: \"python 3.11 | 3.10\"\n\non:\n push:\n branches: [ main ]\n pull_request:\n branches: [ main ]\n\njobs:\n build_and"
},
{
"path": ".github/workflows/welcome.yml",
"chars": 548,
"preview": "name: Welcome WorkFlow\n\non:\n issues:\n types: [opened]\n pull_request_target:\n types: [opened]\n\njobs:\n build:\n "
},
{
"path": ".gitignore",
"chars": 3142,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "LICENSE",
"chars": 1074,
"preview": "MIT License\n\nCopyright (c) 2023 Eternal Reclaimer\n\nPermission is hereby granted, free of charge, to any person obtaining"
},
{
"path": "README.md",
"chars": 6063,
"preview": "[](https://discord.gg/qUtxnK2NMf)\n\n# RT-X\nPytorch implementation of the models RT-1-X "
},
{
"path": "examples/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/efficient_net_example.py",
"chars": 121,
"preview": "from rtx.efficient_net import EfficientNetFilm\n\nmodel = EfficientNetFilm(\"efficientnet-b0\", 10)\n\nout = model(\"img.jpeg\")"
},
{
"path": "examples/rtx1_example.py",
"chars": 824,
"preview": "import torch\nfrom rtx.rtx1 import RTX1, FilmViTConfig\n\n\ndef run(pretrained=False):\n \"\"\"Run RT-X1 example.\n\n Args:\n"
},
{
"path": "examples/train_example.py",
"chars": 122,
"preview": "import torch\nfrom absl import logging\n\n\ndef run(\n model: torch.nn.Module,\n):\n logging.fatal(\"Not yet implemented.\""
},
{
"path": "pyproject.toml",
"chars": 1189,
"preview": "[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n\n[tool.poetry]\nname = \"rtx-to"
},
{
"path": "requirements.txt",
"chars": 210,
"preview": "torch\ntorchvision\ntorch-tb-profiler==0.4.3\ntensorboardx==2.6.2.2\ntensorboard==2.16.2\nolefile==0.47\neinops==0.7.0\nefficie"
},
{
"path": "rtx/__init__.py",
"chars": 275,
"preview": "from rtx.rtx2 import RTX2\nfrom rtx.rtx1 import RTX1\nfrom rtx.efficient_net import EfficientNetFilm\nfrom rtx.data_util im"
},
{
"path": "rtx/data_util.py",
"chars": 4474,
"preview": "import io\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom tensorboardX import SummaryWriter\n\nArrayLike = np.n"
},
{
"path": "rtx/efficient_net.py",
"chars": 1891,
"preview": "from torch import nn\nfrom efficientnet_pytorch import EfficientNet\nfrom torchvision import transforms\nfrom PIL import Im"
},
{
"path": "rtx/rtx1.py",
"chars": 28922,
"preview": "from functools import partial\nimport torch\nfrom torch import nn, einsum, Tensor\nfrom typing import List, Optional, Calla"
},
{
"path": "rtx/rtx2.py",
"chars": 3915,
"preview": "import torch\nfrom torch import nn\nfrom zeta.structs import (\n AutoregressiveWrapper,\n Decoder,\n Encoder,\n Tr"
},
{
"path": "rtx2_example.py",
"chars": 250,
"preview": "import torch\nfrom rtx import RTX2\n\n\ndef run():\n # usage\n img = torch.randn(1, 3, 256, 256)\n text = torch.randin"
},
{
"path": "run.py",
"chars": 1170,
"preview": "from examples import rtx1_example, train_example\nfrom rtx import RTX1, RTX2\nfrom rtx.rtx1 import FilmViTConfig\nfrom absl"
},
{
"path": "tests/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/test_data_utils.py",
"chars": 3190,
"preview": "import io\n\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom rtx.data_util import describe, format_imgs, prepr"
},
{
"path": "tests/test_rtx1.py",
"chars": 2154,
"preview": "import unittest\nimport torch\nfrom rtx.rtx1 import RTX1, FilmViTConfig, RT1Config\n\n\nclass RTX1Test(unittest.TestCase):\n "
},
{
"path": "tests/tests.py",
"chars": 14747,
"preview": "import pytest\nimport torch\nfrom PIL import Image\nfrom zeta.structs import (\n AutoregressiveWrapper,\n ViTransformer"
}
]
About this extraction
This page contains the full source code of the kyegomez/RT-X GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 36 files (82.5 KB), approximately 21.3k tokens, and a symbol index with 139 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.