[
  {
    "path": ".github/FUNDING.yml",
    "content": "# These are supported funding model platforms\n\ngithub: [kyegomez]\npatreon: # Replace with a single Patreon username\nopen_collective: # Replace with a single Open Collective username\nko_fi: # Replace with a single Ko-fi username\ntidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel\ncommunity_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry\nliberapay: # Replace with a single Liberapay username\nissuehunt: # Replace with a single IssueHunt username\notechie: # Replace with a single Otechie username\nlfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry\ncustom: #Nothing\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a detailed report on the bug and it's root cause. Conduct root cause error analysis\ntitle: \"[BUG] \"\nlabels: bug\nassignees: kyegomez\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is and what the main root cause error is. Test very thoroughly before submitting.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: 'kyegomez'\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.yml",
    "content": "<!-- Thank you for contributing to Zeta!\n\nReplace this comment with:\n  - Description: a description of the change, \n  - Issue: the issue # it fixes (if applicable),\n  - Dependencies: any dependencies required for this change,\n  - Tag maintainer: for a quicker response, tag the relevant maintainer (see below),\n  - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out!\n\nIf you're adding a new integration, please include:\n  1. a test for the integration, preferably unit tests that do not rely on network access,\n  2. an example notebook showing its use.\n\nMaintainer responsibilities:\n  - nn / Misc / if you don't know who to tag: kye@apac.ai\n  - tokenizers: kye@apac.ai\n  - training / Prompts: kye@apac.ai\n  - models: kye@apac.ai\n\nIf no one reviews your PR within a few days, feel free to kye@apac.ai\n\nSee contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/kyegomez/zeta"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates\n\nversion: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n\n  - package-ecosystem: \"pip\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n\n"
  },
  {
    "path": ".github/labeler.yml",
    "content": "Documentation:\n- changed-files:\n  - any-glob-to-any-file: '**/*.md'\n  - any-glob-to-any-file: 'docs/**'\n\n# Add 'feature' label to any PR where the head branch name starts with `feature` or has a `feature` section in the name\nfeature:\n - head-branch: ['^feature', 'feature']\n\n# Add 'bug' label to any PR where the head branch name starts with `bug` or has a `bug` section in the name\nbug:\n - head-branch: ['^bug', 'bug']\n"
  },
  {
    "path": ".github/workflows/docs.yml",
    "content": "name: Docs WorkFlow\n\non:\n  push:\n    branches:\n      - master\n      - main\n      - develop\njobs:\n  deploy:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: 3.x\n      - run: pip install mkdocs-material\n      - run: pip install \"mkdocstrings[python]\"\n      - run: mkdocs gh-deploy --force"
  },
  {
    "path": ".github/workflows/label.yml",
    "content": "# This workflow will triage pull requests and apply a label based on the\n# paths that are modified in the pull request.\n#\n# To use this workflow, you will need to set up a .github/labeler.yml\n# file with configuration.  For more information, see:\n# https://github.com/actions/labeler\n\nname: Labeler\non: [pull_request_target]\n\njobs:\n  label:\n\n    runs-on: ubuntu-latest\n    permissions:\n      contents: read\n      pull-requests: write\n\n    steps:\n    - uses: actions/labeler@v5\n      with:\n        repo-token: \"${{ secrets.GITHUB_TOKEN }}\"\n"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "name: Supervision Releases to PyPi\non:\n  push:\n    tags:\n      - '[0-9]+.[0-9]+[0-9]+.[0-9]'\n      - '[0-9]+.[0-9]+[0-9]+.[0-9]'\n      - '[0-9]+.[0-9]+[0-9]+.[0-9]'\n\n  # Allows you to run this workflow manually from the Actions tab\n  workflow_dispatch:\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [3.8]\n    steps:\n      - name: 🛎️ Checkout\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ github.head_ref }}\n      - name: 🐍 Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name:  🏗️ Build source and wheel distributions\n        run: |\n          python -m pip install --upgrade build twine\n          python -m build\n          twine check --strict dist/*\n      - name: 🚀 Publish to PyPi\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          user: ${{ secrets.PYPI_USERNAME }}\n          password: ${{ secrets.PYPI_PASSWORD }}\n      - name: 🚀 Publish to Test-PyPi\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          repository-url: https://test.pypi.org/legacy/\n          user: ${{ secrets.PYPI_TEST_USERNAME }}\n          password: ${{ secrets.PYPI_TEST_PASSWORD }}"
  },
  {
    "path": ".github/workflows/pull-request-links.yml",
    "content": "name: readthedocs/actions\non:\n  pull_request_target:\n    types:\n      - opened\n    paths:\n      - \"docs/**\"\n\npermissions:\n  pull-requests: write\n\njobs:\n  pull-request-links:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: readthedocs/actions/preview@v1\n        with:\n          project-slug: zeta"
  },
  {
    "path": ".github/workflows/pylint.yml",
    "content": "name: Pylint\n\non: [push]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.8\", \"3.9\", \"3.10\"]\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v5\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install pylint\n    - name: Analysing the code with pylint\n      run: |\n        pylint $(git ls-files '*.py')\n"
  },
  {
    "path": ".github/workflows/python-publish.yml",
    "content": "\nname: Upload Python Package\n\non:\n  release:\n    types: [published]\n\npermissions:\n  contents: read\n\njobs:\n  deploy:\n\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python\n      uses: actions/setup-python@v5\n      with:\n        python-version: '3.x'\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install build\n    - name: Build package\n      run: python -m build\n    - name: Publish package\n      uses: pypa/gh-action-pypi-publish@81e9d935c883d0b210363ab89cf05f3894778450\n      with:\n        user: __token__\n        password: ${{ secrets.PYPI_API_TOKEN }}"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.\n#\n# You can adjust the behavior by modifying this file.\n# For more information, see:\n# https://github.com/actions/stale\nname: Mark stale issues and pull requests\n\non:\n  schedule:\n  - cron: '26 12 * * *'\n\njobs:\n  stale:\n\n    runs-on: ubuntu-latest\n    permissions:\n      issues: write\n      pull-requests: write\n\n    steps:\n    - uses: actions/stale@v9\n      with:\n        repo-token: ${{ secrets.GITHUB_TOKEN }}\n        stale-issue-message: 'Stale issue message'\n        stale-pr-message: 'Stale pull request message'\n        stale-issue-label: 'no-issue-activity'\n        stale-pr-label: 'no-pr-activity'"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "# name: test\n\n# on:\n#   push:\n#     branches: [master]\n#   pull_request:\n#   workflow_dispatch:\n\n# env:\n#   POETRY_VERSION: \"1.4.2\"\n\n# jobs:\n#   build:\n#     runs-on: ubuntu-latest\n#     strategy:\n#       matrix:\n#         python-version:\n#           - \"3.8\"\n#           - \"3.9\"\n#           - \"3.10\"\n#           - \"3.11\"\n#         test_type:\n#           - \"core\"\n#           - \"extended\"\n#     name: Python ${{ matrix.python-version }} ${{ matrix.test_type }}\n#     steps:\n#       - uses: actions/checkout@v4\n#       - name: Set up Python ${{ matrix.python-version }}\n#         uses: \"./.github/actions/poetry_setup\"\n#         with:\n#           python-version: ${{ matrix.python-version }}\n#           poetry-version: \"1.4.2\"\n#           cache-key: ${{ matrix.test_type }}\n#           install-command: |\n#               if [ \"${{ matrix.test_type }}\" == \"core\" ]; then\n#                 echo \"Running core tests, installing dependencies with poetry...\"\n#                 poetry install\n#               else\n#                 echo \"Running extended tests, installing dependencies with poetry...\"\n#                 poetry install -E extended_testing\n#               fi\n#       - name: Run ${{matrix.test_type}} tests\n#         run: |\n#           if [ \"${{ matrix.test_type }}\" == \"core\" ]; then\n#             make test\n#           else\n#             make extended_tests\n#           fi\n#         shell: bash"
  },
  {
    "path": ".github/workflows/unit_test.yml",
    "content": "name: \"python 3.11 | 3.10\"\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n  build_and_test:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version:\n          - '3.11'\n          - '3.10'\n    steps:\n    - uses: actions/checkout@v4\n\n    - name: Setup Python\n      uses: actions/setup-python@v5\n      with:\n        python-version:  ${{ matrix.python-version }}\n\n    - name: Install dependencies\n      run: pip install -r requirements.txt\n\n    - name: Verify integration test results\n      run: python3 -m unittest \n\n\n"
  },
  {
    "path": ".github/workflows/welcome.yml",
    "content": "name: Welcome WorkFlow\n\non:\n  issues:\n    types: [opened]\n  pull_request_target:\n    types: [opened]\n\njobs:\n  build:\n    name: 👋 Welcome\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/first-interaction@v1.3.0\n        with:\n          repo-token: ${{ secrets.GITHUB_TOKEN }}\n          issue-message: \"Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you asap.\"\n          pr-message:  \"Hello there, thank you for opening an PR ! 🙏🏻 The team was notified and they will get back to you asap.\""
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\n.ruff_cache/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\ndatasets_cache/\n**/*checkpoints**/*\n.DS_Store\nruns/"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Eternal Reclaimer\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)\n\n# RT-X\nPytorch implementation of the models RT-1-X and RT-2-X from the paper: \"Open X-Embodiment: Robotic Learning Datasets and RT-X Models\".\n\nHere we implement both model architectures, RTX-1 and RTX-2\n\n[Paper Link](https://robotics-transformer-x.github.io/)\n\n- The RTX-2 Implementation does not natively output for simplicity a 7 dimensional vector but rather text tokens, if you wanted to output 7 dimensional vector you could implement the same token learner as in RTX1\n\n\n# Appreciation\n* Lucidrains\n* Agorians\n\n# Install\n`pip install rtx-torch `\n\n# Usage\nTo see detailed usage, run `python run.py --help`.\n## RTX1\n- RTX1 Usage takes in text and videos\n- Does not use Efficient Net yet, we're integrating it now then the implementation will be complete\n- Uses SOTA transformer architecture\n\n```python\n\nimport torch\nfrom rtx.rtx1 import RTX1, FilmViTConfig\n\n# Use a pre-trained MaxVit model from pytorch\nmodel = RTX1(film_vit_config=FilmViTConfig(pretrained=pretrained))\n\nvideo = torch.randn(2, 3, 6, 224, 224)\n\ninstructions = [\"bring me that apple sitting on the table\", \"please pass the butter\"]\n\n# compute the train logits\ntrain_logits = model.train(video, instructions)\n\n# set the model to evaluation mode\nmodel.model.eval()\n\n# compute the eval logits with a conditional scale of 3\neval_logits = model.run(video, instructions, cond_scale=3.0)\nprint(eval_logits.shape)\n```\n\n\n## RTX-2\n- RTX-2 takes in images and text and interleaves them to form multi-modal sentences and outputs text tokens not a 7 dimensional vector of x,y,z,roll,pitch,yaw,and gripper\n```python\n\nimport torch\nfrom rtx import RTX2\n\n# usage\nimg = torch.randn(1, 3, 256, 256)\ntext = torch.randint(0, 20000, (1, 1024))\n\nmodel = RTX2()\noutput = model(img, text)\nprint(output)\n\n```\n\n## EfficientNetFilm\n- Extracts the feature from the given image\n```python\nfrom rtx import EfficientNetFilm\n\nmodel = EfficientNetFilm(\"efficientnet-b0\", 10)\n\nout = model(\"img.jpeg\")\n\n\n```\n# Model Differences from the Paper Implementation\n## RT-1\nThe main difference here is the substitution of a Film-EfficientNet backbone (pre-trained EfficientNet-B3 with Film layers inserted) with a MaxViT model.\n\n\n\n# Tests\nI created a single tests file that uses pytest to run tests on all the modules, RTX1, RTX2, EfficientNetFil, first git clone and get into the repository, install the requirements.txt with pip then run this:\n\n`python -m pytest tests/tests.py`\n\n# License\nMIT\n\n# Citations\n```bibtex\n@misc{open_x_embodiment_rt_x_2023,\ntitle={Open {X-E}mbodiment: Robotic Learning Datasets and {RT-X} Models},\nauthor = {Open X-Embodiment Collaboration and Abhishek Padalkar and Acorn Pooley and Ajinkya Jain and Alex Bewley and Alex Herzog and Alex Irpan and Alexander Khazatsky and Anant Rai and Anikait Singh and Anthony Brohan and Antonin Raffin and Ayzaan Wahid and Ben Burgess-Limerick and Beomjoon Kim and Bernhard Schölkopf and Brian Ichter and Cewu Lu and Charles Xu and Chelsea Finn and Chenfeng Xu and Cheng Chi and Chenguang Huang and Christine Chan and Chuer Pan and Chuyuan Fu and Coline Devin and Danny Driess and Deepak Pathak and Dhruv Shah and Dieter Büchler and Dmitry Kalashnikov and Dorsa Sadigh and Edward Johns and Federico Ceola and Fei Xia and Freek Stulp and Gaoyue Zhou and Gaurav S. Sukhatme and Gautam Salhotra and Ge Yan and Giulio Schiavi and Hao Su and Hao-Shu Fang and Haochen Shi and Heni Ben Amor and Henrik I Christensen and Hiroki Furuta and Homer Walke and Hongjie Fang and Igor Mordatch and Ilija Radosavovic and Isabel Leal and Jacky Liang and Jaehyung Kim and Jan Schneider and Jasmine Hsu and Jeannette Bohg and Jeffrey Bingham and Jiajun Wu and Jialin Wu and Jianlan Luo and Jiayuan Gu and Jie Tan and Jihoon Oh and Jitendra Malik and Jonathan Tompson and Jonathan Yang and Joseph J. Lim and João Silvério and Junhyek Han and Kanishka Rao and Karl Pertsch and Karol Hausman and Keegan Go and Keerthana Gopalakrishnan and Ken Goldberg and Kendra Byrne and Kenneth Oslund and Kento Kawaharazuka and Kevin Zhang and Keyvan Majd and Krishan Rana and Krishnan Srinivasan and Lawrence Yunliang Chen and Lerrel Pinto and Liam Tan and Lionel Ott and Lisa Lee and Masayoshi Tomizuka and Maximilian Du and Michael Ahn and Mingtong Zhang and Mingyu Ding and Mohan Kumar Srirama and Mohit Sharma and Moo Jin Kim and Naoaki Kanazawa and Nicklas Hansen and Nicolas Heess and Nikhil J Joshi and Niko Suenderhauf and Norman Di Palo and Nur Muhammad Mahi Shafiullah and Oier Mees and Oliver Kroemer and Pannag R Sanketi and Paul Wohlhart and Peng Xu and Pierre Sermanet and Priya Sundaresan and Quan Vuong and Rafael Rafailov and Ran Tian and Ria Doshi and Roberto Martín-Martín and Russell Mendonca and Rutav Shah and Ryan Hoque and Ryan Julian and Samuel Bustamante and Sean Kirmani and Sergey Levine and Sherry Moore and Shikhar Bahl and Shivin Dass and Shuran Song and Sichun Xu and Siddhant Haldar and Simeon Adebola and Simon Guist and Soroush Nasiriany and Stefan Schaal and Stefan Welker and Stephen Tian and Sudeep Dasari and Suneel Belkhale and Takayuki Osa and Tatsuya Harada and Tatsuya Matsushima and Ted Xiao and Tianhe Yu and Tianli Ding and Todor Davchev and Tony Z. Zhao and Travis Armstrong and Trevor Darrell and Vidhi Jain and Vincent Vanhoucke and Wei Zhan and Wenxuan Zhou and Wolfram Burgard and Xi Chen and Xiaolong Wang and Xinghao Zhu and Xuanlin Li and Yao Lu and Yevgen Chebotar and Yifan Zhou and Yifeng Zhu and Ying Xu and Yixuan Wang and Yonatan Bisk and Yoonyoung Cho and Youngwoon Lee and Yuchen Cui and Yueh-hua Wu and Yujin Tang and Yuke Zhu and Yunzhu Li and Yusuke Iwasawa and Yutaka Matsuo and Zhuo Xu and Zichen Jeff Cui},\nhowpublished  = {\\url{https://arxiv.org/abs/2310.08864}},\nyear = {2023},\n}\n```\n\n# Todo\n- Integrate EfficientNetFilm with RTX-1\n- Create training script for RTX-1 by unrolling observations and do basic cross entropy in first rt-1\n- Use RTX-2 dataset on huggingface\n- [Check out the project board for more tasks](https://github.com/users/kyegomez/projects/10/views/1)"
  },
  {
    "path": "examples/__init__.py",
    "content": ""
  },
  {
    "path": "examples/efficient_net_example.py",
    "content": "from rtx.efficient_net import EfficientNetFilm\n\nmodel = EfficientNetFilm(\"efficientnet-b0\", 10)\n\nout = model(\"img.jpeg\")\n"
  },
  {
    "path": "examples/rtx1_example.py",
    "content": "import torch\nfrom rtx.rtx1 import RTX1, FilmViTConfig\n\n\ndef run(pretrained=False):\n    \"\"\"Run RT-X1 example.\n\n    Args:\n        pretrained (bool, optional): Whether or not to use a pretrained MaxVit with film (downloads from pytorch).\n            Defaults to False.\n    \"\"\"\n    model = RTX1(vit_config=FilmViTConfig(pretrained=pretrained))\n\n    video = torch.randn(2, 3, 6, 224, 224)\n\n    instructions = [\n        \"bring me that apple sitting on the table\",\n        \"please pass the butter\",\n    ]\n\n    # compute the train logits\n    model.train(video, instructions)\n\n    # set the model to evaluation mode\n    model.model.eval()\n\n    # compute the eval logits with a conditional scale of 3\n    eval_logits = model.run(video, instructions, cond_scale=3.0)\n    print(eval_logits.shape)\n\n\nif __name__ == \"__main__\":\n    run()\n"
  },
  {
    "path": "examples/train_example.py",
    "content": "import torch\nfrom absl import logging\n\n\ndef run(\n    model: torch.nn.Module,\n):\n    logging.fatal(\"Not yet implemented.\")\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n\n[tool.poetry]\nname = \"rtx-torch\"\nversion = \"0.1.3\"\ndescription = \"rtx - Pytorch\"\nlicense = \"MIT\"\nauthors = [\"Kye Gomez <kye@apac.ai>\"]\nhomepage = \"https://github.com/kyegomez/rt-x\"\ndocumentation = \"https://github.com/kyegomez/rt-x\"  # Replace if you have documentation.\nreadme = \"README.md\"  # Assuming you have a README.md\nrepository = \"https://github.com/kyegomez/rtx\"\nkeywords = [\"artificial intelligence\", \"deep learning\", \"optimizers\", \"Prompt Engineering\"]\nclassifiers = [\n    \"Development Status :: 4 - Beta\",\n    \"Intended Audience :: Developers\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"License :: OSI Approved :: MIT License\",\n    \"Programming Language :: Python :: 3.6\"\n]\npackages = [\n    { include = \"rtx\" },\n    { include = \"rtx/**/*.py\" },\n]\n\n\n[tool.poetry.dependencies]\npython = \"^3.9,<3.12\"\ntorch = \"*\"\ntorchvision = \"^0.16.2\"\neinops = \"0.7.0\"\nefficientnet_pytorch = \"0.7.1\"\nzetascale = \"1.2.5\"\nclassifier-free-guidance-pytorch = \"0.5.3\"\nlz4 = \"^4.3.2\"\ntorch-tb-profiler = \"^0.4.3\"\ntensorboardX = \"^2.6.2.2\"\ntensorboard = \"^2.15.1\"\nolefile = \"^0.47\"\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntorchvision\ntorch-tb-profiler==0.4.3\ntensorboardx==2.6.2.2\ntensorboard==2.16.2\nolefile==0.47\neinops==0.7.0\nefficientnet_pytorch==0.7.1\nzetascale==1.2.5\nclassifier-free-guidance-pytorch==0.5.3\ntensorboardX"
  },
  {
    "path": "rtx/__init__.py",
    "content": "from rtx.rtx2 import RTX2\nfrom rtx.rtx1 import RTX1\nfrom rtx.efficient_net import EfficientNetFilm\nfrom rtx.data_util import describe, format_imgs, preprocess\n\n__all__ = [\n    \"RTX2\",\n    \"RTX1\",\n    \"EfficientNetFilm\",\n    \"describe\",\n    \"format_imgs\",\n    \"preprocess\",\n]\n"
  },
  {
    "path": "rtx/data_util.py",
    "content": "import io\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom tensorboardX import SummaryWriter\n\nArrayLike = np.ndarray | list | torch.Tensor\n\n\ndef map_np(input: np.ndarray, idxs: list[int], fn: callable) -> None:\n    \"\"\"Maps a function through a numpy array.\n\n    Args:\n        input (np.ndarray): Input.\n        fn (callable): Function to map.\n\n    Returns: None\n    \"\"\"\n    if sum(input.shape) <= 1:\n        fn(input, idxs)\n        idxs.pop()\n        return\n\n    for i, x in enumerate(input):\n        idxs.append(i)\n        map_np(x, idxs, fn)\n\n\ndef write_dict_to(name: str, writer: SummaryWriter, input: dict, step: int):\n    \"\"\"Writes a dictionary to tensorboard.\n\n    Args:\n        name (str): Name of group to identify values in dict with.\n        writer (SummaryWriter): Tensorboard writer.\n        input (dict): Input dictionary.\n        step (int): Global step value.\n    \"\"\"\n    for k, v in input.items():\n        v = np.array(v).squeeze()\n        if sum(v.shape) <= 1:\n            writer.add_scalar(name + \"_\" + k, v, step)\n            continue\n        map_np(\n            v,\n            [],\n            lambda x, idxs: writer.add_scalar(\n                \"{}_{}-{}\".format(name, k, \"-\".join([str(i) for i in idxs])), x, step\n            ),\n        )\n\n\ndef describe(dic, prefix=\"\", str_built=[]) -> str:\n    \"\"\"Useful to print out the structure of TF Record. ds.info can also be used\n        but it does not show lengths of lists and dicts.\n\n    Args:\n        dic (dict): Input\n        prefix (str, optional): Prefix used for nested indentation. Defaults to ''.\n        str_built (str, optional): Desription string built so far. Defaults to ''.\n    \"\"\"\n    if not isinstance(dic, dict):\n        return \"\"\n\n    def describe_img(img: bytes):\n        img = Image.open(io.BytesIO(img))\n        return f\"{img.__class__.__name__} sz: { img.size}\"\n\n    for k, v in dic.items():\n        if isinstance(v, list):\n            list_type = \"\"\n            if len(v) > 0:\n                v_description = \"\"\n                if isinstance(v[0], torch.Tensor):\n                    v_description = f\"({tuple(v[0].size())}, {v[0].dtype})\"\n                elif isinstance(v[0], bytes):\n                    v_description = describe_img(v[0])\n                list_type = f\"({v[0].__class__.__name__ }{v_description})\"\n            print(f\"{prefix} {k}, {v.__class__.__name__}{list_type} sz:\" f\" {len(v)}\")\n            if len(v) > 0:\n                str_built.append(describe(v[0], prefix + \"  \"))\n        elif isinstance(v, dict):\n            print(f\"{prefix} {k}, {v.__class__.__name__} sz:\" f\" {len(v.items())}\")\n            describe(v, prefix + \"  \")\n        elif isinstance(v, bytes):\n            print(f\"{prefix} {k}, {describe_img( v)}\")\n        elif isinstance(v, str):\n            str_built.append(f\"{prefix} {k}, {v.__class__.__name__} v: {v}\\n\")\n        else:\n            tensor_type = \"\"\n            if isinstance(v, torch.Tensor):\n                tensor_type = f\"({tuple(v[0].size())}, {v[0].dtype})\"\n            print(f\"{prefix} {k}, {v.__class__.__name__} {tensor_type} \")\n\n\ndef preprocess(dic: any, height=224, width=224):\n    \"\"\"Remove nonetypes from a dict, convert images to numpy arrays and return.\n\n    Args:\n        dic (dict): Input.\n\n    Returns:\n        dict: Output.\n    \"\"\"\n    if isinstance(dic, bytes):\n        img = Image.open(io.BytesIO(dic))\n        return np.array(img.resize((width, height)))\n\n    if not isinstance(dic, dict):\n        return dic\n\n    to_remove = []\n    for k, v in dic.items():\n        if isinstance(v, list):\n            processed = []\n            for vv in v:\n                processed.append(preprocess(vv, height, width))\n            dic[k] = processed\n        elif v is None:\n            to_remove.append(k)\n        else:\n            dic[k] = preprocess(v, height, width)\n    for k in to_remove:\n        del dic[k]\n    return dic\n\n\ndef format_imgs(dic: any, sz: int):\n    \"\"\"Resizes images to sz as a numpy array.\n\n    Args:\n        dic (dict): Input.\n\n    Returns:\n        dict: Output.\n    \"\"\"\n    if isinstance(dic, bytes):\n        img = Image.open(io.BytesIO(dic))\n        return np.array(img.resize((sz, sz)))\n        return np.array(img.resize((sz, sz)))\n\n    if not isinstance(dic, dict):\n        return dic\n\n    for k, v in dic.items():\n        if isinstance(v, list):\n            for i in range(len(v)):\n                v[i] = format_imgs(v, sz)\n        else:\n            dic[k] = format_imgs(v, sz)\n    return dic\n"
  },
  {
    "path": "rtx/efficient_net.py",
    "content": "from torch import nn\nfrom efficientnet_pytorch import EfficientNet\nfrom torchvision import transforms\nfrom PIL import Image\n\n\nclass EfficientNetFilm(nn.Module):\n    \"\"\"\n    EfficientNet with FiLM layer\n\n    Args:\n        model (str): EfficientNet model name\n        num_classes (int): Number of classes\n        num_features (int): Number of features to output from the model\n        resize (int): Size to resize the image to\n\n    Attributes:\n        model (EfficientNet): EfficientNet model\n        num_classes (int): Number of classes\n        num_features (int): Number of features to output from the model\n        resize (int): Size to resize the image to\n        transform (torchvision.transforms.Compose): Image transformations\n\n    Example:\n        >>> model = EfficientNetFilm('efficientnet-b0', 10)\n        >>> img = Image.open('img.jpeg')\n        >>> features = model(img)\n        >>> features.shape\n        torch.Size([1, 1280])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        num_classes,\n        num_features=1280,\n        resize=224,\n    ):\n        super().__init__()\n        self.model = model\n        self.num_classes = num_classes\n        self.num_features = num_features\n        self.resize = resize\n\n        self.model = EfficientNet.from_pretrained(model)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(resize),\n                transforms.ToTensor(),\n                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n            ]\n        )\n\n    def __call__(self, img: str):\n        \"\"\"\n        Extract the feature embeddings from the image\n\n        Args:\n            img (str): Path to image\n        \"\"\"\n        img = Image.open(img)\n        img = self.transform(img).unsqueeze(0)\n        print(img.shape)\n\n        features = self.model.extract_features(img)\n        print(features.shape)\n"
  },
  {
    "path": "rtx/rtx1.py",
    "content": "from functools import partial\nimport torch\nfrom torch import nn, einsum, Tensor\nfrom typing import List, Optional, Callable, Tuple\n\n# from beartype import beartype\nfrom einops import pack, unpack, repeat, reduce, rearrange\nfrom einops.layers.torch import Rearrange, Reduce\n\nfrom classifier_free_guidance_pytorch import (\n    TextConditioner as FilmTextConditioner,\n    AttentionTextConditioner as FilmAttentionTextConditioner,\n    classifier_free_guidance,\n)\n\n\n# helpers\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    return val if exists(val) else d\n\n\ndef cast_tuple(val, length=1):\n    return val if isinstance(val, tuple) else ((val,) * length)\n\n\ndef pack_one(x, pattern):\n    return pack([x], pattern)\n\n\ndef unpack_one(x, ps, pattern):\n    return unpack(x, ps, pattern)[0]\n\n\n# sinusoidal positions\n\n\ndef posemb_sincos_1d(seq, dim, temperature=10000, device=None, dtype=torch.float32):\n    n = torch.arange(seq, device=device)\n    omega = torch.arange(dim // 2, device=device) / (dim // 2 - 1)\n    omega = 1.0 / (temperature**omega)\n\n    n = n[:, None] * omega[None, :]\n    pos_emb = torch.cat((n.sin(), n.cos()), dim=1)\n    return pos_emb.type(dtype)\n\n\n# helper classes\n\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x):\n        return self.fn(x) + x\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, mult=4, dropout=0.0):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        self.norm = nn.LayerNorm(dim)\n\n        self.net = nn.Sequential(\n            nn.Linear(dim, inner_dim),\n            nn.GELU(),\n            nn.Dropout(dropout),\n            nn.Linear(inner_dim, dim),\n            nn.Dropout(dropout),\n        )\n\n    def forward(self, x, cond_fn=None):\n        x = self.norm(x)\n\n        if exists(cond_fn):\n            # adaptive layernorm\n            x = cond_fn(x)\n\n        return self.net(x)\n\n\n# MBConv\n\n\nclass SqueezeExcitation(nn.Module):\n    def __init__(self, dim, shrinkage_rate=0.25):\n        super().__init__()\n        hidden_dim = int(dim * shrinkage_rate)\n\n        self.gate = nn.Sequential(\n            Reduce(\"b c h w -> b c\", \"mean\"),\n            nn.Linear(dim, hidden_dim, bias=False),\n            nn.SiLU(),\n            nn.Linear(hidden_dim, dim, bias=False),\n            nn.Sigmoid(),\n            Rearrange(\"b c -> b c 1 1\"),\n        )\n\n    def forward(self, x):\n        return x * self.gate(x)\n\n\nclass MBConvResidual(nn.Module):\n    def __init__(self, fn, dropout=0.0):\n        super().__init__()\n        self.fn = fn\n        self.dropsample = Dropsample(dropout)\n\n    def forward(self, x):\n        out = self.fn(x)\n        out = self.dropsample(out)\n        return out + x\n\n\nclass Dropsample(nn.Module):\n    def __init__(self, prob=0):\n        super().__init__()\n        self.prob = prob\n\n    def forward(self, x):\n        device = x.device\n\n        if self.prob == 0.0 or (not self.training):\n            return x\n\n        keep_mask = (\n            torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()\n            > self.prob\n        )\n        return x * keep_mask / (1 - self.prob)\n\n\ndef MBConv(\n    dim_in,\n    dim_out,\n    *,\n    downsample,\n    expansion_rate=4,\n    shrinkage_rate=0.25,\n    dropout=0.0,\n):\n    hidden_dim = int(expansion_rate * dim_out)\n    stride = 2 if downsample else 1\n\n    net = nn.Sequential(\n        nn.Conv2d(dim_in, hidden_dim, 1),\n        nn.BatchNorm2d(hidden_dim),\n        nn.GELU(),\n        nn.Conv2d(\n            hidden_dim,\n            hidden_dim,\n            3,\n            stride=stride,\n            padding=1,\n            groups=hidden_dim,\n        ),\n        nn.BatchNorm2d(hidden_dim),\n        nn.GELU(),\n        SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),\n        nn.Conv2d(hidden_dim, dim_out, 1),\n        nn.BatchNorm2d(dim_out),\n    )\n\n    if dim_in == dim_out and not downsample:\n        net = MBConvResidual(net, dropout=dropout)\n\n    return net\n\n\n# attention related classes\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, dim_head=32, dropout=0.0, window_size=7):\n        super().__init__()\n        assert (\n            dim % dim_head\n        ) == 0, \"dimension should be divisible by dimension per head\"\n\n        self.norm = nn.LayerNorm(dim)\n\n        self.heads = dim // dim_head\n        self.scale = dim_head**-0.5\n\n        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)\n\n        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))\n\n        self.to_out = nn.Sequential(\n            nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)\n        )\n\n        # relative positional bias\n\n        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)\n\n        pos = torch.arange(window_size)\n        grid = torch.stack(torch.meshgrid(pos, pos, indexing=\"ij\"))\n        grid = rearrange(grid, \"c i j -> (i j) c\")\n        rel_pos = rearrange(grid, \"i ... -> i 1 ...\") - rearrange(\n            grid, \"j ... -> 1 j ...\"\n        )\n        rel_pos += window_size - 1\n        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)\n\n        self.register_buffer(\"rel_pos_indices\", rel_pos_indices, persistent=False)\n\n    def forward(self, x):\n        (\n            batch,\n            height,\n            width,\n            window_height,\n            window_width,\n            _,\n            device,\n            h,\n        ) = (\n            *x.shape,\n            x.device,\n            self.heads,\n        )\n\n        x = self.norm(x)\n\n        # flatten\n\n        x = rearrange(x, \"b x y w1 w2 d -> (b x y) (w1 w2) d\")\n\n        # project for queries, keys, values\n\n        q, k, v = self.to_qkv(x).chunk(3, dim=-1)\n\n        # split heads\n\n        q, k, v = map(\n            lambda t: rearrange(t, \"b n (h d ) -> b h n d\", h=h),\n            (q, k, v),\n        )\n\n        # scale\n\n        q = q * self.scale\n\n        # sim\n\n        sim = einsum(\"b h i d, b h j d -> b h i j\", q, k)\n\n        # add positional bias\n\n        bias = self.rel_pos_bias(self.rel_pos_indices)\n        sim = sim + rearrange(bias, \"i j h -> h i j\")\n\n        # attention\n\n        attn = self.attend(sim)\n\n        # aggregate\n\n        out = einsum(\"b h i j, b h j d -> b h i d\", attn, v)\n\n        # merge heads\n\n        out = rearrange(\n            out,\n            \"b h (w1 w2) d -> b w1 w2 (h d)\",\n            w1=window_height,\n            w2=window_width,\n        )\n\n        # combine heads out\n\n        out = self.to_out(out)\n        return rearrange(out, \"(b x y) ... -> b x y ...\", x=height, y=width)\n\n\nclass FilmViTConfig:\n    \"\"\"Configuration class to store the configuration of a `FilmMaxVit`.\"\"\"\n\n    def __init__(\n        self,\n        num_classes=1000,  # 1000 for ImageNet\n        input_channels=3,\n        stem_channels_in=64,  # Number of stem channels\n        dim_head=32,  # Attention head dimension\n        block_channel_ins: List = [\n            64,\n            128,\n            256,\n            512,\n        ],  # Number of channels for each ViT block\n        block_layers=[\n            2,\n            2,\n            5,\n            2,\n        ],  # Number of layers for each ViT block\n        window_size=7,  # Partition size\n        mbconv_expansion_rate=4,\n        mbconv_shrinkage_rate=0.25,  # MBConv squeeze ratio\n        dropout=0.1,\n        norm_layer: nn.Module = None,\n        activation_layer=nn.GELU,\n        stochastic_depth_prob=0.2,\n        pretrained=False,\n    ):\n        \"\"\"\n        Constructs a MaxVit architecture with optional film layers from\n        `MaxVit: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.\n            Parameters\n            ----------\n            num_classes : int\n                Number of classes for the classification task\n            input_channels : int\n                Number of input channels\n            stem_channels_in : int\n                Number of stem channels\n            dim_head : int\n                Dimension of the head\n            block_channel_ins : List\n                Number of channels for each ViT block\n            block_layers : List\n                Number of layers for each ViT block\n            window_size : int\n                Partition size\n            mbconv_expansion_rate : int\n                MBConv expansion rate\n            mbconv_shrinkage_rate : float\n                MBConv squeeze ratio\n            dropout : float\n                Dropout probability\n            norm_layer : nn.Module\n                Normalization layer\n            activation_layer : nn.Module\n                Activation layer\n            stochastic_depth_prob : float\n                Stochastic depth probability\n        \"\"\"\n        self.num_classes = num_classes\n        self.input_channels = input_channels\n        self.stem_channels_in = stem_channels_in\n        self.block_channel_ins = block_channel_ins\n        self.block_layers = block_layers\n        self.dim_head = dim_head\n        self.stem_channels_in = stem_channels_in\n        self.window_size = window_size\n        self.mbconv_expansion_rate = mbconv_expansion_rate\n        self.mbconv_shrinkage_rate = mbconv_shrinkage_rate\n        self.dropout = dropout\n        self.norm_layer = norm_layer\n        if self.norm_layer is None:\n            self.norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99)\n        self.activation_layer = activation_layer\n        self.pretrained = pretrained\n        self.stochastic_depth_prob = stochastic_depth_prob\n\n\nclass FilmMaxVit(nn.Module):\n    def __init__(\n        self,\n        config: FilmViTConfig,\n    ):\n        super().__init__()\n        assert isinstance(config.block_layers, tuple | list), (\n            \"depth needs to be tuple if integers indicating number of\"\n            \" transformer blocks at that stage\"\n        )\n\n        # List of number of input and output channels for each ViT block.\n        in_channels: List = [config.stem_channels_in] + config.block_channel_ins[:-1]\n        out_channels: List = config.block_channel_ins\n\n        # Condition after each layer starting with the input to the stem block.\n        self.cond_hidden_dims = [config.stem_channels_in]  # Used by FilmTextConditioner\n        for block_in_channels, block_layers in zip(out_channels, config.block_layers):\n            for _ in range(block_layers):\n                self.cond_hidden_dims.append(block_in_channels)\n        self.cond_hidden_dims = self.cond_hidden_dims[\n            :-1\n        ]  # Don't condition on last embedding.\n        self.embed_dim = out_channels[-1]\n\n        if config.pretrained:\n            from torchvision.models import maxvit_t, MaxVit_T_Weights\n\n            self._vit = maxvit_t(weights=MaxVit_T_Weights.DEFAULT)\n            self.conv_stem = self._vit.stem\n            self.mlp_head = self._vit.classifier\n            self.layers = nn.ModuleList([])\n            for block in self._vit.blocks:\n                for layer in block.layers:\n                    self.layers.append(layer)\n            return\n\n        # convolutional stem\n        self.conv_stem = nn.Sequential(\n            nn.Conv2d(\n                config.input_channels,\n                config.stem_channels_in,\n                3,\n                stride=2,\n                padding=1,\n            ),\n            nn.Conv2d(\n                config.stem_channels_in,\n                config.stem_channels_in,\n                3,\n                padding=1,\n            ),\n        )\n        self.layers = nn.ModuleList([])\n\n        for (\n            block_channels_in,\n            block_channels_out,\n            block_num_layers,\n        ) in zip(in_channels, out_channels, config.block_layers):\n            for i in range(block_num_layers):\n                layer_channels_in = block_channels_in if i == 0 else block_channels_out\n\n                layer = nn.Sequential(\n                    MBConv(\n                        layer_channels_in,\n                        block_channels_out,\n                        downsample=(i == 0),\n                        expansion_rate=config.mbconv_expansion_rate,\n                        shrinkage_rate=config.mbconv_shrinkage_rate,\n                    ),\n                    Rearrange(\n                        \"b d (x w1) (y w2) -> b x y w1 w2 d\",\n                        w1=config.window_size,\n                        w2=config.window_size,\n                    ),  # block-like attention\n                    Residual(\n                        Attention(\n                            dim=block_channels_out,\n                            dim_head=config.dim_head,\n                            dropout=config.dropout,\n                            window_size=config.window_size,\n                        )\n                    ),\n                    Residual(\n                        FeedForward(\n                            dim=block_channels_out,\n                            dropout=config.dropout,\n                        )\n                    ),\n                    Rearrange(\"b x y w1 w2 d -> b d (x w1) (y w2)\"),\n                    Rearrange(\n                        \"b d (w1 x) (w2 y) -> b x y w1 w2 d\",\n                        w1=config.window_size,\n                        w2=config.window_size,\n                    ),  # grid-like attention\n                    Residual(\n                        Attention(\n                            dim=block_channels_out,\n                            dim_head=config.dim_head,\n                            dropout=config.dropout,\n                            window_size=config.window_size,\n                        )\n                    ),\n                    Residual(\n                        FeedForward(\n                            dim=block_channels_out,\n                            dropout=config.dropout,\n                        )\n                    ),\n                    Rearrange(\"b x y w1 w2 d -> b d (w1 x) (w2 y)\"),\n                )\n\n                self.layers.append(layer)\n\n        # mlp head out\n\n        self.mlp_head = nn.Sequential(\n            Reduce(\"b d h w -> b d\", \"mean\"),\n            nn.LayerNorm(self.embed_dim),\n            nn.Linear(self.embed_dim, config.num_classes, bias=False),\n        )\n\n    # @beartype\n    def forward(\n        self,\n        x,\n        texts: Optional[List[str]] = None,\n        cond_fns: Optional[Tuple[Callable, ...]] = None,\n        cond_drop_prob=0.0,\n        return_embeddings=False,\n    ):\n        x = self.conv_stem(x)\n\n        cond_fns = iter(default(cond_fns, []))\n\n        for stage in self.layers:\n            cond_fn = next(cond_fns, None)\n\n            if exists(cond_fn):\n                x = cond_fn(x)\n\n            x = stage(x)\n\n        if return_embeddings:\n            return x\n\n        return self.mlp_head(x)\n\n\n# attention\n\n\nclass TransformerAttention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        causal=False,\n        dim_head=64,\n        dim_context=None,\n        heads=8,\n        norm_context=False,\n        dropout=0.1,\n    ):\n        super().__init__()\n        self.heads = heads\n        self.scale = dim_head**-0.5\n        self.causal = causal\n        inner_dim = dim_head * heads\n\n        dim_context = default(dim_context, dim)\n\n        self.norm = nn.LayerNorm(dim)\n        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()\n\n        self.attn_dropout = nn.Dropout(dropout)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias=False)\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout)\n        )\n\n    def forward(\n        self,\n        x,\n        context=None,\n        mask=None,\n        attn_bias=None,\n        attn_mask=None,\n        cond_fn: Optional[Callable] = None,\n    ):\n        x.shape[0]\n\n        if exists(context):\n            context = self.context_norm(context)\n\n        kv_input = default(context, x)\n\n        x = self.norm(x)\n\n        if exists(cond_fn):\n            # adaptive layer-norm\n            x = cond_fn(x)\n\n        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = rearrange(q, \"b n (h d) -> b h n d\", h=self.heads)\n\n        q = q * self.scale\n\n        sim = einsum(\"b h i d, b j d -> b h i j\", q, k)\n\n        if exists(attn_bias):\n            sim = sim + attn_bias\n\n        if exists(attn_mask):\n            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)\n\n        if exists(mask):\n            mask = rearrange(mask, \"b j -> b 1 1 j\")\n            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        if self.causal:\n            i, j = sim.shape[-2:]\n            causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(\n                j - i + 1\n            )\n            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)\n\n        attn = sim.softmax(dim=-1)\n        attn = self.attn_dropout(attn)\n\n        out = einsum(\"b h i j, b j d -> b h i d\", attn, v)\n\n        out = rearrange(out, \"b h n d -> b n (h d)\")\n        return self.to_out(out)\n\n\nclass Transformer(nn.Module):\n    def __init__(\n        self,\n        dim,\n        dim_head=64,\n        heads=8,\n        depth=6,\n        attn_dropout=0.0,\n        ff_dropout=0.0,\n    ):\n        super().__init__()\n        self.layers = nn.ModuleList([])\n        self.norm = nn.LayerNorm(dim)\n\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList(\n                    [\n                        TransformerAttention(\n                            dim=dim, heads=heads, dropout=attn_dropout\n                        ),\n                        FeedForward(dim=dim, dropout=ff_dropout),\n                    ]\n                )\n            )\n\n    def forward(\n        self,\n        x,\n        cond_fns: Optional[Tuple[Callable, ...]] = None,\n        attn_mask=None,\n    ):\n        cond_fns = iter(default(cond_fns, []))\n        x = self.norm(x)\n        for attn, ff in self.layers:\n            x = (\n                attn(\n                    self.norm(x),\n                    attn_mask=attn_mask,\n                    cond_fn=next(cond_fns, None),\n                )\n                + x\n            )\n            x = ff(self.norm(x), cond_fn=next(cond_fns, None)) + x\n        return x\n\n\n# token learner module\n\n\nclass TokenLearner(nn.Module):\n    \"\"\"\n    https://arxiv.org/abs/2106.11297\n    using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map\n    \"\"\"\n\n    def __init__(self, *, dim, ff_mult=2, num_output_tokens=8, num_layers=2):\n        super().__init__()\n        inner_dim = dim * ff_mult * num_output_tokens\n\n        self.num_output_tokens = num_output_tokens\n        self.net = nn.Sequential(\n            nn.Conv2d(\n                dim * num_output_tokens,\n                inner_dim,\n                1,\n                groups=num_output_tokens,\n            ),\n            nn.GELU(),\n            nn.Conv2d(\n                inner_dim,\n                num_output_tokens,\n                1,\n                groups=num_output_tokens,\n            ),\n        )\n\n    def forward(self, x):\n        x, ps = pack_one(x, \"* c h w\")\n        x = repeat(x, \"b c h w -> b (g c) h w\", g=self.num_output_tokens)\n        attn = self.net(x)\n\n        attn = rearrange(attn, \"b g h w -> b 1 g h w\")\n        x = rearrange(x, \"b (g c) h w -> b c g h w\", g=self.num_output_tokens)\n\n        x = reduce(x * attn, \"b c g h w -> b c g\", \"mean\")\n        x = unpack_one(x, ps, \"* c n\")\n        return x\n\n\n# Robotic Transformer\n\n\nclass RT1Config:\n    def __init__(\n        self,\n        num_actions=11,\n        action_bins=256,\n        depth=6,\n        heads=8,\n        dim_head=64,\n        token_learner_ff_mult=2,\n        token_learner_num_layers=2,\n        token_learner_num_output_tokens=8,\n        cond_drop_prob=0.2,\n        use_attn_conditioner=False,\n    ):\n        \"\"\"Configuration class to store the configuration of a `RT1`.\n\n        Args:\n            num_actions (int): Number of actions for the classification task\n            action_bins (int): Number of bins for each action\n            depth (int): Number of transformer blocks\n            heads (int): Number of heads for the transformer\n            dim_head (int): Dimension of the head\n            token_learner_ff_mult (int): Multiplier for the token learner\n            token_learner_num_layers (int): Number of layers for the token learner\n            token_learner_num_output_tokens (int): Number of output tokens for the token learner\n            cond_drop_prob (float): Dropout probability\n            use_attn_conditioner (bool): Whether to use the attention conditioner\n        \"\"\"\n        self.num_actions = num_actions\n        self.action_bins = action_bins\n        self.depth = depth\n        self.heads = heads\n        self.dim_head = dim_head\n        self.token_learner_ff_mult = token_learner_ff_mult\n        self.token_learner_num_layers = token_learner_num_layers\n        self.token_learner_num_output_tokens = token_learner_num_output_tokens\n        self.cond_drop_prob = cond_drop_prob\n        self.use_attn_conditioner = use_attn_conditioner\n\n\n# @beartype\nclass RT1(nn.Module):\n    def __init__(\n        self,\n        config: RT1Config,\n        vit: FilmMaxVit,\n        conditioner_kwargs: dict = dict(),\n    ):\n        super().__init__()\n        self.vit = vit\n        self.num_vit_stages = len(vit.cond_hidden_dims)\n\n        film_layer = (\n            FilmAttentionTextConditioner\n            if config.use_attn_conditioner\n            else FilmTextConditioner\n        )\n\n        self.conditioner = film_layer(\n            hidden_dims=(\n                *tuple(vit.cond_hidden_dims),\n                *((vit.embed_dim,) * config.depth * 2),\n            ),\n            hiddens_channel_first=(\n                *((True,) * self.num_vit_stages),\n                *((False,) * config.depth * 2),\n            ),\n            cond_drop_prob=config.cond_drop_prob,\n            **conditioner_kwargs,\n        )\n\n        self.token_learner = TokenLearner(\n            dim=vit.embed_dim,\n            ff_mult=config.token_learner_ff_mult,\n            num_output_tokens=config.token_learner_num_output_tokens,\n            num_layers=config.token_learner_num_layers,\n        )\n\n        self.num_learned_tokens = config.token_learner_num_output_tokens\n\n        self.transformer_depth = config.depth\n\n        self.transformer = Transformer(\n            dim=vit.embed_dim,\n            dim_head=config.dim_head,\n            heads=config.heads,\n            depth=config.depth,\n        )\n\n        self.norm = nn.LayerNorm(config.embed_dim)\n\n        self.cond_drop_prob = config.cond_drop_prob\n\n        self.to_logits = nn.Sequential(\n            nn.LayerNorm(vit.embed_dim),\n            nn.Linear(vit.embed_dim, config.num_actions * config.action_bins),\n            Rearrange(\"... (a b) -> ... a b\", b=config.action_bins),\n        )\n\n    def embed_texts(self, texts: List[str]):\n        return self.conditioner.embed_texts(texts)\n\n    @classifier_free_guidance\n    def forward(\n        self,\n        video,\n        texts: Optional[List[str]] = None,\n        text_embeds: Optional[Tensor] = None,\n        cond_drop_prob=0.0,\n    ):\n        assert exists(texts) ^ exists(text_embeds)\n        cond_kwargs = dict(texts=texts, text_embeds=text_embeds)\n\n        depth = self.transformer_depth\n        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)\n\n        frames, device = video.shape[2], video.device\n\n        cond_fns, _ = self.conditioner(\n            **cond_kwargs,\n            cond_drop_prob=cond_drop_prob,\n            repeat_batch=(\n                *((frames,) * self.num_vit_stages),\n                *((1,) * self.transformer_depth * 2),\n            ),\n        )\n\n        vit_cond_fns, transformer_cond_fns = (\n            cond_fns[: -(depth * 2)],\n            cond_fns[-(depth * 2) :],\n        )\n\n        video = rearrange(video, \"b c f h w -> b f c h w\")\n        images, packed_shape = pack_one(video, \"* c h w\")\n\n        tokens = self.vit(\n            images,\n            texts=texts,\n            cond_fns=vit_cond_fns,\n            cond_drop_prob=cond_drop_prob,\n            return_embeddings=True,\n        )\n\n        tokens = unpack_one(tokens, packed_shape, \"* c h w\")\n        learned_tokens = self.token_learner(tokens)\n\n        learned_tokens = rearrange(learned_tokens, \"b f c n -> b (f n) c\")\n\n        # causal attention mask\n\n        attn_mask = torch.ones((frames, frames), dtype=torch.bool, device=device).triu(\n            1\n        )\n        attn_mask = repeat(\n            attn_mask,\n            \"i j -> (i r1) (j r2)\",\n            r1=self.num_learned_tokens,\n            r2=self.num_learned_tokens,\n        )\n\n        # sinusoidal positional embedding\n\n        pos_emb = posemb_sincos_1d(\n            frames,\n            learned_tokens.shape[-1],\n            dtype=learned_tokens.dtype,\n            device=learned_tokens.device,\n        )\n\n        learned_tokens = learned_tokens + repeat(\n            pos_emb, \"n d -> (n r) d\", r=self.num_learned_tokens\n        )\n\n        # attention\n\n        attended_tokens = self.transformer(\n            learned_tokens,\n            cond_fns=transformer_cond_fns,\n            attn_mask=~attn_mask,\n        )\n\n        pooled = reduce(attended_tokens, \"b (f n) d -> b f d\", \"mean\", f=frames)\n\n        logits = self.to_logits(pooled)\n        return logits\n\n\nclass RTX1(nn.Module):\n    \"\"\"\n    A class for real-time video processing using Vision Transformers (ViT) and Reinforcement Learning (RT1) models.\n\n    ...\n\n    Attributes\n    ----------\n    vit : FilmMaxVit\n        a Vision Transformer model\n    model : RT1\n        a reinforcement learning model\n\n    Methods\n    -------\n    train(video, instructions):\n        Computes the logits for the given video and instructions using the RT1 model in training mode.\n    eval(video, instructions, cond_scale=1.0):\n        Computes the logits for the given video and instructions using the RT1 model in evaluation mode.\n    \"\"\"\n\n    def __init__(\n        self,\n        rt1_config: RT1Config = None,\n        vit_config: FilmViTConfig = None,\n    ):\n        \"\"\"\n        Constructs all the necessary attributes for the RTX1 object.\n\n        Parameters\n        ----------\n        rt1_config : RT1Config, optional\n            a configuration object for the RT1 model (default is None)\n        vit_config : FilmViTConfig, optional\n            a configuration object for the ViT model (default is None)\n\n\n\n        Example:\n\n        import torch\n        from rtx import RTX1\n\n        model = RTX1()\n\n        video = torch.randn(2, 3, 6, 224, 224)\n\n        instructions = [\"bring me that apple sitting on the table\", \"please pass the butter\"]\n\n        # compute the train logits\n        train_logits = model.train(video, instructions)\n\n        # set the model to evaluation mode\n        model.model.eval()\n\n        # compute the eval logits with a conditional scale of 3\n        eval_logits = model.run(video, instructions, cond_scale=3.0)\n        print(eval_logits.shape)\n        \"\"\"\n        super().__init__()\n        if rt1_config is None:\n            rt1_config = RT1Config()\n        if vit_config is None:\n            vit_config = FilmViTConfig()\n\n        self.vit = FilmMaxVit(vit_config)\n        self.model = RT1(\n            config=rt1_config,\n            vit=self.vit,\n        )\n\n    def train(self, video, instructions):\n        \"\"\"\n        Computes the logits for the given video and instructions using the RT1 model in training mode.\n\n        Parameters\n        ----------\n        video : torch.Tensor\n            a tensor containing the video data\n        instructions : torch.Tensor\n            a tensor containing the instructions\n\n        Returns\n        -------\n        torch.Tensor\n            a tensor containing the computed logits\n        \"\"\"\n\n        try:\n            train_logits = self.model(video, instructions)\n            return train_logits\n        except Exception as e:\n            raise RuntimeError(\"Error in training: {}\".format(e))\n\n    def run(self, video, instructions, cond_scale=1.0):\n        \"\"\"\n        Computes the logits for the given video and instructions using the RT1 model in evaluation mode.\n\n        Parameters\n        ----------\n\n        video : torch.Tensor\n            a tensor containing the video data\n        instructions : torch.Tensor\n            a tensor containing the instructions\n        cond_scale : float, optional\n            a scale factor for the conditional scaling (default is 1.0)\n\n        Returns\n        -------\n        torch.Tensor\n            a tensor containing the computed logits\n        \"\"\"\n\n        try:\n            self.model.eval()\n            # shape => 2, 3, 6, 224, 224\n            eval_logits = self.model(video, instructions, cond_scale=cond_scale)\n            return eval_logits\n        except Exception as e:\n            raise RuntimeError(\"Error in evaluation: {}\".format(e))\n"
  },
  {
    "path": "rtx/rtx2.py",
    "content": "import torch\nfrom torch import nn\nfrom zeta.structs import (\n    AutoregressiveWrapper,\n    Decoder,\n    Encoder,\n    Transformer,\n    ViTransformerWrapper,\n)\n\n\nclass RTX2(torch.nn.Module):\n    \"\"\"\n    RTX2 is a transformer architecture that uses a ViT encoder and a transformer decoder.\n\n    Args:\n\n        image_size (int): Size of the image.\n        patch_size (int): Size of the patch.\n        encoder_dim (int): Dimension of the encoder.\n        encoder_depth (int): Depth of the encoder.\n        encoder_heads (int): Number of heads in the encoder.\n        num_tokens (int): Number of tokens.\n        max_seq_len (int): Maximum sequence length.\n        decoder_dim (int): Dimension of the decoder.\n        decoder_depth (int): Depth of the decoder.\n        decoder_heads (int): Number of heads in the decoder.\n        alibi_num_heads (int): Number of heads in the alibi attention.\n        attn_kv_heads (int): Number of heads in the attention key-value projection.\n        use_abs_pos_emb (bool): Whether to use absolute positional embeddings.\n        cross_attend (bool): Whether to cross attend in the decoder.\n        alibi_pos_bias (bool): Whether to use positional bias in the alibi attention.\n        rotary_xpos (bool): Whether to use rotary positional embeddings.\n        attn_flash (bool): Whether to use attention flash.\n        qk_norm (bool): Whether to normalize the query and key in the attention layer.\n\n    Returns:\n\n            torch.Tensor: The output of the model.\n\n    Usage:\n\n            >>> img = torch.randn(1, 3, 256, 256)\n            >>> text = torch.randint(0, 20000, (1, 1024))\n            >>> model = RTX2()\n            >>> output = model(img, text)\n            >>> print(output)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size=256,\n        patch_size=32,\n        encoder_dim=512,\n        encoder_depth=6,\n        encoder_heads=8,\n        num_tokens=20000,\n        max_seq_len=1024,\n        decoder_dim=512,\n        decoder_depth=6,\n        decoder_heads=8,\n        alibi_num_heads=4,\n        attn_kv_heads=2,\n        use_abs_pos_emb=False,\n        cross_attend=True,\n        alibi_pos_bias=True,\n        rotary_xpos=True,\n        attn_flash=True,\n        qk_norm=True,\n        *args,\n        **kwargs,\n    ):\n        super(RTX2, self).__init__()\n\n        # vit architecture\n        self.encoder = ViTransformerWrapper(\n            image_size=image_size,\n            patch_size=patch_size,\n            attn_layers=Encoder(\n                dim=encoder_dim,\n                depth=encoder_depth,\n                heads=encoder_heads,\n            ),\n        )\n\n        # palm model architecture\n        self.decoder = Transformer(\n            num_tokens=num_tokens,\n            max_seq_len=max_seq_len,\n            use_abs_pos_emb=use_abs_pos_emb,\n            attn_layers=Decoder(\n                dim=decoder_dim,\n                depth=decoder_depth,\n                heads=decoder_heads,\n                cross_attend=cross_attend,\n                alibi_pos_bias=alibi_pos_bias,\n                alibi_num_heads=alibi_num_heads,\n                rotary_xpos=rotary_xpos,\n                attn_kv_heads=attn_kv_heads,\n                attn_flash=attn_flash,\n                qk_norm=qk_norm,\n                *args,\n                **kwargs,\n            ),\n        )\n\n        # autoregressive wrapper to enable generation of tokens\n        self.decoder = AutoregressiveWrapper(self.decoder)\n\n        # Norm\n        self.norm = nn.LayerNorm(encoder_dim)\n\n    def forward(self, img: torch.Tensor, text: torch.Tensor):\n        \"\"\"Forward pass of the model.\"\"\"\n        try:\n            encoded = self.encoder(img, return_embeddings=True)\n            encoded = self.norm(encoded)\n            encoded = self.norm(encoded)\n\n            return self.decoder(text, context=encoded)\n        except Exception as error:\n            print(f\"Failed in forward method: {error}\")\n            raise\n"
  },
  {
    "path": "rtx2_example.py",
    "content": "import torch\nfrom rtx import RTX2\n\n\ndef run():\n    # usage\n    img = torch.randn(1, 3, 256, 256)\n    text = torch.randint(0, 20000, (1, 1024))\n\n    model = RTX2()\n    output = model(img, text)\n    print(output)\n\n\nif __name__ == \"__main__\":\n    run()\n"
  },
  {
    "path": "run.py",
    "content": "from examples import rtx1_example, train_example\nfrom rtx import RTX1, RTX2\nfrom rtx.rtx1 import FilmViTConfig\nfrom absl import app, flags, logging\n\nfrom . import rtx2_example\n\nREGISTRY = {\n    \"rtx1\": RTX1,\n    \"rtx2\": RTX2,\n}\n\nMODES = [\"inference\", \"train\"]\n\nEXAMPLE_SCRIPTS = {\n    \"rtx1\": rtx1_example,\n    \"rtx2\": rtx2_example,\n}\n\nFLAGS = flags.FLAGS\nflags.DEFINE_boolean(\n    \"pretrained_vit\", False, \"Whether to use a  pretrained ViT as a backbone or not.\"\n)\nflags.DEFINE_enum(\"model\", \"rtx1\", REGISTRY.keys(), \"Model to choose from.\")\nflags.DEFINE_enum(\"mode\", \"inference\", MODES, \"Experiment mode to run.\")\n\n\ndef main(_):\n    if FLAGS.mode == \"inference\":\n        EXAMPLE_SCRIPTS[FLAGS.model].run()\n    elif FLAGS.mode == \"train\":\n        if FLAGS.pretrained_vit and FLAGS.model == \"rtx2\":\n            logging.fatal(\n                \"Option `pretrained_vit` is not available for model {} \".format(\n                    FLAGS.model\n                )\n            )\n        model = REGISTRY[FLAGS.model](\n            vit_config=FilmViTConfig(pretrained=FLAGS.pretrained_vit)\n        )\n        train_example.run(model)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/test_data_utils.py",
    "content": "import io\n\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom rtx.data_util import describe, format_imgs, preprocess\n\n\ndef test_describe():\n    dic = {\n        \"key1\": \"value1\",\n        \"key2\": [1, 2, 3],\n        \"key3\": {\"nested_key\": \"nested_value\"},\n    }\n    describe(dic)\n\n\ndef test_describe_empty():\n    dic = {}\n    describe(dic)\n\n\ndef test_describe_non_dict():\n    non_dict = \"not a dict\"\n    describe(non_dict)\n\n\ndef test_preprocess():\n    dic = {\n        \"key1\": \"value1\",\n        \"key2\": [1, 2, 3],\n        \"key3\": {\"nested_key\": \"nested_value\"},\n    }\n    result = preprocess(dic)\n    assert result == dic\n\n\ndef test_preprocess_empty():\n    dic = {}\n    result = preprocess(dic)\n    assert result == dic\n\n\ndef test_preprocess_non_dict():\n    non_dict = \"not a dict\"\n    result = preprocess(non_dict)\n    assert result == non_dict\n\n\ndef test_preprocess_none_value():\n    dic = {\"key1\": None}\n    result = preprocess(dic)\n    assert result == {}\n\n\ndef test_preprocess_image():\n    img = Image.new(\"RGB\", (60, 30), color=\"red\")\n    img_byte_arr = io.BytesIO()\n    img.save(img_byte_arr, format=\"PNG\")\n    img_byte_arr = img_byte_arr.getvalue()\n    result = preprocess(img_byte_arr)\n    assert isinstance(result, np.ndarray)\n\n\ndef test_format_imgs():\n    dic = {\n        \"key1\": \"value1\",\n        \"key2\": [1, 2, 3],\n        \"key3\": {\"nested_key\": \"nested_value\"},\n    }\n    result = format_imgs(dic, 224)\n    assert result == dic\n\n\ndef test_format_imgs_empty():\n    dic = {}\n    result = format_imgs(dic, 224)\n    assert result == dic\n\n\ndef test_format_imgs_non_dict():\n    non_dict = \"not a dict\"\n    result = format_imgs(non_dict, 224)\n    assert result == non_dict\n\n\ndef test_format_imgs_image():\n    img = Image.new(\"RGB\", (60, 30), color=\"red\")\n    img_byte_arr = io.BytesIO()\n    img.save(img_byte_arr, format=\"PNG\")\n    img_byte_arr = img_byte_arr.getvalue()\n    result = format_imgs(img_byte_arr, 224)\n    assert isinstance(result, np.ndarray)\n\n\ndef test_format_imgs_tensor():\n    tensor = torch.tensor([1, 2, 3])\n    result = format_imgs(tensor, 224)\n    assert isinstance(result, torch.Tensor)\n\n\ndef test_format_imgs_list():\n    list_val = [1, 2, 3]\n    result = format_imgs(list_val, 224)\n    assert result == list_val\n\n\ndef test_format_imgs_nested_dict():\n    dic = {\"key1\": {\"nested_key\": \"nested_value\"}}\n    result = format_imgs(dic, 224)\n    assert result == dic\n\n\ndef test_format_imgs_nested_list():\n    dic = {\"key1\": [1, 2, 3]}\n    result = format_imgs(dic, 224)\n    assert result == dic\n\n\ndef test_format_imgs_nested_image():\n    img = Image.new(\"RGB\", (60, 30), color=\"red\")\n    img_byte_arr = io.BytesIO()\n    img.save(img_byte_arr, format=\"PNG\")\n    img_byte_arr = img_byte_arr.getvalue()\n    dic = {\"key1\": img_byte_arr}\n    result = format_imgs(dic, 224)\n    assert isinstance(result[\"key1\"], np.ndarray)\n\n\ndef test_format_imgs_nested_tensor():\n    tensor = torch.tensor([1, 2, 3])\n    dic = {\"key1\": tensor}\n    result = format_imgs(dic, 224)\n    assert isinstance(result[\"key1\"], torch.Tensor)\n\n\ndef test_format_imgs_nested_list():\n    list_val = [1, 2, 3]\n    dic = {\"key1\": list_val}\n    result = format_imgs(dic, 224)\n    assert result == dic\n"
  },
  {
    "path": "tests/test_rtx1.py",
    "content": "import unittest\nimport torch\nfrom rtx.rtx1 import RTX1, FilmViTConfig, RT1Config\n\n\nclass RTX1Test(unittest.TestCase):\n    def setUp(self):\n        self.batch_size = 2\n        self.num_frames = 6\n        self.num_actions = 11\n        self.num_action_bins = 256\n\n        self.video = torch.randn(self.batch_size, 3, self.num_frames, 224, 224)\n        self.instructions = [\n            \"bring me that apple sitting on the table\",\n            \"please pass the butter\",\n        ]\n\n        rt1_config = RT1Config(\n            num_actions=self.num_actions,\n            action_bins=self.num_action_bins,\n        )\n        self.rtx1 = RTX1(rt1_config)\n        self.rtx1_pretrained = RTX1(rt1_config, FilmViTConfig(pretrained=True))\n        self.expected_logits_shape = torch.Size(\n            [\n                self.batch_size,\n                self.num_frames,\n                self.num_actions,\n                self.num_action_bins,\n            ]\n        )\n\n    def test_default_pretrained_has_same_shape(self):\n        # Tests the general shape as the pretrained version from pytorch has\n        # different layernorm and conv2dnorm implementations.\n\n        assert len(self.rtx1.vit.layers) == len(self.rtx1_pretrained.vit.layers)\n\n    def test_default_train_eval(self):\n        train_logits = self.rtx1.train(self.video, self.instructions)\n\n        assert train_logits.shape == self.expected_logits_shape\n        self.rtx1.model.eval()\n\n        # compute the eval logits with a conditional scale of 3\n        eval_logits = self.rtx1.run(self.video, self.instructions, cond_scale=3.0)\n        assert eval_logits.shape == self.expected_logits_shape\n\n    def test_pretrained_train_eval(self):\n        train_logits = self.rtx1_pretrained.train(self.video, self.instructions)\n\n        assert train_logits.shape == self.expected_logits_shape\n        self.rtx1.model.eval()\n\n        # compute the eval logits with a conditional scale of 3\n        eval_logits = self.rtx1_pretrained.run(\n            self.video, self.instructions, cond_scale=3.0\n        )\n        assert eval_logits.shape == self.expected_logits_shape\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/tests.py",
    "content": "import pytest\nimport torch\nfrom PIL import Image\nfrom zeta.structs import (\n    AutoregressiveWrapper,\n    ViTransformerWrapper,\n)\n\nfrom rtx.efficient_net import EfficientNetFilm\nfrom rtx.rtx1 import RT1, RTX1, FilmMaxVit\nfrom rtx.rtx2 import RTX2\n\n\n########################### EfficientNetFilm ###########################\nimg = \"img.jpeg\"\n\n\n# Fixture to create an instance of the EfficientNetFilm class\n@pytest.fixture\ndef efficientnet_model():\n    model = EfficientNetFilm(\"efficientnet-b0\", 10)\n    return model\n\n\n# Test case to check if EfficientNetFilm initializes correctly\ndef test_efficientnet_init(efficientnet_model):\n    assert efficientnet_model is not None\n\n\n# Test case to check if EfficientNetFilm processes an image correctly\ndef test_efficientnet_process_image(efficientnet_model):\n    # Load a sample image\n    image_path = img\n    Image.open(image_path)\n\n    # Process the image using the model\n    features = efficientnet_model(image_path)\n\n    # Check if the output features are of the correct shape\n    assert isinstance(features, torch.Tensor)\n    assert features.shape == (1, efficientnet_model.num_features)\n\n\n# Test case to check if EfficientNetFilm handles image resizing correctly\ndef test_efficientnet_image_resize(efficientnet_model):\n    # Load a sample image\n    image_path = img\n    image = Image.open(image_path)\n\n    # Process the image using the model\n    efficientnet_model(image_path)\n\n    # Check if the input image was resized to the specified size\n    assert image.size == (\n        efficientnet_model.resize,\n        efficientnet_model.resize,\n    )\n\n\n# Test case to check if EfficientNetFilm handles model loading correctly\ndef test_efficientnet_model_loading(efficientnet_model):\n    # Check if the model was loaded successfully\n    assert efficientnet_model.model is not None\n\n\n# Test case to check if EfficientNetFilm handles image transformations correctly\ndef test_efficientnet_image_transformations(efficientnet_model):\n    # Load a sample image\n    image_path = img\n    Image.open(image_path)\n\n    # Process the image using the model\n    features = efficientnet_model(image_path)\n\n    # Check if image transformations were applied correctly\n    assert torch.max(features).item() <= 1.0\n    assert torch.min(features).item() >= -1.0\n\n\n# Test case to check if EfficientNetFilm handles the number of classes correctly\ndef test_efficientnet_num_classes(efficientnet_model):\n    # Check if the number of classes is set correctly\n    assert efficientnet_model.num_classes == 10\n\n\n# Test case to check if EfficientNetFilm handles missing image file correctly\ndef test_efficientnet_missing_image(efficientnet_model):\n    with pytest.raises(FileNotFoundError):\n        efficientnet_model(\"non_existent_image.jpg\")\n\n\n# Test case to check if EfficientNetFilm handles incorrect image file format correctly\ndef test_efficientnet_incorrect_image_format(efficientnet_model):\n    with pytest.raises(ValueError):\n        efficientnet_model(\"sample_image.txt\")\n\n\n# Test case to check if EfficientNetFilm handles model selection correctly\ndef test_efficientnet_model_selection():\n    # Check if different EfficientNet models can be selected\n    model_names = [\n        \"efficientnet-b0\",\n        \"efficientnet-b1\",\n        \"efficientnet-b2\",\n    ]\n    for model_name in model_names:\n        model = EfficientNetFilm(model_name, 10)\n        assert model is not None\n        assert model.model is not None\n\n\n# Test case to check if EfficientNetFilm handles invalid model name correctly\ndef test_efficientnet_invalid_model_name():\n    with pytest.raises(ValueError):\n        EfficientNetFilm(\"invalid_model\", 10)\n\n\n# Test case to check if EfficientNetFilm handles invalid number of classes correctly\ndef test_efficientnet_invalid_num_classes():\n    with pytest.raises(ValueError):\n        EfficientNetFilm(\"efficientnet-b0\", -10)\n\n\n# Test case to check if EfficientNetFilm handles invalid resize size correctly\ndef test_efficientnet_invalid_resize_size():\n    with pytest.raises(ValueError):\n        EfficientNetFilm(\"efficientnet-b0\", 10, resize=-100)\n\n\n# Test case to check if EfficientNetFilm handles input image with incorrect channels correctly\ndef test_efficientnet_incorrect_image_channels(efficientnet_model):\n    # Create an image with incorrect number of channels (4 channels)\n    image = Image.new(\n        \"RGBA\",\n        (efficientnet_model.resize, efficientnet_model.resize),\n        (255, 0, 0, 255),\n    )\n    image_path = \"incorrect_channels_image.png\"\n    image.save(image_path)\n\n    with pytest.raises(ValueError):\n        efficientnet_model(image_path)\n\n\n# Test case to check if EfficientNetFilm handles input image with incorrect size correctly\ndef test_efficientnet_incorrect_image_size(efficientnet_model):\n    # Create an image with incorrect size (smaller than resize size)\n    image = Image.new(\n        \"RGB\",\n        (\n            efficientnet_model.resize - 1,\n            efficientnet_model.resize - 1,\n        ),\n        (255, 0, 0),\n    )\n    image_path = \"incorrect_size_image.jpg\"\n    image.save(image_path)\n\n    with pytest.raises(ValueError):\n        efficientnet_model(image_path)\n\n\n########################### RTX1 ###########################\n\n\n# Fixture to create an instance of the RTX1 class\n@pytest.fixture\ndef rtx1_model():\n    model = RTX1()\n    return model\n\n\n# Test case to check if RTX1 initializes correctly\ndef test_rtx1_initialization(rtx1_model):\n    assert isinstance(rtx1_model, RTX1)\n    assert isinstance(rtx1_model.vit, FilmMaxVit)\n    assert isinstance(rtx1_model.model, RT1)\n\n\n# Test case to check if RTX1 handles training with video and instructions correctly\ndef test_rtx1_train(rtx1_model):\n    video = torch.randn(2, 3, 6, 224, 224)\n    instructions = [\n        \"bring me that apple sitting on the table\",\n        \"please pass the butter\",\n    ]\n\n    train_logits = rtx1_model.train(video, instructions)\n\n    assert isinstance(train_logits, torch.Tensor)\n    assert train_logits.shape == (2, rtx1_model.num_actions)\n\n\n# Test case to check if RTX1 handles evaluation with video and instructions correctly\ndef test_rtx1_eval(rtx1_model):\n    video = torch.randn(2, 3, 6, 224, 224)\n    instructions = [\n        \"bring me that apple sitting on the table\",\n        \"please pass the butter\",\n    ]\n\n    eval_logits = rtx1_model.run(video, instructions, cond_scale=3.0)\n\n    assert isinstance(eval_logits, torch.Tensor)\n    assert eval_logits.shape == (2, rtx1_model.num_actions)\n\n\n# Test case to check if RTX1 raises an error when training with invalid inputs\ndef test_rtx1_train_with_invalid_inputs(rtx1_model):\n    with pytest.raises(RuntimeError):\n        video = torch.randn(2, 3, 6, 224, 224)\n        instructions = [\n            \"bring me that apple sitting on the table\",\n            \"please pass the butter\",\n        ]\n        # Intentionally set an invalid shape for instructions\n        instructions = instructions[:1]  # Instructions shape should be (2,)\n        rtx1_model.train(video, instructions)\n\n\n# Test case to check if RTX1 raises an error when evaluating with invalid inputs\ndef test_rtx1_eval_with_invalid_inputs(rtx1_model):\n    with pytest.raises(RuntimeError):\n        video = torch.randn(2, 3, 6, 224, 224)\n        instructions = [\n            \"bring me that apple sitting on the table\",\n            \"please pass the butter\",\n        ]\n        # Intentionally set an invalid shape for video\n        video = video[:, :, :5]  # Video shape should be (2, 3, 6, 224, 224)\n        rtx1_model.run(video, instructions, cond_scale=3.0)\n\n\n# Test case to check if RTX1 handles conditional scaling correctly\ndef test_rtx1_conditional_scaling(rtx1_model):\n    video = torch.randn(2, 3, 6, 224, 224)\n    instructions = [\n        \"bring me that apple sitting on the table\",\n        \"please pass the butter\",\n    ]\n\n    eval_logits = rtx1_model.run(video, instructions, cond_scale=3.0)\n    eval_logits_without_scaling = rtx1_model.run(video, instructions)\n\n    # Check if the logits with and without scaling are different\n    assert not torch.allclose(eval_logits, eval_logits_without_scaling)\n\n\n# Test case to check if RTX1 handles model selection correctly\ndef test_rtx1_model_selection():\n    model_names = [\n        \"efficientnet-b0\",\n        \"efficientnet-b1\",\n        \"efficientnet-b2\",\n    ]\n    for model_name in model_names:\n        model = RTX1(model_name=model_name)\n        assert isinstance(model, RTX1)\n\n\n# Test case to check if RTX1 raises an error for an invalid model name\ndef test_rtx1_invalid_model_name():\n    with pytest.raises(ValueError):\n        RTX1(model_name=\"invalid_model\")\n\n\n# Test case to check if RTX1 handles negative number of classes correctly\ndef test_rtx1_negative_num_classes():\n    with pytest.raises(ValueError):\n        RTX1(num_classes=-100)\n\n\n# Test case to check if RTX1 handles negative dimension correctly\ndef test_rtx1_negative_dimension():\n    with pytest.raises(ValueError):\n        RTX1(dim=-96)\n\n\n# Test case to check if RTX1 handles negative dimension of convolutional stem correctly\ndef test_rtx1_negative_dim_conv_stem():\n    with pytest.raises(ValueError):\n        RTX1(dim_conv_stem=-64)\n\n\n# Test case to check if RTX1 handles negative dimension of head for ViT correctly\ndef test_rtx1_negative_dim_head_vit():\n    with pytest.raises(ValueError):\n        RTX1(dim_head_vit=-32)\n\n\n# Test case to check if RTX1 handles negative depth of ViT correctly\ndef test_rtx1_negative_depth_vit():\n    with pytest.raises(ValueError):\n        RTX1(depth_vit=(-2, 2, 5, 2))\n\n\n# Test case to check if RTX1 handles negative window size for ViT correctly\ndef test_rtx1_negative_window_size():\n    with pytest.raises(ValueError):\n        RTX1(window_size=-7)\n\n\n# Test case to check if RTX1 handles negative expansion rate for mbconv correctly\ndef test_rtx1_negative_mbconv_expansion_rate():\n    with pytest.raises(ValueError):\n        RTX1(mbconv_expansion_rate=-4)\n\n\n# Test case to check if RTX1 handles negative shrinkage rate for mbconv correctly\ndef test_rtx1_negative_mbconv_shrinkage_rate():\n    with pytest.raises(ValueError):\n        RTX1(mbconv_shrinkage_rate=-0.25)\n\n\n# Test case to check if RTX1 handles negative dropout rate for ViT correctly\ndef test_rtx1_negative_dropout_vit():\n    with pytest.raises(ValueError):\n        RTX1(dropout_vit=-0.1)\n\n\n# Test case to check if RTX1 handles negative number of actions correctly\ndef test_rtx1_negative_num_actions():\n    with pytest.raises(ValueError):\n        RTX1(num_actions=-11)\n\n\n# Test case to check if RTX1 handles negative depth of RT1 correctly\ndef test_rtx1_negative_depth_rt1():\n    with pytest.raises(ValueError):\n        RTX1(depth_rt1=-6)\n\n\n# Test case to check if RTX1 handles negative number of heads for RT1 correctly\ndef test_rtx1_negative_heads():\n    with pytest.raises(ValueError):\n        RTX1(heads=-8)\n\n\n# Test case to check if RTX1 handles negative dimension of head for RT1 correctly\ndef test_rtx1_negative_dim_head_rt1():\n    with pytest.raises(ValueError):\n        RTX1(dim_head_rt1=-64)\n\n\n# Test case to check if RTX1 handles negative conditional drop probability for RT1 correctly\ndef test_rtx1_negative_cond_drop_prob():\n    with pytest.raises(ValueError):\n        RTX1(cond_drop_prob=-0.2)\n\n\n########################### RTX2 ###########################\n\n\n# Fixture to create an instance of the RTX2 class\n@pytest.fixture\ndef rtx2_model():\n    model = RTX2()\n    return model\n\n\n# Test case to check if RTX2 initializes correctly\ndef test_rtx2_initialization(rtx2_model):\n    assert isinstance(rtx2_model, RTX2)\n    assert isinstance(rtx2_model.encoder, ViTransformerWrapper)\n    assert isinstance(rtx2_model.decoder, AutoregressiveWrapper)\n\n\n# Test case to check if RTX2 handles forward pass with image and text correctly\ndef test_rtx2_forward_pass(rtx2_model):\n    img = torch.randn(1, 3, 256, 256)\n    text = torch.randint(0, 20000, (1, 1024))\n\n    output = rtx2_model(img, text)\n\n    assert isinstance(output, torch.Tensor)\n\n\n# Test case to check if RTX2 raises an error when forwarding with invalid inputs\ndef test_rtx2_forward_with_invalid_inputs(rtx2_model):\n    with pytest.raises(Exception):\n        img = torch.randn(1, 3, 256, 256)\n        text = torch.randn(1, 1024, 512)  # Invalid shape for text input\n        rtx2_model(img, text)\n\n\n# Test case to check if RTX2 handles various model configurations correctly\ndef test_rtx2_with_different_configs():\n    config_combinations = [\n        {\"encoder_depth\": 6, \"decoder_depth\": 6},\n        {\"encoder_depth\": 4, \"decoder_depth\": 8},\n        {\"encoder_heads\": 8, \"decoder_heads\": 8},\n        {\"encoder_dim\": 512, \"decoder_dim\": 768},\n    ]\n\n    for config in config_combinations:\n        model = RTX2(**config)\n        assert isinstance(model, RTX2)\n        assert model.encoder.attn_layers.depth == config[\"encoder_depth\"]\n        assert model.decoder.attn_layers.depth == config[\"decoder_depth\"]\n        if \"encoder_heads\" in config:\n            assert model.encoder.attn_layers.heads == config[\"encoder_heads\"]\n        if \"decoder_heads\" in config:\n            assert model.decoder.attn_layers.heads == config[\"decoder_heads\"]\n        if \"encoder_dim\" in config:\n            assert model.encoder.attn_layers.dim == config[\"encoder_dim\"]\n        if \"decoder_dim\" in config:\n            assert model.decoder.attn_layers.dim == config[\"decoder_dim\"]\n\n\n# Test case to check if RTX2 handles negative image size correctly\ndef test_rtx2_negative_image_size():\n    with pytest.raises(ValueError):\n        RTX2(image_size=-256)\n\n\n# Test case to check if RTX2 handles negative patch size correctly\ndef test_rtx2_negative_patch_size():\n    with pytest.raises(ValueError):\n        RTX2(patch_size=-32)\n\n\n# Test case to check if RTX2 handles negative encoder dimension correctly\ndef test_rtx2_negative_encoder_dim():\n    with pytest.raises(ValueError):\n        RTX2(encoder_dim=-512)\n\n\n# Test case to check if RTX2 handles negative encoder depth correctly\ndef test_rtx2_negative_encoder_depth():\n    with pytest.raises(ValueError):\n        RTX2(encoder_depth=-6)\n\n\n# Test case to check if RTX2 handles negative decoder dimension correctly\ndef test_rtx2_negative_decoder_dim():\n    with pytest.raises(ValueError):\n        RTX2(decoder_dim=-512)\n\n\n# Test case to check if RTX2 handles negative decoder depth correctly\ndef test_rtx2_negative_decoder_depth():\n    with pytest.raises(ValueError):\n        RTX2(decoder_depth=-6)\n\n\n# Test case to check if RTX2 handles negative encoder heads correctly\ndef test_rtx2_negative_encoder_heads():\n    with pytest.raises(ValueError):\n        RTX2(encoder_heads=-8)\n\n\n# Test case to check if RTX2 handles negative decoder heads correctly\ndef test_rtx2_negative_decoder_heads():\n    with pytest.raises(ValueError):\n        RTX2(decoder_heads=-8)\n"
  }
]