[
  {
    "path": ".github/FUNDING.yml",
    "content": "# These are supported funding model platforms\n\ngithub: [kyegomez]\npatreon: # Replace with a single Patreon username\nopen_collective: # Replace with a single Open Collective username\nko_fi: # Replace with a single Ko-fi username\ntidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel\ncommunity_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry\nliberapay: # Replace with a single Liberapay username\nissuehunt: # Replace with a single IssueHunt username\notechie: # Replace with a single Otechie username\nlfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry\ncustom: #Nothing\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a detailed report on the bug and its root cause\ntitle: \"[BUG] \"\nlabels: bug\nassignees: kyegomez\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is and, if known, its root cause. Please test thoroughly before submitting.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: 'kyegomez'\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "<!-- Thank you for contributing to ScreenAI!\n\nReplace this comment with:\n  - Description: a description of the change,\n  - Issue: the issue # it fixes (if applicable),\n  - Dependencies: any dependencies required for this change,\n  - Tag maintainer: for a quicker response, tag the relevant maintainer (see below),\n  - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out!\n\nIf you're adding a new integration, please include:\n  1. a test for the integration, preferably unit tests that do not rely on network access,\n  2. an example notebook showing its use.\n\nMaintainer responsibilities:\n  - nn / Misc / if you don't know who to tag: kye@apac.ai\n  - tokenizers: kye@apac.ai\n  - training / Prompts: kye@apac.ai\n  - models: kye@apac.ai\n\nIf no one reviews your PR within a few days, feel free to email kye@apac.ai\n\nSee contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/kyegomez/ScreenAI"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates\n\nversion: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n\n  - package-ecosystem: \"pip\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n\n"
  },
  {
    "path": ".github/labeler.yml",
    "content": "# this is a config file for the github action labeler\n\n# Add 'label1' to any changes within 'example' folder or any subfolders\nexample_change:\n- example/**\n\n# Add 'label2' to any file changes within 'example2' folder\nexample2_change: example2/*\n\n# Add label3 to any change to .txt files within the entire repository. Quotation marks are required for the leading asterisk\ntext_files:\n- '**/*.txt'\n"
  },
  {
    "path": ".github/workflows/code_quality_control.yml",
    "content": "name: Linting and Formatting\n\non:\n  push:\n    branches:\n      - main\n\njobs:\n  lint_and_format:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install --no-cache-dir -r requirements.txt autopep8\n\n      - name: Format Python files with autopep8\n        run: find screenai -name \"*.py\" -type f -exec autopep8 --in-place --aggressive --aggressive {} +\n\n      - name: Commit formatting changes\n        run: |\n          git config user.name \"github-actions[bot]\"\n          git config user.email \"github-actions[bot]@users.noreply.github.com\"\n          git add -A\n          git diff --cached --quiet || git commit -m \"chore: auto-format with autopep8\"\n\n      - name: Push changes\n        uses: ad-m/github-push-action@master\n        with:\n          github_token: ${{ secrets.GITHUB_TOKEN }}"
  },
  {
    "path": ".github/workflows/cos_integration.yml",
    "content": "name: Continuous Integration\n\non:\n  push:\n    branches:\n      - main\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install --no-cache-dir -r requirements.txt\n\n      - name: Run unit tests\n        run: pytest tests/unit\n\n      - name: Run integration tests\n        run: pytest tests/integration\n\n      - name: Run code coverage\n        run: pytest --cov=screenai tests/\n\n      - name: Run linters\n        run: pylint screenai\n\n      - name: Build documentation\n        run: make docs\n\n      - name: Validate documentation\n        run: sphinx-build -b linkcheck docs build/docs\n\n      - name: Run performance tests\n        run: pytest tests/performance"
  },
  {
    "path": ".github/workflows/docs.yml",
    "content": "name: Docs Workflow\n\non:\n  push:\n    branches:\n      - master\n      - main\n      - develop\njobs:\n  deploy:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n      - run: pip install mkdocs-material\n      - run: pip install \"mkdocstrings[python]\"\n      - run: mkdocs gh-deploy --force"
  },
  {
    "path": ".github/workflows/docs_test.yml",
    "content": "name: Documentation Tests\n\non:\n  push:\n    branches:\n      - master\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install --no-cache-dir -r requirements.txt\n\n      - name: Build documentation\n        run: make docs\n\n      - name: Validate documentation\n        run: sphinx-build -b linkcheck docs build/docs"
  },
  {
    "path": ".github/workflows/label.yml",
    "content": "# This workflow will triage pull requests and apply a label based on the\n# paths that are modified in the pull request.\n#\n# To use this workflow, you will need to set up a .github/labeler.yml\n# file with configuration.  For more information, see:\n# https://github.com/actions/labeler\n\nname: Labeler\non: [pull_request_target]\n\njobs:\n  label:\n\n    runs-on: ubuntu-latest\n    permissions:\n      contents: read\n      pull-requests: write\n\n    steps:\n    - uses: actions/labeler@v5.0.0\n      with:\n        repo-token: \"${{ secrets.GITHUB_TOKEN }}\"\n"
  },
  {
    "path": ".github/workflows/lints.yml",
    "content": "name: Linting\n\non:\n  push:\n    branches:\n      - master\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install --no-cache-dir -r requirements.txt\n\n      - name: Run linters\n        run: pylint screenai"
  },
  {
    "path": ".github/workflows/pr_request_checks.yml",
    "content": "name: Pull Request Checks\n\non:\n  pull_request:\n    branches:\n      - master\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install --no-cache-dir -r requirements.txt\n\n      - name: Run tests and checks\n        run: |\n          pytest tests/\n          pylint screenai"
  },
  {
    "path": ".github/workflows/pull-request-links.yml",
    "content": "name: readthedocs/actions\non:\n  pull_request_target:\n    types:\n      - opened\n    paths:\n      - \"docs/**\"\n\npermissions:\n  pull-requests: write\n\njobs:\n  pull-request-links:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: readthedocs/actions/preview@v1\n        with:\n          project-slug: screenai"
  },
  {
    "path": ".github/workflows/pylint.yml",
    "content": "name: Pylint\n\non: [push]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.9\", \"3.10\"]\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v5\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --no-cache-dir --upgrade pip\n        pip install pylint\n    - name: Analysing the code with pylint\n      run: |\n        pylint $(git ls-files '*.py')\n"
  },
  {
    "path": ".github/workflows/python-publish.yml",
    "content": "\nname: Upload Python Package\n\non:\n  release:\n    types: [published]\n\npermissions:\n  contents: read\n\njobs:\n  deploy:\n\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python\n      uses: actions/setup-python@v5\n      with:\n        python-version: '3.10'\n    - name: Install dependencies\n      run: |\n        python -m pip install --no-cache-dir --upgrade pip\n        pip install build\n    - name: Build package\n      run: python -m build\n    - name: Publish package\n      uses: pypa/gh-action-pypi-publish@2f6f737ca5f74c637829c0f5c3acd0e29ea5e8bf\n      with:\n        user: __token__\n        password: ${{ secrets.PYPI_API_TOKEN }}"
  },
  {
    "path": ".github/workflows/quality.yml",
    "content": "name: Quality\n\non:\n  push:\n    branches: [ \"main\" ]\n  pull_request:\n    branches: [ \"main\" ]\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n    steps:\n      - name: Checkout actions\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n      - name: Init environment \n        uses: ./.github/actions/init-environment \n      - name: Run linter\n        run: |\n          pylint `git diff --name-only --diff-filter=d origin/main HEAD | grep -E '\\.py$' | tr '\\n' ' '`"
  },
  {
    "path": ".github/workflows/ruff.yml",
    "content": "name: Ruff\non: [ push, pull_request ]\njobs:\n  ruff:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: chartboost/ruff-action@v1\n"
  },
  {
    "path": ".github/workflows/run_test.yml",
    "content": "name: Python application test\n\non: [push]\n\njobs:\n  build:\n\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python 3.10\n      uses: actions/setup-python@v5\n      with:\n        python-version: '3.10'\n    - name: Install dependencies\n      run: |\n        python -m pip install --no-cache-dir --upgrade pip\n        pip install pytest\n        if [ -f requirements.txt ]; then pip install --no-cache-dir -r requirements.txt; fi\n    - name: Run tests with pytest\n      run: |\n        pytest tests/\n"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.\n#\n# You can adjust the behavior by modifying this file.\n# For more information, see:\n# https://github.com/actions/stale\nname: Mark stale issues and pull requests\n\non:\n  schedule:\n  - cron: '26 12 * * *'\n\njobs:\n  stale:\n\n    runs-on: ubuntu-latest\n    permissions:\n      issues: write\n      pull-requests: write\n\n    steps:\n    - uses: actions/stale@v9\n      with:\n        repo-token: ${{ secrets.GITHUB_TOKEN }}\n        stale-issue-message: 'This issue has been automatically marked as stale because it has had no recent activity. It will be closed if no further activity occurs.'\n        stale-pr-message: 'This pull request has been automatically marked as stale because it has had no recent activity. It will be closed if no further activity occurs.'\n        stale-issue-label: 'no-issue-activity'\n        stale-pr-label: 'no-pr-activity'"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: test\n\non:\n  push:\n    branches: [master]\n  pull_request:\n  workflow_dispatch:\n\nenv:\n  POETRY_VERSION: \"1.4.2\"\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version:\n          - \"3.9\"\n          - \"3.10\"\n          - \"3.11\"\n        test_type:\n          - \"core\"\n          - \"extended\"\n    name: Python ${{ matrix.python-version }} ${{ matrix.test_type }}\n    steps:\n      - uses: actions/checkout@v4\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: \"./.github/actions/poetry_setup\"\n        with:\n          python-version: ${{ matrix.python-version }}\n          poetry-version: \"1.4.2\"\n          cache-key: ${{ matrix.test_type }}\n          install-command: |\n              if [ \"${{ matrix.test_type }}\" == \"core\" ]; then\n                echo \"Running core tests, installing dependencies with poetry...\"\n                poetry install\n              else\n                echo \"Running extended tests, installing dependencies with poetry...\"\n                poetry install -E extended_testing\n              fi\n      - name: Run ${{matrix.test_type}} tests\n        run: |\n          if [ \"${{ matrix.test_type }}\" == \"core\" ]; then\n            make test\n          else\n            make extended_tests\n          fi\n        shell: bash"
  },
  {
    "path": ".github/workflows/testing.yml",
    "content": "name: Unit Tests\n\non:\n  push:\n    branches:\n      - master\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install --no-cache-dir -r requirements.txt\n\n      - name: Run unit tests\n        run: pytest tests/"
  },
  {
    "path": ".github/workflows/unit-test.yml",
    "content": "name: build\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n\n  build:\n\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n\n    - name: Setup Python\n      uses: actions/setup-python@v5\n      with:\n        python-version: '3.10'\n\n    - name: Install dependencies\n      run: pip install --no-cache-dir -r requirements.txt\n\n    - name: Run Python unit tests\n      run: python3 -m unittest discover tests/\n\n    - name: Verify that the Docker image for the action builds\n      run: docker build . --file Dockerfile\n"
  },
  {
    "path": ".github/workflows/welcome.yml",
    "content": "name: Welcome Workflow\n\non:\n  issues:\n    types: [opened]\n  pull_request_target:\n    types: [opened]\n\njobs:\n  build:\n    name: 👋 Welcome\n    permissions: write-all\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/first-interaction@v1.3.0\n        with:\n          repo-token: ${{ secrets.GITHUB_TOKEN }}\n          issue-message: \"Hello there, thank you for opening an issue! 🙏🏻 The team has been notified and will get back to you as soon as possible.\"\n          pr-message: \"Hello there, thank you for opening a PR! 🙏🏻 The team has been notified and will get back to you as soon as possible.\""
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n.vscode/\n.vscode\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\n.ruff_cache/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/psf/black\n    rev: 22.3.0\n    hooks:\n    - id: black\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: 'v0.0.255'\n    hooks:\n      - id: ruff\n        args: [--fix]\n  - repo: https://github.com/nbQA-dev/nbQA\n    rev: 1.6.3\n    hooks:\n    - id: nbqa-black\n      additional_dependencies: [ipython==8.12, black]\n    - id: nbqa-ruff\n      args: [\"--ignore=I001\"]\n      additional_dependencies: [ipython==8.12, ruff]"
  },
  {
    "path": ".readthedocs.yml",
    "content": "version: 2\n\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.11\"\n\nmkdocs:\n  configuration: mkdocs.yml\n\npython:\n   install:\n   - requirements: requirements.txt"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Eternal Reclaimer\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)\n\n# Screen AI\nImplementation of the ScreenAI model from the paper \"ScreenAI: A Vision-Language Model for UI and Infographics Understanding\" ([paper link](https://arxiv.org/abs/2402.04615)). The flow is:\nimg + text -> patch sizes -> vit -> embed + concat -> attn + ffn -> cross attn + ffn + self attn -> to out.\n\n## Install\n`pip3 install screenai`\n\n## Usage\n```python\nimport torch\nfrom screenai.main import ScreenAI\n\n# Create a tensor for the image\nimage = torch.rand(1, 3, 224, 224)\n\n# Create a tensor of token ids for the text\ntext = torch.randint(0, 20000, (1, 1028))\n\n# Create an instance of the ScreenAI model with specified parameters\nmodel = ScreenAI(\n    num_tokens=20000,\n    max_seq_len=1028,\n    patch_size=16,\n    image_size=224,\n    dim=512,\n    depth=6,\n    heads=8,\n    vit_depth=4,\n    multi_modal_encoder_depth=4,\n    llm_decoder_depth=4,\n    mm_encoder_ff_mult=4,\n)\n\n# Perform a forward pass of the model with the given text and image tensors\nout = model(text, image)\n\n# Print the output tensor\nprint(out)\n```\n\n## License\nMIT\n\n## Citation\n```bibtex\n@misc{baechler2024screenai,\n    title={ScreenAI: A Vision-Language Model for UI and Infographics Understanding}, \n    author={Gilles Baechler and Srinivas Sunkara and Maria Wang and Fedir Zubach and Hassan Mansoor and Vincent Etter and Victor Cărbune and Jason Lin and Jindong Chen and Abhanshu Sharma},\n    year={2024},\n    eprint={2402.04615},\n    archivePrefix={arXiv},\n    primaryClass={cs.CV}\n}\n```\n\n## Todo\n- [ ] Implement the nn.ModuleList([]) in the encoder and decoder\n"
  },
  {
    "path": "example.py",
    "content": "import torch\nfrom screenai.main import ScreenAI\n\n# Create a tensor for the image\nimage = torch.rand(1, 3, 224, 224)\n\n# Create a tensor of token ids for the text\ntext = torch.randint(0, 20000, (1, 1028))\n\n# Create an instance of the ScreenAI model with specified parameters\nmodel = ScreenAI(\n    num_tokens=20000,\n    max_seq_len=1028,\n    patch_size=16,\n    image_size=224,\n    dim=512,\n    depth=6,\n    heads=8,\n    vit_depth=4,\n    multi_modal_encoder_depth=4,\n    llm_decoder_depth=4,\n    mm_encoder_ff_mult=4,\n)\n\n# Perform a forward pass of the model with the given text and image tensors\nout = model(text, image)\n\n# Print the output tensor\nprint(out)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n\n[tool.poetry]\nname = \"screenai\"\nversion = \"0.0.8\"\ndescription = \"Screen AI - Pytorch\"\nlicense = \"MIT\"\nauthors = [\"Kye Gomez <kye@apac.ai>\"]\nhomepage = \"https://github.com/kyegomez/ScreenAI\"\ndocumentation = \"https://github.com/kyegomez/ScreenAI\"\nreadme = \"README.md\"\nrepository = \"https://github.com/kyegomez/ScreenAI\"\nkeywords = [\"artificial intelligence\", \"deep learning\", \"optimizers\", \"Prompt Engineering\"]\nclassifiers = [\n    \"Development Status :: 4 - Beta\",\n    \"Intended Audience :: Developers\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"License :: OSI Approved :: MIT License\",\n    \"Programming Language :: Python :: 3.9\"\n]\n\n[tool.poetry.dependencies]\npython = \"^3.9\"\nswarms = \"*\"\nzetascale = \"*\"\neinops = \"*\"\ntorch = \"*\"\ntorchvision = \"*\"\n\n\n[tool.poetry.group.lint.dependencies]\nruff = \"^0.1.6\"\ntypes-toml = \"^0.10.8.1\"\ntypes-redis = \"^4.3.21.6\"\ntypes-pytz = \"^2023.3.0.0\"\nblack = \"^23.1.0\"\ntypes-chardet = \"^5.0.4.6\"\nmypy-protobuf = \"^3.0.0\"\n\n\n[tool.autopep8]\nmax_line_length = 80\nignore = \"E501,W6\"  # or [\"E501\", \"W6\"]\nin-place = true\nrecursive = true\naggressive = 3\n\n\n[tool.ruff]\nline-length = 70\n\n[tool.black]\nline-length = 70\ntarget-version = ['py39']\npreview = true\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\nzetascale\neinops\n"
  },
  {
    "path": "screenai/__init__.py",
    "content": "from screenai.main import (\n    CrossAttention,\n    MultiModalEncoder,\n    MultiModalDecoder,\n    ScreenAI,\n)\n\n\n__all__ = [\n    \"CrossAttention\",\n    \"MultiModalEncoder\",\n    \"MultiModalDecoder\",\n    \"ScreenAI\",\n]\n"
  },
  {
    "path": "screenai/main.py",
    "content": "import torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch import Tensor, einsum, nn\nfrom torch.autograd import Function\nfrom zeta.nn import (\n    SwiGLU,\n    FeedForward,\n    Attention,\n)\nfrom zeta.structs import (\n    Encoder,\n    ViTransformerWrapper,\n)\n\n# helper functions\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    return val if exists(val) else d\n\n\ndef pair(val):\n    return (val, val) if not isinstance(val, tuple) else val\n\n\ndef divisible_by(numer, denom):\n    return (numer % denom) == 0\n\n\ndef dynamic_patching(x, patch_size, image_size):\n    # Calculate the patch size based off the image\n    patch_size = pair(patch_size)\n    image_size = pair(image_size)\n\n    # Get the height and width of the image\n    h, w = image_size\n\n    # Use einops to rearrange the image\n    x = rearrange(\n        x,\n        \"b c (h p1) (w p2) -> b (h w) (p1 p2 c)\",\n        p1=patch_size[0],\n        p2=patch_size[1],\n    )\n\n    return x\n\n\n# distributed\n\n\ndef pad_dim_to(t, length, dim=0):\n    pad_length = length - t.shape[dim]\n    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)\n    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))\n\n\ndef all_gather_variable_batch(t):\n    device, rank, world_size = (\n        t.device,\n        dist.get_rank(),\n        dist.get_world_size(),\n    )\n\n    size = torch.tensor(t.shape[0], device=device, dtype=torch.long)\n    sizes = [\n        torch.empty_like(size, device=device, dtype=torch.long)\n        for i in range(world_size)\n    ]\n    dist.all_gather(sizes, size)\n\n    sizes = torch.stack(sizes)\n    max_size = sizes.amax().item()\n\n    padded_t = pad_dim_to(t, max_size, dim=0)\n    gathered_tensors = [\n        torch.empty_like(\n            padded_t, device=device, dtype=padded_t.dtype\n        )\n        for i in range(world_size)\n    ]\n    
dist.all_gather(gathered_tensors, padded_t)\n\n    gathered_tensor = torch.cat(gathered_tensors)\n    seq = torch.arange(max_size, device=device)\n\n    mask = rearrange(seq, \"j -> 1 j\") < rearrange(sizes, \"i -> i 1\")\n    mask = rearrange(mask, \"i j -> (i j)\")\n\n    gathered_tensor = gathered_tensor[mask]\n    sizes = sizes.tolist()\n\n    return gathered_tensor, sizes\n\n\nclass AllGather(Function):\n    @staticmethod\n    def forward(ctx, x):\n        assert dist.is_initialized() and dist.get_world_size() > 1\n        x, batch_sizes = all_gather_variable_batch(x)\n        ctx.batch_sizes = batch_sizes\n        return x\n\n    @staticmethod\n    def backward(ctx, grads):\n        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()\n        grads_by_rank = grads.split(batch_sizes, dim=0)\n        return grads_by_rank[rank]\n\n\nall_gather = AllGather.apply\n\n\n# normalization\n# they use layernorm without bias, something that pytorch does not offer\n\n\n# to latents\n\n\nclass EmbedToLatents(nn.Module):\n    def __init__(self, dim, dim_latents):\n        super().__init__()\n        self.to_latents = nn.Linear(dim, dim_latents, bias=False)\n\n    def forward(self, x):\n        latents = self.to_latents(x)\n        return F.normalize(latents, dim=-1)\n\n\n# parallel attention and feedforward with residual\n# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward\n\n\nclass CrossAttention(nn.Module):\n    \"\"\"\n    Initializes the CrossAttention module.\n\n    Args:\n    dim (int): The input dimension.\n    context_dim (int, optional): The dimension of the context. Defaults to None.\n    dim_head (int, optional): The dimension of each head. Defaults to 64.\n    heads (int, optional): The number of attention heads. Defaults to 8.\n    parallel_ff (bool, optional): Whether to use parallel feedforward. Defaults to False.\n    ff_mult (int, optional): The multiplier for the feedforward inner dimension. 
Defaults to 4.\n    norm_context (bool, optional): Whether to apply layer normalization to the context. Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        *,\n        context_dim=None,\n        dim_head=64,\n        heads=8,\n        parallel_ff=False,\n        ff_mult=4,\n        norm_context=False,\n    ):\n        super().__init__()\n        self.heads = heads\n        self.scale = dim_head**-0.5\n        inner_dim = heads * dim_head\n        context_dim = default(context_dim, dim)\n\n        self.norm = nn.LayerNorm(dim)\n        self.context_norm = (\n            nn.LayerNorm(context_dim)\n            if norm_context\n            else nn.Identity()\n        )\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n        # whether to have parallel feedforward\n\n        ff_inner_dim = ff_mult * dim\n\n        self.ff = (\n            nn.Sequential(\n                nn.Linear(dim, ff_inner_dim * 2, bias=False),\n                SwiGLU(),\n                nn.Linear(ff_inner_dim, dim, bias=False),\n            )\n            if parallel_ff\n            else None\n        )\n\n    def forward(self, x, context):\n        \"\"\"\n        einstein notation\n        b - batch\n        h - heads\n        n, i, j - sequence length (base sequence length, source, target)\n        d - feature dimension\n        \"\"\"\n\n        # pre-layernorm, for queries and context\n\n        x = self.norm(x)\n        context = self.context_norm(context)\n\n        # get queries\n\n        q = self.to_q(x)\n        q = rearrange(q, \"b n (h d) -> b h n d\", h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # get key / values\n\n        k, v = self.to_kv(context).chunk(2, dim=-1)\n\n        # query / key similarity\n\n        sim = einsum(\"b h i d, b j d -> b h i j\", q, k)\n\n        
# attention\n\n        attn = sim.softmax(dim=-1)\n\n        # aggregate\n\n        out = einsum(\"b h i j, b j d -> b h i d\", attn, v)\n\n        # merge and combine heads\n\n        out = rearrange(out, \"b h n d -> b n (h d)\")\n        out = self.to_out(out)\n\n        # add parallel feedforward (for multimodal layers)\n\n        if exists(self.ff):\n            out = out + self.ff(x)\n\n        return out\n\n\nclass MultiModalEncoder(nn.Module):\n    \"\"\"\n    MultiModalEncoder encodes multi-modal inputs with a self-attention block followed by a feedforward block.\n\n    Args:\n        dim (int): The dimension of the input and output tensors. Default is 512.\n        depth (int): The number of layers in the encoder. Default is 6.\n        dim_head (int): The dimension of each head in the self-attention mechanism. Default is 64.\n        heads (int): The number of attention heads. Default is 8.\n        *args: Variable length argument list.\n        **kwargs: Arbitrary keyword arguments.\n\n    Attributes:\n        dim (int): The dimension of the input and output tensors.\n        depth (int): The number of layers in the encoder.\n        heads (int): The number of attention heads.\n        dim_head (int): The dimension of each head in the self-attention mechanism.\n        attn (Attention): Self-attention layer.\n        ffn (FeedForward): Feedforward layer.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int = 512,\n        depth: int = 6,\n        dim_head: int = 64,\n        heads: int = 8,\n        *args,\n        **kwargs,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.depth = depth\n        self.heads = heads\n        self.dim_head = dim_head\n\n        # Use flash attention only when CUDA is available (a truthy string like \"cpu\" would always enable it)\n        self.flash = torch.cuda.is_available()\n\n        self.attn = Attention(\n            dim,\n            dim_head,\n            heads,\n            causal=True,\n            qk_norm=True,\n            flash=self.flash,\n        )\n        self.ffn = FeedForward(dim, 
dim, 4, *args, **kwargs)\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"\n        Forward pass of the MultiModalEncoder.\n\n        Args:\n            x (Tensor): The input tensor.\n\n        Returns:\n            Tensor: The encoded tensor.\n\n        \"\"\"\n        skip = x\n        x, _ = self.attn(x)\n        x = x + skip\n        x = self.ffn(x) + x\n\n        return x\n\n\nclass MultiModalDecoder(nn.Module):\n    \"\"\"\n    MultiModalDecoder module for decoding multi-modal inputs.\n\n    Args:\n        dim (int): The dimension of the input.\n        depth (int): The number of layers in the decoder.\n        dim_head (int): The dimension of each attention head.\n        heads (int): The number of attention heads.\n        *args: Variable length argument list.\n        **kwargs: Arbitrary keyword arguments.\n\n    Attributes:\n        dim (int): The dimension of the input.\n        depth (int): The number of layers in the decoder.\n        heads (int): The number of attention heads.\n        dim_head (int): The dimension of each attention head.\n        cross_attn (CrossAttention): Cross-attention layer with a parallel feedforward.\n        attn (Attention): Self-attention layer.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int = 512,\n        depth: int = 6,\n        dim_head: int = 64,\n        heads: int = 8,\n        *args,\n        **kwargs,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.depth = depth\n        self.heads = heads\n        self.dim_head = dim_head\n\n        # Use flash attention only when CUDA is available\n        self.flash = torch.cuda.is_available()\n\n        self.cross_attn = CrossAttention(\n            dim,\n            dim_head=dim_head,\n            heads=heads,\n            parallel_ff=True,\n        )\n\n        self.attn = Attention(\n            dim,\n            dim_head,\n            heads,\n            causal=True,\n            qk_norm=True,\n            flash=self.flash,\n        )\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"\n        Forward pass of the MultiModalDecoder.\n\n        Args:\n            x (Tensor): The input tensor.\n\n        Returns:\n            Tensor: The decoded tensor.\n\n        \"\"\"\n        skip = x\n        x = 
self.cross_attn(x, x) + x\n        x, _ = self.attn(x)\n\n        return x + skip\n\n\nclass ScreenAI(nn.Module):\n    \"\"\"\n    ScreenAI module for multimodal learning.\n\n    Args:\n        num_tokens (int): Size of the token vocabulary.\n        max_seq_len (int): Maximum text sequence length.\n        patch_size (int): Size of the image patches.\n        image_size (int): Size of the input image.\n        dim (int): Dimension of the model.\n        depth (int): Depth of the model.\n        dim_head (int): Dimension of the attention head.\n        heads (int): Number of attention heads.\n        vit_depth (int): Depth of the ViT transformer.\n        multi_modal_encoder_depth (int): Depth of the multimodal encoder.\n        llm_decoder_depth (int): Depth of the LLM decoder.\n        channels (int): Number of input image channels.\n        *args: Variable length argument list.\n        **kwargs: Arbitrary keyword arguments.\n\n    Attributes:\n        patch_size (int): Size of the image patches.\n        image_size (int): Size of the input image.\n        dim (int): Dimension of the model.\n        depth (int): Depth of the model.\n        heads (int): Number of attention heads.\n        vit_depth (int): Depth of the ViT transformer.\n        multi_modal_encoder_depth (int): Depth of the multimodal encoder.\n        llm_decoder_depth (int): Depth of the LLM decoder.\n        vit (ViTransformerWrapper): ViT transformer layer.\n        image_embedding (nn.Linear): Image embedding layer.\n        to_out (nn.Sequential): Output layer.\n        flash (bool): Whether to use flash attention (True when CUDA is available).\n        encoder (MultiModalEncoder): Multimodal encoder layer.\n        decoder (MultiModalDecoder): LLM decoder layer.\n        embedding (nn.Embedding): Token embedding layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_tokens: int,\n        max_seq_len: int,\n        patch_size: int,\n        image_size: int = 224,\n        dim: int = 512,\n        depth: int = 6,\n        dim_head: int = 64,\n        heads: int = 8,\n        
vit_depth: int = 4,\n        multi_modal_encoder_depth: int = 4,\n        llm_decoder_depth: int = 4,\n        channels: int = 3,\n        *args,\n        **kwargs,\n    ):\n        super().__init__()\n        self.num_tokens = num_tokens\n        self.max_seq_len = max_seq_len\n        self.patch_size = patch_size\n        self.image_size = image_size\n        self.dim = dim\n        self.depth = depth\n        self.heads = heads\n        self.vit_depth = vit_depth\n        self.multi_modal_encoder_depth = multi_modal_encoder_depth\n        self.llm_decoder_depth = llm_decoder_depth\n\n        # ViTransformerWrapper\n        self.vit = ViTransformerWrapper(\n            image_size=image_size,\n            patch_size=patch_size,\n            post_emb_norm=True,\n            attn_layers=Encoder(\n                dim=dim, depth=vit_depth, heads=heads\n            ),\n        )\n\n        # Image embedding\n        self.image_embedding = nn.Linear(dim, dim)\n\n        # Output head: project to a distribution over the token vocabulary\n        self.to_out = nn.Sequential(\n            nn.LayerNorm(dim),\n            nn.Linear(dim, num_tokens),\n            nn.Softmax(dim=-1),\n        )\n\n        # Use flash attention when CUDA is available\n        self.flash = torch.cuda.is_available()\n\n        # Multimodal encoder\n        self.encoder = MultiModalEncoder(\n            dim,\n            multi_modal_encoder_depth,\n            dim_head,\n            heads,\n        )\n\n        # LLM decoder\n        self.decoder = MultiModalDecoder(\n            dim,\n            llm_decoder_depth,\n            dim_head,\n            heads,\n        )\n\n        # Projection applied to the ViT patch embeddings\n        self.to_patch_embedding = nn.Sequential(\n            nn.LayerNorm(dim),\n            nn.Linear(dim, dim),\n            nn.LayerNorm(dim),\n        )\n\n        # Embedding for the tokens\n        self.embedding = nn.Embedding(num_tokens, dim)\n\n    def forward(self, text: Tensor, img: Tensor) -> Tensor:\n        \"\"\"\n        Forward pass of the ScreenAI 
module.\n\n        Args:\n            text (Tensor): Input text tensor.\n            img (Tensor): Input image tensor.\n\n        Returns:\n            Tensor: Output tensor.\n        \"\"\"\n        # Embed the text tokens\n        text = self.embedding(text)\n\n        # TODO: aspect-ratio-preserving patch grid with a budget of e.g. 25 patches\n\n        # ViT image encoder\n        img = self.vit(img, return_embeddings=True)\n\n        # Project the image patch embeddings\n        img = self.to_patch_embedding(img)\n\n        # Concatenate image and text tokens along the sequence dimension\n        x = torch.cat((img, text), dim=1)\n\n        # Multimodal encoder\n        x = self.encoder(x)\n\n        # Decoder: cross-attention followed by causal self-attention\n        x = self.decoder(x)\n\n        # Project to the output distribution\n        x = self.to_out(x)\n\n        return x\n"
  }
]
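The forward pass above mentions an aspect-ratio-preserving grid with a maximum of about 25 patches, but does not implement one (the `rearrange` there is an identity). A minimal stdlib-only sketch of how such a grid could be chosen under a patch budget; the helper name and the log-ratio selection rule are illustrative assumptions, not part of this codebase:

```python
import math


def grid_shape(height: int, width: int, max_patches: int = 25) -> tuple[int, int]:
    """Pick (rows, cols) with rows * cols <= max_patches whose shape best
    matches the image's aspect ratio.

    Illustrative helper -- not the repo's API.
    """
    target = math.log(height / width)
    best, best_err = (1, 1), float("inf")
    for rows in range(1, max_patches + 1):
        # Most columns allowed for this row count without exceeding the budget
        cols = max_patches // rows
        # Compare aspect ratios in log space so tall and wide errors are symmetric
        err = abs(math.log(rows / cols) - target)
        if err < best_err:
            best, best_err = (rows, cols), err
    return best
```

A square 224x224 image with a budget of 25 yields a 5x5 grid, while a 3:2 image such as 300x200 yields 6x4; the resulting `rows`/`cols` could then parameterize the `p1`/`p2` arguments of a patching `rearrange`.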