[
  {
    "path": ".github/ISSUE_TEMPLATE/breaking-bug-report.md",
    "content": "---\nname: Breaking bug report\nabout: Create a report about a breaking bug\ntitle: \"[BUG: Breaking]\"\nlabels: 'bug: breaking'\nassignees: ''\n\n---\n\n## 🧨 Describe the Bug\n\nA clear and concise description of the breaking issue (e.g., crash, OOM, or exception).\n\n## 📄 Input Document\n\nAttach the PDF or input file that triggered the error.\n\n## 📤 Output Trace / Stack Trace\n\nPaste the **complete** stack trace or error output, if available.\n\n<details>\n<summary>Click to expand</summary>\n\n```\nPaste stack trace here\n```\n\n</details>\n\n## ⚙️ Environment\n\nPlease fill in all relevant details:\n\n- **Marker version**: \n- **Surya version**: \n- **Python version**: \n- **PyTorch version**: \n- **Transformers version**: \n- **Operating System** (incl. container info if relevant): \n\n## ✅ Expected Behavior\n\nWhat did you expect Marker to do?\n\n## 📟 Command or Code Used\n\nPaste the **exact bash command** or **Python code** you used to run Marker:\n\n<details>\n<summary>Click to expand</summary>\n\n```bash\n# or Python code block\nyour_command_here --with-flags\n```\n\n</details>\n\n## 📎 Additional Context\n\nAny other context that might help us debug this (e.g., CLI options, working directory, runtime settings).\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: \"[FEAT]\"\nlabels: enhancement\nassignees: ''\n\n---\n\n## ✨ Is your feature request related to a problem?\n\nA clear and concise description of what the problem is. \n\n## 💡 Describe the Solution You'd Like\n\nA concise description of what you want to happen or how you envision it working.\n\n## 📋 Alternatives Considered\n\nAny alternative solutions or workarounds you've tried.\n\n## 🧩 Additional Context\n\nAny additional context, references, or related issues.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/output-bug-report.md",
    "content": "---\nname: Output bug report\nabout: Create a report about poor output quality\ntitle: \"[BUG: Output]\"\nlabels: 'bug: output'\nassignees: ''\n\n---\n\n## 📝 Describe the Output Issue\n\nA clear and concise description of the incorrect or unexpected output.\n\n## 📄 Input Document\n\nAttach the PDF or input file used.\n\n## 📤 Current Output\n\nPaste the Markdown or HTML that Marker generated:\n\n````markdown\nPaste output here\n````\n\n## ✅ Expected Output\n\nDescribe or paste what you expected Marker to generate.\n\n## ⚙️ Environment\n\nPlease fill in all relevant details:\n\n* **Marker version**:\n* **Surya version**:\n* **Python version**:\n* **PyTorch version**:\n* **Transformers version**:\n* **Operating System**:\n\n## 📟 Command or Code Used\n\nPaste the **exact bash command** or **Python code** you used to run Marker:\n\n<details>\n<summary>Click to expand</summary>\n\n```bash\n# or Python code block\nyour_command_here --with-flags\n```\n\n</details>\n\n## 📎 Additional Context\n\nAny other relevant info, configs, or assumptions.\n"
  },
  {
    "path": ".github/workflows/benchmarks.yml",
    "content": "name: Integration test\n\non: [push]\n\nenv:\n  PYTHONIOENCODING: \"utf-8\"\n\njobs:\n  build:\n    runs-on: t4_gpu\n    steps:\n      - uses: actions/checkout@v3\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v4\n        with:\n          python-version: 3.11\n      - name: Install python dependencies\n        run: |\n          pip install poetry\n          poetry install\n      - name: Run detection benchmark test\n        run: |\n          poetry run python benchmark/detection.py --max_rows 2\n          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection\n      - name: Run recognition benchmark test\n        run: |\n          poetry run python benchmark/recognition.py --max_rows 2\n          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition\n      - name: Run layout benchmark test\n        run: |\n          poetry run python benchmark/layout.py --max_rows 5\n          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout\n      - name: Run ordering benchmark\n        run: |\n          poetry run python benchmark/ordering.py --max_rows 5\n          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering\n      - name: Run table recognition benchmark\n        run: |\n          poetry run python benchmark/table_recognition.py --max_rows 5\n          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition\n      - name: Run texify benchmark\n        run: |\n          poetry run python benchmark/texify.py --max_rows 5\n          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify"
  },
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: Unit tests\n\non: [push]\n\njobs:\n  build:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [t4_gpu, ubuntu-latest, windows-latest]\n    steps:\n      - uses: actions/checkout@v3\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v4\n        with:\n          python-version: 3.11\n      - name: Install python dependencies\n        run: |\n          pip install poetry\n          poetry install\n      - name: Run tests\n        run: poetry run pytest"
  },
  {
    "path": ".github/workflows/cla.yml",
    "content": "name: \"Surya CLA Assistant\"\non:\n  issue_comment:\n    types: [created]\n  pull_request_target:\n    types: [opened,closed,synchronize]\n\n# explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings\npermissions:\n  actions: write\n  contents: write\n  pull-requests: write\n  statuses: write\n\njobs:\n  CLAAssistant:\n    runs-on: ubuntu-latest\n    steps:\n      - name: \"Surya CLA Assistant\"\n        if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'\n        uses: contributor-assistant/github-action@v2.3.0\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          # the below token should have repo scope and must be manually added by you in the repository's secret\n          # This token is required only if you have configured to store the signatures in a remote repository/organization\n          PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}\n        with:\n          path-to-signatures: 'signatures/version1/cla.json'\n          path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md'\n          # branch should not be protected\n          branch: 'master'\n          allowlist: VikParuchuri"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "name: Python package\non:\n  push:\n    tags:\n      - \"v*.*.*\"\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v4\n        with:\n          python-version: 3.11\n      - name: Install python dependencies\n        run: |\n          pip install poetry\n          poetry install\n      - name: Build package\n        run: |\n          poetry build\n      - name: Publish package\n        env:\n          PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}\n        run: |\n          poetry config pypi-token.pypi \"$PYPI_TOKEN\"\n          poetry publish\n"
  },
  {
    "path": ".github/workflows/scripts.yml",
    "content": "name: Test CLI scripts\n\non: [push]\n\njobs:\n  build:\n    runs-on: t4_gpu\n    steps:\n      - uses: actions/checkout@v3\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v4\n        with:\n          python-version: 3.11\n      - name: Install python dependencies\n        run: |\n          pip install poetry\n          poetry install\n      - name: Download benchmark data\n        run: |\n          wget -O benchmark_data.zip \"https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi\"\n          unzip -o benchmark_data.zip\n      - name: Test detection\n        run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0\n      - name: Test OCR\n        env:\n          RECOGNITION_MAX_TOKENS: 25\n        run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0\n      - name: Test layout\n        run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0\n      - name: Test table\n        run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0\n      - name: Test texify\n        env:\n          TEXIFY_MAX_TOKENS: 25\n        run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0\n      - name: Test detection folder\n        run: poetry run surya_detect benchmark_data/pdfs --page_range 0\n"
  },
  {
    "path": ".gitignore",
    "content": "private.py\n.DS_Store\nlocal.env\nexperiments\ntest_data\ntraining\nwandb\nnotebooks\nresults\ndata\nslices\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for 
libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n- repo: https://github.com/astral-sh/ruff-pre-commit\n  # Ruff version.\n  rev: v0.9.10\n  hooks:\n    # Run the linter.\n    - id: ruff\n      types_or: [ python, pyi ]\n      args: [ --fix ]\n    # Run the formatter.\n    - id: ruff-format\n      types_or: [ python, pyi ]"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\nmessage: \"If you use this software, please cite it using the following metadata.\"\ntitle: \"Surya: A lightweight framework for analyzing documents and PDFs at scale\"\nauthors:\n  - family-names: Paruchuri\n    given-names: Vikas\n  - name: Datalab Team\ndate-released: 2025-05-13\nurl: https://github.com/VikParuchuri/surya\nversion: 0.14.0\nrepository-code: https://github.com/VikParuchuri/surya"
  },
  {
    "path": "CLA.md",
    "content": "Surya Contributor Agreement\n\nThis Surya Contributor Agreement (\"SCA\") applies to any contribution that you make to any product or project managed by us (the \"project\"), and sets out the intellectual property rights you grant to us in the contributed materials. The term \"us\" shall mean Endless Labs, Inc. The term \"you\" shall mean the person or entity identified below. \n\nIf you agree to be bound by these terms, sign by writing \"I have read the CLA Document and I hereby sign the CLA\" in response to the CLA bot GitHub comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement.\n\n1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project. \n2. With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution: \n   - you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free, unrestricted license to exercise all rights under those copyrights. 
This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers; \n   - you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work; \n   - you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees; \n   - you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and \n   - you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution. \n3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to:\n   - make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and\n   - at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements. \nIf you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed.\n4. 
Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license. \n5. You covenant, represent, warrant and agree that: \n   - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; \n   - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and \n   - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws.\nYou agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA. \n6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply."
  },
  {
    "path": "LICENSE",
    "content": "                    GNU GENERAL PUBLIC LICENSE\n                       Version 3, 29 June 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU General Public License is a free, copyleft license for\nsoftware and other kinds of works.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nthe GNU General Public License is intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.  We, the Free Software Foundation, use the\nGNU General Public License for most of our software; it applies also to\nany other work released this way by its authors.  You can apply it to\nyour programs, too.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  To protect your rights, we need to prevent others from denying you\nthese rights or asking you to surrender the rights.  Therefore, you have\ncertain responsibilities if you distribute copies of the software, or if\nyou modify it: responsibilities to respect the freedom of others.\n\n  For example, if you distribute copies of such a program, whether\ngratis or for a fee, you must pass on to the recipients the same\nfreedoms that you received.  You must make sure that they, too, receive\nor can get the source code.  
And you must show them these terms so they\nknow their rights.\n\n  Developers that use the GNU GPL protect your rights with two steps:\n(1) assert copyright on the software, and (2) offer you this License\ngiving you legal permission to copy, distribute and/or modify it.\n\n  For the developers' and authors' protection, the GPL clearly explains\nthat there is no warranty for this free software.  For both users' and\nauthors' sake, the GPL requires that modified versions be marked as\nchanged, so that their problems will not be attributed erroneously to\nauthors of previous versions.\n\n  Some devices are designed to deny users access to install or run\nmodified versions of the software inside them, although the manufacturer\ncan do so.  This is fundamentally incompatible with the aim of\nprotecting users' freedom to change the software.  The systematic\npattern of such abuse occurs in the area of products for individuals to\nuse, which is precisely where it is most unacceptable.  Therefore, we\nhave designed this version of the GPL to prohibit the practice for those\nproducts.  If such problems arise substantially in other domains, we\nstand ready to extend this provision to those domains in future versions\nof the GPL, as needed to protect the freedom of users.\n\n  Finally, every program is threatened constantly by software patents.\nStates should not allow patents to restrict development and use of\nsoftware on general-purpose computers, but in those that do, we wish to\navoid the special danger that patents applied to a free program could\nmake it effectively proprietary.  To prevent this, the GPL assures that\npatents cannot be used to render the program non-free.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. 
Definitions.\n\n  \"This License\" refers to version 3 of the GNU General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  \"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. 
Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  \"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. 
Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  
This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. 
Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  
If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  
A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  
Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  
You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  
If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  
If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  
For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  
To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  \"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  
You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. Use with the GNU Affero General Public License.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU Affero General Public License into a single\ncombined work, and to convey the resulting work.  
The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the special requirements of the GNU Affero General Public License,\nsection 13, concerning interaction through a network will apply to the\ncombination as such.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU General Public License from time to time.  Such new versions will\nbe similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  If the Program does not specify a version number of the\nGNU General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  
It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    Surya OCR\n    Copyright (C) 2024  Endless Labs, Inc.\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU General Public License as published by\n    the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU General Public License for more details.\n\n    You should have received a copy of the GNU General Public License\n    along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If the program does terminal interaction, make it output a short\nnotice like this when it starts in an interactive mode:\n\n    Surya OCR Copyright (C) 2024  Endless Labs, Inc.\n    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.\n    This is free software, and you are welcome to redistribute it\n    under certain conditions; type `show c' for details.\n\nThe hypothetical commands `show w' and `show c' should show the appropriate\nparts of the General Public License.  
Of course, your program's commands\nmight be different; for a GUI interface, you would use an \"about box\".\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU GPL, see\n<https://www.gnu.org/licenses/>.\n\n  The GNU General Public License does not permit incorporating your program\ninto proprietary programs.  If your program is a subroutine library, you\nmay consider it more useful to permit linking proprietary applications with\nthe library.  If this is what you want to do, use the GNU Lesser General\nPublic License instead of this License.  But first, please read\n<https://www.gnu.org/licenses/why-not-lgpl.html>."
  },
  {
    "path": "MODEL_LICENSE",
    "content": "                   AI PUBS OPEN RAIL-M LICENSE (MODIFIED)\n\nVersion 0.1, March 2, 2023 (Modified)\nhttp://licenses.ai/\n\nPLEASE READ THESE TERMS CAREFULLY BEFORE USING THE MODEL OR A DERIVATIVE WORKS OF THE MODEL MADE AVAILABLE IN CONNECTION WITH THESE TERMS.  BY DOWNLOADING, REPRODUCING, DISTRIBUTING OR USING THE MODEL OR A DERIVATIVE WORK OF THE MODEL IN ANY MANNER, YOU (“YOU”) AGREE TO BE BOUND BY THESE TERMS (THE “AGREEMENT”) TO THE EXCLUSION OF ALL OTHER TERMS. YOU REPRESENT AND WARRANT THAT YOU HAVE THE AUTHORITY TO ENTER INTO THIS AGREEMENT; IF YOU ARE ENTERING INTO THIS AGREEMENT ON BEHALF OF AN ORGANIZATION OR ENTITY, REFERENCES TO AND “YOU” IN THIS AGREEMENT, REFER TO THAT ORGANIZATION OR ENTITY. IF YOU DO NOT AGREE TO ALL OF THE FOLLOWING, YOU MAY NOT DOWNLOAD, REPRODUCE, DISTRIBUTE OR USE THE MODEL OR A DERIVATIVE WORK OF THE MODEL IN ANY MANNER.\n Section  I:  PREAMBLE\nThis OpenRAIL-M License, as modified, is generally applicable to any machine-learning Model.\nThe “Open” nomenclature indicates that the licensed Model is be freely accessible to downstream and other users.  The “RAIL” nomenclature indicates that there are use restrictions prohibiting the use of the Model. These restrictions are intended to avoid potential misuse. This License specifies that the  use restrictions in the original License must apply to such derivatives.\nNOW THEREFORE, You and Licensor agree as follows:\n1. Definitions\n(a) “Complementary Material” means the applicable source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, and any related information, if any. 
Complementary Material is not licensed under this License.\n(b) \"Contribution\" means any work, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the rights owner or by an individual or legal entity authorized to submit on behalf of the rights owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the rights owner as \"Not a Contribution.\"\n(c) \"Contributor\"  means Licensor and any individual or legal entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.\n(d) “Data” means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. 
The Data is not licensed under this License.\n(e) “Derivatives of the Model” means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.\n(f) “Distribution” means any transmission, reproduction, publication, distribution, or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means, including but not limited to API-based or web access.\n(g) “Harm” includes but is not limited to physical, mental, psychological, financial and reputational damage, pain, or loss\n(h) \"License\" means the terms and conditions for use, reproduction, and Distribution as defined in this document.\n(i) “Licensor” means the rights owner or entity authorized by the rights owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.\n(j) “Model” means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.\n(k) “Output” means the results of operating a Model as embodied in informational content resulting therefrom.\n(l) “Third Parties” means individuals or legal entities that are not under common control with Licensor or You.\n(m) \"You\" (or \"Your\")  means an individual or legal entity exercising 
permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application, including but not limited to a chatbot, translator, or image generator.\n                                              Section II:   INTELLECTUAL PROPERTY RIGHTS\nBoth copyright and patent grants may apply to the Model and Derivatives of the Model. The Model and Derivatives of the Model are subject to additional terms as described in Section III, which shall govern the use of the Model and Derivatives of the Model even in the event Section II is held unenforceable.\n2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Model and Derivatives of the Model.\n3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and/or Derivatives of the Model where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model or Derivatives of the Model to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model or Derivative of the Model and/or a Contribution incorporated within the Model or Derivative of the Model constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Derivative of the Model shall terminate as of the date such litigation is asserted or filed.\nSection III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION\n4. Distribution and Redistribution. You may host the Model or Derivatives of the Model for remote access by Third Parties, including but not limited to  software-as-a-service, reproduce,  or Distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the conditions in this Section III:\n(a) Use-based restrictions in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (for example, a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model and Derivatives of the Model are subject to paragraph 5;\n(b) You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;\n(c) You must cause any modified files to carry prominent notices stating that You changed the files; and\n(d) You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model or  Derivatives of the Model.\nYou may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions, consistent with paragraph 4.a., for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the 
Model otherwise complies with the conditions stated in this License.\n5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Accordingly, You cannot use the Model or the Derivatives of the Model in violation of such restrictions. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, fine-tuning, updating, running, training, evaluating and/or re-parametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph 5.\n6.  The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are solely responsible for the Output you generate and its subsequent uses. No use of the Output can contravene any provision as stated in the License.\n7.  Attribution.  In connection with any Output, or use of Distribution of any Model or Derivatives of the Model, You agree to give appropriate credit and attribution to Licensor, provide a link to the original Model or Derivatives of the Model, provide a copy of this License, and identify any changes You have made to the Model or Derivatives of the Model (collectively, the “Attribution”).  The Attribution must not suggest endorsement by any Licensor.\n8.  Share-a-Like.  As a condition to the license and authorizations herein, You agree to apply this License (to the exclusion of all others) to any and all copies of the Model, Derivatives of the Model, any changes or improvements to the Model or Derivatives of the Model, and to the Output and any derivatives, changes or improvements to or of the Output.\nSection IV: OTHER PROVISIONS\n9. Updates and Runtime Restrictions. 
To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or cause modification to the Output resulting from updates to the Model based.\n10. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.\n11. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model (and each Contributor provides its Contributions) on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model and Derivatives of the Model, and assume any risks associated with Your exercise of permissions under this License.\n12. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.\n13. Accepting Warranty or Additional Liability. 
While Distributing the Model or Derivatives of the Model, You may choose to charge a fee in exchange for support, warranty, indemnity, or other obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor or Licensor, and only if You agree to indemnify, defend, and hold each Contributor and the Licensor harmless for any liability incurred by, or claims asserted against, such Contributor or Licensor by reason of your accepting any such warranty or additional liability.\n14. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.\nEND OF TERMS AND CONDITIONS\n\nAttachment A\nUSE RESTRICTIONS\nAs conditions to the Licenses set forth in this Agreement, You agree not to use, reproduce, modify, create or Distribute the Model, Derivatives of the Model, or Output (collectively, “Use”)  in any of the following ways:\n1. Legal:\n(a) In any way that violates any applicable national, federal, state, local or international law or regulation; or\n(b) to directly or indirectly infringe or misappropriate any third party intellectual property rights (including those of Licensor or any Contributor)\n2. 
Commercial:\n(a) for any purpose if You (your employer, or the entity you are affiliated with) generated more than two million US Dollars ($2,000,000) in gross revenue in the prior year, except where Your Use is limited to personal use or research purposes;\n(b) for any purpose if You (your employer, or the entity you are affiliated with) has raised more than two million US dollars ($2,000,000) in total equity or debt funding from any source, except where Your Use is limited to personal use or research purposes; or\n(c)  for any purpose if You (your employer, or the entity you are affiliated with) provides or otherwise makes available any product or service that competes with any product or service offered by or made available by Licensor or any of its affiliates.\nCommercial and broader use licenses may be available from Licensor at the following URL: https://www.datalab.to/"
  },
  {
    "path": "README.md",
    "content": "# Surya\n\nSurya is a document OCR toolkit that does:\n\n- OCR in 90+ languages that benchmarks favorably vs cloud services\n- Line-level text detection in any language\n- Layout analysis (table, image, header, etc detection)\n- Reading order detection\n- Table recognition (detecting rows/columns)\n- LaTeX OCR\n\nIt works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).\n\nFor our managed API or on-prem document intelligence solution, check out [our platform here](https://datalab.to?utm_source=gh-surya).\n\n\n|                            Detection                             |                                   OCR                                   |\n|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|\n|  <img src=\"static/images/excerpt.png\" width=\"500px\"/>  |  <img src=\"static/images/excerpt_text.png\" width=\"500px\"/> |\n\n|                               Layout                               |                               Reading Order                                |\n|:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|\n| <img src=\"static/images/excerpt_layout.png\" width=\"500px\"/> | <img src=\"static/images/excerpt_reading.jpg\" width=\"500px\"/> |\n\n|                       Table Recognition                       |                       LaTeX OCR                        |\n|:-------------------------------------------------------------:|:------------------------------------------------------:|\n| <img src=\"static/images/scanned_tablerec.png\" width=\"500px\"/> | <img src=\"static/images/latex_ocr.png\" width=\"500px\"/> |\n\n\nSurya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.\n\n## Community\n\n[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss 
future development.\n\n## Examples\n\n| Name             |              Detection              |                                      OCR |                                     Layout |                                       Order |                                    Table Rec |\n|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:|\n| Japanese         | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) |\n| Chinese          | [Image](static/images/chinese.jpg)  |  [Image](static/images/chinese_text.jpg) |  [Image](static/images/chinese_layout.jpg) |  [Image](static/images/chinese_reading.jpg) |                                              |\n| Hindi            |  [Image](static/images/hindi.jpg)   |    [Image](static/images/hindi_text.jpg) |    [Image](static/images/hindi_layout.jpg) |    [Image](static/images/hindi_reading.jpg) |                                              |\n| Arabic           |  [Image](static/images/arabic.jpg)  |   [Image](static/images/arabic_text.jpg) |   [Image](static/images/arabic_layout.jpg) |   [Image](static/images/arabic_reading.jpg) |                                              |\n| Chinese + Hindi  | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) |                                              |\n| Presentation     |   [Image](static/images/pres.png)   |     [Image](static/images/pres_text.jpg) |     [Image](static/images/pres_layout.jpg) |     [Image](static/images/pres_reading.jpg) |     [Image](static/images/pres_tablerec.png) |\n| Scientific Paper |  
[Image](static/images/paper.jpg)   |    [Image](static/images/paper_text.jpg) |    [Image](static/images/paper_layout.jpg) |    [Image](static/images/paper_reading.jpg) |    [Image](static/images/paper_tablerec.png) |\n| Scanned Document | [Image](static/images/scanned.png)  |  [Image](static/images/scanned_text.jpg) |  [Image](static/images/scanned_layout.jpg) |  [Image](static/images/scanned_reading.jpg) |  [Image](static/images/scanned_tablerec.png) |\n| New York Times   |   [Image](static/images/nyt.jpg)    |      [Image](static/images/nyt_text.jpg) |      [Image](static/images/nyt_layout.jpg) |        [Image](static/images/nyt_order.jpg) |                                              |\n| Scanned Form     |  [Image](static/images/funsd.png)   |    [Image](static/images/funsd_text.jpg) |    [Image](static/images/funsd_layout.jpg) |    [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) |\n| Textbook         | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) |   [Image](static/images/textbook_order.jpg) |                                              |\n\n# Hosted API\n\nThere is a hosted API for all surya models available [here](https://www.datalab.to?utm_source=gh-surya):\n\n- Works with PDF, images, word docs, and powerpoints\n- Consistent speed, with no latency spikes\n- High reliability and uptime\n\n# Commercial usage\n\nOur model weights use a modified AI Pubs Open Rail-M license (free for research, personal use, and startups under $2M funding/revenue) and our code is GPL. For broader commercial licensing or to remove GPL requirements, visit our pricing page [here](https://www.datalab.to/pricing?utm_source=gh-surya).\n\n\n# Installation\n\nYou'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine.  
See [here](https://pytorch.org/get-started/locally/) for more details.\n\nInstall with:\n\n```shell\npip install surya-ocr\n```\n\nModel weights will automatically download the first time you run surya.\n\n# Usage\n\n- Inspect the settings in `surya/settings.py`.  You can override any settings with environment variables.\n- Your torch device will be automatically detected, but you can override this.  For example, `TORCH_DEVICE=cuda`.\n\n## Interactive App\n\nI've included a streamlit app that lets you interactively try Surya on images or PDF files.  Run it with:\n\n```shell\npip install streamlit pdftext\nsurya_gui\n```\n\n## OCR (text recognition)\n\nThis command will write out a json file with the detected text and bboxes:\n\n```shell\nsurya_ocr DATA_PATH\n```\n\n- `DATA_PATH` can be an image, pdf, or folder of images/pdfs\n- `--task_name` will specify which task to use for predicting the lines.  `ocr_with_boxes` is the default, which will format text and give you bboxes.  If you get bad performance, try `ocr_without_boxes`, which will give you potentially better performance but no bboxes.  For blocks like equations and paragraphs, try `block_without_boxes`.\n- `--images` will save images of the pages and detected text lines (optional)\n- `--output_dir` specifies the directory to save results to instead of the default\n- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.\n- `--disable_math` - by default, surya will recognize math in text.  This can lead to false positives - you can disable this with this flag.\n\nThe `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  
Each page dictionary contains:\n\n- `text_lines` - the detected text and bounding boxes for each line\n  - `text` - the text in the line\n  - `confidence` - the confidence of the model in the detected text (0-1)\n  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.\n  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.\n  - `chars` - the individual characters in the line\n    - `text` - the text of the character\n    - `bbox` - the character bbox (same format as line bbox)\n    - `polygon` - the character polygon (same format as line polygon)\n    - `confidence` - the confidence of the model in the detected character (0-1)\n    - `bbox_valid` - if the character is a special token or math, the bbox may not be valid\n  - `words` - the individual words in the line (computed from the characters)\n    - `text` - the text of the word\n    - `bbox` - the word bbox (same format as line bbox)\n    - `polygon` - the word polygon (same format as line polygon)\n    - `confidence` - mean character confidence\n    - `bbox_valid` - if the word is a special token or math, the bbox may not be valid\n- `page` - the page number in the file\n- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.\n\n**Performance tips**\n\nSetting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `40MB` of VRAM, so very high batch sizes are possible.  The default is a batch size `512`, which will use about 20GB of VRAM.  
Depending on your CPU core count, it may help, too - the default CPU batch size is `32`.\n\n### From python\n\n```python\nfrom PIL import Image\nfrom surya.foundation import FoundationPredictor\nfrom surya.recognition import RecognitionPredictor\nfrom surya.detection import DetectionPredictor\n\nimage = Image.open(IMAGE_PATH)\nfoundation_predictor = FoundationPredictor()\nrecognition_predictor = RecognitionPredictor(foundation_predictor)\ndetection_predictor = DetectionPredictor()\n\npredictions = recognition_predictor([image], det_predictor=detection_predictor)\n```\n\n\n## Text line detection\n\nThis command will write out a json file with the detected bboxes.\n\n```shell\nsurya_detect DATA_PATH\n```\n\n- `DATA_PATH` can be an image, pdf, or folder of images/pdfs\n- `--images` will save images of the pages and detected text lines (optional)\n- `--output_dir` specifies the directory to save results to instead of the default\n- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.\n\nThe `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:\n\n- `bboxes` - detected bounding boxes for text\n  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.\n  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  
The points are in clockwise order from the top left.\n  - `confidence` - the confidence of the model in the detected text (0-1)\n- `vertical_lines` - vertical lines detected in the document\n  - `bbox` - the axis-aligned line coordinates.\n- `page` - the page number in the file\n- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.\n\n**Performance tips**\n\nSetting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `440MB` of VRAM, so very high batch sizes are possible.  The default is a batch size `36`, which will use about 16GB of VRAM.  Depending on your CPU core count, it might help, too - the default CPU batch size is `6`.\n\n### From python\n\n```python\nfrom PIL import Image\nfrom surya.detection import DetectionPredictor\n\nimage = Image.open(IMAGE_PATH)\ndet_predictor = DetectionPredictor()\n\n# predictions is a list of dicts, one per image\npredictions = det_predictor([image])\n```\n\n## Layout and reading order\n\nThis command will write out a json file with the detected layout and reading order.\n\n```shell\nsurya_layout DATA_PATH\n```\n\n- `DATA_PATH` can be an image, pdf, or folder of images/pdfs\n- `--images` will save images of the pages and detected text lines (optional)\n- `--output_dir` specifies the directory to save results to instead of the default\n- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.\n\nThe `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  
Each page dictionary contains:\n\n- `bboxes` - detected bounding boxes for text\n  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.\n  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.\n  - `position` - the reading order of the box.\n  - `label` - the label for the bbox.  One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.\n  - `top_k` - the top-k other potential labels for the box.  A dictionary with labels as keys and confidences as values.\n- `page` - the page number in the file\n- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.\n\n**Performance tips**\n\nSetting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `220MB` of VRAM, so very high batch sizes are possible.  The default is a batch size `32`, which will use about 7GB of VRAM.  Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.\n\n### From python\n\n```python\nfrom PIL import Image\nfrom surya.foundation import FoundationPredictor\nfrom surya.layout import LayoutPredictor\nfrom surya.settings import settings\n\nimage = Image.open(IMAGE_PATH)\nlayout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))\n\n# layout_predictions is a list of dicts, one per image\nlayout_predictions = layout_predictor([image])\n```\n\n## Table Recognition\n\nThis command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes.  
If you want to get cell positions and text, along with nice formatting, check out the [marker](https://www.github.com/VikParuchuri/marker) repo.  You can use the `TableConverter` to detect and extract tables in images and PDFs.  It supports output in json (with bboxes), markdown, and html.\n\n```shell\nsurya_table DATA_PATH\n```\n\n- `DATA_PATH` can be an image, pdf, or folder of images/pdfs\n- `--images` will save images of the pages and detected table cells + rows and columns (optional)\n- `--output_dir` specifies the directory to save results to instead of the default\n- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.\n- `--detect_boxes` specifies whether cells should be detected.  By default, they're pulled out of the PDF, but this is not always possible.\n- `--skip_table_detection` tells table recognition not to detect tables first.  Use this if your image is already cropped to a table.\n\nThe `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:\n\n- `rows` - detected table rows\n  - `bbox` - the bounding box of the table row\n  - `row_id` - the id of the row\n  - `is_header` - whether it is a header row.\n- `cols` - detected table columns\n  - `bbox` - the bounding box of the table column\n  - `col_id` - the id of the column\n  - `is_header` - whether it is a header column\n- `cells` - detected table cells\n  - `bbox` - the axis-aligned rectangle for the cell in (x1, y1, x2, y2) format.  
(x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.\n  - `text` - if text could be pulled out of the pdf, the text of this cell.\n  - `row_id` - the id of the row the cell belongs to.\n  - `col_id` - the id of the column the cell belongs to.\n  - `colspan` - the number of columns spanned by the cell.\n  - `rowspan` - the number of rows spanned by the cell.\n  - `is_header` - whether it is a header cell.\n- `page` - the page number in the file\n- `table_idx` - the index of the table on the page (sorted in vertical order)\n- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.\n\n**Performance tips**\n\nSetting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `150MB` of VRAM, so very high batch sizes are possible.  The default is a batch size `64`, which will use about 10GB of VRAM.  Depending on your CPU core count, it might help, too - the default CPU batch size is `8`.\n\n### From python\n\n```python\nfrom PIL import Image\nfrom surya.table_rec import TableRecPredictor\n\nimage = Image.open(IMAGE_PATH)\ntable_rec_predictor = TableRecPredictor()\n\ntable_predictions = table_rec_predictor([image])\n```\n\n## LaTeX OCR\n\nThis command will write out a json file with the LaTeX of the equations.  You must pass in images that are already cropped to the equations.  
You can do this by running the layout model, then cropping, if you want.\n\n```shell\nsurya_latex_ocr DATA_PATH\n```\n\n- `DATA_PATH` can be an image, pdf, or folder of images/pdfs\n- `--output_dir` specifies the directory to save results to instead of the default\n- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.\n\nThe `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  See the OCR section above for the format of the output.\n\n### From python\n\n```python\nfrom PIL import Image\nfrom surya.texify import TexifyPredictor\n\nimage = Image.open(IMAGE_PATH)\npredictor = TexifyPredictor()\n\npredictor([image])\n```\n\n### Interactive app\n\nYou can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with:\n\n```shell\npip install streamlit==1.40 streamlit-drawable-canvas-jsretry\ntexify_gui\n```\n\n## Compilation\n\nThe following models have support for compilation. 
You will need to set the following environment variables to enable compilation:\n\n- Detection: `COMPILE_DETECTOR=true`\n- Layout: `COMPILE_LAYOUT=true`\n- Table recognition: `COMPILE_TABLE_REC=true`\n\nAlternatively, you can set `COMPILE_ALL=true`, which will compile all models.\n\nHere are the speedups on an A10 GPU:\n\n| Model             | Time per page (s) | Compiled time per page (s) | Speedup (%) |\n| ----------------- | ----------------- | -------------------------- | ----------- |\n| Detection         | 0.108808          | 0.10521                    | 3.31        |\n| Layout            | 0.27319           | 0.27063                    | 0.94        |\n| Table recognition | 0.0219            | 0.01938                    | 11.51       |\n\n# Limitations\n\n- This is specialized for document OCR.  It will likely not work on photos or other images.\n- It is for printed text, not handwriting (though it may work on some handwriting).\n- The text detection model has been trained to ignore advertisements.\n- You can find language support for OCR in `surya/recognition/languages.py`.  Text detection, layout analysis, and reading order will work with any language.\n\n## Troubleshooting\n\nIf OCR isn't working properly:\n\n- Try increasing the resolution of the image so the text is bigger.  If the resolution is already very high, try decreasing it to no more than a `2048px` width.\n- Preprocessing the image (binarizing, deskewing, etc.) can help with very old/blurry images.\n- You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results.  `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space.  `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text.  `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range.  
Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds).\n\n# Manual install\n\nIf you want to develop surya, you can install it manually:\n\n- `git clone https://github.com/VikParuchuri/surya.git`\n- `cd surya`\n- `poetry install` - installs main and dev dependencies\n- `poetry shell` - activates the virtual environment\n\n# Benchmarks\n\n## OCR\n\n![Benchmark chart tesseract](static/images/benchmark_rec_chart.png)\n\n| Model     | Time per page (s) | Avg similarity (⬆) |\n|-----------|-------------------|--------------------|\n| surya     | .62               | 0.97               |\n| tesseract | .45               | 0.88               |\n\n[Full language results](static/images/rec_acc_table.png)\n\nTesseract is CPU-based, and surya is CPU or GPU.  I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean).\n\n### Google Cloud Vision\n\nI benchmarked OCR against Google Cloud Vision since it has similar language coverage to Surya.\n\n![Benchmark chart google cloud](static/images/gcloud_rec_bench.png)\n\n[Full language results](static/images/gcloud_full_langs.png)\n\n**Methodology**\n\nI measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs.  I sampled PDFs from Common Crawl, then filtered out the ones with bad OCR.  I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those.\n\nI gave both tesseract and surya the reference line bboxes from the PDFs, so that only OCR quality was evaluated.\n\nFor Google Cloud, I aligned the output from Google Cloud with the ground truth.  
I had to skip RTL languages since they didn't align well.\n\n## Text line detection\n\n![Benchmark chart](static/images/benchmark_chart_small.png)\n\n| Model     | Time (s)   | Time per page (s)   | precision   |   recall |\n|-----------|------------|---------------------|-------------|----------|\n| surya     | 47.2285    | 0.094452            | 0.835857    | 0.960807 |\n| tesseract | 74.4546    | 0.290838            | 0.631498    | 0.997694 |\n\n\nTesseract is CPU-based, and surya is CPU or GPU.  I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU.  This was the resource usage:\n\n- tesseract - 32 CPU cores, or 8 workers using 4 cores each\n- surya - 36 batch size, for 16GB VRAM usage\n\n**Methodology**\n\nSurya predicts line-level bboxes, while tesseract and others predict word-level or character-level.  It's hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation.\n\nI instead used coverage, which calculates:\n\n- Precision - how well the predicted bboxes cover ground truth bboxes\n- Recall - how well ground truth bboxes cover predicted bboxes\n\nFirst calculate coverage for each bbox, then add a small penalty for double coverage, since we want the detection to have non-overlapping bboxes.  Anything with a coverage of 0.5 or higher is considered a match.\n\nThen we calculate precision and recall for the whole dataset.\n\n## Layout analysis\n\n| Layout Type   |   precision |   recall |\n|---------------|-------------|----------|\n| Image         |     0.91265 |  0.93976 |\n| List          |     0.80849 |  0.86792 |\n| Table         |     0.84957 |  0.96104 |\n| Text          |     0.93019 |  0.94571 |\n| Title         |     0.92102 |  0.95404 |\n\nTime per image - .13 seconds on GPU (A10).\n\n**Methodology**\n\nI benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data.  
I had to align publaynet labels with the surya layout labels.  I was then able to find coverage for each layout type:\n\n- Precision - how well the predicted bboxes cover ground truth bboxes\n- Recall - how well ground truth bboxes cover predicted bboxes\n\n## Reading Order\n\n88% mean accuracy, and .4 seconds per image on an A10 GPU.  See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.\n\n**Methodology**\n\nI benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data.  Unfortunately, this dataset is fairly noisy, and not all the labels are correct.  It was very hard to find a dataset annotated with reading order and also layout information.  I wanted to avoid using a cloud service for the ground truth.\n\nThe accuracy is computed by checking whether each pair of layout boxes is in the correct order, then taking the % that are correct.\n\n## Table Recognition\n\n| Model             |   Row Intersection |   Col Intersection |   Time Per Image |\n|-------------------|--------------------|--------------------|------------------|\n| Surya             |               1    |            0.98625 |          0.30202 |\n| Table transformer |               0.84 |            0.86857 |          0.08082 |\n\nHigher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions.  This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker).\n\n**Methodology**\n\nThe benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM.  It has labeled rows and columns.  After table recognition is run, the predicted rows and columns are compared to the ground truth.  
There is an additional penalty for predicting too many or too few rows/columns.\n\n## LaTeX OCR\n\n| Method | edit distance ⬇ | time taken (s) ⬇ |\n|--------|-----------------|------------------|\n| texify | 0.122617        | 35.6345          |\n\nThis runs texify on a ground truth set of LaTeX, then computes the edit distance between the predictions and the ground truth.  This is a bit noisy, since two LaTeX strings that render the same can have different symbols in them.\n\n## Running your own benchmarks\n\nYou can benchmark the performance of surya on your machine.\n\n- Follow the manual install instructions above.\n- `poetry install --group dev` - installs dev dependencies\n\n**Text line detection**\n\nThis will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).\n\n```shell\npython benchmark/detection.py --max_rows 256\n```\n\n- `--max_rows` controls how many images to process for the benchmark\n- `--debug` will render images and detected bboxes\n- `--pdf_path` will let you specify a pdf to benchmark instead of the default data\n- `--results_dir` will let you specify a directory to save results to instead of the default one\n\n**Text recognition**\n\nThis will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).\n\n```shell\npython benchmark/recognition.py --tesseract\n```\n\n- `--max_rows` controls how many images to process for the benchmark\n- `--debug 2` will render images with detected text\n- `--results_dir` will let you specify a directory to save results to instead of the default one\n- `--tesseract` will run the benchmark with tesseract.  
You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.\n\n- Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark.\n- Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking.  This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus).\n\n**Layout analysis**\n\nThis will evaluate surya on the publaynet dataset.\n\n```shell\npython benchmark/layout.py\n```\n\n- `--max_rows` controls how many images to process for the benchmark\n- `--debug` will render images with detected text\n- `--results_dir` will let you specify a directory to save results to instead of the default one\n\n**Reading Order**\n\n```shell\npython benchmark/ordering.py\n```\n\n- `--max_rows` controls how many images to process for the benchmark\n- `--debug` will render images with detected text\n- `--results_dir` will let you specify a directory to save results to instead of the default one\n\n**Table Recognition**\n\n```shell\npython benchmark/table_recognition.py --max_rows 1024 --tatr\n```\n\n- `--max_rows` controls how many images to process for the benchmark\n- `--debug` will render images with detected text\n- `--results_dir` will let you specify a directory to save results to instead of the default one\n- `--tatr` specifies whether to also run table transformer\n\n**LaTeX OCR**\n\n```shell\npython benchmark/texify.py --max_rows 128\n```\n\n- `--max_rows` controls how many images to process for the benchmark\n- `--results_dir` will let you specify a directory to save results to instead of the default one\n\n# Training\n\nText detection was trained on 4x A6000s for 3 days.  It used a diverse set of images as training data.  It was trained from scratch using a modified efficientvit architecture for semantic segmentation.\n\nText recognition was trained on 4x A6000s for 2 weeks.  
It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).\n\n# Finetuning Surya OCR\nYou can now take Surya OCR further by training it on your own data with our [finetuning script](/surya/scripts/finetune_ocr.py).\nIt’s built on the Hugging Face Trainer, so it supports all of the Trainer’s [arguments](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) as well as integrations like torchrun and DeepSpeed.\n\nTo set up your dataset, follow the example dataset format [here](https://huggingface.co/datasets/datalab-to/ocr_finetune_example) and provide the path to your own dataset when launching the training script.\n\n```bash\n# Tested on 1xH100 GPU\n# Set --pretrained_checkpoint_path to load from a custom checkpoint, otherwise\n# the default surya ocr weights will be loaded as the initialization\npython surya/scripts/finetune_ocr.py \\\n  --output_dir $OUTPUT_DIR \\\n  --dataset_name datalab-to/ocr_finetune_example \\\n  --per_device_train_batch_size 64 \\\n  --gradient_checkpointing true \\\n  --max_sequence_length 1024\n```\n\nThis is a minimal training script to get you started finetuning Surya. Our internal training stack includes character bounding box finetuning, sliding window attention with specialized attention masks, custom kernels, augmentations, and other optimizations that can push OCR accuracy well beyond standard finetuning. 
If you want to get the most out of your data, reach out to us at hi@datalab.to!\n\n# Thanks\n\nThis work would not have been possible without amazing open source AI work:\n\n- [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA\n- [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT\n- [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman\n- [Donut](https://github.com/clovaai/donut) from Naver\n- [transformers](https://github.com/huggingface/transformers) from huggingface\n- [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model\n\nThank you to everyone who makes open source AI possible.\n\n# Citation\n\nIf you use surya (or the associated models) in your work or research, please consider citing us using the following BibTeX entry:\n\n```bibtex\n@misc{paruchuri2025surya,\n  author       = {Vikas Paruchuri and Datalab Team},\n  title        = {Surya: A lightweight document OCR and analysis toolkit},\n  year         = {2025},\n  howpublished = {\\url{https://github.com/VikParuchuri/surya}},\n  note         = {GitHub repository},\n}\n```\n"
  },
  {
    "path": "benchmark/detection.py",
    "content": "import collections\nimport copy\nimport json\n\nimport click\n\nfrom benchmark.utils.bbox import get_pdf_lines\nfrom benchmark.utils.metrics import precision_recall\nfrom benchmark.utils.tesseract import tesseract_parallel\nfrom surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb\nfrom surya.debug.draw import draw_polys_on_image\nfrom surya.common.util import rescale_bbox\nfrom surya.settings import settings\nfrom surya.detection import DetectionPredictor\n\nimport os\nimport time\nfrom tabulate import tabulate\nimport datasets\n\n\n@click.command(help=\"Benchmark detection model.\")\n@click.option(\"--pdf_path\", type=str, help=\"Path to PDF to detect bboxes in.\", default=None)\n@click.option(\"--results_dir\", type=str, help=\"Directory to save benchmark results to.\", default=os.path.join(settings.RESULT_DIR, \"benchmark\"))\n@click.option(\"--max_rows\", type=int, help=\"Maximum number of pages or images to benchmark.\", default=100)\n@click.option(\"--debug\", is_flag=True, help=\"Save debug images with the detected polygons drawn.\", default=False)\n@click.option(\"--tesseract\", is_flag=True, help=\"Run tesseract as well.\", default=False)\ndef main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool):\n    det_predictor = DetectionPredictor()\n\n    if pdf_path is not None:\n        pathname = pdf_path\n        doc = open_pdf(pdf_path)\n        page_count = len(doc)\n        page_indices = list(range(page_count))\n        page_indices = page_indices[:max_rows]\n\n        images = get_page_images(doc, page_indices)\n        doc.close()\n\n        image_sizes = [img.size for img in images]\n        correct_boxes = get_pdf_lines(pdf_path, image_sizes)\n    else:\n        pathname = \"det_bench\"\n        # These have already been shuffled randomly, so sampling from the start is fine\n        dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f\"train[:{max_rows}]\")\n        images = 
list(dataset[\"image\"])\n        images = convert_if_not_rgb(images)\n        correct_boxes = []\n        for i, boxes in enumerate(dataset[\"bboxes\"]):\n            img_size = images[i].size\n            # 1000,1000 is bbox size for doclaynet\n            correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])\n\n    if settings.DETECTOR_STATIC_CACHE:\n        # Run through one batch to compile the model\n        det_predictor(images[:1])\n\n    start = time.time()\n    predictions = det_predictor(images)\n    surya_time = time.time() - start\n\n    if tesseract:\n        start = time.time()\n        tess_predictions = tesseract_parallel(images)\n        tess_time = time.time() - start\n    else:\n        tess_predictions = [None] * len(images)\n        tess_time = None\n\n    folder_name = os.path.basename(pathname).split(\".\")[0]\n    result_path = os.path.join(results_dir, folder_name)\n    os.makedirs(result_path, exist_ok=True)\n\n    page_metrics = collections.OrderedDict()\n    for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):\n        surya_boxes = [s.bbox for s in sb.bboxes]\n        surya_polys = [s.polygon for s in sb.bboxes]\n\n        surya_metrics = precision_recall(surya_boxes, cb)\n        if tb is not None:\n            tess_metrics = precision_recall(tb, cb)\n        else:\n            tess_metrics = None\n\n        page_metrics[idx] = {\n            \"surya\": surya_metrics,\n            \"tesseract\": tess_metrics\n        }\n\n        if debug:\n            bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))\n            bbox_image.save(os.path.join(result_path, f\"{idx}_bbox.png\"))\n\n    mean_metrics = {}\n    metric_types = sorted(page_metrics[0][\"surya\"].keys())\n    models = [\"surya\"]\n    if tesseract:\n        models.append(\"tesseract\")\n\n    for k in models:\n        for m in metric_types:\n            metric = []\n            for page in 
page_metrics:\n                metric.append(page_metrics[page][k][m])\n            if k not in mean_metrics:\n                mean_metrics[k] = {}\n            mean_metrics[k][m] = sum(metric) / len(metric)\n\n    out_data = {\n        \"times\": {\n            \"surya\": surya_time,\n            \"tesseract\": tess_time\n        },\n        \"metrics\": mean_metrics,\n        \"page_metrics\": page_metrics\n    }\n\n    with open(os.path.join(result_path, \"results.json\"), \"w+\", encoding=\"utf-8\") as f:\n        json.dump(out_data, f, indent=4)\n\n    table_headers = [\"Model\", \"Time (s)\", \"Time per page (s)\"] + metric_types\n    table_data = [\n        [\"surya\", surya_time, surya_time / len(images)] + [mean_metrics[\"surya\"][m] for m in metric_types],\n    ]\n    if tesseract:\n        table_data.append(\n            [\"tesseract\", tess_time, tess_time / len(images)] + [mean_metrics[\"tesseract\"][m] for m in metric_types]\n        )\n\n    print(tabulate(table_data, headers=table_headers, tablefmt=\"github\"))\n    print(\"Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.  There is a precision penalty for multiple boxes overlapping reference lines.\")\n    print(f\"Wrote results to {result_path}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/layout.py",
    "content": "import collections\nimport copy\nimport json\n\nimport click\n\nfrom benchmark.utils.metrics import precision_recall\nfrom surya.foundation import FoundationPredictor\nfrom surya.layout import LayoutPredictor\nfrom surya.input.processing import convert_if_not_rgb\nfrom surya.debug.draw import draw_bboxes_on_image\nfrom surya.settings import settings\nimport os\nimport time\nfrom tabulate import tabulate\nimport datasets\n\n\n@click.command(help=\"Benchmark surya layout model.\")\n@click.option(\n    \"--results_dir\",\n    type=str,\n    help=\"Directory to save benchmark results to.\",\n    default=os.path.join(settings.RESULT_DIR, \"benchmark\"),\n)\n@click.option(\n    \"--max_rows\",\n    type=int,\n    help=\"Maximum number of images to run benchmark on.\",\n    default=100,\n)\n@click.option(\"--debug\", is_flag=True, help=\"Run in debug mode.\", default=False)\ndef main(results_dir: str, max_rows: int, debug: bool):\n    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)\n    layout_predictor = LayoutPredictor(foundation_predictor)\n\n    pathname = \"layout_bench\"\n    # These have already been shuffled randomly, so sampling from the start is fine\n    dataset = datasets.load_dataset(\n        settings.LAYOUT_BENCH_DATASET_NAME, split=f\"train[:{max_rows}]\"\n    )\n    images = list(dataset[\"image\"])\n    images = convert_if_not_rgb(images)\n\n    if settings.LAYOUT_STATIC_CACHE:\n        # Run through one batch to compile the model\n        layout_predictor(images[:1])\n\n    start = time.time()\n    layout_predictions = layout_predictor(images)\n    surya_time = time.time() - start\n\n    folder_name = os.path.basename(pathname).split(\".\")[0]\n    result_path = os.path.join(results_dir, folder_name)\n    os.makedirs(result_path, exist_ok=True)\n\n    label_alignment = {  # First is publaynet, second is surya\n        \"Image\": [[\"Figure\"], [\"Picture\", \"Figure\"]],\n        \"Table\": [[\"Table\"], [\"Table\", \"Form\", \"TableOfContents\"]],\n      
  \"Text\": [\n            [\"Text\"],\n            [\n                \"Text\",\n                \"Formula\",\n                \"Footnote\",\n                \"Caption\",\n                \"TextInlineMath\",\n                \"Code\",\n                \"Handwriting\",\n            ],\n        ],\n        \"List\": [[\"List\"], [\"ListItem\"]],\n        \"Title\": [[\"Title\"], [\"SectionHeader\", \"Title\"]],\n    }\n\n    page_metrics = collections.OrderedDict()\n    for idx, pred in enumerate(layout_predictions):\n        row = dataset[idx]\n        all_correct_bboxes = []\n        page_results = {}\n        for label_name in label_alignment:\n            correct_cats, surya_cats = label_alignment[label_name]\n            correct_bboxes = [\n                b\n                for b, category in zip(row[\"bboxes\"], row[\"labels\"])\n                if category in correct_cats\n            ]\n            all_correct_bboxes.extend(correct_bboxes)\n            pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats]\n\n            metrics = precision_recall(\n                pred_bboxes, correct_bboxes, penalize_double=False\n            )\n            weight = len(correct_bboxes)\n            metrics[\"weight\"] = weight\n            page_results[label_name] = metrics\n\n        page_metrics[idx] = page_results\n\n        if debug:\n            bbox_image = draw_bboxes_on_image(\n                all_correct_bboxes, copy.deepcopy(images[idx])\n            )\n            bbox_image.save(os.path.join(result_path, f\"{idx}_layout.png\"))\n\n    mean_metrics = collections.defaultdict(dict)\n    layout_types = sorted(page_metrics[0].keys())\n    metric_types = sorted(page_metrics[0][layout_types[0]].keys())\n    metric_types.remove(\"weight\")\n    for label in layout_types:\n        for m in metric_types:\n            metric = []\n            total = 0\n            for page in page_metrics:\n                metric.append(\n                    
page_metrics[page][label][m] * page_metrics[page][label][\"weight\"]\n                )\n                total += page_metrics[page][label][\"weight\"]\n\n            value = sum(metric)\n            if value > 0:\n                value /= total\n            mean_metrics[label][m] = value\n\n    out_data = {\n        \"time\": surya_time,\n        \"metrics\": mean_metrics,\n        \"page_metrics\": page_metrics,\n    }\n\n    with open(os.path.join(result_path, \"results.json\"), \"w+\", encoding=\"utf-8\") as f:\n        json.dump(out_data, f, indent=4)\n\n    table_headers = [\n        \"Layout Type\",\n    ] + metric_types\n    table_data = []\n    for layout_type in layout_types:\n        table_data.append(\n            [\n                layout_type,\n            ]\n            + [f\"{mean_metrics[layout_type][m]:.5f}\" for m in metric_types]\n        )\n\n    print(tabulate(table_data, headers=table_headers, tablefmt=\"github\"))\n    print(\n        f\"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total.\"\n    )\n    print(\n        \"Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.\"\n    )\n    print(f\"Wrote results to {result_path}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/ordering.py",
    "content": "import collections\nimport json\n\nimport click\n\nfrom surya.foundation import FoundationPredictor\nfrom surya.input.processing import convert_if_not_rgb\nfrom surya.layout import LayoutPredictor\nfrom surya.common.polygon import PolygonBox\nfrom surya.settings import settings\nfrom benchmark.utils.metrics import rank_accuracy\nimport os\nimport time\nimport datasets\n\n\n@click.command(help=\"Benchmark surya layout for reading order.\")\n@click.option(\n    \"--results_dir\",\n    type=str,\n    help=\"Path to JSON file with benchmark results.\",\n    default=os.path.join(settings.RESULT_DIR, \"benchmark\"),\n)\n@click.option(\n    \"--max_rows\",\n    type=int,\n    help=\"Maximum number of images to run benchmark on.\",\n    default=None,\n)\ndef main(results_dir: str, max_rows: int):\n    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)\n    layout_predictor = LayoutPredictor(foundation_predictor)\n    pathname = \"order_bench\"\n    # These have already been shuffled randomly, so sampling from the start is fine\n    split = \"train\"\n    if max_rows is not None:\n        split = f\"train[:{max_rows}]\"\n    dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)\n    images = list(dataset[\"image\"])\n    images = convert_if_not_rgb(images)\n\n    start = time.time()\n    layout_predictions = layout_predictor(images)\n    surya_time = time.time() - start\n\n    folder_name = os.path.basename(pathname).split(\".\")[0]\n    result_path = os.path.join(results_dir, folder_name)\n    os.makedirs(result_path, exist_ok=True)\n\n    page_metrics = collections.OrderedDict()\n    mean_accuracy = 0\n    for idx, order_pred in enumerate(layout_predictions):\n        row = dataset[idx]\n        labels = row[\"labels\"]\n        bboxes = row[\"bboxes\"]\n        pred_positions = []\n        for label, bbox in zip(labels, bboxes):\n            max_intersection = 0\n            matching_idx = 0\n 
           for pred_box in order_pred.bboxes:\n                intersection = pred_box.intersection_pct(PolygonBox(polygon=bbox))\n                if intersection > max_intersection:\n                    max_intersection = intersection\n                    matching_idx = pred_box.position\n            pred_positions.append(matching_idx)\n        accuracy = rank_accuracy(pred_positions, labels)\n        mean_accuracy += accuracy\n        page_results = {\"accuracy\": accuracy, \"box_count\": len(labels)}\n\n        page_metrics[idx] = page_results\n\n    mean_accuracy /= len(layout_predictions)\n\n    out_data = {\n        \"time\": surya_time,\n        \"mean_accuracy\": mean_accuracy,\n        \"page_metrics\": page_metrics,\n    }\n\n    with open(os.path.join(result_path, \"results.json\"), \"w+\", encoding=\"utf-8\") as f:\n        json.dump(out_data, f, indent=4)\n\n    print(f\"Mean accuracy is {mean_accuracy:.2f}.\")\n    print(\n        f\"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.\"\n    )\n    print(\"Mean accuracy is the % of correct ranking pairs.\")\n    print(f\"Wrote results to {result_path}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/recognition.py",
    "content": "import re\nimport unicodedata\nfrom collections import defaultdict\n\nimport click\n\nfrom benchmark.utils.scoring import overlap_score, overlap_score_exact\nfrom surya.input.processing import convert_if_not_rgb\nfrom surya.debug.text import draw_text_on_image\nfrom surya.foundation import FoundationPredictor\nfrom surya.recognition import RecognitionPredictor\nfrom surya.settings import settings\nfrom surya.recognition.languages import CODE_TO_LANGUAGE\nfrom benchmark.utils.tesseract import (\n    tesseract_ocr_parallel,\n    surya_lang_to_tesseract,\n    TESS_CODE_TO_LANGUAGE,\n)\nfrom benchmark.utils.textract import textract_ocr_parallel\nimport os\nimport datasets\nimport json\nimport time\nfrom tabulate import tabulate\n\nKEY_LANGUAGES = [\n    \"Chinese\",\n    \"Spanish\",\n    \"English\",\n    \"Arabic\",\n    \"Hindi\",\n    \"Bengali\",\n    \"Russian\",\n    \"Japanese\",\n]\n\n\ndef list_in(lst: str | list, lst2: list):\n    if isinstance(lst, str):\n        lst = [lst]\n    return any([item in lst for item in lst2])\n\n\ndef standardize_bullets(text):\n    patterns = [\n        r\"•\\s+\",\n        r\"·\\s+\",\n        r\"○\\s+\",\n        r\"◦\\s+\",\n        r\"▪\\s+\",\n        r\"▫\\s+\",\n        r\"➢\\s+\",\n        r\"➤\\s+\",\n        r\"★\\s+\",\n        r\"✓\\s+\",\n        r\"✗\\s+\",\n        r\"✦\\s+\",\n        r\"\\\\bullet\\s+\",\n    ]\n\n    combined_pattern = \"|\".join(patterns)\n    text = re.sub(combined_pattern, \"*\", text)\n\n    return text\n\n\ndef normalize_text(text: str) -> str:\n    # Remove HTML tags\n    text = re.sub(r\"<[^>]+>\", \"\", text)\n    # Remove LaTeX tags\n    text = re.sub(r\"\\\\[a-zA-Z]+\", \"\", text)\n    text = standardize_bullets(text)\n    text = unicodedata.normalize(\"NFKC\", text)\n    return text.strip().lower().replace(\",\", \".\")\n\n\n@click.command(help=\"Benchmark recognition model.\")\n@click.option(\n    \"--results_dir\",\n    type=str,\n    help=\"Path to JSON file 
with OCR results.\",\n    default=os.path.join(settings.RESULT_DIR, \"benchmark\"),\n)\n@click.option(\n    \"--max_rows\", type=int, help=\"Maximum number of images to benchmark.\", default=None\n)\n@click.option(\n    \"--debug\",\n    type=int,\n    help=\"Debug level (0 disables debugging; 2 also renders predicted and reference text images).\",\n    default=0,\n)\n@click.option(\n    \"--tesseract\", is_flag=True, help=\"Run benchmarks on tesseract.\", default=False\n)\n@click.option(\n    \"--textract\", is_flag=True, help=\"Run benchmarks on textract.\", default=False\n)\n@click.option(\n    \"--tess_cpus\", type=int, help=\"Number of CPUs to use for tesseract.\", default=28\n)\n@click.option(\n    \"--textract_cpus\", type=int, help=\"Number of CPUs to use for textract.\", default=28\n)\n@click.option(\n    \"--languages\",\n    type=str,\n    help=\"Comma-separated list of languages to benchmark.\",\n    default=None,\n)\n@click.option(\n    \"--print_results\",\n    is_flag=True,\n    help=\"Print the predicted and reference text for every line.\",\n)\ndef main(\n    results_dir: str,\n    max_rows: int,\n    debug: int,\n    tesseract: bool,\n    textract: bool,\n    tess_cpus: int,\n    textract_cpus: int,\n    languages: str | None,\n    print_results: bool,\n):\n    foundation_predictor = FoundationPredictor()\n    rec_predictor = RecognitionPredictor(foundation_predictor)\n\n    split = \"train\"\n    dataset = datasets.load_dataset(\n        settings.RECOGNITION_BENCH_DATASET_NAME, split=split\n    )\n\n    if languages:\n        languages = languages.split(\",\")\n        dataset = dataset.filter(\n            lambda x: list_in(x[\"language\"], languages), num_proc=4\n        )\n\n    if max_rows and max_rows < len(dataset):\n        dataset = dataset.shuffle(seed=1).select(range(max_rows))\n\n    images = list(dataset[\"image\"])\n    images = convert_if_not_rgb(images)\n    bboxes = list(dataset[\"bboxes\"])\n    line_text = list(dataset[\"text\"])\n    languages = list(dataset[\"language\"])\n\n    print(f\"Loaded {len(images)} images. 
Running OCR...\")\n\n    start = time.time()\n    predictions_by_image = rec_predictor(images, None, bboxes=bboxes)\n    surya_time = time.time() - start\n\n    lang_list = []\n    for lang in languages:\n        if not isinstance(lang, list):\n            lang_list.append([lang])\n        else:\n            lang_list.append(lang)\n\n    surya_scores = defaultdict(list)\n    img_surya_scores = []\n    outputs = []\n    for idx, (pred, ref_text, langs) in enumerate(\n        zip(predictions_by_image, line_text, lang_list)\n    ):\n        pred_text = [line.text for line in pred.text_lines]\n\n        score_ref_text = [normalize_text(line) for line in ref_text]\n        score_pred_text = [normalize_text(text) for text in pred_text]\n        image_scores, image_weights = overlap_score_exact(\n            score_pred_text, score_ref_text\n        )\n        normalized_scores = [\n            score / max(1, weight) for score, weight in zip(image_scores, image_weights)\n        ]\n        image_score = sum(image_scores) / max(1, sum(image_weights))\n\n        img_surya_scores.append(image_score)\n        for lang in langs:\n            surya_scores[CODE_TO_LANGUAGE[lang]].append(image_score)\n\n        assert len(pred_text) == len(ref_text) == len(bboxes[idx])\n        if debug:\n            for j, (pred_line, ref_line, score, bbox) in enumerate(\n                zip(pred_text, ref_text, normalized_scores, bboxes[idx])\n            ):\n                image_slice = images[idx].crop(bbox)\n\n                outputs.append(\n                    {\n                        \"image\": image_slice,\n                        \"bbox\": bbox,\n                        \"score\": score,\n                        \"pred\": pred_line,\n                        \"ref\": ref_line,\n                        \"langs\": \",\".join(langs),\n                    }\n                )\n\n    if debug:\n        out_ds = datasets.Dataset.from_list(outputs)\n        
out_ds.push_to_hub(\"datalab-to/rec_bench_outputs\", private=True)\n\n    flat_surya_scores = [score for lang in surya_scores for score in surya_scores[lang]]\n    benchmark_stats = {\n        \"surya\": {\n            \"avg_score\": sum(flat_surya_scores) / max(1, len(flat_surya_scores)),\n            \"lang_scores\": {\n                lang: sum(scores) / max(1, len(scores))\n                for lang, scores in surya_scores.items()\n            },\n            \"time_per_img\": surya_time / max(1, len(images)),\n        }\n    }\n\n    result_path = os.path.join(results_dir, \"rec_bench\")\n    os.makedirs(result_path, exist_ok=True)\n\n    with open(os.path.join(result_path, \"surya_scores.json\"), \"w+\") as f:\n        json.dump(surya_scores, f)\n\n    if tesseract:\n        tess_valid = []\n        tess_langs = []\n        for idx, lang in enumerate(lang_list):\n            # Tesseract does not support all languages\n            tess_lang = surya_lang_to_tesseract(lang[0])\n            if tess_lang is None:\n                continue\n\n            tess_valid.append(idx)\n            tess_langs.append(tess_lang)\n\n        tess_imgs = [images[i] for i in tess_valid]\n        tess_bboxes = [bboxes[i] for i in tess_valid]\n        tess_reference = [line_text[i] for i in tess_valid]\n        start = time.time()\n        tess_predictions = tesseract_ocr_parallel(\n            tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus\n        )\n        tesseract_time = time.time() - start\n\n        tess_scores = defaultdict(list)\n        for idx, (pred, ref_text, lang) in enumerate(\n            zip(tess_predictions, tess_reference, tess_langs)\n        ):\n            image_scores, image_weights, _ = overlap_score(pred, ref_text)\n            image_score = sum(image_scores) / max(1, sum(image_weights))\n            tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)\n\n        flat_tess_scores = [\n            score for lang in tess_scores for score in 
tess_scores[lang]\n        ]\n        benchmark_stats[\"tesseract\"] = {\n            \"avg_score\": sum(flat_tess_scores) / max(1, len(flat_tess_scores)),\n            \"lang_scores\": {\n                lang: sum(scores) / len(scores) for lang, scores in tess_scores.items()\n            },\n            \"time_per_img\": tesseract_time / max(1, len(tess_imgs)),\n        }\n\n        with open(os.path.join(result_path, \"tesseract_scores.json\"), \"w+\") as f:\n            json.dump(tess_scores, f)\n\n    if textract:\n        start = time.time()\n        textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus)\n        textract_time = time.time() - start\n\n        textract_scores = defaultdict(list)\n        for idx, (pred, ref_text, langs) in enumerate(\n            zip(textract_predictions, line_text, lang_list)\n        ):\n            image_scores, image_weights, _ = overlap_score(pred, ref_text)\n            image_score = sum(image_scores) / max(1, sum(image_weights))\n\n            for lang in langs:\n                textract_scores[CODE_TO_LANGUAGE[lang]].append(image_score)\n\n        flat_textract_scores = [\n            score for lang in textract_scores for score in textract_scores[lang]\n        ]\n        benchmark_stats[\"textract\"] = {\n            \"avg_score\": sum(flat_textract_scores) / max(1, len(flat_textract_scores)),\n            \"lang_scores\": {\n                lang: sum(scores) / len(scores)\n                for lang, scores in textract_scores.items()\n            },\n            \"time_per_img\": textract_time / max(1, len(images)),\n        }\n\n        with open(os.path.join(result_path, \"textract_scores.json\"), \"w+\") as f:\n            json.dump(textract_scores, f)\n\n    with open(os.path.join(result_path, \"results.json\"), \"w+\", encoding=\"utf-8\") as f:\n        json.dump(benchmark_stats, f)\n\n    key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]\n    table_headers = [\"Model\", 
\"Time per page (s)\", \"Avg Score\"] + key_languages\n    table_data = [\n        [\n            \"surya\",\n            benchmark_stats[\"surya\"][\"time_per_img\"],\n            benchmark_stats[\"surya\"][\"avg_score\"],\n        ]\n        + [benchmark_stats[\"surya\"][\"lang_scores\"][lang] for lang in key_languages],\n    ]\n    if tesseract:\n        table_data.append(\n            [\n                \"tesseract\",\n                benchmark_stats[\"tesseract\"][\"time_per_img\"],\n                benchmark_stats[\"tesseract\"][\"avg_score\"],\n            ]\n            + [\n                benchmark_stats[\"tesseract\"][\"lang_scores\"].get(lang, 0)\n                for lang in key_languages\n            ]\n        )\n    if textract:\n        table_data.append(\n            [\n                \"textract\",\n                benchmark_stats[\"textract\"][\"time_per_img\"],\n                benchmark_stats[\"textract\"][\"avg_score\"],\n            ]\n            + [\n                benchmark_stats[\"textract\"][\"lang_scores\"].get(lang, 0)\n                for lang in key_languages\n            ],\n        )\n\n    print(tabulate(table_data, headers=table_headers, tablefmt=\"github\"))\n    print(\n        \"Only a few major languages are displayed. See the result path for additional languages.\"\n    )\n\n    if debug >= 1:\n        bad_detections = []\n        # Use per-image scores so each score lines up with its image's languages\n        for idx, (score, lang) in enumerate(zip(img_surya_scores, lang_list)):\n            if score < 0.8:\n                bad_detections.append((idx, lang, score))\n        print(f\"Found {len(bad_detections)} bad detections. 
Writing to file...\")\n        with open(os.path.join(result_path, \"bad_detections.json\"), \"w+\") as f:\n            json.dump(bad_detections, f)\n\n    if debug == 2:\n        for idx, (image, pred, ref_text, bbox, lang) in enumerate(\n            zip(images, predictions_by_image, line_text, bboxes, lang_list)\n        ):\n            pred_image_name = f\"{'_'.join(lang)}_{idx}_pred.png\"\n            ref_image_name = f\"{'_'.join(lang)}_{idx}_ref.png\"\n            pred_text = [line.text for line in pred.text_lines]\n            pred_image = draw_text_on_image(bbox, pred_text, image.size)\n            pred_image.save(os.path.join(result_path, pred_image_name))\n            ref_image = draw_text_on_image(bbox, ref_text, image.size)\n            ref_image.save(os.path.join(result_path, ref_image_name))\n            image.save(os.path.join(result_path, f\"{'_'.join(lang)}_{idx}_image.png\"))\n\n    print(f\"Wrote results to {result_path}\")\n\n    if print_results:\n        for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)):\n            print(f\"Image {idx}\")\n            print(\"----\")\n            for line_idx, (pred_line, ref_line) in enumerate(\n                zip(pred.text_lines, ref_text)\n            ):\n                print(f\"Sample {line_idx}\")\n                print(f\"Pred: {pred_line.text}\")\n                print(f\"Ref: {ref_line}\")\n                print()\n\n    if settings.TORCH_DEVICE == \"xla\":\n        import torch_xla.debug.metrics as met\n\n        print(met.short_metrics_report())\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/table_recognition.py",
    "content": "import click\nimport collections\nimport json\n\nfrom surya.debug.draw import draw_bboxes_on_image\nfrom tabulate import tabulate\n\nfrom surya.input.processing import convert_if_not_rgb\nfrom surya.table_rec import TableRecPredictor\nfrom surya.settings import settings\nfrom benchmark.utils.metrics import penalized_iou_score\nfrom benchmark.utils.tatr import load_tatr, batch_inference_tatr\nimport os\nimport time\nimport datasets\n\n\n@click.command(help=\"Benchmark table rec dataset\")\n@click.option(\n    \"--results_dir\",\n    type=str,\n    help=\"Path to JSON file with benchmark results.\",\n    default=os.path.join(settings.RESULT_DIR, \"benchmark\"),\n)\n@click.option(\n    \"--max_rows\",\n    type=int,\n    help=\"Maximum number of images to run benchmark on.\",\n    default=512,\n)\n@click.option(\"--tatr\", is_flag=True, help=\"Run table transformer.\", default=False)\n@click.option(\"--debug\", is_flag=True, help=\"Enable debug mode.\", default=False)\ndef main(results_dir: str, max_rows: int, tatr: bool, debug: bool):\n    table_rec_predictor = TableRecPredictor()\n\n    pathname = \"table_rec_bench\"\n    # These have already been shuffled randomly, so sampling from the start is fine\n    split = \"train\"\n    if max_rows is not None:\n        split = f\"train[:{max_rows}]\"\n    dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)\n    images = list(dataset[\"image\"])\n    images = convert_if_not_rgb(images)\n\n    if settings.TABLE_REC_STATIC_CACHE:\n        # Run through one batch to compile the model\n        table_rec_predictor(images[:1])\n\n    start = time.time()\n    table_rec_predictions = table_rec_predictor(images)\n    surya_time = time.time() - start\n\n    folder_name = os.path.basename(pathname).split(\".\")[0]\n    result_path = os.path.join(results_dir, folder_name)\n    os.makedirs(result_path, exist_ok=True)\n\n    page_metrics = collections.OrderedDict()\n    mean_col_iou = 0\n    
mean_row_iou = 0\n    for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)):\n        row = dataset[idx]\n        pred_row_boxes = [p.bbox for p in pred.rows]\n        pred_col_bboxes = [p.bbox for p in pred.cols]\n        actual_row_bboxes = [r[\"bbox\"] for r in row[\"rows\"]]\n        actual_col_bboxes = [c[\"bbox\"] for c in row[\"columns\"]]\n        row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)\n        col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)\n        page_results = {\n            \"row_score\": row_score,\n            \"col_score\": col_score,\n            \"row_count\": len(actual_row_bboxes),\n            \"col_count\": len(actual_col_bboxes),\n        }\n\n        mean_col_iou += col_score\n        mean_row_iou += row_score\n\n        page_metrics[idx] = page_results\n\n        if debug:\n            # Save debug images\n            draw_img = image.copy()\n            draw_bboxes_on_image(\n                pred_row_boxes,\n                draw_img,\n                [f\"Row {i}\" for i in range(len(pred_row_boxes))],\n            )\n            draw_bboxes_on_image(\n                pred_col_bboxes,\n                draw_img,\n                [f\"Col {i}\" for i in range(len(pred_col_bboxes))],\n                color=\"blue\",\n            )\n            draw_img.save(os.path.join(result_path, f\"{idx}_bbox.png\"))\n\n            actual_draw_image = image.copy()\n            draw_bboxes_on_image(\n                actual_row_bboxes,\n                actual_draw_image,\n                [f\"Row {i}\" for i in range(len(actual_row_bboxes))],\n            )\n            draw_bboxes_on_image(\n                actual_col_bboxes,\n                actual_draw_image,\n                [f\"Col {i}\" for i in range(len(actual_col_bboxes))],\n                color=\"blue\",\n            )\n            actual_draw_image.save(os.path.join(result_path, f\"{idx}_actual.png\"))\n\n    mean_col_iou /= 
len(table_rec_predictions)\n    mean_row_iou /= len(table_rec_predictions)\n\n    out_data = {\n        \"surya\": {\n            \"time\": surya_time,\n            \"mean_row_iou\": mean_row_iou,\n            \"mean_col_iou\": mean_col_iou,\n            \"page_metrics\": page_metrics,\n        }\n    }\n\n    if tatr:\n        tatr_model = load_tatr()\n        start = time.time()\n        tatr_predictions = batch_inference_tatr(tatr_model, images, 1)\n        tatr_time = time.time() - start\n\n        page_metrics = collections.OrderedDict()\n        mean_col_iou = 0\n        mean_row_iou = 0\n        for idx, pred in enumerate(tatr_predictions):\n            row = dataset[idx]\n            pred_row_boxes = [p[\"bbox\"] for p in pred[\"rows\"]]\n            pred_col_bboxes = [p[\"bbox\"] for p in pred[\"cols\"]]\n            actual_row_bboxes = [r[\"bbox\"] for r in row[\"rows\"]]\n            actual_col_bboxes = [c[\"bbox\"] for c in row[\"columns\"]]\n            row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)\n            col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)\n            page_results = {\n                \"row_score\": row_score,\n                \"col_score\": col_score,\n                \"row_count\": len(actual_row_bboxes),\n                \"col_count\": len(actual_col_bboxes),\n            }\n\n            mean_col_iou += col_score\n            mean_row_iou += row_score\n\n            page_metrics[idx] = page_results\n\n        mean_col_iou /= len(tatr_predictions)\n        mean_row_iou /= len(tatr_predictions)\n\n        out_data[\"tatr\"] = {\n            \"time\": tatr_time,\n            \"mean_row_iou\": mean_row_iou,\n            \"mean_col_iou\": mean_col_iou,\n            \"page_metrics\": page_metrics,\n        }\n\n    with open(os.path.join(result_path, \"results.json\"), \"w+\", encoding=\"utf-8\") as f:\n        json.dump(out_data, f, indent=4)\n\n    table = [\n        [\"Model\", \"Row 
Intersection\", \"Col Intersection\", \"Time Per Image\"],\n        [\n            \"Surya\",\n            f\"{out_data['surya']['mean_row_iou']:.2f}\",\n            f\"{out_data['surya']['mean_col_iou']:.2f}\",\n            f\"{surya_time / len(images):.5f}\",\n        ],\n    ]\n\n    if tatr:\n        table.append(\n            [\n                \"Table transformer\",\n                f\"{out_data['tatr']['mean_row_iou']:.2f}\",\n                f\"{out_data['tatr']['mean_col_iou']:.2f}\",\n                f\"{tatr_time / len(images):.5f}\",\n            ]\n        )\n\n    print(tabulate(table, headers=\"firstrow\", tablefmt=\"github\"))\n\n    print(\n        \"Intersection is the average fraction of each ground-truth row/column covered by its matched prediction, with penalties for too many or too few predictions.\"\n    )\n    print(\n        \"Note that Table Transformer is run unbatched, since the example code in its repo is unbatched.\"\n    )\n    print(f\"Wrote results to {result_path}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/texify.py",
    "content": "import os.path\nimport re\nimport time\nfrom pathlib import Path\nfrom typing import List\n\nimport click\nimport datasets\nfrom tabulate import tabulate\nfrom bs4 import BeautifulSoup\n\nfrom surya.common.surya.schema import TaskNames\nfrom surya.settings import settings\nfrom surya.foundation import FoundationPredictor\nfrom surya.recognition import RecognitionPredictor, OCRResult\nimport json\nfrom rapidfuzz.distance import Levenshtein\n\n\ndef normalize_text(text):\n    soup = BeautifulSoup(text, \"html.parser\")\n    # Unwrap math tags\n    for tag in soup.find_all():\n        if tag.name == \"math\":\n            tag.unwrap()\n    text = soup.get_text()\n    text = re.sub(r\"\\n\", \" \", text)\n    text = re.sub(r\"\\s+\", \" \", text)\n    return text.strip()\n\n\ndef score_text(predictions, references):\n    lev_dist = []\n    for p, r in zip(predictions, references):\n        p = normalize_text(p)\n        r = normalize_text(r)\n        lev_dist.append(Levenshtein.normalized_distance(p, r))\n\n    return sum(lev_dist) / len(lev_dist)\n\n\ndef inference_texify(\n    source_data, predictor: RecognitionPredictor, line_mode: bool = False\n):\n    images = [sd[\"image\"] for sd in source_data]\n    mode = TaskNames.ocr_with_boxes if line_mode else TaskNames.block_without_boxes\n    tasks = [mode] * len(images)\n    bboxes = [[[0, 0, image.width, image.height]] for image in images]\n    texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes)\n    out_data = [\n        {\n            \"text\": texify_predictions[i].text_lines[0].text,\n            \"equation\": source_data[i][\"equation\"],\n        }\n        for i in range(len(texify_predictions))\n    ]\n\n    return out_data\n\n\n@click.command(help=\"Benchmark the performance of texify.\")\n@click.option(\n    \"--ds_name\",\n    type=str,\n    help=\"Path to dataset file with source images/equations.\",\n    
default=settings.TEXIFY_BENCHMARK_DATASET,\n)\n@click.option(\n    \"--results_dir\",\n    type=str,\n    help=\"Path to JSON file with benchmark results.\",\n    default=os.path.join(settings.RESULT_DIR, \"benchmark\"),\n)\n@click.option(\n    \"--max_rows\", type=int, help=\"Maximum number of images to benchmark.\", default=None\n)\n@click.option(\n    \"--line_mode\", is_flag=True, help=\"Use line mode for texify.\", default=False\n)\ndef main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool):\n    foundation_predictor = FoundationPredictor()\n    predictor = RecognitionPredictor(foundation_predictor)\n    ds = datasets.load_dataset(ds_name, split=\"train\")\n\n    if max_rows:\n        ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True)\n\n    start = time.time()\n    predictions = inference_texify(ds, predictor, line_mode)\n    time_taken = time.time() - start\n\n    text = [p[\"text\"] for p in predictions]\n    references = [p[\"equation\"] for p in predictions]\n    scores = score_text(text, references)\n\n    write_data = {\n        \"scores\": scores,\n        \"text\": [{\"prediction\": p, \"reference\": r} for p, r in zip(text, references)],\n    }\n\n    score_table = [[\"texify\", write_data[\"scores\"], time_taken]]\n    score_headers = [\"edit\", \"time taken (s)\"]\n    score_dirs = [\"⬇\", \"⬇\"]\n\n    score_headers = [f\"{h} {d}\" for h, d in zip(score_headers, score_dirs)]\n    table = tabulate(score_table, headers=[\"Method\", *score_headers])\n    print()\n    print(table)\n\n    result_path = Path(results_dir) / \"texify_bench\"\n    result_path.mkdir(parents=True, exist_ok=True)\n    with open(result_path / \"results.json\", \"w\", encoding=\"utf-8\") as f:\n        json.dump(write_data, f, indent=4)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/utils/__init__.py",
    "content": ""
  },
  {
    "path": "benchmark/utils/bbox.py",
    "content": "import fitz as pymupdf\nfrom surya.common.util import rescale_bbox\n\n\ndef get_pdf_lines(pdf_path, img_sizes):\n    doc = pymupdf.open(pdf_path)\n    page_lines = []\n    for idx, img_size in enumerate(img_sizes):\n        page = doc[idx]\n        blocks = page.get_text(\"dict\", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)[\"blocks\"]\n\n        line_boxes = []\n        for block_idx, block in enumerate(blocks):\n            for l in block[\"lines\"]:\n                line_boxes.append(list(l[\"bbox\"]))\n\n        page_box = page.bound()\n        pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1]\n        line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes]\n        page_lines.append(line_boxes)\n\n    return page_lines\n\n\ndef merge_boxes(box1, box2):\n    return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]))\n\n\ndef join_lines(bboxes, max_gap=5):\n    # Record later boxes that lie horizontally within box1 and whose bottom edge is within max_gap of box1's top edge\n    to_merge = {}\n    for i, box1 in enumerate(bboxes):\n        for z, box2 in enumerate(bboxes[i + 1:]):\n            j = i + z + 1  # index of box2 in bboxes\n            if box1 == box2:\n                continue\n\n            if box1[0] <= box2[0] and box1[2] >= box2[2]:\n                if abs(box1[1] - box2[3]) <= max_gap:\n                    if i not in to_merge:\n                        to_merge[i] = []\n                    to_merge[i].append(j)\n\n    merged_boxes = set()\n    merged = []\n    for i, box in enumerate(bboxes):\n        if i in merged_boxes:\n            continue\n\n        if i in to_merge:\n            for j in to_merge[i]:\n                box = merge_boxes(box, bboxes[j])\n                merged_boxes.add(j)\n\n        merged.append(box)\n    return merged\n"
  },
  {
    "path": "benchmark/utils/metrics.py",
    "content": "from functools import partial\nfrom itertools import repeat\n\nimport numpy as np\nfrom concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor\n\n\ndef box_area(box):\n    return (box[2] - box[0]) * (box[3] - box[1])\n\n\ndef calculate_iou(box1, box2, box1_only=False):\n    intersection = intersection_area(box1, box2)\n    union = box_area(box1)\n    if not box1_only:\n        union += box_area(box2) - intersection\n\n    if union == 0:\n        return 0\n    return intersection / union\n\n\ndef match_boxes(preds, references):\n    num_actual = len(references)\n    num_predicted = len(preds)\n\n    iou_matrix = np.zeros((num_actual, num_predicted))\n    for i, actual in enumerate(references):\n        for j, pred in enumerate(preds):\n            iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True)\n\n    sorted_indices = np.argsort(iou_matrix, axis=None)[::-1]\n    sorted_ious = iou_matrix.flatten()[sorted_indices]\n    actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape)\n\n    assigned_actual = set()\n    assigned_pred = set()\n\n    matches = []\n    for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious):\n        i, j = idx\n        if i not in assigned_actual and j not in assigned_pred:\n            iou_val = iou_matrix[i, j]\n            if iou_val > .95: # Account for rounding on box edges\n                iou_val = 1.0\n            matches.append((i, j, iou_val))\n            assigned_actual.add(i)\n            assigned_pred.add(j)\n\n    unassigned_actual = set(range(num_actual)) - assigned_actual\n    unassigned_pred = set(range(num_predicted)) - assigned_pred\n    matches.extend([(i, None, -1.0) for i in unassigned_actual])\n    matches.extend([(None, j, 0.0) for j in unassigned_pred])\n\n    return matches\n\ndef penalized_iou_score(preds, references):\n    matches = match_boxes(preds, references)\n    iou = sum([match[2] for match in matches]) / len(matches)\n    
return iou\n\n\ndef intersection_pixels(box1, box2):\n    x_left = max(box1[0], box2[0])\n    y_top = max(box1[1], box2[1])\n    x_right = min(box1[2], box2[2])\n    y_bottom = min(box1[3], box2[3])\n\n    if x_right < x_left or y_bottom < y_top:\n        return set()\n\n    x_left, x_right = int(x_left), int(x_right)\n    y_top, y_bottom = int(y_top), int(y_bottom)\n\n    coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom))\n    pixels = set(zip(coords[0].flat, coords[1].flat))\n\n    return pixels\n\n\ndef calculate_coverage(box, other_boxes, penalize_double=False):\n    box_area = (box[2] - box[0]) * (box[3] - box[1])\n    if box_area == 0:\n        return 0\n\n    # find total coverage of the box\n    covered_pixels = set()\n    double_coverage = list()\n    for other_box in other_boxes:\n        ia = intersection_pixels(box, other_box)\n        double_coverage.append(list(covered_pixels.intersection(ia)))\n        covered_pixels = covered_pixels.union(ia)\n\n    # Penalize double coverage - count every pixel that was already covered by an earlier box\n    double_coverage_penalty = sum(len(pixels) for pixels in double_coverage)\n    if not penalize_double:\n        double_coverage_penalty = 0\n    covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty)\n    return covered_pixels_count / box_area\n\n\ndef intersection_area(box1, box2):\n    x_left = max(box1[0], box2[0])\n    y_top = max(box1[1], box2[1])\n    x_right = min(box1[2], box2[2])\n    y_bottom = min(box1[3], box2[3])\n\n    if x_right < x_left or y_bottom < y_top:\n        return 0.0\n\n    return (x_right - x_left) * (y_bottom - y_top)\n\n\ndef calculate_coverage_fast(box, other_boxes, penalize_double=False):\n    box = np.array(box)\n    other_boxes = np.array(other_boxes)\n\n    # Calculate box area\n    box_area = (box[2] - box[0]) * (box[3] - box[1])\n    if box_area == 0:\n        return 0\n\n    x_left = np.maximum(box[0], other_boxes[:, 0])\n    y_top = np.maximum(box[1], 
other_boxes[:, 1])\n    x_right = np.minimum(box[2], other_boxes[:, 2])\n    y_bottom = np.minimum(box[3], other_boxes[:, 3])\n\n    widths = np.maximum(0, x_right - x_left)\n    heights = np.maximum(0, y_bottom - y_top)\n    intersect_areas = widths * heights\n\n    total_intersect = np.sum(intersect_areas)\n\n    return min(1.0, total_intersect / box_area)\n\n\ndef precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):\n    if len(references) == 0:\n        return {\n            \"precision\": 1,\n            \"recall\": 1,\n        }\n\n    if len(preds) == 0:\n        return {\n            \"precision\": 0,\n            \"recall\": 0,\n        }\n\n    # If we're not penalizing double coverage, we can use a faster calculation\n    coverage_func = calculate_coverage_fast\n    if penalize_double:\n        coverage_func = calculate_coverage\n\n    with ThreadPoolExecutor(max_workers=workers) as executor:\n        precision_func = partial(coverage_func, penalize_double=penalize_double)\n        precision_iou = executor.map(precision_func, preds, repeat(references))\n        reference_iou = executor.map(coverage_func, references, repeat(preds))\n\n    precision_classes = [1 if i > threshold else 0 for i in precision_iou]\n    precision = sum(precision_classes) / len(precision_classes)\n\n    recall_classes = [1 if i > threshold else 0 for i in reference_iou]\n    recall = sum(recall_classes) / len(recall_classes)\n\n    return {\n        \"precision\": precision,\n        \"recall\": recall,\n    }\n\n\ndef mean_coverage(preds, references):\n    coverages = []\n\n    for box1 in references:\n        coverage = calculate_coverage(box1, preds)\n        coverages.append(coverage)\n\n    for box2 in preds:\n        coverage = calculate_coverage(box2, references)\n        coverages.append(coverage)\n\n    # Calculate the average coverage over all comparisons\n    if len(coverages) == 0:\n        return 0\n    coverage = sum(coverages) / 
len(coverages)\n    return {\"coverage\": coverage}\n\n\ndef rank_accuracy(preds, references):\n    # Preds and references need to be aligned so each position refers to the same bbox\n    pairs = []\n    for i, pred in enumerate(preds):\n        for j, pred2 in enumerate(preds):\n            if i == j:\n                continue\n            pairs.append((i, j, pred > pred2))\n\n    # Find how many of the prediction rankings are correct\n    correct = 0\n    for i, ref in enumerate(references):\n        for j, ref2 in enumerate(references):\n            if (i, j, ref > ref2) in pairs:\n                correct += 1\n\n    return correct / len(pairs)"
  },
  {
    "path": "benchmark/utils/scoring.py",
    "content": "import math\nfrom typing import List\n\nfrom rapidfuzz import fuzz\n\n\ndef overlap_score(pred_lines: List[str], reference_lines: List[str]):\n    line_scores = []\n    line_weights = []\n    line_match = {}\n    for i, pred_line in enumerate(pred_lines):\n        max_score = 0\n        line_weight = 1\n        match = None\n        for j, ref_line in enumerate(reference_lines):\n            score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100\n            if score > max_score:\n                max_score = score\n                line_weight = math.sqrt(len(ref_line))\n                match = j\n        line_scores.append(max_score)\n        line_weights.append(line_weight)\n        line_match[i] = match\n    line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))]\n\n    return line_scores, line_weights, line_match\n\n\ndef overlap_score_exact(pred_lines: List[str], reference_lines: List[str]):\n    line_scores = []\n    line_weights = []\n    assert len(pred_lines) == len(reference_lines)\n\n    for i, (pred_line, ref_line) in enumerate(zip(pred_lines, reference_lines)):\n        score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100\n        weight = math.sqrt(len(ref_line))\n        line_scores.append(score * weight)\n        line_weights.append(weight)\n\n    return line_scores, line_weights\n"
  },
  {
    "path": "benchmark/utils/tatr.py",
    "content": "import torch\nfrom transformers import AutoModelForObjectDetection\nfrom surya.settings import settings\nimport numpy as np\n\n\nclass MaxResize(object):\n    def __init__(self, max_size=800):\n        self.max_size = max_size\n\n    def __call__(self, image):\n        width, height = image.size\n        current_max_size = max(width, height)\n        scale = self.max_size / current_max_size\n        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))\n\n        return resized_image\n\n\ndef to_tensor(image):\n    # Convert PIL Image to NumPy array\n    np_image = np.array(image).astype(np.float32)\n\n    # Rearrange dimensions to [C, H, W] format\n    np_image = np_image.transpose((2, 0, 1))\n\n    # Normalize to [0.0, 1.0]\n    np_image /= 255.0\n\n    return torch.from_numpy(np_image)\n\n\ndef normalize(tensor, mean, std):\n    for t, m, s in zip(tensor, mean, std):\n        t.sub_(m).div_(s)\n    return tensor\n\n\ndef structure_transform(image):\n    image = MaxResize(1000)(image)\n    tensor = to_tensor(image)\n    normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n    return normalized_tensor\n\n\ndef box_cxcywh_to_xyxy(x):\n    x_c, y_c, w, h = x.unbind(-1)\n    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]\n    return torch.stack(b, dim=1)\n\n\ndef rescale_bboxes(out_bbox, size):\n    width, height = size\n    boxes = box_cxcywh_to_xyxy(out_bbox)\n    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)\n    return boxes\n\n\ndef outputs_to_objects(outputs, img_sizes, id2label):\n    m = outputs.logits.softmax(-1).max(-1)\n    batch_labels = list(m.indices.detach().cpu().numpy())\n    batch_scores = list(m.values.detach().cpu().numpy())\n    batch_bboxes = outputs['pred_boxes'].detach().cpu()\n\n    batch_objects = []\n    for i in range(len(img_sizes)):\n        pred_bboxes = [elem.tolist() for elem in 
rescale_bboxes(batch_bboxes[i], img_sizes[i])]\n        pred_scores = batch_scores[i]\n        pred_labels = batch_labels[i]\n\n        objects = []\n        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):\n            class_label = id2label[int(label)]\n            if not class_label == 'no object':\n                objects.append({\n                    'label': class_label,\n                    'score': float(score),\n                    'bbox': [float(elem) for elem in bbox]}\n                )\n\n        rows = []\n        cols = []\n        for cell in objects:\n            if cell[\"label\"] == \"table column\":\n                cols.append(cell)\n\n            if cell[\"label\"] == \"table row\":\n                rows.append(cell)\n        batch_objects.append({\n            \"rows\": rows,\n            \"cols\": cols\n        })\n\n    return batch_objects\n\n\ndef load_tatr():\n    return AutoModelForObjectDetection.from_pretrained(\"microsoft/table-transformer-structure-recognition-v1.1-all\").to(settings.TORCH_DEVICE_MODEL)\n\n\ndef batch_inference_tatr(model, images, batch_size):\n    device = model.device\n    # Copy the label map so the model config is not mutated; the extra final logit is the \"no object\" class\n    id2label = dict(model.config.id2label)\n    id2label[len(id2label)] = \"no object\"\n\n    rows_cols = []\n    for i in range(0, len(images), batch_size):\n        batch_images = images[i:i + batch_size]\n        pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)\n\n        # forward pass\n        with torch.no_grad():\n            outputs = model(pixel_values)\n\n        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))\n    return rows_cols\n"
  },
  {
    "path": "benchmark/utils/tesseract.py",
    "content": "from typing import List, Optional\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom surya.input.processing import slice_bboxes_from_image\nfrom surya.settings import settings\nimport os\nfrom concurrent.futures import ProcessPoolExecutor\nfrom surya.recognition.languages import CODE_TO_LANGUAGE\nfrom surya.recognition import RecognitionPredictor\nfrom surya.detection import DetectionPredictor\n\n\ndef surya_lang_to_tesseract(code: str) -> Optional[str]:\n    lang_str = CODE_TO_LANGUAGE[code]\n    try:\n        tess_lang = TESS_LANGUAGE_TO_CODE[lang_str]\n    except KeyError:\n        return None\n    return tess_lang\n\n\ndef tesseract_ocr(img, bboxes, lang: str):\n    import pytesseract\n    line_imgs = slice_bboxes_from_image(img, bboxes)\n    config = f'--tessdata-dir \"{settings.TESSDATA_PREFIX}\"'\n    lines = []\n    for line_img in line_imgs:\n        line = pytesseract.image_to_string(line_img, lang=lang, config=config)\n        lines.append(line)\n    return lines\n\n\ndef tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):\n    tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size())\n    if not cpus:\n        cpus = os.cpu_count()\n    tess_parallel_cores = min(tess_parallel_cores, cpus)\n\n    # Tesseract uses up to 4 processes per instance\n    # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images\n    tess_parallel = max(tess_parallel_cores // 2, 1)\n\n    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:\n        tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc=\"Running tesseract OCR\")\n        tess_text = list(tess_text)\n    return tess_text\n\n\ndef tesseract_bboxes(img):\n    import pytesseract\n    from pytesseract import Output\n    arr_img = np.asarray(img, dtype=np.uint8)\n    ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT)\n\n    bboxes = []\n    n_boxes = len(ocr['level'])\n   
 for i in range(n_boxes):\n        # It is possible to merge by line here with line number, but it gives bad results.\n        _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i]\n        bbox = (x, y, x + w, y + h)\n        bboxes.append(bbox)\n\n    return bboxes\n\n\ndef tesseract_parallel(imgs):\n    # Tesseract uses 4 threads per instance\n    tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size())\n    cpus = os.cpu_count()\n    tess_parallel_cores = min(tess_parallel_cores, cpus)\n\n    # Tesseract uses 4 threads per instance\n    tess_parallel = max(tess_parallel_cores // 4, 1)\n\n    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:\n        tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc=\"Running tesseract bbox detection\")\n        tess_bboxes = list(tess_bboxes)\n    return tess_bboxes\n\n\nTESS_CODE_TO_LANGUAGE = {\n    \"afr\": \"Afrikaans\",\n    \"amh\": \"Amharic\",\n    \"ara\": \"Arabic\",\n    \"asm\": \"Assamese\",\n    \"aze\": \"Azerbaijani\",\n    \"bel\": \"Belarusian\",\n    \"ben\": \"Bengali\",\n    \"bod\": \"Tibetan\",\n    \"bos\": \"Bosnian\",\n    \"bre\": \"Breton\",\n    \"bul\": \"Bulgarian\",\n    \"cat\": \"Catalan\",\n    \"ceb\": \"Cebuano\",\n    \"ces\": \"Czech\",\n    \"chi_sim\": \"Chinese\",\n    \"chr\": \"Cherokee\",\n    \"cym\": \"Welsh\",\n    \"dan\": \"Danish\",\n    \"deu\": \"German\",\n    \"dzo\": \"Dzongkha\",\n    \"ell\": \"Greek\",\n    \"eng\": \"English\",\n    \"epo\": \"Esperanto\",\n    \"est\": \"Estonian\",\n    \"eus\": \"Basque\",\n    \"fas\": \"Persian\",\n    \"fin\": \"Finnish\",\n    \"fra\": \"French\",\n    \"fry\": \"Western Frisian\",\n    \"guj\": \"Gujarati\",\n    \"gla\": \"Scottish Gaelic\",\n    \"gle\": \"Irish\",\n    \"glg\": \"Galician\",\n    \"heb\": \"Hebrew\",\n    \"hin\": \"Hindi\",\n    \"hrv\": \"Croatian\",\n    \"hun\": \"Hungarian\",\n    \"hye\": 
\"Armenian\",\n    \"iku\": \"Inuktitut\",\n    \"ind\": \"Indonesian\",\n    \"isl\": \"Icelandic\",\n    \"ita\": \"Italian\",\n    \"jav\": \"Javanese\",\n    \"jpn\": \"Japanese\",\n    \"kan\": \"Kannada\",\n    \"kat\": \"Georgian\",\n    \"kaz\": \"Kazakh\",\n    \"khm\": \"Khmer\",\n    \"kir\": \"Kyrgyz\",\n    \"kor\": \"Korean\",\n    \"lao\": \"Lao\",\n    \"lat\": \"Latin\",\n    \"lav\": \"Latvian\",\n    \"lit\": \"Lithuanian\",\n    \"mal\": \"Malayalam\",\n    \"mar\": \"Marathi\",\n    \"mkd\": \"Macedonian\",\n    \"mlt\": \"Maltese\",\n    \"mon\": \"Mongolian\",\n    \"msa\": \"Malay\",\n    \"mya\": \"Burmese\",\n    \"nep\": \"Nepali\",\n    \"nld\": \"Dutch\",\n    \"nor\": \"Norwegian\",\n    \"ori\": \"Oriya\",\n    \"pan\": \"Punjabi\",\n    \"pol\": \"Polish\",\n    \"por\": \"Portuguese\",\n    \"pus\": \"Pashto\",\n    \"ron\": \"Romanian\",\n    \"rus\": \"Russian\",\n    \"san\": \"Sanskrit\",\n    \"sin\": \"Sinhala\",\n    \"slk\": \"Slovak\",\n    \"slv\": \"Slovenian\",\n    \"snd\": \"Sindhi\",\n    \"spa\": \"Spanish\",\n    \"sqi\": \"Albanian\",\n    \"srp\": \"Serbian\",\n    \"swa\": \"Swahili\",\n    \"swe\": \"Swedish\",\n    \"syr\": \"Syriac\",\n    \"tam\": \"Tamil\",\n    \"tel\": \"Telugu\",\n    \"tgk\": \"Tajik\",\n    \"tha\": \"Thai\",\n    \"tir\": \"Tigrinya\",\n    \"tur\": \"Turkish\",\n    \"uig\": \"Uyghur\",\n    \"ukr\": \"Ukrainian\",\n    \"urd\": \"Urdu\",\n    \"uzb\": \"Uzbek\",\n    \"vie\": \"Vietnamese\",\n    \"yid\": \"Yiddish\"\n}\n\nTESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()}\n"
  },
  {
    "path": "benchmark/utils/textract.py",
    "content": "import os\nfrom concurrent.futures import ThreadPoolExecutor\nfrom tqdm import tqdm\nimport traceback\n\nfrom surya.input.processing import slice_bboxes_from_image\nfrom surya.recognition import RecognitionPredictor\n\n\ndef textract_ocr(extractor, img):\n    try:\n        document = extractor.detect_document_text(file_source=img)\n        return [line.text for line in document.lines]\n    except Exception:\n        # Log the failure and return a placeholder so one bad page does not abort the whole run\n        traceback.print_exc()\n        return [None]\n\n\ndef textract_ocr_parallel(imgs, cpus=None):\n    from textractor import Textractor  # Optional dependency\n\n    extractor = Textractor(profile_name='default')\n    parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size())\n    if not cpus:\n        cpus = os.cpu_count()\n    parallel_cores = min(parallel_cores, cpus)\n\n    with ThreadPoolExecutor(max_workers=parallel_cores) as executor:\n        textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc=\"Running textract OCR\")\n        textract_text = list(textract_text)\n    return textract_text\n"
  },
  {
    "path": "benchmark/utils/verify_benchmark_scores.py",
    "content": "import json\nimport click\n\n\ndef verify_layout(data):\n    scores = data[\"metrics\"]\n    for layout_type, metrics in scores.items():\n        if layout_type == \"List\":  # Skip lists since none appear early on\n            continue\n\n        if metrics[\"precision\"] <= 0.6 or metrics[\"recall\"] <= 0.6:\n            raise ValueError(\"Scores do not meet the required threshold\")\n\n\ndef verify_det(data):\n    scores = data[\"metrics\"][\"surya\"]\n    if scores[\"precision\"] <= 0.9 or scores[\"recall\"] <= 0.9:\n        raise ValueError(\"Scores do not meet the required threshold\")\n\n\ndef verify_rec(data):\n    scores = data[\"surya\"]\n    if scores[\"avg_score\"] <= 0.9:\n        raise ValueError(\"Scores do not meet the required threshold\")\n\n\ndef verify_order(data):\n    score = data[\"mean_accuracy\"]\n    if score < 0.75:\n        raise ValueError(\"Scores do not meet the required threshold\")\n\n\ndef verify_table_rec(data):\n    row_score = data[\"surya\"][\"mean_row_iou\"]\n    col_score = data[\"surya\"][\"mean_col_iou\"]\n\n    if row_score < 0.75 or col_score < 0.75:\n        raise ValueError(\"Scores do not meet the required threshold\")\n\n\ndef verify_texify(data):\n    edit_dist = data[\"scores\"]\n    if edit_dist > 0.2:\n        raise ValueError(\"Scores do not meet the required threshold\")\n\n\n@click.command(help=\"Verify benchmark scores\")\n@click.argument(\"file_path\", type=str)\n@click.option(\n    \"--bench_type\", type=str, help=\"Type of benchmark to verify\", default=\"detection\"\n)\ndef main(file_path, bench_type):\n    with open(file_path, \"r\") as file:\n        data = json.load(file)\n\n    if bench_type == \"detection\":\n        verify_det(data)\n    elif bench_type == \"recognition\":\n        verify_rec(data)\n    elif bench_type == \"layout\":\n        verify_layout(data)\n    elif bench_type == \"ordering\":\n        verify_order(data)\n    elif bench_type == \"table_recognition\":\n        
verify_table_rec(data)\n    elif bench_type == \"texify\":\n        verify_texify(data)\n    else:\n        raise ValueError(\"Invalid benchmark type\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "detect_layout.py",
    "content": "from surya.scripts.detect_layout import detect_layout_cli\n\nif __name__ == \"__main__\":\n    detect_layout_cli()\n"
  },
  {
    "path": "detect_text.py",
    "content": "from surya.scripts.detect_text import detect_text_cli\n\nif __name__ == \"__main__\":\n    detect_text_cli()\n"
  },
  {
    "path": "ocr_app.py",
    "content": "from surya.scripts.run_streamlit_app import streamlit_app_cli\n\nif __name__ == \"__main__\":\n    streamlit_app_cli()\n"
  },
  {
    "path": "ocr_latex.py",
    "content": "from surya.scripts.ocr_latex import ocr_latex_cli\n\nif __name__ == \"__main__\":\n    ocr_latex_cli()\n"
  },
  {
    "path": "ocr_text.py",
    "content": "from surya.scripts.ocr_text import ocr_text_cli\n\nif __name__ == \"__main__\":\n    ocr_text_cli()\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.poetry]\nname = \"surya-ocr\"\nversion = \"0.17.1\"\ndescription = \"OCR, layout, reading order, and table recognition in 90+ languages\"\nauthors = [\"Vik Paruchuri <vik.paruchuri@gmail.com>\"]\nreadme = \"README.md\"\nlicense = \"GPL-3.0-or-later\"\nrepository = \"https://github.com/VikParuchuri/surya\"\nkeywords = [\"ocr\", \"pdf\", \"text detection\", \"text recognition\", \"tables\"]\npackages = [\n    {include = \"surya\"}\n]\n\n[tool.poetry.dependencies]\npython = \"^3.10\"\ntransformers = \">=4.56.1\"\ntorch = \"^2.7.0\"\npydantic = \"^2.5.3\"\npydantic-settings = \"^2.1.0\"\npython-dotenv = \"^1.0.0\"\npillow = \"^10.2.0\"\npypdfium2 = \"=4.30.0\"\nfiletype = \"^1.2.0\"\nclick = \"^8.1.8\"\nplatformdirs = \"^4.3.6\"\nopencv-python-headless = \"==4.11.0.86\"\neinops = \"^0.8.1\"\npre-commit = \"^4.2.0\"\n\n[tool.poetry.group.dev.dependencies]\njupyter = \"^1.0.0\"\npytesseract = \"^0.3.10\"\npymupdf = \"^1.23.8\"\ndatasets = \"^2.16.1\"\nrapidfuzz = \"^3.6.1\"\nstreamlit = \"^1.31.0\"\npytest = \"^8.3.4\"\npdftext = \"^0.5.1\"\ntabulate = \"^0.9.0\"\n\n[tool.poetry.scripts]\nsurya_detect = \"surya.scripts.detect_text:detect_text_cli\"\nsurya_ocr = \"surya.scripts.ocr_text:ocr_text_cli\"\nsurya_layout = \"surya.scripts.detect_layout:detect_layout_cli\"\nsurya_gui = \"surya.scripts.run_streamlit_app:streamlit_app_cli\"\nsurya_table = \"surya.scripts.table_recognition:table_recognition_cli\"\nsurya_latex_ocr = \"surya.scripts.ocr_latex:ocr_latex_cli\"\ntexify_gui = \"surya.scripts.run_texify_app:texify_app_cli\"\n\n[build-system]\nrequires = [\"poetry-core\"]\nbuild-backend = \"poetry.core.masonry.api\"\n\n[[tool.poetry.source]]\nname = \"libtpu-releases\"\nurl = \"https://storage.googleapis.com/libtpu-releases/index.html\"\npriority = \"supplemental\"\n\n[[tool.poetry.source]]\nname = \"libtpu-wheels\"\nurl = \"https://storage.googleapis.com/libtpu-wheels/index.html\"\npriority = \"supplemental\"\n\n[tool.poetry.group.xla]\noptional = 
true\n\n[tool.poetry.group.xla.dependencies]\ntorch-xla = {version = \"^2.4.1\", extras = [\"tpu\"]}\n"
  },
  {
    "path": "pytest.ini",
    "content": "[pytest]\ntestpaths=tests\npythonpath=.\nfilterwarnings =\n    ignore::UserWarning\n    ignore::PendingDeprecationWarning\n    ignore::DeprecationWarning"
  },
  {
    "path": "signatures/version1/cla.json",
    "content": "{\n  \"signedContributors\": [\n    {\n      \"name\": \"rishiraj\",\n      \"id\": 44090649,\n      \"comment_id\": 2170578748,\n      \"created_at\": \"2024-06-15T19:31:20Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 135\n    },\n    {\n      \"name\": \"mmacvicar\",\n      \"id\": 59354,\n      \"comment_id\": 2236493182,\n      \"created_at\": \"2024-07-18T13:17:43Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 152\n    },\n    {\n      \"name\": \"jimexist\",\n      \"id\": 622789,\n      \"comment_id\": 2255151376,\n      \"created_at\": \"2024-07-29T07:23:55Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 160\n    },\n    {\n      \"name\": \"michaeldriscoll-avant\",\n      \"id\": 85255083,\n      \"comment_id\": 2259143427,\n      \"created_at\": \"2024-07-30T20:21:33Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 161\n    },\n    {\n      \"name\": \"EdoardoPona\",\n      \"id\": 29152472,\n      \"comment_id\": 2271115922,\n      \"created_at\": \"2024-08-06T11:58:00Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 167\n    },\n    {\n      \"name\": \"hidenori-endo\",\n      \"id\": 15546605,\n      \"comment_id\": 2307217499,\n      \"created_at\": \"2024-08-23T14:31:17Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 182\n    },\n    {\n      \"name\": \"dobosevych\",\n      \"id\": 12053536,\n      \"comment_id\": 2430376828,\n      \"created_at\": \"2024-10-22T21:48:34Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 220\n    },\n    {\n      \"name\": \"iammosespaulr\",\n      \"id\": 28682735,\n      \"comment_id\": 2447941238,\n      \"created_at\": \"2024-10-30T17:55:23Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 235\n    },\n    {\n      \"name\": \"ArthurMor4is\",\n      \"id\": 42987302,\n      \"comment_id\": 2515315717,\n      \"created_at\": \"2024-12-03T18:37:45Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 
255\n    },\n    {\n      \"name\": \"tarun-menta\",\n      \"id\": 66506307,\n      \"comment_id\": 2543457960,\n      \"created_at\": \"2024-12-15T05:43:33Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 261\n    },\n    {\n      \"name\": \"jonaskahn\",\n      \"id\": 4338500,\n      \"comment_id\": 2556622097,\n      \"created_at\": \"2024-12-20T09:36:20Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 269\n    },\n    {\n      \"name\": \"kumsumit\",\n      \"id\": 95072784,\n      \"comment_id\": 2574534622,\n      \"created_at\": \"2025-01-07T07:05:59Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 276\n    },\n    {\n      \"name\": \"kevinhu\",\n      \"id\": 6051736,\n      \"comment_id\": 2614135351,\n      \"created_at\": \"2025-01-25T23:34:12Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 291\n    },\n    {\n      \"name\": \"zanussbaum\",\n      \"id\": 33707069,\n      \"comment_id\": 3008673416,\n      \"created_at\": \"2025-06-26T14:20:46Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 403\n    },\n    {\n      \"name\": \"mebriki\",\n      \"id\": 35892987,\n      \"comment_id\": 3154706976,\n      \"created_at\": \"2025-08-05T10:54:27Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 418\n    },\n    {\n      \"name\": \"starikovplusplus\",\n      \"id\": 56602036,\n      \"comment_id\": 3168958011,\n      \"created_at\": \"2025-08-08T18:29:50Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 423\n    },\n    {\n      \"name\": \"sandy0kwon\",\n      \"id\": 78377296,\n      \"comment_id\": 3207932260,\n      \"created_at\": \"2025-08-20T20:07:15Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 434\n    },\n    {\n      \"name\": \"n0kovo\",\n      \"id\": 16690056,\n      \"comment_id\": 3208251881,\n      \"created_at\": \"2025-08-20T22:22:06Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 435\n    },\n    {\n      \"name\": \"davidxifeng\",\n 
     \"id\": 158052,\n      \"comment_id\": 3249594859,\n      \"created_at\": \"2025-09-03T14:52:16Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 445\n    },\n    {\n      \"name\": \"u-ashish\",\n      \"id\": 14264791,\n      \"comment_id\": 3258734182,\n      \"created_at\": \"2025-09-05T15:16:48Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 447\n    },\n    {\n      \"name\": \"Mohking1\",\n      \"id\": 63689545,\n      \"comment_id\": 3314908963,\n      \"created_at\": \"2025-09-20T11:21:42Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 462\n    },\n    {\n      \"name\": \"wkpark\",\n      \"id\": 232347,\n      \"comment_id\": 3330009557,\n      \"created_at\": \"2025-09-24T17:42:55Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 464\n    },\n    {\n      \"name\": \"coval3nte\",\n      \"id\": 65908512,\n      \"comment_id\": 3848768229,\n      \"created_at\": \"2026-02-04T17:28:32Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 483\n    },\n    {\n      \"name\": \"bailey-coding\",\n      \"id\": 29517254,\n      \"comment_id\": 3955014177,\n      \"created_at\": \"2026-02-24T22:09:52Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 487\n    },\n    {\n      \"name\": \"Br1an67\",\n      \"id\": 29810238,\n      \"comment_id\": 3979412700,\n      \"created_at\": \"2026-03-01T07:32:18Z\",\n      \"repoId\": 741297064,\n      \"pullRequestNo\": 489\n    }\n  ]\n}"
  },
  {
    "path": "static/fonts/.gitignore",
    "content": "*\n!.gitignore"
  },
  {
    "path": "surya/__init__.py",
    "content": ""
  },
  {
    "path": "surya/common/__init__.py",
    "content": "\n\n\n"
  },
  {
    "path": "surya/common/adetr/decoder.py",
    "content": "from typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom transformers import PretrainedConfig\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\nfrom transformers.modeling_outputs import BaseModelOutputWithNoAttention\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\nfrom surya.common.xla import mark_step\n\n_MAX_SQRT_GRADIENT = 1000.0\n\n\nclass WrappedEmbedding(nn.Embedding):\n    def forward(self, input_ids, *args, **kwargs):\n        return super().forward(input_ids)\n\n\nclass SuryaADETRDecoderRMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.zeros(dim))\n\n    def _norm(self, x):\n        variance = x.pow(2).mean(-1, keepdim=True)\n\n        # Add clipping to prevent division by zero\n        variance = torch.clamp(variance, min=self.eps)\n        return x * torch.rsqrt(variance)\n\n    def forward(self, x):\n        output = self._norm(x.float())\n        # Llama does x.to(float16) * w whilst SuryaADETRDecoder is (x * w).to(float16)\n        # See https://github.com/huggingface/transformers/pull/29402\n        output = output * (1.0 + self.weight.float())\n        # Clamp to float16 range\n        f16_info = torch.finfo(x.dtype)\n        output = output.clamp(min=f16_info.min, max=f16_info.max)\n        output = torch.where(\n            torch.isnan(output), torch.tensor(0.0, device=output.device), output\n        )\n        return output.type_as(x)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.eps}\"\n\n\nALL_LAYERNORM_LAYERS.append(SuryaADETRDecoderRMSNorm)\n\n\nclass SuryaADETRDecoderRotaryEmbedding(nn.Module):\n    def __init__(self, dim, base=10000, device=None):\n        
super().__init__()\n        self.dim = dim\n        self.base = base\n        inv_freq = 1.0 / (\n            self.base\n            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)\n        )\n        self.register_buffer(\"inv_freq\", tensor=inv_freq, persistent=False)\n\n    @torch.no_grad()\n    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaADETRDecoder\n    def forward(self, x, position_ids, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        self.inv_freq.to(x.device)\n        inv_freq_expanded = (\n            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        )\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(\n            1, 2\n        )\n        emb = torch.cat((freqs, freqs), dim=-1)\n        cos = emb.cos()\n        sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so 
that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass SuryaADETRDecoderSdpaCrossAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\n    Modified for GQA\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig):\n        super().__init__()\n        self.config = config\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.head_dim = config.head_dim\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads\n\n        self.q_proj = nn.Linear(\n            self.hidden_size,\n            self.num_attention_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.k_proj = nn.Linear(\n            self.config.encoder_hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.v_proj = nn.Linear(\n            self.config.encoder_hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.num_attention_heads * self.head_dim, self.hidden_size, bias=True\n        )\n        self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(\n            self.head_dim,\n            base=config.rope_theta,\n        )\n\n    def forward(\n        self,\n        hidden_states: 
torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # Encoder attention mask currently ignored\n\n        bsz, q_len, _ = hidden_states.size()\n        _, v_len, _ = encoder_hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        query_states = query_states.view(\n            bsz, q_len, self.num_attention_heads, self.head_dim\n        ).transpose(1, 2)\n\n        if self.key_states is None:\n            key_states = self.k_proj(encoder_hidden_states)\n            value_states = self.v_proj(encoder_hidden_states)\n            key_states = key_states.view(\n                bsz, v_len, self.num_key_value_heads, self.head_dim\n            ).transpose(1, 2)\n            value_states = value_states.view(\n                bsz, v_len, self.num_key_value_heads, self.head_dim\n            ).transpose(1, 2)\n            if use_cache:\n                self._update_cache(key_states, value_states)\n        else:\n            key_states = self.key_states\n            value_states = self.value_states\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=None,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            scale=self.head_dim**-0.5,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, self.hidden_size)\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n    def _clear_cache(self):\n        if 
self.value_states is not None:\n            del self.value_states\n        if self.key_states is not None:\n            del self.key_states\n\n    def _setup_cache(self, batch_size, device, dtype=None):\n        # Setup initial caches\n        self.value_states = None\n        self.key_states = None\n\n    @torch.no_grad()\n    def _update_cache(self, key_states, value_states, **cache_kwargs):\n        self.value_states = value_states\n        self.key_states = key_states\n\n\nclass SuryaADETRDecoderSdpaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: PretrainedConfig, static_cache=False, max_boxes=None):\n        super().__init__()\n        self.config = config\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.head_dim = config.head_dim\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads\n\n        self.q_proj = nn.Linear(\n            self.hidden_size,\n            self.num_attention_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.k_proj = nn.Linear(\n            self.hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.v_proj = nn.Linear(\n            self.hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.num_attention_heads * self.head_dim, self.hidden_size, bias=True\n        )\n        self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(\n            self.head_dim,\n            base=config.rope_theta,\n        )\n\n        self.static_cache = static_cache\n        self.max_boxes = max_boxes\n\n    def 
forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        use_cache: bool = False,\n        window_attn: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        # Final is bsz, num_attention_heads, seq_len, head_dim\n        query_states = query_states.view(\n            bsz, q_len, self.num_attention_heads, self.head_dim\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n        value_states = value_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n\n        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if use_cache and hasattr(self, \"key_states\"):\n            cache_kwargs = {\n                \"cache_position\": cache_position,\n                \"window_attn\": window_attn,\n            }\n            key_states, value_states = self._update_cache(\n                key_states, value_states, **cache_kwargs\n            )\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        causal_mask = attention_mask\n        if attention_mask is not None:\n            # Mask is batch, head, seq_len, kv_len\n            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]\n            if cache_position is not None and 
self.static_cache:\n                current_pos = cache_position[-1]\n                causal_mask[:, :, :, current_pos + 1 :] = torch.finfo(\n                    causal_mask.dtype\n                ).min\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=causal_mask,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            scale=self.head_dim**-0.5,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, self.hidden_size)\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n    def _setup_cache(self, batch_size, device, dtype=None):\n        if dtype is None and self.config.torch_dtype is not None:\n            dtype = self.config.torch_dtype\n        dtype = dtype if dtype is not None else torch.float32\n\n        # Setup initial caches\n        self.value_states = None\n        self.key_states = None\n\n        if self.static_cache:\n            cache_shape = (\n                batch_size,\n                self.num_key_value_heads,\n                self.max_boxes,\n                self.head_dim,\n            )\n            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)\n            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)\n\n    def _clear_cache(self):\n        if self.value_states is not None:\n            del self.value_states\n        if self.key_states is not None:\n            del self.key_states\n\n    def _update_static_cache(self, key_states, value_states, **cache_kwargs):\n        cache_position = cache_kwargs.get(\"cache_position\")\n        k_out, v_out = (\n            self.key_states.to(key_states.device),\n            self.value_states.to(value_states.device),\n        )\n\n        k_out[:, :, cache_position] = key_states.to(k_out.dtype)\n        
v_out[:, :, cache_position] = value_states.to(v_out.dtype)\n\n        self.key_states, self.value_states = k_out, v_out\n        return k_out, v_out\n\n    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):\n        k_out = key_states\n        if self.key_states is not None:\n            k_out = torch.cat([self.key_states, key_states], dim=2)\n\n        v_out = value_states\n        if self.value_states is not None:\n            v_out = torch.cat([self.value_states, value_states], dim=2)\n\n        self.key_states, self.value_states = k_out, v_out\n        return k_out, v_out\n\n    @torch.no_grad()\n    def _update_cache(self, key_states, value_states, **cache_kwargs):\n        if self.static_cache:\n            return self._update_static_cache(key_states, value_states, **cache_kwargs)\n\n        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)\n\n\nclass SuryaADETRDecoderMlp(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        if config.hidden_activation is None:\n            config.hidden_activation = \"gelu_pytorch_tanh\"\n        hidden_activation = config.hidden_activation\n        self.act_fn = ACT2FN[hidden_activation]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass SuryaADETRDecoderLayer(nn.Module):\n    def __init__(self, config, layer_idx, static_cache=False, max_boxes=None):\n        super().__init__()\n        self.cross_pre_norm = SuryaADETRDecoderRMSNorm(\n            config.hidden_size, 
eps=config.rms_norm_eps\n        )\n        self.temporal_pre_norm = SuryaADETRDecoderRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n        self.temporal_block = None\n        if layer_idx in config.self_attn_layers:\n            self.temporal_block = SuryaADETRDecoderSdpaAttention(\n                config, static_cache=static_cache, max_boxes=max_boxes\n            )\n\n        self.cross_attn_block = None\n        if layer_idx in config.cross_attn_layers:\n            self.cross_attn_block = SuryaADETRDecoderSdpaCrossAttention(config)\n\n        self.window_attn = layer_idx not in config.global_attn_layers\n        self.channel_pre_norm = SuryaADETRDecoderRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.mlp_block = SuryaADETRDecoderMlp(config)\n\n        self.double_residual_flow = getattr(config, \"double_residual_flow\", False)\n\n    def forward(\n        self,\n        activations: torch.Tensor,\n        position_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        encoder_attention_mask: torch.Tensor = None,\n        cache_position: torch.Tensor = None,\n        use_cache: bool = None,\n    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:\n        if self.double_residual_flow:\n            return self.double_res_forward(\n                activations,\n                position_ids,\n                attention_mask,\n                encoder_hidden_states,\n                encoder_attention_mask,\n                cache_position,\n                use_cache,\n            )\n\n        hidden_states = activations\n        if self.cross_attn_block is not None:\n            # Do cross-attention on encoder outputs\n            cross_attn_inputs = self.cross_pre_norm(hidden_states)\n            cross_attn_path = self.cross_attn_block(\n                cross_attn_inputs,\n                encoder_hidden_states,\n                
attention_mask,\n                encoder_attention_mask,\n                use_cache=use_cache,\n            )\n            hidden_states = cross_attn_path + hidden_states\n\n        if self.temporal_block is not None:\n            temporal_inputs = self.temporal_pre_norm(\n                hidden_states\n            )  # RMSNorm introduces slight differences\n            temporal_path = self.temporal_block(\n                temporal_inputs,\n                position_ids,\n                attention_mask,\n                cache_position=cache_position,\n                use_cache=use_cache,\n                window_attn=self.window_attn,\n            )\n\n            hidden_states = temporal_path + hidden_states\n\n        block_input = hidden_states\n        hidden_states = self.channel_pre_norm(block_input)\n        hidden_states = self.mlp_block(hidden_states)\n        hidden_states = hidden_states + block_input\n\n        return hidden_states\n\n    def double_res_forward(\n        self,\n        activations: torch.Tensor,\n        position_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        encoder_attention_mask: torch.Tensor = None,\n        cache_position: torch.Tensor = None,\n        use_cache: bool = None,\n    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:\n        raw_activations = activations\n\n        if self.cross_attn_block is not None:\n            # Do cross-attention on encoder outputs\n            cross_attn_inputs = self.cross_pre_norm(activations)\n            cross_attn_path = self.cross_attn_block(\n                cross_attn_inputs,\n                encoder_hidden_states,\n                attention_mask,\n                encoder_attention_mask,\n                use_cache=use_cache,\n            )\n            cross_attn_output = cross_attn_path + raw_activations\n        else:\n            cross_attn_output = raw_activations\n\n        if self.temporal_block is not 
None:\n            inputs_normalized = self.temporal_pre_norm(\n                cross_attn_output\n            )  # RMSNorm introduces slight differences\n            hidden_states = self.temporal_block(\n                inputs_normalized,\n                position_ids,\n                attention_mask,\n                cache_position=cache_position,\n                use_cache=use_cache,\n                window_attn=self.window_attn,\n            )\n\n            residual = hidden_states + raw_activations\n        else:\n            residual = cross_attn_output\n\n        hidden_states = self.channel_pre_norm(residual)\n        hidden_states = self.mlp_block(hidden_states)\n\n        hidden_states = hidden_states + residual\n        return hidden_states\n\n\nclass SuryaADETRDecoderPreTrainedModel(SuryaPreTrainedModel):\n    config_class = PretrainedConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"SuryaADETRDecoderLayer\"]\n    _skip_keys_device_placement = [\"cache\"]\n    _supports_flash_attn_2 = False\n    _supports_sdpa = False  # we can't compare with eager for now\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n\n    def _init_weights(self, module):\n        if isinstance(module, SuryaADETRDecoderSdpaAttention):\n            torch.nn.init.normal_(\n                module.q_proj.weight, mean=0.0, std=self.config.init_std\n            )\n            torch.nn.init.normal_(\n                module.k_proj.weight, mean=0.0, std=self.config.init_std\n            )\n            torch.nn.init.normal_(\n                module.v_proj.weight, mean=0.0, std=self.config.init_std\n            )\n\n            torch.nn.init.normal_(\n                module.o_proj.weight, mean=0.0, std=self.config.init_std\n            )\n        elif isinstance(module, nn.Linear):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)\n            if getattr(module, 
\"bias\", None) is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.init_std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def _setup_cache(self, config, batch, device, dtype):\n        layers = getattr(self, \"model\", self).layers\n        for layer in layers:\n            if layer.temporal_block:\n                layer.temporal_block._setup_cache(batch, device, dtype)\n            if layer.cross_attn_block:\n                layer.cross_attn_block._setup_cache(batch, device, dtype)\n\n    def _clear_cache(self):\n        layers = getattr(self, \"model\", self).layers\n        for layer in layers:\n            if layer.temporal_block:\n                layer.temporal_block._clear_cache()\n            if layer.cross_attn_block:\n                layer.cross_attn_block._clear_cache()\n\n    def reset_cache(self, batch, device, dtype):\n        pass\n\n    def _tie_weights(self):\n        pass\n\n    def tie_weights(self):\n        pass\n\n\nclass SuryaADETRDecoderModel(SuryaADETRDecoderPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`SuryaADETRDecoderLayer`]\n\n    Args:\n        config: PretrainedConfig\n    \"\"\"\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        embedder: nn.Module = None,\n        max_boxes: int = None,\n        static_cache: bool = False,\n    ):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.causal = config.causal\n\n        self.embed_tokens = embedder\n        self.max_boxes = max_boxes\n        self.static_cache = static_cache\n\n        self.layers = nn.ModuleList(\n            [\n                SuryaADETRDecoderLayer(\n                    config, layer_idx, static_cache=static_cache, max_boxes=max_boxes\n                )\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.final_norm = SuryaADETRDecoderRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.gradient_checkpointing = False\n\n        self.register_buffer(\n            \"normalizer\",\n            torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32),\n            persistent=False,\n        )\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        input_boxes_counts: torch.LongTensor = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_hidden_states: 
Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        prefill: bool = False,\n    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            use_cache = False\n\n        inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)\n        hidden_states = inputs_embeds\n\n        if use_cache and prefill:\n            self._setup_cache(\n                self.config,\n                hidden_states.shape[0],\n                hidden_states.device,\n                hidden_states.dtype,\n            )\n\n        if cache_position is None:\n            cache_position = torch.arange(\n                hidden_states.shape[1], device=hidden_states.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position\n        )\n\n        all_hidden_states = () if output_hidden_states else None\n        for i, residual_block in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n            if self.gradient_checkpointing and self.training:\n                hidden_states = self._gradient_checkpointing_func(\n                    residual_block.__call__,\n                    hidden_states,\n                    position_ids,\n                    causal_mask,\n                    encoder_hidden_states,\n                    
encoder_attention_mask,\n                    cache_position,\n                    use_cache,\n                )\n            else:\n                hidden_states = residual_block(\n                    hidden_states,\n                    position_ids,\n                    causal_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    cache_position,\n                    use_cache,\n                )\n\n        hidden_states = self.final_norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)\n\n        return BaseModelOutputWithNoAttention(\n            last_hidden_state=hidden_states,\n            hidden_states=all_hidden_states,\n        )\n\n    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at every decode step due to the dynamic shapes.\n    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n    # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114\n    # Ignore copy\n    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):\n        if not self.causal:\n            return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        target_length = max(self.max_boxes, sequence_length)\n\n        diagonal = torch.full(\n            (sequence_length, target_length),\n            fill_value=min_dtype,\n            dtype=dtype,\n            device=device,\n        )\n        causal_mask = diagonal\n        if sequence_length != 1:\n            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)\n            # triu will be the min_dtype, everything else is 0 (attended to)\n            causal_mask = torch.triu(diagonal, diagonal=1)\n\n        causal_mask *= torch.arange(\n            target_length, device=device\n        ) > cache_position.reshape(-1, 1)\n        causal_mask = causal_mask[None, None, :, :].expand(\n            input_tensor.shape[0], 1, -1, -1\n        )\n        if attention_mask is not None:\n            causal_mask = (\n                causal_mask.clone()\n            )  # copy to contiguous memory for in-place edit\n            if attention_mask.dim() == 2:\n                # Mask positions in the causal mask that are masked in the attention mask\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[\n                    :, None, None, :\n                ].eq(0.0)\n                causal_mask[..., :mask_length] = causal_mask[\n                    ..., :mask_length\n                ].masked_fill(padding_mask, min_dtype)\n\n        if attention_mask is not None and attention_mask.device.type == \"cuda\":\n            # Attend to all tokens in fully masked rows in 
the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(\n                causal_mask, min_dtype\n            )\n\n        return causal_mask\n"
  },
  {
    "path": "surya/common/donut/encoder.py",
    "content": "import collections.abc\nimport math\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.pytorch_utils import (\n    find_pruneable_heads_and_indices,\n    meshgrid,\n    prune_linear_layer,\n)\nfrom transformers.utils import ModelOutput\nfrom transformers import DonutSwinConfig\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\nfrom surya.common.xla import mark_step\n\n_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]\n\n\n@dataclass\n# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin\nclass DonutSwinEncoderOutput(ModelOutput):\n    last_hidden_state: torch.FloatTensor = None\n    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None\n    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n\n\n@dataclass\nclass DonutSwinModelOutput(ModelOutput):\n    last_hidden_state: torch.FloatTensor = None\n\n\n# Copied from transformers.models.swin.modeling_swin.window_partition\ndef window_partition(input_feature, window_size):\n    \"\"\"\n    Partitions the given input into windows.\n    \"\"\"\n    batch_size, height, width, num_channels = input_feature.shape\n    input_feature = input_feature.view(\n        batch_size,\n        height // window_size,\n        window_size,\n        width // window_size,\n        window_size,\n        num_channels,\n    )\n    windows = (\n        input_feature.permute(0, 1, 3, 2, 4, 5)\n        .contiguous()\n        .view(-1, window_size, window_size, num_channels)\n    )\n    return windows\n\n\n# Copied from transformers.models.swin.modeling_swin.window_reverse\ndef window_reverse(windows, window_size, height, width):\n    \"\"\"\n    Merges windows to produce higher resolution features.\n    \"\"\"\n    num_channels = 
windows.shape[-1]\n    windows = windows.view(\n        -1,\n        height // window_size,\n        width // window_size,\n        window_size,\n        window_size,\n        num_channels,\n    )\n    windows = (\n        windows.permute(0, 1, 3, 2, 4, 5)\n        .contiguous()\n        .view(-1, height, width, num_channels)\n    )\n    return windows\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin\nclass DonutSwinEmbeddings(nn.Module):\n    \"\"\"\n    Construct the patch and position embeddings. Optionally, also the mask token.\n    \"\"\"\n\n    def __init__(self, config, use_mask_token=False):\n        super().__init__()\n\n        self.patch_embeddings = DonutSwinPatchEmbeddings(config)\n        num_patches = self.patch_embeddings.num_patches\n        self.patch_grid = self.patch_embeddings.grid_size\n        self.mask_token = (\n            nn.Parameter(torch.zeros(1, 1, config.embed_dim))\n            if use_mask_token\n            else None\n        )\n\n        self.position_embeddings = None\n        self.row_embeddings = None\n        self.column_embeddings = None\n        if config.use_absolute_embeddings:\n            self.position_embeddings = nn.Parameter(\n                torch.zeros(1, num_patches + 1, config.embed_dim)\n            )\n\n        if hasattr(config, \"use_2d_embeddings\") and config.use_2d_embeddings:\n            self.row_embeddings = nn.Parameter(\n                torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim)\n            )\n            self.column_embeddings = nn.Parameter(\n                torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim)\n            )\n\n        self.norm = nn.LayerNorm(config.embed_dim)\n\n    def interpolate_pos_encoding(\n        self, embeddings: torch.Tensor, height: int, width: int\n    ) -> torch.Tensor:\n        \"\"\"\n        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher\n     
   resolution images.\n\n        Source:\n        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174\n        \"\"\"\n\n        num_patches = embeddings.shape[1] - 1\n        num_positions = self.position_embeddings.shape[1] - 1\n        if num_patches == num_positions and height == width:\n            return self.position_embeddings\n        class_pos_embed = self.position_embeddings[:, 0]\n        patch_pos_embed = self.position_embeddings[:, 1:]\n        dim = embeddings.shape[-1]\n        # this module does not store the config, so read the patch size from the patch embedder\n        h0 = height // self.patch_embeddings.patch_size[0]\n        w0 = width // self.patch_embeddings.patch_size[1]\n        # we add a small number to avoid floating point error in the interpolation\n        # see discussion at https://github.com/facebookresearch/dino/issues/8\n        h0, w0 = h0 + 0.1, w0 + 0.1\n        patch_pos_embed = patch_pos_embed.reshape(\n            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim\n        )\n        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed,\n            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),\n            mode=\"bicubic\",\n            align_corners=False,\n        )\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor],\n        bool_masked_pos: Optional[torch.BoolTensor] = None,\n        interpolate_pos_encoding: bool = False,\n    ) -> Tuple[torch.Tensor]:\n        _, num_channels, height, width = pixel_values.shape\n        embeddings, output_dimensions = self.patch_embeddings(pixel_values)\n        embeddings = self.norm(embeddings)\n        batch_size, seq_len, _ = 
embeddings.size()\n\n        if bool_masked_pos is not None:\n            mask_tokens = 
self.mask_token.expand(batch_size, seq_len, -1)\n            # replace the masked visual tokens by mask_tokens\n            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)\n            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask\n\n        if self.position_embeddings is not None:\n            if interpolate_pos_encoding:\n                embeddings = embeddings + self.interpolate_pos_encoding(\n                    embeddings, height, width\n                )\n            else:\n                embeddings = embeddings + self.position_embeddings[:, :seq_len]\n\n        if self.row_embeddings is not None and self.column_embeddings is not None:\n            # Patches are flattened row-major, so each row embedding is repeated once per column (0, 0, 0, 1, 1, 1, ...)\n            # and the column embeddings are tiled once per row (0, 1, 2, 0, 1, 2, ...)\n            row_embeddings = self.row_embeddings[\n                :, : output_dimensions[0], :\n            ].repeat_interleave(output_dimensions[1], dim=1)\n            column_embeddings = self.column_embeddings[\n                :, : output_dimensions[1], :\n            ].repeat(1, output_dimensions[0], 1)\n\n            embeddings = embeddings + row_embeddings + column_embeddings\n\n        return embeddings, output_dimensions\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin\nclass DonutSwinPatchEmbeddings(nn.Module):\n    \"\"\"\n    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial\n    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a\n    Transformer.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        image_size, patch_size = config.image_size, config.patch_size\n        num_channels, hidden_size = config.num_channels, config.embed_dim\n        image_size = (\n            image_size\n            if isinstance(image_size, collections.abc.Iterable)\n            
else (image_size, image_size)\n        )\n        
patch_size = (\n            patch_size\n            if isinstance(patch_size, collections.abc.Iterable)\n            else (patch_size, patch_size)\n        )\n        num_patches = (image_size[1] // patch_size[1]) * (\n            image_size[0] // patch_size[0]\n        )\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.num_channels = num_channels\n        self.num_patches = num_patches\n        self.grid_size = (\n            image_size[0] // patch_size[0],\n            image_size[1] // patch_size[1],\n        )\n\n        self.projection = nn.Conv2d(\n            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size\n        )\n\n    def maybe_pad(self, pixel_values, height, width):\n        if width % self.patch_size[1] != 0:\n            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])\n            pixel_values = nn.functional.pad(pixel_values, pad_values)\n        if height % self.patch_size[0] != 0:\n            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])\n            pixel_values = nn.functional.pad(pixel_values, pad_values)\n        return pixel_values\n\n    def forward(\n        self, pixel_values: Optional[torch.FloatTensor]\n    ) -> Tuple[torch.Tensor, Tuple[int]]:\n        _, num_channels, height, width = pixel_values.shape\n        # pad the input to be divisible by self.patch_size, if needed\n        pixel_values = self.maybe_pad(pixel_values, height, width)\n        embeddings = self.projection(pixel_values)\n        _, _, height, width = embeddings.shape\n        output_dimensions = (height, width)\n        embeddings = embeddings.flatten(2).transpose(1, 2)\n\n        return embeddings, output_dimensions\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging\nclass DonutSwinPatchMerging(nn.Module):\n    \"\"\"\n    Patch Merging Layer.\n\n    Args:\n        input_resolution (`Tuple[int]`):\n            Resolution of input 
feature.\n        dim (`int`):\n            Number of input channels.\n        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):\n            Normalization layer class.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_resolution: Tuple[int],\n        dim: int,\n        norm_layer: nn.Module = nn.LayerNorm,\n    ) -> None:\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def maybe_pad(self, input_feature, height, width):\n        should_pad = (height % 2 == 1) or (width % 2 == 1)\n        if should_pad:\n            pad_values = (0, 0, 0, width % 2, 0, height % 2)\n            input_feature = nn.functional.pad(input_feature, pad_values)\n\n        return input_feature\n\n    def forward(\n        self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]\n    ) -> torch.Tensor:\n        height, width = input_dimensions\n        # `dim` is height * width\n        batch_size, dim, num_channels = input_feature.shape\n\n        input_feature = input_feature.view(batch_size, height, width, num_channels)\n        # pad height and width to even sizes (divisible by 2), if needed\n        input_feature = self.maybe_pad(input_feature, height, width)\n        # [batch_size, height/2, width/2, num_channels]\n        input_feature_0 = input_feature[:, 0::2, 0::2, :]\n        # [batch_size, height/2, width/2, num_channels]\n        input_feature_1 = input_feature[:, 1::2, 0::2, :]\n        # [batch_size, height/2, width/2, num_channels]\n        input_feature_2 = input_feature[:, 0::2, 1::2, :]\n        # [batch_size, height/2, width/2, num_channels]\n        input_feature_3 = input_feature[:, 1::2, 1::2, :]\n        # batch_size height/2 width/2 4*num_channels\n        input_feature = torch.cat(\n            [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1\n   
     )\n        input_feature = input_feature.view(\n            batch_size, -1, 4 * num_channels\n        )  # batch_size height/2*width/2 4*C\n\n        input_feature = self.norm(input_feature)\n        input_feature = self.reduction(input_feature)\n\n        return input_feature\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin\nclass DonutSwinSelfAttention(nn.Module):\n    def __init__(self, config, dim, num_heads, num_kv_heads, window_size):\n        super().__init__()\n        if dim % num_heads != 0:\n            raise ValueError(\n                f\"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})\"\n            )\n\n        self.num_attention_heads = num_heads\n        self.num_kv_heads = num_kv_heads\n        self.kv_repeats = self.num_attention_heads // self.num_kv_heads\n        self.attention_head_size = int(dim / num_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n        self.kv_head_size = self.num_kv_heads * self.attention_head_size\n        self.window_size = (\n            window_size\n            if isinstance(window_size, collections.abc.Iterable)\n            else (window_size, window_size)\n        )\n\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(\n                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads\n            )\n        )\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(meshgrid([coords_h, coords_w], indexing=\"ij\"))\n        coords_flatten = torch.flatten(coords, 1)\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()\n        relative_coords[:, :, 0] += 
self.window_size[0] - 1\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.query = nn.Linear(\n            self.all_head_size, self.all_head_size, bias=config.qkv_bias\n        )\n        self.key = nn.Linear(\n            self.all_head_size, self.kv_head_size, bias=config.qkv_bias\n        )\n        self.value = nn.Linear(\n            self.all_head_size, self.kv_head_size, bias=config.qkv_bias\n        )\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (\n            self.num_attention_heads,\n            self.attention_head_size,\n        )\n        x = x.view(new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def transpose_kv_for_scores(self, x, repeats):\n        new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size)\n        x = x.view(new_x_shape)\n        x = x.repeat(\n            1, 1, repeats, 1\n        )  # repeat the values for each key-value head to match query dim\n        return x.permute(0, 2, 1, 3).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.Tensor]:\n        batch_size, dim, num_channels = hidden_states.shape\n        mixed_query_layer = self.query(hidden_states)\n\n        # Final is (batch_size, num_attention_heads, seq_len, attention_head_size)\n        key_layer = self.transpose_kv_for_scores(\n            self.key(hidden_states), self.kv_repeats\n        )\n        value_layer = self.transpose_kv_for_scores(\n            self.value(hidden_states), self.kv_repeats\n        )\n        query_layer = 
self.transpose_for_scores(mixed_query_layer)\n\n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)\n        ]\n        relative_position_bias = relative_position_bias.view(\n            self.window_size[0] * self.window_size[1],\n            self.window_size[0] * self.window_size[1],\n            -1,\n        )\n        relative_position_bias = (\n            relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)\n        )\n        relative_position_bias = relative_position_bias.repeat(batch_size, 1, 1, 1)\n\n        if attention_mask is None:\n            attention_mask = relative_position_bias\n        else:\n            mask_shape = attention_mask.shape[0]\n            repeat_count = batch_size // mask_shape\n            attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)\n            attention_mask = attention_mask + relative_position_bias\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_layer,\n            key_layer,\n            value_layer,\n            attn_mask=attention_mask,\n            dropout_p=0.0,\n            scale=self.attention_head_size**-0.5,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(batch_size, dim, num_channels)\n\n        outputs = (attn_output,)\n        return outputs\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput\nclass DonutSwinSelfOutput(nn.Module):\n    def __init__(self, config, dim):\n        super().__init__()\n        self.dense = nn.Linear(dim, dim)\n\n    def forward(\n        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor\n    ) -> torch.Tensor:\n        return self.dense(hidden_states)\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin\nclass DonutSwinAttention(nn.Module):\n    def __init__(self, config, dim, num_heads, 
num_kv_heads, window_size):\n        super().__init__()\n        self.self = DonutSwinSelfAttention(\n            config, dim, num_heads, num_kv_heads, window_size\n        )\n        self.output = DonutSwinSelfOutput(config, dim)\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads,\n            self.self.num_attention_heads,\n            self.self.attention_head_size,\n            self.pruned_heads,\n        )\n\n        # Prune linear layers\n        self.self.query = prune_linear_layer(self.self.query, index)\n        self.self.key = prune_linear_layer(self.self.key, index)\n        self.self.value = prune_linear_layer(self.self.value, index)\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n\n        # Update hyper params and store pruned heads\n        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n        self.self.all_head_size = (\n            self.self.attention_head_size * self.self.num_attention_heads\n        )\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.Tensor]:\n        self_outputs = self.self(\n            hidden_states, attention_mask, head_mask, output_attentions\n        )\n        attention_output = self.output(self_outputs[0], hidden_states)\n        outputs = (attention_output,) + self_outputs[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinIntermediate\nclass DonutSwinIntermediate(nn.Module):\n    def __init__(self, config, dim):\n        super().__init__()\n       
 self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))\n        if isinstance(config.hidden_act, str):\n            self.intermediate_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.intermediate_act_fn = config.hidden_act\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinOutput\nclass DonutSwinOutput(nn.Module):\n    def __init__(self, config, dim):\n        super().__init__()\n        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        return self.dense(hidden_states)\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin\nclass DonutSwinLayer(nn.Module):\n    def __init__(\n        self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0\n    ):\n        super().__init__()\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.shift_size = shift_size\n        self.window_size = config.window_size\n        self.input_resolution = input_resolution\n        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)\n        self.attention = DonutSwinAttention(\n            config, dim, num_heads, num_kv_heads, window_size=self.window_size\n        )\n        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)\n        self.intermediate = DonutSwinIntermediate(config, dim)\n        self.output = DonutSwinOutput(config, dim)\n\n    def set_shift_and_window_size(self, input_resolution):\n        if min(input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = int(0)\n            self.window_size = (\n                
torch.min(torch.tensor(input_resolution))\n                if torch.jit.is_tracing()\n                else min(input_resolution)\n            )\n\n    def get_attn_mask(self, height, width, dtype, device):\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)\n            height_slices = (\n                slice(0, -self.window_size),\n                slice(-self.window_size, -self.shift_size),\n                slice(-self.shift_size, None),\n            )\n            width_slices = (\n                slice(0, -self.window_size),\n                slice(-self.window_size, -self.shift_size),\n                slice(-self.shift_size, None),\n            )\n            count = 0\n            for height_slice in height_slices:\n                for width_slice in width_slices:\n                    img_mask[:, height_slice, width_slice, :] = count\n                    count += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(\n                attn_mask != 0, float(-100.0)\n            ).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n        return attn_mask\n\n    def maybe_pad(self, hidden_states, height, width):\n        pad_right = (self.window_size - width % self.window_size) % self.window_size\n        pad_bottom = (self.window_size - height % self.window_size) % self.window_size\n        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)\n        hidden_states = nn.functional.pad(hidden_states, pad_values)\n        return hidden_states, pad_values\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        input_dimensions: Tuple[int, int],\n        
head_mask: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = False,\n        always_partition: Optional[bool] = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        if not always_partition:\n            self.set_shift_and_window_size(input_dimensions)\n        else:\n            pass\n        height, width = input_dimensions\n        batch_size, _, channels = hidden_states.size()\n        shortcut = hidden_states\n\n        hidden_states = self.layernorm_before(hidden_states)\n\n        hidden_states = hidden_states.view(batch_size, height, width, channels)\n\n        # pad hidden_states to multiples of window size\n        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)\n\n        _, height_pad, width_pad, _ = hidden_states.shape\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_hidden_states = torch.roll(\n                hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)\n            )\n        else:\n            shifted_hidden_states = hidden_states\n\n        # partition windows\n        hidden_states_windows = window_partition(\n            shifted_hidden_states, self.window_size\n        )\n        hidden_states_windows = hidden_states_windows.view(\n            -1, self.window_size * self.window_size, channels\n        )\n        attn_mask = self.get_attn_mask(\n            height_pad,\n            width_pad,\n            dtype=hidden_states.dtype,\n            device=hidden_states_windows.device,\n        )\n\n        attention_outputs = self.attention(\n            hidden_states_windows,\n            attn_mask,\n            head_mask,\n            output_attentions=output_attentions,\n        )\n\n        attention_output = attention_outputs[0]\n\n        attention_windows = attention_output.view(\n            -1, self.window_size, self.window_size, channels\n        )\n        shifted_windows = window_reverse(\n            attention_windows, 
self.window_size, height_pad, width_pad\n        )\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            attention_windows = torch.roll(\n                shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)\n            )\n        else:\n            attention_windows = shifted_windows\n\n        was_padded = pad_values[3] > 0 or pad_values[5] > 0\n        if was_padded:\n            attention_windows = attention_windows[:, :height, :width, :].contiguous()\n\n        attention_windows = attention_windows.view(batch_size, height * width, channels)\n\n        hidden_states = shortcut + attention_windows\n\n        layer_output = self.layernorm_after(hidden_states)\n        layer_output = self.intermediate(layer_output)\n        layer_output = hidden_states + self.output(layer_output)\n\n        layer_outputs = (\n            (layer_output, attention_outputs[1])\n            if output_attentions\n            else (layer_output,)\n        )\n        return layer_outputs\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin\nclass DonutSwinStage(nn.Module):\n    def __init__(\n        self,\n        config,\n        layer_num,\n        dim,\n        input_resolution,\n        depth,\n        num_heads,\n        num_kv_heads,\n        downsample,\n    ):\n        super().__init__()\n        self.config = config\n        self.dim = dim\n        self.blocks = nn.ModuleList(\n            [\n                DonutSwinLayer(\n                    config=config,\n                    dim=dim,\n                    input_resolution=input_resolution,\n                    num_heads=num_heads,\n                    num_kv_heads=num_kv_heads,\n                    shift_size=0 if (i % 2 == 0) else config.window_size // 2,\n                )\n                for i in range(depth)\n            ]\n        )\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = 
downsample(\n                input_resolution, dim=dim, norm_layer=nn.LayerNorm\n            )\n        else:\n            self.downsample = None\n\n        self.pointing = False\n\n        self.positional_encoding = None\n        if config.use_positional_embeddings:\n            self.positional_encoding = self.build_2d_sincos_position_embedding(\n                input_resolution[1],\n                input_resolution[0],\n                embed_dim=dim,\n            )\n\n    @staticmethod\n    def build_2d_sincos_position_embedding(\n        width,\n        height,\n        embed_dim=256,\n        temperature=10000.0,\n        device=\"cpu\",\n        dtype=torch.float32,\n    ):\n        grid_w = torch.arange(int(width), dtype=dtype, device=device)\n        grid_h = torch.arange(int(height), dtype=dtype, device=device)\n        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing=\"ij\")\n        if embed_dim % 4 != 0:\n            raise ValueError(\n                \"Embed dimension must be divisible by 4 for 2D sin-cos position embedding\"\n            )\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim\n        omega = 1.0 / (temperature**omega)\n\n        out_w = grid_w.flatten()[..., None] @ omega[None]\n        out_h = grid_h.flatten()[..., None] @ omega[None]\n\n        return torch.concat(\n            [out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1\n        )[None, :, :]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        input_dimensions: Tuple[int, int],\n        head_mask: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = False,\n        always_partition: Optional[bool] = False,\n    ) -> Tuple[torch.Tensor]:\n        height, width = input_dimensions\n\n        if self.positional_encoding is not None:\n            hidden_states = hidden_states + self.positional_encoding.to(\n                hidden_states.dtype\n       
     ).to(hidden_states.device)\n\n        for i, layer_module in enumerate(self.blocks):\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n\n            layer_outputs = layer_module(\n                hidden_states,\n                input_dimensions,\n                layer_head_mask,\n                output_attentions,\n                always_partition,\n            )\n\n            hidden_states = layer_outputs[0]\n\n        hidden_states_before_downsampling = hidden_states\n        if self.downsample is not None:\n            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2\n            output_dimensions = (height, width, height_downsampled, width_downsampled)\n            hidden_states = self.downsample(\n                hidden_states_before_downsampling, input_dimensions\n            )\n        else:\n            output_dimensions = (height, width, height, width)\n\n        stage_outputs = (\n            hidden_states,\n            hidden_states_before_downsampling,\n            output_dimensions,\n        )\n\n        if output_attentions:\n            stage_outputs += layer_outputs[1:]\n        return stage_outputs\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin\nclass DonutSwinEncoder(nn.Module):\n    def __init__(self, config, grid_size):\n        super().__init__()\n        self.num_layers = len(config.depths)\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                DonutSwinStage(\n                    config=config,\n                    layer_num=i_layer,\n                    dim=int(config.embed_dim * 2**i_layer),\n                    input_resolution=(\n                        grid_size[0] // (2**i_layer),\n                        grid_size[1] // (2**i_layer),\n                    ),\n                    depth=config.depths[i_layer],\n                    num_heads=config.num_heads[i_layer],\n                    
num_kv_heads=config.num_kv_heads[i_layer]\n                    if hasattr(config, \"num_kv_heads\")\n                    else config.num_heads[i_layer],\n                    downsample=DonutSwinPatchMerging\n                    if (i_layer < self.num_layers - 1)\n                    else None,\n                )\n                for i_layer in range(self.num_layers)\n            ]\n        )\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        input_dimensions: Tuple[int, int],\n        head_mask: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = False,\n        output_hidden_states: Optional[bool] = False,\n        output_hidden_states_before_downsampling: Optional[bool] = False,\n        always_partition: Optional[bool] = False,\n        return_dict: Optional[bool] = True,\n    ) -> Union[Tuple, DonutSwinEncoderOutput]:\n        all_hidden_states = () if output_hidden_states else None\n        all_reshaped_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n\n        if output_hidden_states:\n            batch_size, _, hidden_size = hidden_states.shape\n            # rearrange b (h w) c -> b c h w\n            reshaped_hidden_state = hidden_states.view(\n                batch_size, *input_dimensions, hidden_size\n            )\n            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)\n            all_hidden_states += (hidden_states,)\n            all_reshaped_hidden_states += (reshaped_hidden_state,)\n\n        for i, layer_module in enumerate(self.layers):\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    layer_module.__call__,\n                    hidden_states,\n                    
input_dimensions,\n                    layer_head_mask,\n                    output_attentions,\n                    always_partition,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    input_dimensions,\n                    layer_head_mask,\n                    output_attentions,\n                    always_partition,\n                )\n\n            hidden_states = layer_outputs[0]\n            hidden_states_before_downsampling = layer_outputs[1]\n            output_dimensions = layer_outputs[2]\n            input_dimensions = (output_dimensions[-2], output_dimensions[-1])\n\n            if output_hidden_states and output_hidden_states_before_downsampling:\n                batch_size, _, hidden_size = hidden_states_before_downsampling.shape\n                # rearrange b (h w) c -> b c h w\n                # here we use the original (not downsampled) height and width\n                reshaped_hidden_state = hidden_states_before_downsampling.view(\n                    batch_size,\n                    *(output_dimensions[0], output_dimensions[1]),\n                    hidden_size,\n                )\n                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)\n                all_hidden_states += (hidden_states_before_downsampling,)\n                all_reshaped_hidden_states += (reshaped_hidden_state,)\n            elif output_hidden_states and not output_hidden_states_before_downsampling:\n                batch_size, _, hidden_size = hidden_states.shape\n                # rearrange b (h w) c -> b c h w\n                reshaped_hidden_state = hidden_states.view(\n                    batch_size, *input_dimensions, hidden_size\n                )\n                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)\n                all_hidden_states += (hidden_states,)\n                all_reshaped_hidden_states += 
(reshaped_hidden_state,)\n\n            if output_attentions:\n                all_self_attentions += layer_outputs[3:]\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, all_hidden_states, all_self_attentions]\n                if v is not None\n            )\n\n        return DonutSwinEncoderOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            reshaped_hidden_states=all_reshaped_hidden_states,\n        )\n\n\n# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin\nclass DonutSwinPreTrainedModel(SuryaPreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = DonutSwinConfig\n    base_model_prefix = \"swin\"\n    main_input_name = \"pixel_values\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"DonutSwinStage\"]\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, (nn.Linear, nn.Conv2d)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n"
  },
  {
    "path": "surya/common/donut/processor.py",
    "content": "from typing import Dict, Union, Optional, List, Iterable\n\nimport cv2\nfrom torch import TensorType\nfrom transformers import ImageProcessingMixin\nfrom transformers.image_processing_utils import BatchFeature\nfrom transformers.image_transforms import pad, normalize\nfrom transformers.image_utils import (\n    ImageInput,\n    ChannelDimension,\n    make_list_of_images,\n    get_image_size,\n)\nimport numpy as np\nfrom PIL import Image\nimport PIL\nfrom transformers.utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD\n\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.settings import settings\n\n\nclass SuryaEncoderImageProcessor(S3DownloaderMixin, ImageProcessingMixin):\n    def __init__(\n        self,\n        *args,\n        max_size=None,\n        align_long_axis=False,\n        rescale_factor: Union[int, float] = 1 / 255,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        **kwargs,\n    ):\n        super().__init__(*args, **kwargs)\n\n        self.patch_size = kwargs.get(\"patch_size\", (4, 4))\n        self.max_size = max_size\n        self.do_align_long_axis = align_long_axis\n        self.resample = Image.Resampling.BILINEAR\n        self.rescale_factor = rescale_factor\n        self.image_mean = (\n            image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN\n        )\n        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD\n\n    def __call__(self, images, **kwargs) -> PIL.Image.Image:\n        \"\"\"Preprocess an image or a batch of images.\"\"\"\n        return self.preprocess(images, **kwargs)\n\n    @classmethod\n    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):\n        max_width, max_height = size[\"width\"], size[\"height\"]\n\n        resized_image = cv2.resize(\n            image, (max_width, max_height), interpolation=interpolation\n        )\n 
       resized_image = resized_image.transpose(2, 0, 1)\n\n        return resized_image\n\n    def process_inner(self, images: List[np.ndarray]):\n        assert images[0].shape[2] == 3  # RGB input images, channel dim last\n\n        if self.do_align_long_axis:\n            # Rotate if the bbox is wider than it is tall\n            images = [\n                SuryaEncoderImageProcessor.align_long_axis(\n                    image, size=self.max_size, input_data_format=ChannelDimension.LAST\n                )\n                for image in images\n            ]\n\n            # Verify that the image is wider than it is tall\n            for img in images:\n                assert img.shape[1] >= img.shape[0]\n\n        # This also applies the right channel dim format, to channel x height x width\n        images = [\n            SuryaEncoderImageProcessor.numpy_resize(img, self.max_size, self.resample)\n            for img in images\n        ]\n        assert images[0].shape[0] == 3  # RGB input images, channel dim first\n\n        # Convert to float32 for rescale/normalize\n        images = [img.astype(np.float32) for img in images]\n\n        # Pads with 255 (whitespace)\n        # Pad to max size to improve performance\n        max_size = self.max_size\n        images = [\n            SuryaEncoderImageProcessor.pad_image(\n                image=image,\n                size=max_size,\n                input_data_format=ChannelDimension.FIRST,\n                pad_value=settings.RECOGNITION_PAD_VALUE,\n            )\n            for image in images\n        ]\n\n        # Rescale and normalize\n        for idx in range(len(images)):\n            images[idx] = (images[idx].astype(np.float64) * self.rescale_factor).astype(\n                np.float32\n            )\n\n        images = [\n            SuryaEncoderImageProcessor.normalize(\n                img,\n                mean=self.image_mean,\n                std=self.image_std,\n                
input_data_format=ChannelDimension.FIRST,\n            )\n            for img in images\n        ]\n\n        return images\n\n    def preprocess(\n        self,\n        images: ImageInput,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        **kwargs,\n    ) -> PIL.Image.Image:\n        images = make_list_of_images(images)\n\n        # Convert to numpy for later processing steps\n        images = [np.array(img) for img in images]\n        images = self.process_inner(images)\n\n        data = {\"pixel_values\": images}\n        return BatchFeature(data=data, tensor_type=return_tensors)\n\n    @classmethod\n    def pad_image(\n        cls,\n        image: np.ndarray,\n        size: Dict[str, int],\n        data_format: Optional[Union[str, ChannelDimension]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n        pad_value: float = 0.0,\n    ) -> np.ndarray:\n        output_height, output_width = size[\"height\"], size[\"width\"]\n        input_height, input_width = get_image_size(image, channel_dim=input_data_format)\n\n        delta_width = output_width - input_width\n        delta_height = output_height - input_height\n\n        assert delta_width >= 0 and delta_height >= 0\n\n        pad_top = delta_height // 2\n        pad_left = delta_width // 2\n\n        pad_bottom = delta_height - pad_top\n        pad_right = delta_width - pad_left\n\n        padding = ((pad_top, pad_bottom), (pad_left, pad_right))\n        return pad(\n            image,\n            padding,\n            data_format=data_format,\n            input_data_format=input_data_format,\n            constant_values=pad_value,\n        )\n\n    @classmethod\n    def align_long_axis(\n        cls, image: np.ndarray, size: Dict[str, int], **kwargs\n    ) -> np.ndarray:\n        input_height, input_width = image.shape[:2]\n        output_height, output_width = size[\"height\"], size[\"width\"]\n\n        if (output_width < output_height and 
input_width > input_height) or (\n            output_width > output_height and input_width < input_height\n        ):\n            image = np.rot90(image, 3)\n\n        return image\n\n    @classmethod\n    def normalize(\n        cls,\n        image: np.ndarray,\n        mean: Union[float, Iterable[float]],\n        std: Union[float, Iterable[float]],\n        data_format: Optional[Union[str, ChannelDimension]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n        **kwargs,\n    ) -> np.ndarray:\n        return normalize(\n            image,\n            mean=mean,\n            std=std,\n            data_format=data_format,\n            input_data_format=input_data_format,\n            **kwargs,\n        )\n"
  },
  {
    "path": "surya/common/load.py",
    "content": "from typing import Optional, Any\n\nimport torch\n\nfrom surya.settings import settings\n\n\nclass ModelLoader:\n    def __init__(self, checkpoint: Optional[str] = None):\n        self.checkpoint = checkpoint\n\n    def model(\n        self,\n        device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,\n        dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,\n        attention_implementation: Optional[str] = None,\n    ) -> Any:\n        raise NotImplementedError()\n\n    def processor(\n        self,\n        device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,\n        dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,\n    ) -> Any:\n        raise NotImplementedError()\n"
  },
  {
    "path": "surya/common/polygon.py",
    "content": "import copy\nfrom typing import List, Optional\n\nimport numpy as np\nfrom pydantic import BaseModel, field_validator, computed_field\nimport numbers\n\n\nclass PolygonBox(BaseModel):\n    polygon: List[List[float]]\n    confidence: Optional[float] = None\n\n    @field_validator(\"polygon\", mode=\"before\")\n    @classmethod\n    def convert_bbox_to_polygon(cls, value):\n        if isinstance(value, (list, tuple)) and len(value) == 4:\n            if all(isinstance(x, numbers.Number) for x in value):\n                value = [float(v) for v in value]\n                x_min, y_min, x_max, y_max = value\n                polygon = [\n                    [x_min, y_min],\n                    [x_max, y_min],\n                    [x_max, y_max],\n                    [x_min, y_max],\n                ]\n                return polygon\n            elif all(\n                isinstance(point, (list, tuple)) and len(point) == 2 for point in value\n            ):\n                value = [[float(v) for v in point] for point in value]\n                return value\n        elif isinstance(value, np.ndarray):\n            if value.shape == (4, 2):\n                return value.tolist()\n\n        raise ValueError(\n            f\"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)].  All values must be numeric. You passed {value} of type {type(value)}.  
The first value is of type {type(value[0])}.\"\n        )\n\n    @property\n    def height(self):\n        return self.bbox[3] - self.bbox[1]\n\n    @property\n    def width(self):\n        return self.bbox[2] - self.bbox[0]\n\n    @property\n    def area(self):\n        return self.width * self.height\n\n    @computed_field\n    @property\n    def bbox(self) -> List[float]:\n        x_coords = [point[0] for point in self.polygon]\n        y_coords = [point[1] for point in self.polygon]\n        return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]\n\n    def rescale(self, processor_size, image_size):\n        # Point is in x, y format\n        page_width, page_height = processor_size\n\n        img_width, img_height = image_size\n        width_scaler = img_width / page_width\n        height_scaler = img_height / page_height\n\n        for corner in self.polygon:\n            corner[0] = int(corner[0] * width_scaler)\n            corner[1] = int(corner[1] * height_scaler)\n\n    def round(self, divisor):\n        for corner in self.polygon:\n            corner[0] = int(corner[0] / divisor) * divisor\n            corner[1] = int(corner[1] / divisor) * divisor\n\n    def fit_to_bounds(self, bounds):\n        new_corners = copy.deepcopy(self.polygon)\n        for corner in new_corners:\n            corner[0] = max(min(corner[0], bounds[2]), bounds[0])\n            corner[1] = max(min(corner[1], bounds[3]), bounds[1])\n        self.polygon = new_corners\n\n    def merge(self, other):\n        x1 = min(self.bbox[0], other.bbox[0])\n        y1 = min(self.bbox[1], other.bbox[1])\n        x2 = max(self.bbox[2], other.bbox[2])\n        y2 = max(self.bbox[3], other.bbox[3])\n        self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]\n\n    def merge_left(self, other):\n        x1 = min(self.bbox[0], other.bbox[0])\n        self.polygon[0][0] = x1\n        self.polygon[3][0] = x1\n\n    def merge_right(self, other):\n        x2 = max(self.bbox[2], 
other.bbox[2])\n        self.polygon[1][0] = x2\n        self.polygon[2][0] = x2\n\n    def expand(self, x_margin: float, y_margin: float):\n        new_polygon = []\n        x_margin = x_margin * self.width\n        y_margin = y_margin * self.height\n        for idx, poly in enumerate(self.polygon):\n            if idx == 0:\n                new_polygon.append([int(poly[0] - x_margin), int(poly[1] - y_margin)])\n            elif idx == 1:\n                new_polygon.append([int(poly[0] + x_margin), int(poly[1] - y_margin)])\n            elif idx == 2:\n                new_polygon.append([int(poly[0] + x_margin), int(poly[1] + y_margin)])\n            elif idx == 3:\n                new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)])\n        self.polygon = new_polygon\n\n    def intersection_polygon(self, other) -> List[List[float]]:\n        new_poly = []\n        for i in range(4):\n            if i == 0:\n                new_corner = [\n                    max(self.polygon[0][0], other.polygon[0][0]),\n                    max(self.polygon[0][1], other.polygon[0][1]),\n                ]\n            elif i == 1:\n                new_corner = [\n                    min(self.polygon[1][0], other.polygon[1][0]),\n                    max(self.polygon[1][1], other.polygon[1][1]),\n                ]\n            elif i == 2:\n                new_corner = [\n                    min(self.polygon[2][0], other.polygon[2][0]),\n                    min(self.polygon[2][1], other.polygon[2][1]),\n                ]\n            elif i == 3:\n                new_corner = [\n                    max(self.polygon[3][0], other.polygon[3][0]),\n                    min(self.polygon[3][1], other.polygon[3][1]),\n                ]\n            new_poly.append(new_corner)\n\n        return new_poly\n\n    def intersection_area(self, other, x_margin=0, y_margin=0):\n        x_overlap = self.x_overlap(other, x_margin)\n        y_overlap = self.y_overlap(other, 
y_margin)\n        return x_overlap * y_overlap\n\n    def x_overlap(self, other, x_margin=0):\n        return max(\n            0,\n            min(self.bbox[2] + x_margin, other.bbox[2] + x_margin)\n            - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin),\n        )\n\n    def y_overlap(self, other, y_margin=0):\n        return max(\n            0,\n            min(self.bbox[3] + y_margin, other.bbox[3] + y_margin)\n            - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin),\n        )\n\n    def intersection_pct(self, other, x_margin=0, y_margin=0):\n        assert 0 <= x_margin <= 1\n        assert 0 <= y_margin <= 1\n        if self.area == 0:\n            return 0\n\n        if x_margin:\n            x_margin = int(min(self.width, other.width) * x_margin)\n        if y_margin:\n            y_margin = int(min(self.height, other.height) * y_margin)\n\n        intersection = self.intersection_area(other, x_margin, y_margin)\n        return intersection / self.area\n\n    def shift(self, x_shift: float | None = None, y_shift: float | None = None):\n        if x_shift is not None:\n            for corner in self.polygon:\n                corner[0] += x_shift\n        if y_shift is not None:\n            for corner in self.polygon:\n                corner[1] += y_shift\n\n    def clamp(self, bbox: List[float]):\n        for corner in self.polygon:\n            corner[0] = max(min(corner[0], bbox[2]), bbox[0])\n            corner[1] = max(min(corner[1], bbox[3]), bbox[1])\n\n    @property\n    def center(self):\n        return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2]\n\n    def distance(self, other):\n        center = self.center\n        other_center = other.center\n\n        return (\n            (center[0] - other_center[0]) ** 2 + (center[1] - other_center[1]) ** 2\n        ) ** 0.5\n\n    def __hash__(self):\n        return hash(tuple(self.bbox))\n"
  },
  {
    "path": "surya/common/predictor.py",
    "content": "from typing import Optional\nimport torch\nimport torch.nn.functional as F\n\nfrom surya.common.load import ModelLoader\nfrom surya.settings import settings\n\n\nclass BasePredictor:\n    model_loader_cls = ModelLoader\n    batch_size: Optional[int] = None\n    default_batch_sizes = {\"cpu\": 1, \"mps\": 1, \"cuda\": 1}\n    torch_dtype = settings.MODEL_DTYPE\n\n    @property\n    def disable_tqdm(self) -> bool:\n        return self._disable_tqdm\n\n    @disable_tqdm.setter\n    def disable_tqdm(self, value: bool) -> None:\n        self._disable_tqdm = bool(value)\n\n    def __init__(\n        self,\n        checkpoint: Optional[str] = None,\n        device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,\n        dtype: Optional[torch.dtype | str] = None,\n        attention_implementation: Optional[str] = None,\n    ):\n        if dtype is None:\n            dtype = self.torch_dtype\n\n        self.model = None\n        self.processor = None\n        loader = self.model_loader_cls(checkpoint)\n\n        self.model = loader.model(device, dtype, attention_implementation)\n        self.processor = loader.processor()\n\n        self._disable_tqdm = settings.DISABLE_TQDM\n\n    def to(self, device_dtype: torch.device | str | None = None):\n        model_moved = False\n        if hasattr(self, \"model\") and self.model:\n            self.model.to(device_dtype)\n            model_moved = True\n        if hasattr(self, \"foundation_predictor\") and self.foundation_predictor:\n            self.foundation_predictor.model.to(device_dtype)\n            model_moved = True\n\n        if not model_moved:\n            raise ValueError(\"Model not loaded\")\n\n    def get_batch_size(self):\n        batch_size = self.batch_size\n        if batch_size is None:\n            batch_size = self.default_batch_sizes[\"cpu\"]\n            if settings.TORCH_DEVICE_MODEL in self.default_batch_sizes:\n                batch_size = 
self.default_batch_sizes[settings.TORCH_DEVICE_MODEL]\n        return batch_size\n\n    @staticmethod\n    def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):\n        current_batch_size = tensor.shape[0]\n        if current_batch_size >= batch_size:\n            return tensor\n\n        if len(tensor.shape) == 1:\n            # If tensor is 1D, we need to pad it to the batch size\n            pad_size = batch_size - current_batch_size\n            return F.pad(tensor, (0, pad_size), mode=\"constant\", value=0)\n\n        pad_size = batch_size - current_batch_size\n        padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)\n\n        return F.pad(tensor, padding, mode=\"constant\", value=0)\n\n    def __call__(self, *args, **kwargs):\n        raise NotImplementedError()\n"
  },
  {
    "path": "surya/common/pretrained.py",
    "content": "from typing import Optional\n\nfrom transformers import PreTrainedModel\nfrom transformers.utils import is_flash_attn_2_available\n\n\nclass SuryaPreTrainedModel(PreTrainedModel):\n    # Overrides the transformers check: an explicitly passed attn_implementation is returned unchanged,\n    # so attention can be set freely in the config. Otherwise, prefer flash_attention_2 (if supported\n    # and installed), then sdpa, then eager.\n    def _check_and_adjust_attn_implementation(\n        self, attn_implementation: Optional[str], **kwargs\n    ):\n        if attn_implementation is None:\n            try:\n                self._sdpa_can_dispatch(True)\n                attn_implementation = \"sdpa\"\n            except (ValueError, ImportError):\n                attn_implementation = \"eager\"\n\n            if self._supports_flash_attn and is_flash_attn_2_available():\n                attn_implementation = \"flash_attention_2\"\n\n        return attn_implementation\n"
  },
  {
    "path": "surya/common/s3.py",
    "content": "import json\nimport os\nimport shutil\nimport tempfile\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom pathlib import Path\n\nimport requests\nfrom tqdm import tqdm\n\nfrom surya.logging import get_logger\nfrom surya.settings import settings\n\nlogger = get_logger()\n\n# Lock file expiration time in seconds (10 minutes)\nLOCK_EXPIRATION = 600\n\n\ndef join_urls(url1: str, url2: str):\n    url1 = url1.rstrip(\"/\")\n    url2 = url2.lstrip(\"/\")\n    return f\"{url1}/{url2}\"\n\n\ndef get_model_name(pretrained_model_name_or_path: str):\n    return pretrained_model_name_or_path.split(\"/\")[0]\n\n\ndef download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024):\n    local_path = Path(local_path)\n    try:\n        response = requests.get(remote_path, stream=True, allow_redirects=True)\n        response.raise_for_status()  # Raise an exception for bad status codes\n\n        # Get file size from headers for progress bar\n        total_size = int(response.headers.get('content-length', 0))\n        \n        # Create progress bar with file name and size info\n        filename = local_path.name\n        pbar = tqdm(\n            total=total_size,\n            unit='B',\n            unit_scale=True,\n            unit_divisor=1024,\n            desc=f\"Downloading {filename}\",\n            miniters=1\n        )\n\n        with open(local_path, \"wb\") as f:\n            downloaded = 0\n            for chunk in response.iter_content(chunk_size=chunk_size):\n                if chunk:\n                    f.write(chunk)\n                    downloaded += len(chunk)\n                    pbar.update(len(chunk))\n        \n        pbar.close()\n        return local_path\n    except Exception as e:\n        if local_path.exists():\n            local_path.unlink()\n        logger.error(f\"Download error for file {remote_path}: {str(e)}\")\n        raise\n\n\ndef check_manifest(local_dir: str):\n    local_dir = 
Path(local_dir)\n    manifest_path = local_dir / \"manifest.json\"\n    if not os.path.exists(manifest_path):\n        return False\n\n    try:\n        with open(manifest_path, \"r\") as f:\n            manifest = json.load(f)\n        for file in manifest[\"files\"]:\n            if not os.path.exists(local_dir / file):\n                return False\n    except Exception:\n        return False\n\n    return True\n\n\ndef download_directory(remote_path: str, local_dir: str):\n    model_name = get_model_name(remote_path)\n    s3_url = join_urls(settings.S3_BASE_URL, remote_path)\n    # Check to see if it's already downloaded\n    model_exists = check_manifest(local_dir)\n    if model_exists:\n        return\n\n    # Use tempfile.TemporaryDirectory to automatically clean up\n    with tempfile.TemporaryDirectory() as temp_dir:\n        # Download the manifest file\n        manifest_file = join_urls(s3_url, \"manifest.json\")\n        manifest_path = os.path.join(temp_dir, \"manifest.json\")\n        download_file(manifest_file, manifest_path)\n\n        # List and download all files\n        with open(manifest_path, \"r\") as f:\n            manifest = json.load(f)\n\n        pbar = tqdm(\n            desc=f\"Downloading {model_name} model to {local_dir}\",\n            total=len(manifest[\"files\"]),\n        )\n\n        with ThreadPoolExecutor(\n            max_workers=settings.PARALLEL_DOWNLOAD_WORKERS\n        ) as executor:\n            futures = []\n            for file in manifest[\"files\"]:\n                remote_file = join_urls(s3_url, file)\n                local_file = os.path.join(temp_dir, file)\n                futures.append(executor.submit(download_file, remote_file, local_file))\n\n            for future in futures:\n                future.result()\n                pbar.update(1)\n\n        pbar.close()\n\n        # Move all files to new directory\n        for file in os.listdir(temp_dir):\n            shutil.move(os.path.join(temp_dir, file), 
local_dir)\n\n\nclass S3DownloaderMixin:\n    s3_prefix = \"s3://\"\n\n    @classmethod\n    def get_local_path(cls, pretrained_model_name_or_path) -> str:\n        if pretrained_model_name_or_path.startswith(cls.s3_prefix):\n            pretrained_model_name_or_path = pretrained_model_name_or_path.replace(\n                cls.s3_prefix, \"\"\n            )\n            cache_dir = settings.MODEL_CACHE_DIR\n            local_path = os.path.join(cache_dir, pretrained_model_name_or_path)\n            os.makedirs(local_path, exist_ok=True)\n        else:\n            local_path = \"\"\n        return local_path\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):\n        # Allow loading models directly from the hub, or using s3\n        if not pretrained_model_name_or_path.startswith(cls.s3_prefix):\n            return super().from_pretrained(\n                pretrained_model_name_or_path, *args, **kwargs\n            )\n\n        local_path = cls.get_local_path(pretrained_model_name_or_path)\n        pretrained_model_name_or_path = pretrained_model_name_or_path.replace(\n            cls.s3_prefix, \"\"\n        )\n\n        # Retry logic for downloading the model folder\n        retries = 3\n        delay = 5\n        attempt = 0\n        success = False\n        while not success and attempt < retries:\n            try:\n                download_directory(pretrained_model_name_or_path, local_path)\n                success = True  # If download succeeded\n            except Exception as e:\n                logger.error(\n                    f\"Error downloading model from {pretrained_model_name_or_path}. Attempt {attempt + 1} of {retries}. 
Error: {e}\"\n                )\n                attempt += 1\n                if attempt < retries:\n                    logger.info(f\"Retrying in {delay} seconds...\")\n                    time.sleep(delay)  # Wait before retrying\n                else:\n                    logger.error(\n                        f\"Failed to download {pretrained_model_name_or_path} after {retries} attempts.\"\n                    )\n                    raise e  # Reraise exception after max retries\n\n        return super().from_pretrained(local_path, *args, **kwargs)\n"
  },
  {
    "path": "surya/common/surya/__init__.py",
    "content": "import warnings\nfrom typing import Optional, Tuple, TypedDict\nfrom dataclasses import dataclass\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.common.surya.config import SuryaModelConfig\nfrom surya.common.surya.decoder import SuryaDecoderModel\nfrom surya.common.surya.embedder import SimpleTokenEmbedder\nfrom surya.common.surya.encoder import SuryaEncoderModel\nfrom surya.common.util import pad_to_batch_size, pad_to_batch_size_repeat\nfrom surya.common.xla import get_nearest_pad\nfrom surya.settings import settings\n\nfrom surya.logging import get_logger\n\nlogger = get_logger()\n\n\n@dataclass\nclass SuryaModelOutput(CausalLMOutputWithPast):\n    bbox_logits: torch.FloatTensor = None\n    lm_logits: torch.FloatTensor = None\n\n\nclass FlashAttentionKwargs(TypedDict, total=False):\n    \"\"\"\n    Keyword arguments for Flash Attention with Compile.\n\n    Attributes:\n        cu_seq_lens_q (`torch.LongTensor`, *optional*):\n            Cumulative sequence lengths for the query states.\n        cu_seq_lens_k (`torch.LongTensor`, *optional*):\n            Cumulative sequence lengths for the key states.\n        max_length_q (`int`, *optional*):\n            Maximum sequence length for query state.\n        max_length_k (`int`, *optional*):\n            Maximum sequence length for key state.\n    \"\"\"\n\n    cu_seq_lens_q: Optional[torch.LongTensor]\n    cu_seq_lens_k: Optional[torch.LongTensor]\n    max_length_q: Optional[int]\n    max_length_k: Optional[int]\n\n\nclass KwargsForCausalLM(FlashAttentionKwargs): ...\n\n\nclass DistanceProjection(nn.Module):\n    def __init__(self, in_features: int, out_features: int):\n        
super().__init__()\n        self.fc1 = nn.Linear(in_features, out_features)\n        self.act = nn.SiLU()\n        self.fc2 = nn.Linear(out_features, out_features)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.fc2(x)\n        return x\n\n    def init_weights(self):\n        nn.init.xavier_uniform_(self.fc1.weight)\n        nn.init.xavier_uniform_(self.fc2.weight)\n        nn.init.zeros_(self.fc1.bias)\n        nn.init.zeros_(self.fc2.bias)\n\n\nclass BboxHead(nn.Module):\n    def __init__(self, in_features: int, out_features: int):\n        super().__init__()\n        self.proj_layers = nn.ModuleList(\n            [nn.Linear(in_features, in_features) for _ in range(6)]\n        )\n        self.act = nn.SiLU()\n        self.out_proj = nn.Linear(in_features, out_features)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for layer in self.proj_layers:\n            x = layer(x)\n            x = self.act(x)\n\n        x = self.out_proj(x)\n        return x\n\n\nclass SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):\n    config_class = SuryaModelConfig\n    supports_gradient_checkpointing = True\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n    _supports_attention_backend = True\n    main_input_name = \"input_ids\"\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(\n        self,\n        config: SuryaModelConfig,\n        embedder: SimpleTokenEmbedder = None,\n        vision_encoder: SuryaEncoderModel = None,\n        decoder: SuryaDecoderModel = None,\n        **kwargs,\n    ):\n        super().__init__(config, **kwargs)\n\n        if vision_encoder is None:\n            vision_encoder = SuryaEncoderModel(config.vision_encoder)\n\n        if 
decoder is None:\n            decoder = SuryaDecoderModel(config.decoder)\n\n        if embedder is None:\n            embedder = SimpleTokenEmbedder(config)\n\n        self.vision_encoder = vision_encoder\n        self.decoder = decoder\n        self.embedder = embedder\n\n        # Simple encoding for image patches\n        self.img_w_embed = nn.Embedding(\n            self.config.image_embed_encoding_size,\n            self.config.hidden_size,\n        )\n\n        self.img_h_embed = nn.Embedding(\n            self.config.image_embed_encoding_size,\n            self.config.hidden_size,\n        )\n\n        # Tying configs\n        self.vision_encoder.config = self.config.vision_encoder\n        self.decoder.config = self.config.decoder\n\n        self.bbox_head = BboxHead(config.hidden_size, 6)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)\n\n        if (\n            self.config.multi_output_distance is not None\n            and self.config.multi_output_distance > 0\n        ):\n            self.multi_output_projections = nn.ModuleList(\n                [\n                    DistanceProjection(\n                        in_features=config.hidden_size, out_features=config.hidden_size\n                    )\n                    for _ in range(self.config.multi_output_distance)\n                ]\n            )\n\n    def tie_weights(self):\n        self._tie_weights()\n\n    def _tie_weights(self):\n        # Tie weights of lm head and token embedder\n        self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed)\n\n    def get_output_embeddings(self) -> nn.Module:\n        return self.lm_head\n\n    def get_input_embeddings(self) -> nn.Module:\n        return self.embedder.token_embed\n\n    def set_output_embeddings(self, new_embeddings: nn.Module):\n        self.lm_head = new_embeddings\n\n    def set_input_embeddings(self, new_embeddings: nn.Module):\n        self.embedder.token_embed = new_embeddings\n\n    def 
maybe_static_pad_image_inputs(\n        self,\n        chunk_pixels: torch.Tensor,\n        chunk_grid_thw: torch.Tensor,\n        actual_chunk_len: int,\n        encoder_chunk_size: int,\n    ) -> Tuple[torch.Tensor, torch.Tensor, int]:\n        valid_embed_len = actual_chunk_len // (\n            self.vision_encoder.spatial_merge_size**2\n        )\n        if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size:\n            padding_len = encoder_chunk_size - actual_chunk_len\n            chunk_pixels = F.pad(\n                chunk_pixels,\n                (0, 0, 0, padding_len),\n                mode=\"constant\",\n                value=0.0,\n            )\n\n            padding_grid = torch.tensor(\n                [[1, 2, padding_len // 2]],\n                device=chunk_grid_thw.device,\n                dtype=chunk_grid_thw.dtype,\n            )\n            chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0)\n\n        return chunk_pixels, chunk_grid_thw, valid_embed_len\n\n    def get_image_embeddings(\n        self,\n        pixel_values: torch.Tensor,\n        grid_thw: torch.Tensor,\n        encoder_chunk_size: int,\n        valid_batch_size: torch.Tensor | None = None,\n        max_batch_size: int | None = None,\n    ):\n        # embed all images with the vision encoder after they have already been tiled and flattened into a single batch\n        chunks = [0]\n        grid_chunks = [0]\n        curr_chunk_len = 0\n        curr_seq_len = 0\n        for i in range(len(grid_thw)):\n            curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item()\n            if curr_chunk_len > encoder_chunk_size:\n                chunks.append(curr_chunk_len + curr_seq_len)\n                curr_seq_len += curr_chunk_len\n                curr_chunk_len = 0\n                grid_chunks.append(i + 1)\n\n        if curr_chunk_len > 0:\n            chunks.append(pixel_values.shape[0])\n            
grid_chunks.append(len(grid_thw))\n\n        assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], (\n            f\"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}\"\n        )\n\n        logger.debug(\n            f\"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}\"\n        )\n        embeddings = []\n        for i in range(len(chunks) - 1):\n            start = chunks[i]\n            end = chunks[i + 1]\n            grid_start = grid_chunks[i]\n            grid_end = grid_chunks[i + 1]\n\n            chunk_pixels = pixel_values[start:end]\n            chunk_grid_thw = grid_thw[grid_start:grid_end]\n            actual_chunk_len = end - start\n            chunk_pixels, chunk_grid_thw, valid_embed_len = (\n                self.maybe_static_pad_image_inputs(\n                    chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size\n                )\n            )\n\n            chunk_embeddings = self.vision_encoder.embed_images(\n                image_batch=chunk_pixels.unsqueeze(0).to(device=self.device),\n                grid_thw=chunk_grid_thw.unsqueeze(0).to(device=self.device),\n            )\n            embeddings.append(chunk_embeddings[:valid_embed_len].squeeze(0))\n\n        if len(embeddings) == 0:\n            raise ValueError(\n                \"No image embeddings were generated. 
Check the input images and grid sizes.\"\n            )\n        elif len(embeddings) == 1:\n            embeddings = embeddings[0]\n        else:\n            embeddings = torch.cat(embeddings, dim=0)\n\n        encoding_2d = self.get_2d_learned_embeddings(\n            grid_thw,\n            device=embeddings.device,\n            bbox_size=self.config.image_embed_encoding_multiplier,\n        )\n        assert embeddings.shape[0] == encoding_2d.shape[0], (\n            f\"Mismatch in image embedding seq len: {embeddings.shape} vs {encoding_2d.shape}\"\n        )\n        assert embeddings.shape[1] == encoding_2d.shape[1], (\n            f\"Mismatch in image embedding dimension: {embeddings.shape} vs {encoding_2d.shape}\"\n        )\n\n        embeddings = embeddings + encoding_2d\n\n        return embeddings\n\n    def embed_ids_boxes_images(\n        self,\n        input_ids,\n        image_embeddings,\n        encoder_chunk_size: int,\n        valid_batch_size: torch.Tensor | None = None,\n        input_boxes: torch.Tensor | None = None,\n        embed_boxes: torch.Tensor | None = None,\n    ):\n        \"\"\"\n        Insert embedded image tiles at the corresponding positions of the full input sequence.\n\n        The insertion positions are marked by the special image token id.\n        \"\"\"\n        # This is batched in the inner call\n        inputs_embeds = self.embedder.embed(\n            input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes\n        )\n\n        if image_embeddings is not None:\n            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)\n            special_image_mask = special_image_mask.expand_as(inputs_embeds)\n            if inputs_embeds[special_image_mask].numel() != image_embeddings.numel():\n                n_image_tokens = torch.sum((input_ids == self.config.image_token_id))\n                n_image_features = image_embeddings.shape[0] * 
image_embeddings.shape[1]\n                warnings.warn(\n                    f\"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. This may lead to unexpected results\"\n                )\n            image_features = image_embeddings.to(inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(\n                special_image_mask, image_features\n            )\n        else:\n            assert (input_ids == self.config.image_token_id).sum() == 0, (\n                \"Image tokens were present in the input but no input images were provided\"\n            )\n\n        return inputs_embeds\n\n    def get_2d_learned_embeddings(\n        self,\n        grid_thw,\n        device: str | torch.device = \"cpu\",\n        bbox_size: int = 256,\n    ):\n        all_embeddings = []\n        for grid_t, grid_h, grid_w in grid_thw:\n            llm_grid_h, llm_grid_w = (\n                grid_h // self.config.merge_size,\n                grid_w // self.config.merge_size,\n            )\n\n            # Scale grid indices to the range [0, bbox_size]\n            llm_grid_h = (\n                torch.arange(llm_grid_h, device=device)\n                / max(1, (llm_grid_h - 1))\n                * bbox_size\n            )\n            llm_grid_w = (\n                torch.arange(llm_grid_w, device=device)\n                / max(1, (llm_grid_w - 1))\n                * bbox_size\n            )\n\n            llm_grid_w_idx = llm_grid_w.to(torch.long)\n            llm_grid_h_idx = llm_grid_h.to(torch.long)\n\n            llm_grid_w = self.img_w_embed(llm_grid_w_idx)\n            llm_grid_h = self.img_h_embed(llm_grid_h_idx)\n\n            full_grid = llm_grid_h[:, None] + llm_grid_w[None, :]\n\n            flattened = full_grid.flatten(\n                0, 1\n            )  # Flatten the first two dimensions, so the result is seq_len x embed_dim\n            all_embeddings.append(flattened)\n        return torch.concat(\n            all_embeddings, dim=0\n  
      )  # Shape is num_image_tokens x embed_dim\n\n    def get_logits(self, hidden_states):\n        assert hidden_states.shape[1] == 1, (\n            \"Multi-output predictions are only applied to the last token\"\n        )\n\n        all_lm_logits = []\n        all_bbox_logits = []\n\n        current_hidden = hidden_states\n\n        # Loop includes initial prediction (i=0) plus multi_output_distance additional predictions\n        for i in range(self.config.multi_output_distance + 1):\n            if i > 0:\n                current_hidden = self.multi_output_projections[i - 1](current_hidden)\n\n            lm_logits = self.lm_head(current_hidden)\n            bbox_logits = F.sigmoid(self.bbox_head(current_hidden))\n\n            all_lm_logits.append(lm_logits)\n            all_bbox_logits.append(bbox_logits)\n\n        # Concatenate along sequence dimension (dim=1)\n        final_lm_logits = torch.cat(all_lm_logits, dim=1)\n        final_bbox_logits = torch.cat(all_bbox_logits, dim=1)\n\n        return final_lm_logits, final_bbox_logits\n\n    def forward(\n        self,\n        input_ids=None,\n        image_embeddings=None,\n        labels=None,\n        image_tiles=None,\n        grid_thw=None,\n        inputs_embeds=None,\n        attention_mask=None,\n        position_ids=None,\n        cache_position=None,\n        past_key_values=None,\n        output_hidden_states=False,\n        output_attentions=False,\n        use_cache=False,\n        encoder_chunk_size=32768,\n        cache_idxs=None,\n        num_valid_tokens=None,\n        prefill=True,\n        text_lengths=None,\n        valid_batch_size: torch.Tensor = None,\n        input_boxes=None,\n        embed_boxes=None,\n        logits_to_keep=None,\n        **kwargs: KwargsForCausalLM,\n    ):\n        if any([\n            input_ids is None,\n            position_ids is None,\n            cache_position is None,\n            (\n                prefill\n                and not (\n                    
(image_tiles is not None and grid_thw is not None)\n                    or image_embeddings is not None\n                )\n            ),\n        ]):\n            raise ValueError(\n                \"`input_ids`, `position_ids`, and `cache_position` **must** be specified. \"\n                \"For prefill, you must provide either (`image_tiles` and `grid_thw`) or `image_embeddings`.\"\n            )\n\n\n        inputs_embeds = self.embed_ids_boxes_images(\n            input_ids, image_embeddings, encoder_chunk_size, valid_batch_size, input_boxes, embed_boxes\n        )\n\n        # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder\n        # Skipped during decoding since not required\n        if self.decoder.config._attn_implementation == \"flash_attention_2\" and prefill:\n            # Needed for CPU -> GPU\n            from surya.common.surya.flash_attn_utils import _get_unpad_data\n            batch_size, query_length, _ = inputs_embeds.shape\n            indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(\n                attention_mask\n            )\n            kwargs[\"batch_size\"] = batch_size\n            kwargs[\"query_length\"] = query_length\n            kwargs[\"indices_k\"] = indices_k\n            kwargs[\"cu_seqlens_k\"] = cu_seqlens_k\n            kwargs[\"max_seqlen_in_batch_k\"] = max_seqlen_in_batch_k\n\n        causal_mask = self._update_causal_mask(\n            attention_mask,\n            inputs_embeds,\n            cache_position,\n            past_key_values,\n            output_attentions,\n        )\n\n        attention_mask = causal_mask\n        outputs = self.decoder(\n            inputs_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            cache_position=cache_position,\n            past_key_values=past_key_values,\n            return_dict=True,\n            use_cache=use_cache,\n            
cache_idxs=cache_idxs,\n            num_valid_tokens=num_valid_tokens,\n            prefill=prefill,\n            text_lengths=text_lengths,\n            **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        if logits_to_keep is not None:\n            hidden_states = hidden_states[:, -logits_to_keep:, :]\n        hidden_states = hidden_states.contiguous()\n\n        loss = None\n        if labels is not None:\n            # Training, return full logits\n            lm_logits = self.lm_head(hidden_states)\n            bbox_logits = None\n            vocab_size = lm_logits.shape[-1]\n            labels = torch.roll(labels, shifts=-1, dims=-1)\n            loss = F.cross_entropy(\n                lm_logits.view(-1, vocab_size), labels.view(-1), reduction=\"mean\"\n            )\n        else:\n            lm_logits, bbox_logits = self.get_logits(hidden_states)\n\n        return SuryaModelOutput(\n            loss=loss,\n            bbox_logits=bbox_logits,\n            lm_logits=lm_logits,\n            hidden_states=outputs.hidden_states if output_hidden_states else None,\n            attentions=outputs.attentions if output_attentions else None,\n            past_key_values=outputs.past_key_values,\n        )\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        if self.decoder.config._attn_implementation == \"flash_attention_2\":\n            return attention_mask\n\n        # We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        target_length = (\n            attention_mask.shape[-1]\n            if isinstance(attention_mask, torch.Tensor)\n  
          else past_key_values.max_cache_len\n        )\n\n        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).\n        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(\n            attention_mask,\n            sequence_length=sequence_length,\n            target_length=target_length,\n            dtype=dtype,\n            device=device,\n            cache_position=cache_position,\n            batch_size=input_tensor.shape[0],\n            config=self.config,\n            past_key_values=past_key_values,\n        )\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type in [\"cuda\", \"xpu\"]\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(\n                causal_mask, min_dtype\n            )\n\n        return causal_mask\n\n    @staticmethod\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        config: SuryaModelConfig,\n        past_key_values: Cache,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D 
attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The target length. When generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not yet filled).\n            dtype (`torch.dtype`):\n                The dtype to use for the 4D attention mask.\n            device (`torch.device`):\n                The device to place the 4D attention mask on.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence. Shape `(batch_size, sequence_length)`.\n            batch_size (`int`):\n                Batch size.\n            config (`SuryaModelConfig`):\n                The model's configuration.\n            past_key_values (`Cache`):\n                The cache currently being used for generation.\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length),\n                fill_value=min_dtype,\n                dtype=dtype,\n                device=device,\n            )\n            # Batch-aware diagonal attend mask\n            diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(\n                0\n            ) > cache_position.unsqueeze(-1)\n            causal_mask = (\n                causal_mask.unsqueeze(0) * diagonal_attend_mask\n            )  # (batch_size, seq_len, target_len)\n            
causal_mask = causal_mask[\n                :, None, :, :\n            ]  # (batch_size, 1, seq_len, target_len)\n            if attention_mask is not None:\n                causal_mask = (\n                    causal_mask.clone()\n                )  # copy to contiguous memory for in-place edit\n                if attention_mask.shape[-1] > target_length:\n                    attention_mask = attention_mask[:, :target_length]\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[\n                    :, None, None, :\n                ].to(causal_mask.device)\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[\n                    :, :, :, :mask_length\n                ].masked_fill(padding_mask, min_dtype)\n        return causal_mask\n\nclass SuryaXLAModel(SuryaModel):\n    def get_image_embeddings(\n        self,\n        pixel_values: torch.Tensor,\n        grid_thw: torch.Tensor,\n        encoder_chunk_size: int,\n        valid_batch_size: torch.Tensor | None = None,\n        max_batch_size: int | None = None,\n    ):\n        # embed all images with the vision encoder after they have already been tiled and flattened into a single batch\n        unpadded_max_grid_size = (\n            (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).max().item()\n        )\n        max_grid_size = get_nearest_pad(\n            unpadded_max_grid_size,\n        )  # If we need zero padding, we still need to allocate a bit of room for the extra grid_thw\n\n        # Always need 2 items in each row batch\n        if max_grid_size == unpadded_max_grid_size:\n            max_grid_size += 16\n\n        full_image_grid = torch.zeros(\n            (valid_batch_size, max_grid_size, pixel_values.shape[-1]),\n            dtype=pixel_values.dtype,\n        )\n\n        # Roll out into a full grid\n        seq_len = 0\n        row_grids = 
[]\n        for i in range(valid_batch_size):\n            curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]\n            full_image_grid[i, -curr_sample_len:] = pixel_values[\n                seq_len : seq_len + curr_sample_len\n            ]\n            padded_len = max_grid_size - curr_sample_len\n            if padded_len > 0:\n                row_grid = torch.tensor(\n                    [\n                        [1, 4, padded_len // 4],\n                        grid_thw[i].tolist(),\n                    ],\n                    dtype=torch.long,\n                )\n            else:\n                row_grid = torch.tensor(\n                    [\n                        grid_thw[i].tolist(),\n                    ],\n                    dtype=torch.long,\n                )\n\n            row_grids.append(row_grid)\n            seq_len += curr_sample_len\n\n        # bsz, 2, 3\n        row_grids = torch.stack(row_grids, dim=0)\n\n        if settings.FOUNDATION_STATIC_CACHE:\n            # Pad to max batch size, repeat the final row\n            row_grids = pad_to_batch_size_repeat(\n                row_grids,\n                batch_size=max_batch_size,\n            )\n            full_image_grid = pad_to_batch_size(\n                full_image_grid,\n                batch_size=max_batch_size,\n            )\n\n        full_image_grid = full_image_grid.to(self.device)\n\n        embeddings = self.vision_encoder.embed_images(\n            image_batch=full_image_grid, grid_thw=row_grids.to(self.device)\n        )\n\n        encoding_2d = self.get_2d_learned_embeddings(\n            row_grids,\n            bbox_size=self.config.image_embed_encoding_multiplier,\n        )\n        embeddings += encoding_2d\n\n        return embeddings\n\n    def embed_ids_boxes_images(\n        self,\n        input_ids,\n        image_embeddings,\n        encoder_chunk_size: int,\n        valid_batch_size: torch.Tensor | None = None,\n        input_boxes: 
torch.Tensor | None = None,\n        embed_boxes: torch.Tensor | None = None,\n    ):\n        \"\"\"\n        Insert embedded image tiles at the corresponding positions of the full input sequence.\n\n        The insertion positions are marked by the special image token id.\n        \"\"\"\n        # This is batched in the inner call\n        inputs_embeds = self.embedder.embed(\n            input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes\n        )\n\n        if image_embeddings is not None:\n            image_token_id_tensor = torch.tensor(\n                self.config.image_token_id,\n                device=inputs_embeds.device,\n                dtype=torch.long,\n            )\n            mask = input_ids == image_token_id_tensor\n            last_image_token_pos = (\n                mask.size(1)\n                - 1\n                - mask.flip(dims=[1]).long().argmax(dim=1, keepdim=True)\n            )\n            # Calculate start position to replace N positions ending at (and including) the last image token\n            start_positions = last_image_token_pos - image_embeddings[0].shape[0]\n            batch_size, insert_len = image_embeddings.shape[:2]\n\n            # Create position indices for each insertion\n            pos_indices = torch.arange(\n                insert_len, device=inputs_embeds.device\n            ).unsqueeze(0)\n            insert_positions = start_positions + pos_indices\n\n            idx = insert_positions.unsqueeze(-1).expand(\n                -1, -1, inputs_embeds.size(-1)\n            )  # [B,N,D]\n            inputs_embeds = inputs_embeds.scatter(1, idx, image_embeddings)\n\n        inputs_embeds = inputs_embeds * (\n            input_ids != self.config.pad_token_id\n        ).unsqueeze(-1).to(inputs_embeds.dtype)\n        return inputs_embeds\n\n    def get_2d_learned_embeddings(\n        self,\n        grid_thw,\n        bbox_size: int = 256,\n    ):\n        dev = grid_thw.device\n  
      all_row_coords = []\n        all_col_coords = []\n        for row_grid in grid_thw:\n            merge = self.config.merge_size\n\n            # per-sample grid sizes after merge\n            H = (row_grid[:, 1] // merge).long()  # (B,)\n            W = (row_grid[:, 2] // merge).long()  # (B,)\n\n            row_coords = torch.cat(\n                [\n                    torch.linspace(0, bbox_size, steps=int(h), device=dev)\n                    .round()\n                    .repeat_interleave(w)  # repeat each row value w times\n                    for h, w in zip(H.tolist(), W.tolist())\n                ]\n            )  # (full_grid_size,)\n\n            col_coords = torch.cat(\n                [\n                    torch.linspace(0, bbox_size, steps=int(w), device=dev)\n                    .round()\n                    .repeat(int(h))  # tile the column vector h times\n                    for h, w in zip(H.tolist(), W.tolist())\n                ]\n            )  # (full_grid_size,)\n            all_row_coords.append(row_coords)\n            all_col_coords.append(col_coords)\n        row_coords = torch.stack(all_row_coords, dim=0).to(self.device)\n        col_coords = torch.stack(all_col_coords, dim=0).to(self.device)\n\n        emb = self.img_h_embed(row_coords.long()) + self.img_w_embed(col_coords.long())\n        return emb\n"
  },
  {
    "path": "surya/common/surya/config.py",
    "content": "from typing import Optional\nfrom transformers import PretrainedConfig\n\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.common.surya.encoder.config import SuryaEncoderConfig\nfrom surya.common.surya.decoder.config import SuryaDecoderConfig\n\n\nclass SuryaModelConfig(S3DownloaderMixin, PretrainedConfig):\n    model_type = \"surya-multimodal-foundation\"\n    is_composition = True\n\n    def __init__(\n        self,\n        vocab_size=65536,\n        bbox_size=1025,\n        blank_bbox_token_id=1025,\n        bos_token_id=0,\n        eos_token_id=1,\n        pad_token_id=2,\n        image_token_id=3,\n        register_token_ids=(4, 5, 6, 7),\n        eoi_token_id=8,\n        beacon_token_id=9,\n        special_token_count=4,\n        max_sequence_length=1536,\n        special_ocr_tokens=None,\n        vision_encoder=None,\n        decoder=None,\n        tasks: dict | None = None,\n        bbox_embed_size: int = 64,\n        num_register_tokens: int = 4,\n        image_embed_encoding_size: int = 1024,\n        image_embed_encoding_multiplier: int = 256,\n        num_beacon_tokens: int = 1,\n        beacon_token_interval: int = 4096,\n        sliding_window: Optional[int] = None,\n        multi_output_distance: int = 4,\n        max_multi_out: int = 8,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.is_encoder_decoder = False\n        self.vocab_size = vocab_size\n        self.bbox_size = bbox_size\n        self.blank_bbox_token_id = blank_bbox_token_id\n        self.image_token_id = image_token_id\n        self.bos_token_id = bos_token_id\n        self.eos_token_id = eos_token_id\n        self.pad_token_id = pad_token_id\n        self.eoi_token_id = eoi_token_id\n        self.beacon_token_id = beacon_token_id\n        self.special_ocr_tokens = special_ocr_tokens\n        self.special_token_count = special_token_count  # pad, bos, etc, tokens\n        self.max_sequence_length = max_sequence_length\n        
self.tasks = tasks\n        self.tie_word_embeddings = True\n        self.bbox_embed_size = bbox_embed_size\n        self.num_register_tokens = num_register_tokens\n        self.register_token_ids = register_token_ids\n        self.image_embed_encoding_size = image_embed_encoding_size\n        self.image_embed_encoding_multiplier = image_embed_encoding_multiplier\n        self.num_beacon_tokens = num_beacon_tokens\n        self.beacon_token_interval = beacon_token_interval\n        self.sliding_window = sliding_window\n        self.multi_output_distance = multi_output_distance\n        self.max_multi_out = max_multi_out\n\n        if self.sliding_window is None:\n            self.sliding_window = self.max_sequence_length\n\n        if isinstance(vision_encoder, dict):\n            vision_encoder = SuryaEncoderConfig(**vision_encoder)\n        elif vision_encoder is None:\n            vision_encoder = SuryaEncoderConfig()\n        self.vision_encoder = vision_encoder\n\n        if isinstance(decoder, dict):\n            decoder = SuryaDecoderConfig(**decoder)\n        elif decoder is None:\n            decoder = SuryaDecoderConfig()\n        self.decoder = decoder\n\n        self.hidden_size = self.decoder.hidden_size\n\n        self.patch_size = self.vision_encoder.spatial_patch_size\n        self.merge_size = self.vision_encoder.spatial_merge_size\n"
  },
  {
    "path": "surya/common/surya/decoder/__init__.py",
    "content": "from typing import Callable, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import (\n    Cache,\n)\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n)\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\nfrom transformers.processing_utils import Unpack\nfrom transformers.utils import (\n    logging,\n)\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\nfrom surya.common.surya.decoder.config import SuryaDecoderConfig\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen2MLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n      
  sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(\n        query.dtype\n    )\n    attn_weights = nn.functional.dropout(\n        attn_weights, p=dropout, training=module.training\n    )\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass Qwen2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: SuryaDecoderConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(\n            config, \"head_dim\", config.hidden_size // config.num_attention_heads\n        )\n        self.num_key_value_groups = (\n            config.num_attention_heads // config.num_key_value_heads\n      
  )\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=True\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        cache_idxs: Optional[List[int]] = None,\n        num_valid_tokens: Optional[List[int]] = None,\n        text_lengths: Optional[List[int]] = None,\n        prefill: bool = False,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            # cache_idxs, num_valid_tokens, and 
prefill add support for our new caching mechanism\n            cache_kwargs = {\n                \"sin\": sin,\n                \"cos\": cos,\n                \"cache_position\": cache_position,\n                \"cache_idxs\": cache_idxs,\n                \"num_valid_tokens\": num_valid_tokens,\n                \"prefill\": prefill,\n                \"text_lengths\": text_lengths,\n            }\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and kwargs.get(\n                \"output_attentions\", False\n            ):\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            elif self.config._attn_implementation == \"flash_attention_2\":\n                # Needed for CPU -> GPU\n                from surya.common.surya.flash_attn_utils import (\n                    flash_attn_decode,\n                    flash_attn_prefill,\n                )\n\n                if prefill:\n                    attention_interface = flash_attn_prefill\n                else:\n                    attention_interface = flash_attn_decode\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[\n                    self.config._attn_implementation\n                ]\n\n        \"\"\"\n        IMPORTANT:\n        We sometimes use a custom sliding window impl. 
during training\n\n        We force this to None to ensure that the HF attention integrations do not\n        perform any special handling - FA2 in particular will ignore the 4D mask, and use this instead\n        to infer the final mask\n\n        SDPA ignores this completely, and is fully dependent on the 4D mask - (https://github.com/huggingface/transformers/blob/b9faf2f93085e3cf2c65184a69d1d9e502f95786/src/transformers/integrations/sdpa_attention.py#L23)\n        \"\"\"\n        sliding_window = None\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            sliding_window=sliding_window,  # main diff with Llama\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Qwen2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Qwen2DecoderLayer(nn.Module):\n    def __init__(self, config: SuryaDecoderConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = 
config.hidden_size\n        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)\n        self.mlp = Qwen2MLP(config)\n        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen2RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        cache_idxs: Optional[List[int]] = None,\n        num_valid_tokens: Optional[List[int]] = None,\n        text_lengths: Optional[List[int]] = None,\n        prefill: bool = False,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # necessary, but kept here for BC\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            cache_idxs=cache_idxs,\n            num_valid_tokens=num_valid_tokens,\n            text_lengths=text_lengths,\n            prefill=prefill,\n            **kwargs,\n        )\n        hidden_states = residual + 
hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        return outputs\n\n\nclass Qwen2RotaryEmbedding(nn.Module):\n    def __init__(self, config: SuryaDecoderConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\n                \"rope_type\", config.rope_scaling.get(\"type\")\n            )\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        if seq_len > self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(\n                self.config, device, seq_len=seq_len\n            )\n            self.register_buffer(\n                \"inv_freq\", inv_freq, persistent=False\n          
  )  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if (\n            seq_len < self.original_max_seq_len\n            and self.max_seq_len_cached > self.original_max_seq_len\n        ):  # reset\n            # This .to() is needed if the model has been moved to a device after being initialized (because\n            # the buffer is automatically moved, but not the original copy)\n            self.original_inv_freq = self.original_inv_freq.to(device)\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        if \"dynamic\" in self.rope_type:\n            self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block\n        inv_freq_expanded = (\n            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        )\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = (\n            device_type\n            if isinstance(device_type, str) and device_type != \"mps\"\n            else \"cpu\"\n        )\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (\n                inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()\n            ).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass Qwen2PreTrainedModel(SuryaPreTrainedModel):\n    config_class = SuryaDecoderConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen2DecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n    _supports_attention_backend = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nclass SuryaDecoderModel(Qwen2PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`]\n    This variant has been modified to remove the embedding layer completely - It only supports inputs_embeds as an input\n\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: SuryaDecoderConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.layers = nn.ModuleList(\n            [\n                Qwen2DecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Qwen2RotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n        self,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        cache_idxs: Optional[List[int]] = None,\n        num_valid_tokens: Optional[List[int]] = None,\n        text_lengths: Optional[List[int]] = None,\n        prefill: bool = False,\n        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if inputs_embeds is None:\n            raise ValueError(\"You must specify inputs_embeds\")\n\n        if 
cache_position is None:\n            raise ValueError(\"You must specify cache_position\")\n\n        if position_ids is None:\n            raise ValueError(\"You must specify position_ids\")\n\n        hidden_states = inputs_embeds\n        causal_mask = (\n            attention_mask  # We make the 4D mask in the combined model when needed\n        )\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=causal_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n                cache_idxs=cache_idxs,\n                num_valid_tokens=num_valid_tokens,\n                prefill=prefill,\n                text_lengths=text_lengths,\n                **flash_attn_kwargs,\n            )\n\n            hidden_states = layer_outputs[0]\n\n        hidden_states = self.norm(hidden_states)\n\n        output = BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values if use_cache else None,\n        )\n        return output if return_dict else output.to_tuple()\n"
  },
  {
    "path": "surya/common/surya/decoder/config.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass SuryaDecoderConfig(PretrainedConfig):\n    model_type = \"qwen2\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen2`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=4096,\n        intermediate_size=22016,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=28,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n   
     self.use_sliding_window = False  # Disable sliding window\n        self.sliding_window = (\n            sliding_window  # we check `use_sliding_window` in the modeling code\n        )\n        self.max_window_layers = max_window_layers\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "surya/common/surya/embedder/__init__.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SimpleTokenEmbedder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.bbox_embed = nn.ModuleList(\n            [\n                nn.Embedding(\n                    config.bbox_size + config.special_token_count,\n                    config.bbox_embed_size,\n                )\n                for _ in range(6)\n            ]\n        )\n        self.max_bbox_embedding = config.bbox_size + config.special_token_count - 1\n        self.max_bbox_size = config.bbox_size\n\n    def embed(\n        self,\n        input_tokens: torch.Tensor,\n        input_boxes: torch.Tensor | None,\n        embed_boxes: torch.Tensor,\n    ) -> torch.Tensor:\n        # Embed tokens\n        token_embeds = self.token_embed(input_tokens)\n\n        # Optionally embed boxes\n        if input_boxes is not None and embed_boxes.any():  # Is none in prefill\n            input_boxes = input_boxes.to(torch.long)\n            bbox_loss_ignore_mask = (\n                (input_boxes[:, :, 0] < 0) | (input_boxes[:, :, 0] > self.max_bbox_size)\n            ).unsqueeze(-1)\n            input_boxes = torch.clamp(input_boxes, 0, self.max_bbox_embedding)\n\n            bbox_embeds = torch.sum(\n                torch.stack(\n                    [\n                        self.bbox_embed[i](input_boxes[:, :, i])\n                        for i in range(len(self.bbox_embed))\n                    ],\n                    dim=-1,\n                ),\n                dim=-1,\n            )\n\n            bbox_embeds = F.pad(\n                bbox_embeds, (token_embeds.shape[-1] - bbox_embeds.shape[-1], 0)\n            )\n            embed_boxes = embed_boxes.unsqueeze(1).unsqueeze(1).expand_as(bbox_embeds)\n            bbox_loss_ignore_mask = bbox_loss_ignore_mask.expand_as(bbox_embeds)\n\n            mask 
= embed_boxes & ~bbox_loss_ignore_mask\n            bbox_embeds *= mask.float()\n\n            token_embeds = token_embeds + bbox_embeds\n\n        return token_embeds\n"
  },
  {
    "path": "surya/common/surya/encoder/__init__.py",
    "content": "import math\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers.activations import ACT2FN\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\nfrom surya.common.surya.encoder.config import SuryaEncoderConfig\nfrom surya.common.xla import get_nearest_pad\nfrom surya.logging import get_logger\nfrom surya.settings import settings\n\nif settings.FOUNDATION_XLA:\n    import torch_xla.experimental.custom_kernel\n\nlogger = get_logger()\n\n\nclass Qwen2_5_VLMLP(nn.Module):\n    def __init__(self, config, bias: bool = False):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_state):\n        return self.down_proj(\n            self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)\n        )\n\n\nclass Qwen2_5_VisionPatchEmbed(nn.Module):\n    def __init__(\n        self,\n        patch_size: int = 14,\n        temporal_patch_size: int = 2,\n        in_channels: int = 3,\n        embed_dim: int = 1152,\n    ) -> None:\n        super().__init__()\n        self.patch_size = patch_size\n        self.temporal_patch_size = temporal_patch_size\n        self.in_channels = in_channels\n        self.embed_dim = embed_dim\n\n        kernel_size = [temporal_patch_size, patch_size, patch_size]\n        self.proj = nn.Conv3d(\n            in_channels,\n            embed_dim,\n            kernel_size=kernel_size,\n            stride=kernel_size,\n            bias=False,\n        )\n\n    def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        bsz = hidden_states.shape[0]\n        hidden_states = hidden_states.view(\n            -1,\n            self.in_channels,\n            self.temporal_patch_size,\n            self.patch_size,\n            self.patch_size,\n        )\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(\n            bsz, -1, self.embed_dim\n        )\n        return hidden_states\n\n\nclass Qwen2_5_VisionRotaryEmbedding(nn.Module):\n    def __init__(self, dim: int, theta: float = 10000.0) -> None:\n        super().__init__()\n        self.inv_freq = 1.0 / (\n            theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)\n        )\n\n    def forward(self, seqlen: int) -> torch.Tensor:\n        seq = torch.arange(seqlen, device=\"cpu\", dtype=self.inv_freq.dtype)\n        freqs = torch.outer(seq, self.inv_freq)\n        return freqs\n\n\nclass Qwen2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Qwen2_5_VLPatchMerger(nn.Module):\n    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:\n        super().__init__()\n        self.hidden_size = context_dim * (spatial_merge_size**2)\n        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)\n 
       self.mlp = nn.Sequential(\n            nn.Linear(self.hidden_size, self.hidden_size),\n            nn.GELU(),\n            nn.Linear(self.hidden_size, dim),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        bsz = x.shape[0]\n        x = self.mlp(self.ln_q(x).view(bsz, -1, self.hidden_size))\n        return x\n\n\ndef apply_rotary_pos_emb_flashatt(\n    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    from flash_attn.layers.rotary import apply_rotary_emb\n\n    cos = cos.chunk(2, dim=-1)[0].contiguous()\n    sin = sin.chunk(2, dim=-1)[0].contiguous()\n    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)\n    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)\n    return q_embed, k_embed\n\n\nclass Qwen2_5_VLVisionXLASdpaAttention(nn.Module):\n    def __init__(self, dim: int, num_heads: int = 16) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        self.qkv = nn.Linear(dim, dim * 3, bias=True)\n        self.proj = nn.Linear(dim, dim)\n        self.head_dim = dim // num_heads\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: Optional[torch.Tensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]\n        q, k, v = (\n            self.qkv(hidden_states)\n            .reshape(bsz, seq_length, 3, self.num_heads, -1)\n            .permute(0, 2, 1, 3, 4)\n            .unbind(1)\n        )\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using 
externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        else:\n            cos, sin = position_embeddings\n        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)\n\n        attention_mask = torch.zeros([bsz, 1, seq_length, seq_length], dtype=torch.bool)\n        cu_seqlens_cpu = cu_seqlens.cpu()\n        for j in range(bsz):\n            batch_seqlens = cu_seqlens_cpu[j]\n            for i in range(1, len(batch_seqlens)):\n                attention_mask[\n                    j,\n                    ...,\n                    batch_seqlens[i - 1] : batch_seqlens[i],\n                    batch_seqlens[i - 1] : batch_seqlens[i],\n                ] = True\n\n        attention_mask = attention_mask.to(q.device)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attention_mask,\n            dropout_p=0.0,\n        )\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, seq_length, -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2_5_VLVisionXLAFlashAttention2(nn.Module):\n    def __init__(self, dim: int, num_heads: int = 16) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        self.qkv = nn.Linear(dim, dim * 3, bias=True)\n        self.proj = nn.Linear(dim, dim)\n        self.head_dim = dim // num_heads\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: Optional[torch.Tensor] = None,\n        
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        # Note, this is faster than SDPA, but pretty memory inefficient\n        # It also has significant accuracy issues\n\n        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]\n\n        # Single reshape to target layout - avoid multiple operations\n        q, k, v = (\n            self.qkv(hidden_states)\n            .reshape(bsz, seq_length, 3, self.num_heads, -1)\n            .permute(0, 2, 1, 3, 4)\n            .unbind(1)\n        )\n\n        # Apply rotary embeddings if provided\n        if position_embeddings is not None:\n            cos, sin = position_embeddings\n            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)\n\n        # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim]\n        q = q.transpose(1, 2)  # [bsz, num_heads, seq_len, head_dim]\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        total_seqlen = q.shape[2]\n        # from cu_seqlens to segment ids for each position in dim 0\n        additive_bias = torch.zeros((bsz, 1, total_seqlen, total_seqlen), dtype=q.dtype)\n        min_val = torch.finfo(q.dtype).min\n\n        for i in range(bsz):\n            padding_end = cu_seqlens[i][1].item()\n            additive_bias[i, :, :, :padding_end] = min_val\n\n        additive_bias = additive_bias.to(hidden_states.device)\n\n        attn_scale = 1 / math.sqrt(self.head_dim)\n        attn_output = torch_xla.experimental.custom_kernel.flash_attention(\n            q, k, v, sm_scale=attn_scale, ab=additive_bias\n        )\n        attn_output = (\n            attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_length, -1)\n        )\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2_5_VLVisionFlashAttention2(nn.Module):\n    def __init__(self, dim: int, num_heads: int = 16) -> None:\n        super().__init__()\n        
self.num_heads = num_heads\n        self.qkv = nn.Linear(dim, dim * 3, bias=True)\n        self.proj = nn.Linear(dim, dim)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: Optional[torch.Tensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        from flash_attn import flash_attn_varlen_func\n\n        bsz = hidden_states.shape[0]\n        seq_length = hidden_states.shape[1]\n        q, k, v = (\n            self.qkv(hidden_states)\n            .reshape(bsz, seq_length, 3, self.num_heads, -1)\n            .permute(0, 2, 1, 3, 4)\n            .unbind(1)\n        )\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.54 `rotary_pos_emb` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        else:\n            cos, sin = position_embeddings\n\n        q, k = apply_rotary_pos_emb_flashatt(q, k, cos.squeeze(0), sin.squeeze(0))\n\n        q = q.squeeze(0)\n        k = k.squeeze(0)\n        v = v.squeeze(0)\n        cu_seqlens = cu_seqlens.squeeze(0)\n\n        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()\n        attn_output = flash_attn_varlen_func(\n            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen\n        ).reshape(bsz, seq_length, -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb_vision(\n    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    orig_q_dtype = q.dtype\n    orig_k_dtype = k.dtype\n    q, k = q.float(), k.float()\n    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    q_embed = q_embed.to(orig_q_dtype)\n    k_embed = k_embed.to(orig_k_dtype)\n    return q_embed, k_embed\n\n\nclass Qwen2_5_VLVisionAttention(nn.Module):\n    def __init__(self, dim: int, num_heads: int = 16) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.qkv = nn.Linear(dim, dim * 3, bias=True)\n        self.proj = nn.Linear(dim, dim)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: 
Optional[torch.Tensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]\n        q, k, v = (\n            self.qkv(hidden_states)\n            .reshape(bsz, seq_length, 3, self.num_heads, -1)\n            .permute(0, 2, 1, 3, 4)\n            .unbind(1)\n        )\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        else:\n            cos, sin = position_embeddings\n\n        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)\n\n        attention_mask = torch.full(\n            [bsz, 1, seq_length, seq_length],\n            torch.finfo(q.dtype).min,\n            device=q.device,\n            dtype=q.dtype,\n        )\n        for j in range(bsz):\n            batch_seqlens = cu_seqlens[j]\n            for i in range(1, len(batch_seqlens)):\n                attention_mask[\n                    j,\n                    ...,\n                    batch_seqlens[i - 1] : batch_seqlens[i],\n                    batch_seqlens[i - 1] : batch_seqlens[i],\n                ] = 0\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)\n        attn_weights = attn_weights + attention_mask\n        attn_weights = 
nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(q.dtype)\n        attn_output = torch.matmul(attn_weights, v)\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, seq_length, -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2_5_VLVisionSdpaAttention(nn.Module):\n    def __init__(self, dim: int, num_heads: int = 16) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        self.qkv = nn.Linear(dim, dim * 3, bias=True)\n        self.proj = nn.Linear(dim, dim)\n\n    def unpack_qkv_with_mask(self, q, k, v, cu_seqlens):\n        \"\"\"\n        Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask.\n\n        Args:\n            q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim)\n            cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths\n\n        Returns:\n            batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)\n            batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)\n            batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)\n            attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len)\n                            with 0 for valid tokens and -inf for padding (for additive attention)\n        \"\"\"\n        device = q.device\n        dtype = q.dtype\n\n        batch_size = cu_seqlens.shape[0] - 1\n        num_heads = q.shape[1]\n        head_dim = q.shape[2]\n\n        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # Keep as tensor\n        max_seq_len = seq_lengths.max().item()  # Use .max() on tensor\n\n        if settings.FOUNDATION_STATIC_CACHE:\n            # Pad max_seq_len to the nearest multiple for compilation\n            max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16)\n\n            # Pad batch_size 
to the nearest multiple for compilation\n            batch_size = get_nearest_pad(batch_size, pad_multiple=2)\n\n            # Ensure seq_lengths is a tensor of the correct size\n            seq_lengths = F.pad(\n                seq_lengths, (0, batch_size - seq_lengths.size(0)), \"constant\", 0\n            )\n\n        # some day, you may look at this, and think: \"what if I used repeat_interleave or some other fancy torch instead\"?\n        # don't do this - it's a path to madness.  For some reason, this loop is optimal\n\n        batch_indices = []\n        position_indices = []\n\n        for i, seq_len in enumerate(\n            seq_lengths.tolist()\n        ):  # Convert to list only for iteration\n            batch_indices.extend([i] * seq_len)\n            position_indices.extend(list(range(seq_len)))\n\n        batch_indices = torch.tensor(batch_indices, device=device)\n        position_indices = torch.tensor(position_indices, device=device)\n\n        batched_q = torch.zeros(\n            (batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype\n        )\n        batched_k = torch.zeros_like(batched_q)\n        batched_v = torch.zeros_like(batched_q)\n\n        # Create additive attention mask\n        attention_mask = torch.full(\n            (batch_size, max_seq_len, max_seq_len),\n            fill_value=float(\"-inf\"),\n            device=device,\n            dtype=dtype,\n        )\n\n        # Create mask for valid positions\n        seq_range = torch.arange(max_seq_len, device=device)\n        valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze(\n            1\n        )  # (batch_size, max_seq_len)\n        valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(\n            1\n        )  # (batch_size, max_seq_len, max_seq_len)\n\n        # Simply use boolean indexing to set valid positions to 0\n        attention_mask[valid_2d] = 0\n\n        attention_mask = attention_mask.unsqueeze(\n            1\n        )  # 
(batch_size, 1, max_seq_len, max_seq_len)\n\n        batched_q[batch_indices, position_indices] = q\n        batched_k[batch_indices, position_indices] = k\n        batched_v[batch_indices, position_indices] = v\n\n        return (\n            batched_q,\n            batched_k,\n            batched_v,\n            attention_mask,\n            batch_indices,\n            position_indices,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: Optional[torch.Tensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        hidden_states = hidden_states.squeeze(0)\n        cu_seqlens = cu_seqlens.squeeze(0)\n\n        seq_length = hidden_states.shape[0]\n        q, k, v = (\n            self.qkv(hidden_states)\n            .reshape(seq_length, 3, self.num_heads, -1)\n            .permute(1, 0, 2, 3)\n            .unbind(0)\n        )\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.54 `rotary_pos_emb` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        else:\n            cos, sin = position_embeddings\n        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)\n        q = q.squeeze(0)\n        k = k.squeeze(0)\n\n        q, k, v, attention_mask, batch_indices, position_indices = (\n            self.unpack_qkv_with_mask(q, k, v, cu_seqlens)\n        )\n        batch_size, max_seqlen = q.shape[:2]\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attention_mask,\n            dropout_p=0.0,\n        )\n        attn_output = attn_output.permute(0, 2, 1, 3).reshape(\n            batch_size, max_seqlen, -1\n        )  # Bring back to (batch_size, max_seqlen, hidden_dim)\n        attn_output = attn_output[batch_indices, position_indices]\n        attn_output = self.proj(attn_output)\n\n        return attn_output.unsqueeze(0)\n\n\nQWEN2_5_VL_VISION_ATTENTION_CLASSES = {\n    \"eager\": Qwen2_5_VLVisionAttention,\n    \"flash_attention_2\": Qwen2_5_VLVisionXLAFlashAttention2\n    if settings.FOUNDATION_XLA\n    else Qwen2_5_VLVisionFlashAttention2,\n    \"sdpa\": Qwen2_5_VLVisionXLASdpaAttention\n    if settings.FOUNDATION_XLA\n    else Qwen2_5_VLVisionSdpaAttention,\n}\n\n\nclass Qwen2_5_VLVisionBlock(nn.Module):\n    def __init__(self, config, attn_implementation: str = \"sdpa\") -> None:\n        super().__init__()\n        self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)\n        self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)\n        self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](\n            config.hidden_size, num_heads=config.num_heads\n        )\n        
self.mlp = Qwen2_5_VLMLP(config, bias=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: Optional[torch.Tensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        hidden_states = hidden_states + self.attn(\n            self.norm1(hidden_states),\n            cu_seqlens=cu_seqlens,\n            rotary_pos_emb=rotary_pos_emb,\n            position_embeddings=position_embeddings,\n        )\n        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))\n        return hidden_states\n\n\nQwen2_5_VL_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen2_5_VLConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\nclass Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel):\n    config_class = SuryaEncoderConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen2_5_VLDecoderLayer\", \"Qwen2_5_VLVisionBlock\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, (nn.Linear, nn.Conv3d)):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nclass Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):\n    config_class = SuryaEncoderConfig\n    _no_split_modules = [\"Qwen2_5_VLVisionBlock\"]\n\n    def __init__(self, config, *inputs, **kwargs) -> None:\n        super().__init__(config, *inputs, **kwargs)\n        self.spatial_merge_size = config.spatial_merge_size\n        self.patch_size = config.patch_size\n        self.fullatt_block_indexes = config.fullatt_block_indexes\n        self.window_size = config.window_size\n        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size\n\n        self.patch_embed = Qwen2_5_VisionPatchEmbed(\n            patch_size=config.patch_size,\n            temporal_patch_size=config.temporal_patch_size,\n            in_channels=config.in_channels,\n            embed_dim=config.hidden_size,\n        )\n\n        
head_dim = config.hidden_size // config.num_heads\n        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)\n\n        self.blocks = nn.ModuleList(\n            [\n                Qwen2_5_VLVisionBlock(config, config._attn_implementation)\n                for _ in range(config.depth)\n            ]\n        )\n        self.merger = Qwen2_5_VLPatchMerger(\n            dim=config.out_hidden_size,\n            context_dim=config.hidden_size,\n            spatial_merge_size=config.spatial_merge_size,\n        )\n        self.gradient_checkpointing = False\n\n    def rot_pos_emb(self, grid_thw):\n        rotary_pos_emb = []\n        grid_thw_list = grid_thw.cpu().tolist()\n        for batch_item in grid_thw_list:\n            row_pos_ids = []\n            heights = [h for _, h, _ in batch_item]\n            widths = [w for _, _, w in batch_item]\n            max_grid_size = max(heights + widths)\n            for t, h, w in batch_item:\n                hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n                hpos_ids = hpos_ids.reshape(\n                    h // self.spatial_merge_size,\n                    self.spatial_merge_size,\n                    w // self.spatial_merge_size,\n                    self.spatial_merge_size,\n                )\n                hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n                hpos_ids = hpos_ids.flatten()\n\n                wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n                wpos_ids = wpos_ids.reshape(\n                    h // self.spatial_merge_size,\n                    self.spatial_merge_size,\n                    w // self.spatial_merge_size,\n                    self.spatial_merge_size,\n                )\n                wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n                wpos_ids = wpos_ids.flatten()\n                # shape: token_count, 2\n                row_pos_ids.append(\n                    torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)\n             
   )\n            # shape: token_count, 2\n            pos_ids = torch.cat(row_pos_ids, dim=0)\n            rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)\n            rotary_pos_emb_row = rotary_pos_emb_full[pos_ids].flatten(1)\n            rotary_pos_emb.append(rotary_pos_emb_row)\n        rotary_pos_emb = torch.stack(rotary_pos_emb, dim=0)\n        return rotary_pos_emb\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        grid_thw: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            hidden_states (`torch.Tensor` of shape `(bsz, seq_len, hidden_size)`):\n                The final hidden states of the model.\n            grid_thw (`torch.Tensor` of shape `(bsz, num_images_or_videos, 3)`):\n                The temporal, height and width of feature shape of each image in LLM.\n\n        Returns:\n            `torch.Tensor`: hidden_states.\n        \"\"\"\n        bsz, seq_len, _ = hidden_states.size()\n        hidden_states = self.patch_embed(hidden_states)  # (bsz, seq_len, hidden_dim)\n        rotary_pos_emb = self.rot_pos_emb(grid_thw)\n\n        # hidden_states = hidden_states.reshape(bsz, seq_len, -1)\n        # rotary_pos_emb = rotary_pos_emb.reshape(bsz, seq_len, -1)\n        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to(\n            hidden_states.device\n        )\n        position_embeddings = (emb.cos(), emb.sin())\n\n        cu_seqlens = (grid_thw[:, :, 1] * grid_thw[:, :, 2]).cumsum(\n            dim=1,\n            # Select dtype based on the following factors:\n            #  - FA2 requires that cu_seqlens_q must have dtype int32\n            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw\n            # See https://github.com/huggingface/transformers/pull/34852 for more information\n            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,\n        )\n        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)\n        
for layer_num, blk in enumerate(self.blocks):\n            if self.gradient_checkpointing and self.training:\n                hidden_states = self._gradient_checkpointing_func(\n                    blk.__call__,\n                    hidden_states,\n                    cu_seqlens,\n                    None,\n                    position_embeddings,\n                )\n            else:\n                hidden_states = blk(\n                    hidden_states,\n                    cu_seqlens=cu_seqlens,\n                    position_embeddings=position_embeddings,\n                )\n\n        hidden_states = self.merger(hidden_states)\n        return hidden_states\n\n\nclass SuryaEncoderModel(Qwen2_5_VisionTransformerPretrainedModel):\n    @property\n    def image_size(self) -> Tuple[int, int]:\n        config: SuryaEncoderConfig = self.config\n        if isinstance(config.image_size, tuple) and len(config.image_size) == 2:\n            return config.image_size\n        elif isinstance(config.image_size, int):\n            return (config.image_size, config.image_size)\n\n        raise ValueError(\n            f\"The `image_size` for SuryaEncoderConfig should be a tuple of (int, int) or a single int but found {type(config.image_size)}\"\n        )\n\n    @property\n    def hidden_size(self) -> int:\n        config: SuryaEncoderConfig = self.config\n        return config.hidden_size\n\n    def embed_images(\n        self,\n        image_batch: torch.Tensor,\n        grid_thw: torch.Tensor,\n    ) -> torch.Tensor:\n        return super().forward(\n            hidden_states=image_batch,\n            grid_thw=grid_thw,\n        )\n"
  },
  {
    "path": "surya/common/surya/encoder/config.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass SuryaEncoderConfig(PretrainedConfig):\n    model_type = \"qwen2_5_vl\"\n    base_config_key = \"vision_config\"\n\n    attribute_map = {\n        \"num_attention_heads\": \"num_heads\",\n        \"num_hidden_layers\": \"depth\",\n    }\n\n    def __init__(\n        self,\n        depth=8,\n        hidden_size=1280,\n        hidden_act=\"silu\",\n        intermediate_size=3420,\n        num_heads=16,\n        in_channels=3,\n        patch_size=14,\n        spatial_merge_size=2,\n        spatial_patch_size=14,\n        temporal_patch_size=1,\n        tokens_per_second=4,\n        window_size=112,\n        out_hidden_size=1280,\n        fullatt_block_indexes=(3, 7),\n        initializer_range=0.02,\n        image_size=4096,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.depth = depth\n        self.hidden_size = hidden_size\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.num_heads = num_heads\n        self.in_channels = in_channels\n        self.patch_size = patch_size\n        self.spatial_merge_size = spatial_merge_size\n        self.temporal_patch_size = temporal_patch_size\n        self.tokens_per_second = tokens_per_second\n        self.window_size = window_size\n        self.fullatt_block_indexes = fullatt_block_indexes\n        self.out_hidden_size = out_hidden_size\n        self.initializer_range = initializer_range\n        self.spatial_patch_size = spatial_patch_size\n        self.image_size = image_size\n"
  },
  {
    "path": "surya/common/surya/flash_attn_utils.py",
    "content": "from typing import Optional\nimport torch\nimport torch.nn.functional as F\nfrom flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func\nfrom flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache\nfrom flash_attn.bert_padding import index_first_axis as _index_first_axis\nfrom flash_attn.bert_padding import pad_input\n\ndef _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:\n    \"\"\"\n    Retrieves indexing data required to repad unpadded (ragged) tensors.\n\n    Arguments:\n        attention_mask (`torch.Tensor`):\n            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.\n\n    Return:\n        indices (`torch.Tensor`):\n            The indices of non-masked tokens from the flattened input sequence.\n        cu_seqlens (`torch.Tensor`):\n            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).\n        max_seqlen_in_batch (`int`):\n            Maximum sequence length in batch.\n    \"\"\"\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\ndef _upad_input(\n    query_layer: torch.Tensor,\n    key_layer: torch.Tensor,\n    value_layer: torch.Tensor,\n    query_length: int,\n    indices_k,\n    cu_seqlens_k,\n    max_seqlen_in_batch_k\n):\n    \"\"\"\n    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.\n\n    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary\n    
tensors for query, key, value tensors.\n\n    Arguments:\n        query_layer (`torch.Tensor`):\n            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).\n        key_layer (`torch.Tensor`):\n            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).\n        value_layer (`torch.Tensor`):\n            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).\n        query_length (`int`):\n            Target length. Only `query_length == kv_seq_len` (prefill) and `query_length == 1` (decode) are supported.\n        indices_k (`torch.Tensor`):\n            The indices of non-masked key/value tokens in the flattened input (e.g. as computed by `_get_unpad_data`).\n        cu_seqlens_k (`torch.Tensor`):\n            The cumulative key/value sequence lengths (e.g. as computed by `_get_unpad_data`). Shape: (batch_size + 1,).\n        max_seqlen_in_batch_k (`int`):\n            Maximum key/value sequence length in the batch (e.g. as computed by `_get_unpad_data`).\n\n    Return:\n        query_layer (`torch.Tensor`):\n            Query state without padding. Shape: (total_target_length, num_heads, head_dim).\n        key_layer (`torch.Tensor`):\n            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).\n        value_layer (`torch.Tensor`):\n            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).\n        indices_q (`torch.Tensor`):\n            The indices of non-masked tokens from the flattened input target sequence.\n        (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):\n            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).\n        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):\n            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. 
key/value).\n    \"\"\"\n    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n    key_layer = _index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)\n    value_layer = _index_first_axis(\n        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k\n    )\n    if query_length == kv_seq_len:\n        query_layer = _index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)\n        cu_seqlens_q = cu_seqlens_k\n        max_seqlen_in_batch_q = max_seqlen_in_batch_k\n        indices_q = indices_k\n    elif query_length == 1:\n        max_seqlen_in_batch_q = 1\n        cu_seqlens_q = torch.arange(\n            batch_size + 1, dtype=torch.int32, device=query_layer.device\n        )  # There is a memcpy here, that is very bad.\n        indices_q = cu_seqlens_q[:-1]\n        query_layer = query_layer.squeeze(1)\n    else:\n        raise NotImplementedError()\n\n    return (\n        query_layer,\n        key_layer,\n        value_layer,\n        indices_q,\n        (cu_seqlens_q, cu_seqlens_k),\n        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n    )\n\ndef flash_attn_prefill(\n    module: torch.nn.Module,\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: torch.Tensor,\n    dropout: float,\n    scaling: float,\n    query_length: int,\n    batch_size: int,\n    indices_k: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_in_batch_k: int,\n    **kwargs\n):\n    \"\"\"\n    Wrapper for flash attention during the prefill stage\n    query_states must have shape (batch_size, num_heads, seq_len, head_dim)\n    key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)\n\n    This is the opposite of what is required by flash attention, but keeps parity with the HF convention\n\n    query_length, batch_size, indices_k, 
cu_seqlens_k, and max_seqlen_in_batch_k should come from the flash attention kwargs\n    \"\"\"\n    query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)\n    q_flash, k_flash, v_flash, indices_q, cu_seq_lens, max_seq_lens = _upad_input(\n        query_states, key_states, value_states, query_length, indices_k, cu_seqlens_k, max_seqlen_in_batch_k\n    )\n    cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n    # Returning None for attn_weights to match other attention interfaces\n    flash_attn_out = _flash_attn_varlen_func(\n        q_flash,\n        k_flash,\n        v_flash,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k,\n        max_seqlen_q=max_seqlen_in_batch_q,\n        max_seqlen_k=max_seqlen_in_batch_k,\n        dropout_p=dropout,\n        softmax_scale=scaling,\n        causal=module.is_causal,\n    )\n    return pad_input(flash_attn_out, indices_q, batch_size, query_length), None\n\n# NOTE: Does not support dropout, accepts argument as kwargs to maintain compatibility\n# This function is an order of magnitude faster than the prefill variant, or using the HF interface\ndef flash_attn_decode(\n    module: torch.nn.Module,\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: torch.Tensor,\n    scaling: float,\n    **kwargs,\n):\n    \"\"\"\n    Wrapper for flash attention during the decode stage\n    \n    query_states must have shape (batch_size, num_heads, seq_len, head_dim), 1 is the seq length in the decoding stage\n    key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)\n\n    This is the opposite of what is required by flash attention, but keeps parity with the HF convention\n\n    This function computes the left pad and cache seqlens to pass into FA2. 
For example - \n    Given an attention_mask shaped (batch_size=2, seq_len=8), where 0 = padding, 1 = real token\n    attention_mask =\n    tensor([\n        [0, 0, 1, 1, 1, 0, 0, 0],  # ← batch 0\n        [0, 1, 1, 1, 1, 1, 1, 0],  # ← batch 1\n    ])\n    cache_leftpad = tensor([2, 1], dtype=torch.int32)\n    cache_seqlens = tensor([5, 7], dtype=torch.int32)\n    These values allow FlashAttention to use a static cache layout with efficient slicing during decoding.\n    \"\"\"\n    query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)\n\n    cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32)\n    cache_seqlens = (attention_mask * torch.arange(attention_mask.size(1), device=attention_mask.device)).argmax(dim=1).to(torch.int32) + 1\n\n    # Returning None for attn_weights to match other attention interfaces\n    return _flash_attn_with_kvcache(\n        q=query_states,\n        k_cache=key_states,\n        v_cache=value_states,\n        cache_leftpad=cache_leftpad,\n        cache_seqlens=cache_seqlens,\n        causal=module.is_causal,\n        softmax_scale=scaling,\n    ), None"
  },
  {
    "path": "surya/common/surya/processor/__init__.py",
    "content": "import math\n\nimport cv2\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom torch.nn.utils.rnn import pad_sequence\n\nfrom typing import List, Optional, Tuple\n\nfrom transformers.feature_extraction_utils import BatchFeature\nfrom transformers.processing_utils import ProcessorMixin\nfrom transformers.tokenization_utils import PreTrainedTokenizer\n\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.common.surya.processor.schema import (\n    TextInput,\n    ImageInput,\n    ProcessorOutput,\n)\nfrom surya.common.surya.schema import TaskNames\nfrom surya.logging import get_logger\nfrom surya.settings import settings\n\nlogger = get_logger()\n\n# Task agnostic tokens - Every task will use these in some form or another\nEOS_TOKEN = \"</S>\"\nEOI_TOKEN = \"<EOI>\"  # This is end of INPUT, not image. Images are always followed by a task specific BOS token, so that serves as a delimiter anyways.\nIMAGE_TOKEN = \"<IMAGE>\"\nPAD_TOKEN = \"<PAD>\"\nNO_OUTPUT_TOKEN = \"<NOP>\"\nIMAGE_ROTATED_TOKEN = \"<ROT>\"\nREGISTER_TOKENS = [\"<REG1>\", \"<REG2>\", \"<REG3>\", \"<REG4>\"]\nBEACON_TOKEN = \"<BEACON>\"\nNOMATH_TOKEN = \"<NO-MATH>\"\n\n# Task specific tokens\nOCR_WITH_BOXES_BOS_TOKEN = \"<OCR-WB>\"\nOCR_WITHOUT_BOXES_BOS_TOKEN = \"<OCR-WOB>\"\nBLOCK_WITHOUT_BOXES_TOKEN = \"<BLOCKS-WOB>\"\nLAYOUT_BOS_TOKEN = \"<LAYOUT>\"\nTABLE_STRUCTURE_BOS_TOKEN = \"<TABLE-STRUCTURE>\"\n\n\nclass SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin):\n    attributes = [\"image_processor\", \"ocr_tokenizer\"]\n    image_processor_class = \"BaseImageProcessor\"\n    ocr_tokenizer_class = \"PreTrainedTokenizer\"\n    rescale_factor = 1 / 255.0\n    image_mean = (0.485, 0.456, 0.406)\n    image_std = (0.229, 0.224, 0.225)\n\n    def __init__(\n        self,\n        ocr_tokenizer: PreTrainedTokenizer,\n        blank_bbox_token_id: int,\n        num_register_tokens: int,\n        patch_size: int,\n        merge_size: int,\n        num_beacon_tokens: int,\n     
   beacon_token_interval: int,\n        model_device: str,\n        **kwargs,\n    ):\n        self.ocr_tokenizer = ocr_tokenizer\n        self.patch_size = patch_size\n        self.merge_size = merge_size\n        self.num_register_tokens = num_register_tokens\n        self.num_beacon_tokens = num_beacon_tokens\n        self.beacon_token_interval = beacon_token_interval\n\n        self.tokenizer_vocab_size = 0\n        for attr in self.attributes:\n            if \"tokenizer\" in attr:\n                self.tokenizer_vocab_size += getattr(self, attr).vocab_size\n\n        self.offsets = {\"ocr\": 0}\n\n        # Create special token mapping\n        self.special_token_mapping = self.ocr_tokenizer.system_tokens\n\n        self.register_token_ids = [\n            self.special_token_mapping.get(r) for r in REGISTER_TOKENS\n        ]\n        self.beacon_token_id = self.special_token_mapping.get(BEACON_TOKEN)\n        self.image_token_id = self.special_token_mapping.get(IMAGE_TOKEN)\n        self.pad_token_id = self.special_token_mapping.get(PAD_TOKEN)\n        self.eos_token_id = self.special_token_mapping.get(EOS_TOKEN)\n        self.eoi_token_id = self.special_token_mapping.get(EOI_TOKEN)\n        self.no_output_token = self.special_token_mapping.get(NO_OUTPUT_TOKEN)\n        self.image_rotated_token = self.special_token_mapping.get(IMAGE_ROTATED_TOKEN)\n        self.nomath_token = self.special_token_mapping.get(NOMATH_TOKEN)\n\n        self.bos_token_id = {\n            TaskNames.ocr_with_boxes: self.special_token_mapping.get(\n                OCR_WITH_BOXES_BOS_TOKEN\n            ),\n            TaskNames.ocr_without_boxes: self.special_token_mapping.get(\n                OCR_WITHOUT_BOXES_BOS_TOKEN\n            ),\n            TaskNames.block_without_boxes: self.special_token_mapping.get(\n                BLOCK_WITHOUT_BOXES_TOKEN\n            ),\n            TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN),\n            
TaskNames.table_structure: self.special_token_mapping.get(\n                TABLE_STRUCTURE_BOS_TOKEN\n            ),\n        }\n\n        if self.image_token_id is None:\n            logger.warning(\"Warning: Image token not found in special tokens\")\n\n        self.blank_bbox_token_id = blank_bbox_token_id\n        self.bbox_pad_token_id = self.blank_bbox_token_id\n\n        self.ignore_bbox_token_ids = [\n            v\n            for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()\n            if k not in self.ocr_tokenizer.special_tokens[\"math_external\"]\n        ]\n        math_end_token = \"</math>\"\n        self.math_start_token_ids = [\n            v\n            for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()\n            if k in self.ocr_tokenizer.special_tokens[\"math_external\"]\n            and k != math_end_token\n        ]\n        self.math_end_token_ids = [\n            v\n            for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()\n            if k == math_end_token\n        ]\n\n        if self.num_register_tokens > len(self.register_token_ids):\n            raise ValueError(\n                \"The number of register tokens requested exceeds the number of register tokens defined in the special token mapping.\"\n            )\n\n        self.image_mean = np.array(self.image_mean, dtype=np.float32)\n        self.image_std = np.array(self.image_std, dtype=np.float32)\n        self.model_device = model_device\n\n    @property\n    def vocab_size(self):\n        return self.tokenizer_vocab_size\n\n    def image_processor(self, image: Image.Image) -> np.ndarray:\n        # Convert to array\n        image = np.asarray(image, dtype=np.float32)\n        return image\n\n    @staticmethod\n    def scale_to_fit(\n        img: np.ndarray,\n        max_size: Tuple[int, int],\n        min_size: Tuple[int, int] = (168, 168),\n    ):\n        # Get current dimensions\n        height, width = img.shape[:2]\n\n        
# Check for empty or invalid image\n        if width == 0 or height == 0:\n            return img\n\n        max_width, max_height = max_size\n        min_width, min_height = min_size\n\n        # Calculate pixel counts\n        current_pixels = width * height\n        max_pixels = max_width * max_height\n        min_pixels = min_width * min_height\n\n        if current_pixels > max_pixels:\n            scale_factor = (max_pixels / current_pixels) ** 0.5\n\n            new_width = math.floor(width * scale_factor)\n            new_height = math.floor(height * scale_factor)\n        elif current_pixels == 0:\n            return img\n        elif current_pixels < min_pixels:\n            scale_factor = (min_pixels / current_pixels) ** 0.5\n\n            new_width = math.ceil(width * scale_factor)\n            new_height = math.ceil(height * scale_factor)\n        else:\n            return img\n\n        return cv2.resize(\n            img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4\n        )\n\n    def _image_processor(self, image: np.ndarray):\n        image = image.astype(np.float64) * self.rescale_factor\n        image = (image.astype(np.float32) - self.image_mean) / self.image_std\n        return image\n\n    def _process_and_tile(\n        self, image: np.ndarray\n    ) -> Tuple[torch.Tensor, Tuple[int, int, int]]:\n        \"\"\"\n        Resizes the input image to the closest multiple of tile_size while preserving the aspect ratio\n        and returns a tensor of image tiles.\n        \"\"\"\n        extra_multipler = (\n            4 if settings.FOUNDATION_XLA else 1\n        )  # Needed to force same size grid_thws per row with padding\n\n        factor = (\n            self.patch_size * self.merge_size * extra_multipler\n        )  # Make a multiple of window size\n\n        height, width = image.shape[:2]\n\n        h_bar = math.ceil(height / factor) * factor\n        w_bar = math.ceil(width / factor) * factor\n        if h_bar != height or 
w_bar != width:\n            if height == 0 or width == 0:\n                image = np.zeros((h_bar, w_bar, 3), dtype=np.uint8)\n            else:\n                image = cv2.resize(image, (w_bar, h_bar), interpolation=cv2.INTER_CUBIC)\n\n        # Handle scaling and normalization\n        image = self._image_processor(image)\n        height, width = image.shape[:2]\n\n        # Numpy array to torch tensor\n        img_tensor = torch.from_numpy(image.transpose(2, 0, 1))\n        patches = img_tensor.unsqueeze(0)\n\n        channel = patches.shape[1]\n        grid_t = patches.shape[0]\n        grid_h, grid_w = height // self.patch_size, width // self.patch_size\n\n        patches = patches.reshape(\n            grid_t,\n            1,\n            channel,\n            grid_h // self.merge_size,\n            self.merge_size,\n            self.patch_size,\n            grid_w // self.merge_size,\n            self.merge_size,\n            self.patch_size,\n        )\n        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)\n        flatten_patches = patches.reshape(\n            grid_t * grid_h * grid_w, channel * 1 * self.patch_size * self.patch_size\n        )\n\n        return flatten_patches, (grid_t, grid_h, grid_w)\n\n    # Handle image input dictionaries - Process image, tile accordingly, and setup the input ids and boxes correspondingly\n    def _process_image_input(self, image_input: ImageInput) -> ProcessorOutput:\n        rotated = image_input.get(\"rotated\", False)\n        image = image_input.get(\"image\", None)\n\n        assert image is not None, (\n            \"A PIL Image must be provided when the input type is `image`\"\n        )\n        image_tiles, grid_thw = self._process_and_tile(image)\n\n        num_tokens = image_tiles.shape[0] / self.merge_size**2\n        assert num_tokens.is_integer(), (\n            f\"Expected number of tokens to be an integer, got {num_tokens}\"\n        )\n\n        input_ids = [self.image_token_id] * 
int(num_tokens)\n        input_ids += self.register_token_ids[: self.num_register_tokens]\n\n        # Handle the image having been rotated in the dataset\n        if rotated:\n            input_ids = [self.image_rotated_token] + input_ids\n\n        return ProcessorOutput(\n            input_ids=input_ids,\n            image_tiles=image_tiles,\n            grid_thw=grid_thw,\n        )\n\n    def _process_text_input(self, text_input: TextInput, task: str) -> ProcessorOutput:\n        input_text = text_input.get(\"text\", None)\n        math_mode = text_input.get(\"math\", False)\n\n        input_ids = self.ocr_tokenizer(input_text, tasks=task)[\"input_ids\"][0]\n        input_ids = [self.offsets[\"ocr\"] + id for id in input_ids]\n\n        # nomath token does not work for layout\n        if not math_mode and task != \"layout\":\n            input_ids.insert(0, self.nomath_token)\n\n        return ProcessorOutput(\n            input_ids=input_ids,\n            image_tiles=None,\n            grid_thw=None,\n        )\n\n    def _process_input(self, input_dict: dict, task: str):\n        input_type = input_dict[\"type\"]\n        if input_type == \"image\":\n            return self._process_image_input(input_dict)\n        elif input_type == \"text\":\n            return self._process_text_input(input_dict, task)\n\n        raise NotImplementedError(f\"Input of type `{input_type}` is not implemented\")\n\n    # Preprocessing for the OCR task\n    # The task input is expected to contain image_dict, user_input_dict, and (for training) output_dict\n    # user_input_dict may contain empty text, but it must be present\n    def _process_ocr_with_boxes(\n        self,\n        mixed_input: List[dict],\n        bos_token_id: int,\n        task: str = TaskNames.ocr_with_boxes,\n    ):\n        processed_input_ids = []\n        all_image_tiles = []\n        all_grid_thw = []\n\n        # 1. 
Process each input in order: image, user text, and (training only) target text\n        for i, input_dict in enumerate(mixed_input):\n            processor_output = self._process_input(input_dict, task)\n            input_ids = processor_output[\"input_ids\"]\n            image_tiles = processor_output[\"image_tiles\"]\n            grid_thw = processor_output[\"grid_thw\"]\n\n            # Special handling of some delimiter tokens\n            if i == 1:\n                assert input_dict[\"type\"] == \"text\", (\n                    \"Expected text input for model input.\"\n                )\n                # User input: wrap with the task-specific BOS token and the end-of-input token\n                # We do not want the model to learn how to predict inputs. Hence IGNORE_INDEX for these\n                input_ids = [bos_token_id] + input_ids + [self.eoi_token_id]\n            if i == 2:\n                assert input_dict[\"type\"] == \"text\", (\n                    \"Expected text for final model input\"\n                )\n                input_ids = input_ids + [self.eos_token_id]\n            elif i > 2:\n                raise ValueError(f\"Too many inputs received. Expected 2 for inference, 3 for training. 
Received: {len(mixed_input)}\")\n\n            # Some input types don't return any image tiles, accounting for that\n            if image_tiles is not None:\n                all_image_tiles.append(image_tiles)\n                all_grid_thw.append(grid_thw)\n\n            processed_input_ids.extend(input_ids)\n\n        return (\n            torch.tensor(processed_input_ids, dtype=torch.long),\n            all_image_tiles,\n            all_grid_thw,\n        )\n\n    def _process_layout(self, mixed_input: List[dict], bos_token_id: int):\n        return self._process_ocr_with_boxes(\n            mixed_input, bos_token_id=bos_token_id, task=\"layout\"\n        )\n\n    def _process_table_structure(self, mixed_input: List[dict], bos_token_id: int):\n        return self._process_ocr_with_boxes(\n            mixed_input, bos_token_id=bos_token_id, task=\"table_structure\"\n        )\n\n    def _process_ocr_without_boxes(\n        self,\n        mixed_input: List[dict],\n        bos_token_id: int,\n        task: str = \"ocr_without_boxes\",\n    ):\n        # Boxes are set to None, so this will work\n        # TODO: improve this behavior\n        return self._process_ocr_with_boxes(\n            mixed_input, bos_token_id=bos_token_id, task=task\n        )\n\n    def _process_block_without_boxes(\n        self,\n        mixed_input: List[dict],\n        bos_token_id: int,\n        task: str = \"block_without_boxes\",\n    ):\n        return self._process_ocr_with_boxes(\n            mixed_input, bos_token_id=bos_token_id, task=task\n        )\n\n    def align_long_axis(self, image: np.ndarray) -> Tuple[np.ndarray, bool]:\n        height, width, _ = image.shape\n        if height > width:  # Rotate vertical lines\n            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)\n            return image, True\n\n        return image, False\n\n    def __call__(\n        self,\n        mixed_batch: List[dict],\n        padding_side: Optional[str] = \"left\",\n        
device: Optional[torch.device] = None,\n        pad_to_multiple: Optional[int] = None,\n    ):\n        all_image_tiles = []\n        all_input_ids = []\n        all_grid_thw = []\n\n        for b in mixed_batch:\n            mixed_input = b[\"inputs\"]\n            task = b[\"task\"]\n            assert task in self.bos_token_id, f\"Task {task} has no bos token defined.\"\n\n            # Select the correct processing function based on the task type\n            input_ids, image_tiles, grid_thw = getattr(self, f\"_process_{task}\")(\n                mixed_input, self.bos_token_id[task]\n            )\n\n            all_input_ids.append(input_ids)\n            all_image_tiles.extend(image_tiles)\n            all_grid_thw.extend(grid_thw)\n\n        batched_input_ids = pad_sequence(\n            all_input_ids,\n            batch_first=True,\n            padding_side=padding_side,\n            padding_value=self.pad_token_id,\n        )\n\n        if pad_to_multiple is not None:\n            current_len = batched_input_ids.shape[1]\n            # Calculate the next multiple of pad_to_multiple\n            padded_len = (\n                (current_len + pad_to_multiple - 1) // pad_to_multiple\n            ) * pad_to_multiple\n\n            if padded_len > current_len:\n                pad_len = padded_len - current_len\n                batched_input_ids = torch.nn.functional.pad(\n                    batched_input_ids, (pad_len, 0), value=self.pad_token_id\n                )\n\n        attention_mask = batched_input_ids.ne(self.pad_token_id)\n\n        # Generating position IDs that are independent of left and right padding;\n        # This should ensure same results for either padding side. 
The exact position ids of the pad tokens themselves don't matter since they are masked\n        position_ids = attention_mask.cumsum(dim=-1) - 1\n        position_ids[position_ids < 0] = (\n            0  # With left padding, the shift makes the padding position ids -1; set them to 0\n        )\n        position_ids = (\n            attention_mask.to(torch.long) * position_ids\n        )  # Ensure right pad ids get set to zero\n\n        batched_image_tiles = torch.cat(all_image_tiles, dim=0)\n        batched_grid_thw = torch.from_numpy(np.array(all_grid_thw))\n\n        # Pin memory for CUDA\n        if device == torch.device(\"cuda\"):\n            batched_image_tiles = batched_image_tiles.pin_memory()\n            batched_grid_thw = batched_grid_thw.pin_memory()\n            attention_mask = attention_mask.pin_memory()\n            batched_input_ids = batched_input_ids.pin_memory()\n            position_ids = position_ids.pin_memory()\n\n        return BatchFeature(\n            {\n                \"input_ids\": batched_input_ids,\n                \"image_tiles\": batched_image_tiles,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n                \"grid_thw\": batched_grid_thw,\n            }\n        )\n\n    # Decode model outputs, stripping special tokens\n    def decode(self, tokens: List[int], task: str):\n        filtered_tokens = [\n            t\n            for t in tokens\n            if t not in self.special_token_mapping.values() and t != -100\n        ]  # Skip special tokens and loss ignore index\n        return self.ocr_tokenizer.decode(filtered_tokens, task=task)\n"
  },
  {
    "path": "surya/common/surya/processor/schema.py",
    "content": "from typing import TypedDict, Literal, List, Tuple\n\nimport torch\nfrom PIL import Image\n\n\nclass TaskDict(TypedDict):\n    datasets: List[str]\n    img_size: Tuple[int, int]\n\n\nclass TasksDict(TypedDict):\n    ocr_with_boxes: TaskDict\n    ocr_without_boxes: TaskDict\n    block_without_boxes: TaskDict\n\n\nclass ProcessorInput(TypedDict):\n    type: Literal[\"image\", \"ocr\", \"text\", \"empty_output\"]\n\n\nclass ImageInput(ProcessorInput):\n    type: Literal[\"image\"]\n    image: Image.Image\n    rotated: bool\n\n\nclass TextInput(ProcessorInput):\n    type: Literal[\"text\"]\n    text: str\n    math: bool\n\n\nclass ProcessorOutput(TypedDict):\n    input_ids: List[int]\n    image_tiles: torch.Tensor | None\n    grid_thw: torch.Tensor | None\n"
  },
  {
    "path": "surya/common/surya/processor/tokenizer.py",
    "content": "import html\nimport re\nfrom typing import List, Union, Dict, Optional, Tuple, Iterable\nimport numpy as np\nimport torch\nfrom tokenizers import AddedToken\nimport json\nimport os\nfrom transformers import PreTrainedTokenizer, Qwen2Tokenizer as Qwen2OriginalTokenizer\n\n\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.common.surya.schema import TASK_NAMES, TaskNames\nfrom surya.logging import get_logger\nfrom surya.settings import settings\n\nlogger = get_logger()\n\n\ndef create_token_regex(tokens):\n    escaped_tokens = [re.escape(token) for token in tokens]\n    escaped_tokens.sort(key=len, reverse=True)\n    pattern = r\"^(\" + \"|\".join(escaped_tokens) + r\")\"\n    regex = re.compile(pattern)\n    return regex\n\n\nclass Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer):\n    pass\n\nclass GreedyMathUTF16Tokenizer(S3DownloaderMixin, PreTrainedTokenizer):\n    \"\"\"\n    HuggingFace slow tokenizer implementing:\n      - UTF-16 code units as the base [0..65535]\n      - Math tokens as greedy-longest-match ids after UTF-16\n      - Literal special tokens after math tokens\n    Absolute ID layout:\n      [0 .. 65535]                      : UTF-16 units\n      [65536 .. 65536+M-1]              : math tokens\n      [65536+M .. 
65536+M+S-1]          : special tokens\n    \"\"\"\n\n    vocab_files_names = {\n        \"vocab_file\": \"vocab_math.json\",  # {\"\\\\frac\": 0, \"\\\\alpha\": 1, ...} raw contiguous ids 0..M-1\n        \"specials_file\": \"specials.json\",  # [flat list for legacy]\n        \"specials_dict_file\": \"specials_dict.json\",  # category dict (preferred)\n    }\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n    is_fast = False\n\n    # ---------- helpers ----------\n    @staticmethod\n    def _to_utf16_units(s: str) -> List[int]:\n        b = s.encode(\"utf-16le\")\n        return [int.from_bytes(b[i : i + 2], \"little\") for i in range(0, len(b), 2)]\n\n    @staticmethod\n    def _from_utf16_units(units: List[int]) -> str:\n        b = bytearray()\n        for u in units:\n            b += int(u).to_bytes(2, \"little\")\n        return b.decode(\"utf-16le\", errors=\"ignore\")\n\n    class _TrieNode:\n        __slots__ = (\"child\", \"id\", \"leaf\")\n\n        def __init__(self):\n            self.child: Dict[str, \"GreedyMathUTF16Tokenizer._TrieNode\"] = {}\n            self.id: Optional[int] = None\n            self.leaf: bool = False\n\n    @classmethod\n    def _build_trie(\n        cls, token_to_id: Dict[str, int]\n    ) -> \"GreedyMathUTF16Tokenizer._TrieNode\":\n        root = cls._TrieNode()\n        for tok, tid in token_to_id.items():\n            node = root\n            for ch in tok:\n                node = node.child.setdefault(ch, cls._TrieNode())\n            node.leaf = True\n            node.id = tid\n        return root\n\n    def _build_escape_patterns(self, math_token_to_rawid):\n        \"\"\"Build pattern list from vocab commands that start with control characters.\n\n        Scans the math vocab for LaTeX commands that could be corrupted by JSON\n        escape sequence interpretation (e.g., \\\\begin becomes <backspace>egin).\n        \"\"\"\n        control_chars = {\n            '\\x08': 'b',  # backspace\n            
'\\t': 't',    # tab\n            '\\n': 'n',    # newline\n            '\\r': 'r',    # carriage return\n            '\\f': 'f',    # form feed\n            '\\x07': 'a',  # bell\n            '\\x0b': 'v',  # vertical tab\n        }\n\n        patterns = {char: [] for char in control_chars}\n\n        for token in math_token_to_rawid.keys():\n            if token.startswith('\\\\') and len(token) > 1:\n                letter = token[1:2]  # First char after backslash\n                for ctrl_char, ctrl_letter in control_chars.items():\n                    if letter == ctrl_letter:\n                        # This token could be corrupted: \\token -> <ctrl>oken\n                        suffix = token[2:]  # Everything after \\X\n                        patterns[ctrl_char].append((suffix, token))\n\n        # Sort by length (longest first) to avoid partial matches\n        for char in patterns:\n            patterns[char].sort(key=lambda x: len(x[0]), reverse=True)\n\n        return patterns\n\n    @classmethod\n    def _encode_math_greedy(\n        cls,\n        s: str,\n        trie: \"GreedyMathUTF16Tokenizer._TrieNode\",\n        math_base: int,\n        debug: bool = False,\n    ) -> List[int]:\n        i, n = 0, len(s)\n        out: List[int] = []\n        while i < n:\n            node = trie\n            j = i\n            last_id = None\n            last_j = i\n            while j < n and (ch := s[j]) in node.child:\n                node = node.child[ch]\n                j += 1\n                if node.leaf:\n                    last_id, last_j = node.id, j\n            if last_id is not None:\n                if debug:\n                    print(f\"[MATH] matched {s[i:last_j]!r} -> {last_id}\")\n                out.append(math_base + last_id)\n                i = last_j\n            else:\n                units = cls._to_utf16_units(s[i])\n                if debug:\n                    print(f\"[MATH] fallback {s[i]!r} -> utf16 {units}\")\n                
out.extend(units)\n                i += 1\n        return out\n\n    # ---------- init ----------\n    def __init__(\n        self,\n        vocab_file: Optional[str] = None,\n        specials_file: Optional[str] = None,\n        specials_dict_file: Optional[str] = None,\n        *,\n        # You can also pass programmatically instead of files:\n        math_vocab: Optional[Dict[str, int]] = None,\n        special_tokens: Optional[List[str]] = None,\n        special_tokens_dict: Optional[Dict[str, List[str]]] = None,\n        debug: bool = False,\n        # Standard HF special token kwargs:\n        bos_token: Optional[str] = None,\n        eos_token: Optional[str] = None,\n        pad_token: Optional[str] = None,\n        unk_token: Optional[str] = None,\n        **kwargs,\n    ):\n        # Load math vocab\n        if vocab_file and os.path.isfile(vocab_file):\n            with open(vocab_file, \"r\", encoding=\"utf-8\") as f:\n                mv = json.load(f)\n        else:\n            mv = math_vocab or {}\n\n        # Make math ids contiguous if needed\n        if mv:\n            max_id = max(mv.values())\n            if set(mv.values()) != set(range(max_id + 1)):\n                items = sorted(mv.items(), key=lambda kv: kv[1])\n                mv = {tok: i for i, (tok, _) in enumerate(items)}\n\n        # Load special tokens (prefer category dict; fallback to flat list or defaults)\n        sp_dict = None\n        if specials_dict_file and os.path.isfile(specials_dict_file):\n            with open(specials_dict_file, \"r\", encoding=\"utf-8\") as f:\n                sp_dict = json.load(f)\n        elif special_tokens_dict is not None:\n            sp_dict = dict(special_tokens_dict)\n\n        if sp_dict is None:\n            # Legacy path: flat list from file or provided/default list\n            if specials_file and os.path.isfile(specials_file):\n                with open(specials_file, \"r\", encoding=\"utf-8\") as f:\n                    
sp_list_flat = json.load(f)\n            else:\n                sp_list_flat = special_tokens or SPECIAL_TOKENS\n            sp_dict = {\"all\": list(sp_list_flat)}\n\n        # Ensure \"all\" exists and is unique/preserved in order.\n        if \"all\" not in sp_dict or not isinstance(sp_dict[\"all\"], list):\n            order = [\n                \"system\",\n                \"formatting\",\n                \"math_external\",\n                \"script\",\n                \"layout\",\n                \"reasoning\",\n                \"table_structure\",\n                \"reserved\",\n            ]\n            seen = set()\n            all_tokens: List[str] = []\n            for k in order:\n                if k in sp_dict and isinstance(sp_dict[k], list):\n                    for t in sp_dict[k]:\n                        if t not in seen:\n                            all_tokens.append(t)\n                            seen.add(t)\n            sp_dict[\"all\"] = all_tokens\n\n        # Keep a copy of categories (if present) for downstream processor logic.\n        self.special_tokens = sp_dict\n        sp_list = list(sp_dict.get(\"all\", []))\n        # Regex list should favor longest-first to avoid partial matches.\n        specials_for_regex = sorted(sp_list, key=len, reverse=True)\n\n        self.debug = debug\n        self.UTF16_SPACE = 65536\n        self.math_token_to_rawid = dict(mv)  # 0..M-1\n        self.math_vocab_size = len(self.math_token_to_rawid)\n        self.MATH_BASE = self.UTF16_SPACE\n        self.SPECIAL_BASE = self.UTF16_SPACE + self.math_vocab_size\n\n        # Maps\n        self.math_absid_to_token = {\n            self.MATH_BASE + rid: tok for tok, rid in self.math_token_to_rawid.items()\n        }\n        self.special_tokens_list = sp_list  # ID assignment order\n        self.special_to_absid = {\n            tok: self.SPECIAL_BASE + i for i, tok in enumerate(self.special_tokens_list)\n        }\n        self.absid_to_special = {v: k for 
k, v in self.special_to_absid.items()}\n\n        # Public attributes for legacy/processor:\n        # All specials mapping (token -> absolute id)\n        self.SPECIAL_TOKEN_MAPPING: Dict[str, int] = dict(self.special_to_absid)\n        # Subset used heavily by processor for quick access\n        self.reverse_special_token_mapping = {\n            v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()\n        }\n        # Category lists are absent on the legacy flat-list path, so default to empty\n        self.LAYOUT_LABEL2ID = {\n            k: v\n            for k, v in self.SPECIAL_TOKEN_MAPPING.items()\n            if k in self.special_tokens.get(\"layout\", [])\n        }\n        self.TABLE_STRUCTURE_LABEL2ID = {\n            k: v\n            for k, v in self.SPECIAL_TOKEN_MAPPING.items()\n            if k in self.special_tokens.get(\"table_structure\", [])\n        }\n        if not self.special_tokens.get(\"system\", []):\n            print(\"Warning: No system tokens found in special_tokens\")\n\n        self.MATH_TAG_START = \"<math\"\n        self.MATH_END_TAG = \"</math>\"\n\n        sys_list = self.special_tokens.get(\"system\", [])\n        self.system_tokens: Dict[str, int] = {\n            t: self.special_to_absid[t] for t in sys_list if t in self.special_to_absid\n        }\n\n        # Regex for literal specials\n        self.specials_pattern = (\n            re.compile(r\"(\" + \"|\".join(re.escape(k) for k in specials_for_regex) + r\")\")\n            if specials_for_regex\n            else None\n        )\n\n        # Trie for math greedy match\n        self.trie = self._build_trie(self.math_token_to_rawid)\n\n        # Build escape fix patterns from vocab\n        self.latex_escape_patterns = self._build_escape_patterns(self.math_token_to_rawid)\n\n        # Tell HF about special tokens (metadata)\n        kwargs.setdefault(\"bos_token\", bos_token)\n        kwargs.setdefault(\"eos_token\", eos_token or \"</S>\")\n        kwargs.setdefault(\"pad_token\", pad_token or \"<PAD>\")\n        kwargs.setdefault(\"unk_token\", unk_token)\n\n
        super().__init__(\n            vocab_file=vocab_file,\n            specials_file=specials_file,\n            specials_dict_file=specials_dict_file,\n            **kwargs,\n        )\n\n    # ---------- required HF surface ----------\n    @property\n    def vocab_size(self) -> int:\n        return self.UTF16_SPACE + self.math_vocab_size + len(self.special_tokens_list)\n\n    def get_vocab(self) -> Dict[str, int]:\n        # Compact vocab: just math+specials with ABSOLUTE ids.\n        v = {tok: self.MATH_BASE + rid for tok, rid in self.math_token_to_rawid.items()}\n        v.update(self.special_to_absid)\n        return v\n\n    def __len__(self) -> int:\n        return self.vocab_size\n\n    # Core encode/decode on ABSOLUTE ids\n    def _encode_core(self, text: str) -> List[int]:\n        text = html.unescape(text)\n        ids: List[int] = []\n        in_math = False\n        chunks = self.specials_pattern.split(text) if self.specials_pattern else [text]\n        for chunk in chunks:\n            if chunk in self.special_to_absid:\n                ids.append(self.special_to_absid[chunk])\n                if chunk.startswith(\"<math\"):\n                    in_math = True\n                elif chunk.startswith(\"</math>\"):\n                    in_math = False\n                if self.debug:\n                    print(f\"[TAG] {chunk!r} -> {self.special_to_absid[chunk]}\")\n                continue\n\n            if in_math:\n                ids.extend(\n                    self._encode_math_greedy(\n                        chunk, self.trie, self.MATH_BASE, debug=self.debug\n                    )\n                )\n            else:\n                units = self._to_utf16_units(chunk)\n                if self.debug and units:\n                    print(\n                        f\"[TEXT] utf16 {chunk[:32]!r} -> {units[:8]}{'...' 
if len(units) > 8 else ''}\"\n                    )\n                ids.extend(units)\n        return ids\n\n    def _fix_latex_escapes(self, text: str) -> str:\n        \"\"\"Fix improperly escaped LaTeX commands in decoded text.\n\n        Operates on the complete decoded string, replacing control character\n        sequences with their intended LaTeX commands based on vocab patterns.\n        \"\"\"\n        result = []\n        i = 0\n        while i < len(text):\n            char = text[i]\n            if char in self.latex_escape_patterns:\n                # Check if any pattern matches\n                matched = False\n                for pattern, replacement in self.latex_escape_patterns[char]:\n                    if text[i+1:].startswith(pattern):\n                        result.append(replacement)\n                        i += 1 + len(pattern)\n                        matched = True\n                        break\n                if not matched:\n                    # Not a LaTeX command, keep the control char as-is\n                    result.append(char)\n                    i += 1\n            else:\n                result.append(char)\n                i += 1\n\n        return ''.join(result)\n\n    def _decode_core(self, ids: Iterable[int]) -> str:\n        out: List[str] = []\n        buf: List[int] = []\n\n        def flush():\n            if buf:\n                out.append(self._from_utf16_units(buf))\n                buf.clear()\n\n        for tid in ids:\n            if tid >= self.MATH_BASE and tid < self.SPECIAL_BASE:\n                flush()\n                out.append(self.math_absid_to_token.get(tid, \"\"))\n            elif tid >= self.SPECIAL_BASE:\n                flush()\n                out.append(self.absid_to_special.get(tid, \"\"))\n            else:\n                buf.append(int(tid))\n        flush()\n        decoded = \"\".join(out)\n        return self._fix_latex_escapes(decoded)\n\n    # ---- Tokenizer interface ----\n    
def _tokenize(self, text: str, **kwargs) -> List[str]:\n        ids = self._encode_core(text)\n        toks: List[str] = []\n        for i in ids:\n            if i < self.MATH_BASE:\n                toks.append(f\"<U+{i:04X}>\")\n            elif i < self.SPECIAL_BASE:\n                toks.append(self.math_absid_to_token.get(i, \"<UNK_MATH>\"))\n            else:\n                toks.append(self.absid_to_special.get(i, \"<UNK_SPECIAL>\"))\n        return toks\n\n    def _convert_token_to_id(self, token: str) -> int:\n        if token.startswith(\"<U+\") and token.endswith(\">\"):\n            try:\n                return int(token[3:-1], 16)  # UTF-16 unit\n            except Exception:\n                return self.unk_token_id if self.unk_token_id is not None else 0\n        # math or specials\n        if token in self.math_token_to_rawid:\n            return self.MATH_BASE + self.math_token_to_rawid[token]\n        if token in self.special_to_absid:\n            return self.special_to_absid[token]\n        # rare path: single-char token -> its UTF-16 unit\n        if len(token) == 1:\n            u = self._to_utf16_units(token)\n            if len(u) == 1:\n                return u[0]\n        return self.unk_token_id if self.unk_token_id is not None else 0\n\n    def _convert_id_to_token(self, index: int) -> str:\n        if index < self.MATH_BASE:\n            return f\"<U+{index:04X}>\"\n        if index < self.SPECIAL_BASE:\n            return self.math_absid_to_token.get(index, \"<UNK_MATH>\")\n        return self.absid_to_special.get(index, \"<UNK_SPECIAL>\")\n\n    def convert_tokens_to_string(self, tokens: List[str]) -> str:\n        ids = [self._convert_token_to_id(t) for t in tokens]\n        return self._decode_core(ids)\n\n    def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str:\n        # Accept int, list, tuple, numpy, torch\n        if hasattr(token_ids, \"tolist\"):\n            token_ids = token_ids.tolist()\n       
 elif isinstance(token_ids, int):\n            token_ids = [token_ids]\n        else:\n            token_ids = list(token_ids)\n        token_ids = [int(i) for i in token_ids]  # normalize early\n\n        if skip_special_tokens:\n            token_ids = [i for i in token_ids if i < self.SPECIAL_BASE]\n        return self._decode_core(token_ids)\n\n    # HF plumbing\n    def build_inputs_with_special_tokens(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        out = (\n            list(token_ids_0)\n            if token_ids_1 is None\n            else list(token_ids_0) + list(token_ids_1)\n        )\n        # if self.eos_token_id is not None and (not out or out[-1] != self.eos_token_id):\n        #     out.append(self.eos_token_id)\n        return out\n\n    def get_special_tokens_mask(\n        self,\n        token_ids_0: List[int],\n        token_ids_1: Optional[List[int]] = None,\n        already_has_special_tokens: bool = False,\n    ) -> List[int]:\n        def mask(seq: List[int]) -> List[int]:\n            return [1 if i >= self.SPECIAL_BASE else 0 for i in seq]\n\n        return (\n            mask(token_ids_0)\n            if token_ids_1 is None\n            else mask(token_ids_0) + mask(token_ids_1)\n        )\n\n    def create_token_type_ids_from_sequences(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        return [0] * (\n            len(token_ids_0)\n            if token_ids_1 is None\n            else len(token_ids_0) + len(token_ids_1)\n        )\n\n    # Save/load raw assets\n    def save_vocabulary(\n        self, save_directory: str, filename_prefix: Optional[str] = None\n    ) -> Tuple[str, str]:\n        os.makedirs(save_directory, exist_ok=True)\n        pre = (filename_prefix + \"-\") if filename_prefix else \"\"\n        vocab_path = os.path.join(\n            save_directory, pre + self.vocab_files_names[\"vocab_file\"]\n        
)\n        specials_path = os.path.join(\n            save_directory, pre + self.vocab_files_names[\"specials_file\"]\n        )\n        specials_dict_path = os.path.join(\n            save_directory, pre + self.vocab_files_names[\"specials_dict_file\"]\n        )\n        with open(vocab_path, \"w\", encoding=\"utf-8\") as f:\n            json.dump(self.math_token_to_rawid, f, ensure_ascii=False, indent=2)\n        # Save both the flat list (\"all\") and the category dict (preferred)\n        with open(specials_path, \"w\", encoding=\"utf-8\") as f:\n            json.dump(self.special_tokens_list, f, ensure_ascii=False, indent=2)\n        with open(specials_dict_path, \"w\", encoding=\"utf-8\") as f:\n            json.dump(self.special_tokens, f, ensure_ascii=False, indent=2)\n        return (vocab_path, specials_path)\n\n\nclass SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer):\n    def __init__(\n        self,\n        special_tokens: Dict[str, list] | None = None,\n        model_checkpoint: str = settings.FOUNDATION_MODEL_CHECKPOINT,\n        **kwargs,\n    ):\n        if special_tokens is None:\n            special_tokens = dict()\n\n        self.special_tokens = special_tokens\n\n        self.ocr_tokenizer = GreedyMathUTF16Tokenizer.from_pretrained(\n            model_checkpoint,\n        )\n        self.system_tokens = {\n            v: self.ocr_tokenizer(v)[\"input_ids\"][0]\n            for v in special_tokens.get(\"system\", [])\n        }\n        self.SPECIAL_TOKEN_MAPPING = self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING\n\n        super().__init__(**kwargs)\n\n    def get_vocab(self) -> Dict[str, int]:\n        return self.ocr_tokenizer.get_vocab()\n\n    def _add_tokens(\n        self,\n        new_tokens: Union[List[str], List[AddedToken]],\n        special_tokens: bool = False,\n    ) -> int:\n        return self.ocr_tokenizer._add_tokens(\n            new_tokens, special_tokens=special_tokens\n        )\n\n    @property\n    def 
vocab_size(self):\n        return self.ocr_tokenizer.vocab_size\n\n    def _tokenize(self, text: str, **kwargs):\n        # task = kwargs.get(\"task\", TaskNames.ocr_with_boxes)\n        # assert task in TASK_NAMES, f\"Invalid task: {task}\"\n\n        tokens = self.ocr_tokenizer(text)[\"input_ids\"]\n\n        return tokens\n\n    def __call__(\n        self,\n        texts: Union[str, List[str]],\n        tasks: Union[str, List[str]] = None,\n        **kwargs,\n    ) -> Dict[str, List[List[int]]]:\n        \"\"\"Tokenizes text and returns input IDs.\"\"\"\n        tokenized = []\n\n        if isinstance(texts, str):\n            texts = [texts]\n            assert isinstance(tasks, str), \"Tasks must be a string if texts is a string\"\n            tasks = [tasks]\n\n        if isinstance(texts, list):\n            assert isinstance(tasks, list), \"Tasks must be a list if texts is a list\"\n\n        for text, task in zip(texts, tasks):\n            tokens = self._tokenize(text, task=task)\n            tokenized.append(tokens)\n\n        return {\"input_ids\": tokenized}\n\n    def decode(self, token_ids, **kwargs):\n        if isinstance(token_ids, (np.ndarray, torch.Tensor)):\n            token_ids = token_ids.tolist()\n\n        decoded_text = self.ocr_tokenizer.decode(token_ids, skip_special_tokens=False)\n        # replace all <SCRIPT-...> tokens with empty strings\n        decoded_text = re.sub(r\"<SCRIPT-.*?>\", \"\", decoded_text)\n        # replace </S> with empty string\n        decoded_text = re.sub(r\"</S>\", \"\", decoded_text)\n        return decoded_text\n"
  },
  {
    "path": "surya/common/surya/schema.py",
    "content": "class TaskNames:\n    block_without_boxes = \"block_without_boxes\"\n    ocr_with_boxes = \"ocr_with_boxes\"\n    ocr_without_boxes = \"ocr_without_boxes\"\n    layout = \"layout\"\n    table_structure = \"table_structure\"\n\n\nTASK_NAMES = [\n    TaskNames.block_without_boxes,\n    TaskNames.ocr_with_boxes,\n    TaskNames.ocr_without_boxes,\n    TaskNames.layout,\n    TaskNames.table_structure,\n]\n"
  },
  {
    "path": "surya/common/util.py",
    "content": "import copy\nfrom typing import List\nimport torch\nfrom functools import lru_cache\n\nimport torch.nn.functional as F\n\nfrom surya.common.polygon import PolygonBox\n\n\ndef clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:\n    new_boxes = []\n    for box_obj in boxes:\n        xs = [point[0] for point in box_obj.polygon]\n        ys = [point[1] for point in box_obj.polygon]\n        if max(xs) == min(xs) or max(ys) == min(ys):\n            continue\n\n        box = box_obj.bbox\n        contained = False\n        for other_box_obj in boxes:\n            if other_box_obj.polygon == box_obj.polygon:\n                continue\n\n            other_box = other_box_obj.bbox\n            if box == other_box:\n                continue\n            if (\n                box[0] >= other_box[0]\n                and box[1] >= other_box[1]\n                and box[2] <= other_box[2]\n                and box[3] <= other_box[3]\n            ):\n                contained = True\n                break\n        if not contained:\n            new_boxes.append(box_obj)\n    return new_boxes\n\n\ndef rescale_bbox(bbox, processor_size, image_size):\n    page_width, page_height = processor_size\n\n    img_width, img_height = image_size\n    width_scaler = img_width / page_width\n    height_scaler = img_height / page_height\n\n    new_bbox = copy.deepcopy(bbox)\n    new_bbox[0] = int(new_bbox[0] * width_scaler)\n    new_bbox[1] = int(new_bbox[1] * height_scaler)\n    new_bbox[2] = int(new_bbox[2] * width_scaler)\n    new_bbox[3] = int(new_bbox[3] * height_scaler)\n    return new_bbox\n\n\ndef expand_bbox(bbox, expansion_factor=0.01):\n    expansion_low = 1 - expansion_factor\n    expansion_high = 1 + expansion_factor\n    return [\n        bbox[0] * expansion_low,\n        bbox[1] * expansion_low,\n        bbox[2] * expansion_high,\n        bbox[3] * expansion_high,\n    ]\n\nSCRIPT_TOKEN_MAPPING = {\n    \"latin\": \"<SCRIPT-LATIN>\",\n    \"punctuation\": 
\"<SCRIPT-PUNCTUATION>\",\n    \"cyrillic\": \"<SCRIPT-CYRILLIC>\",\n    \"arabic\": \"<SCRIPT-ARABIC>\",\n    \"chinese\": \"<SCRIPT-CHINESE>\",\n    \"japanese\": \"<SCRIPT-JAPANESE>\",\n    \"korean\": \"<SCRIPT-KOREAN>\",\n    \"symbols\": \"<SCRIPT-SYMBOLS>\",\n    \"greek\": \"<SCRIPT-GREEK>\",\n    \"armenian\": \"<SCRIPT-ARMENIAN>\",\n    \"hebrew\": \"<SCRIPT-HEBREW>\",\n    \"devanagari\": \"<SCRIPT-DEVANAGARI>\",\n    \"bengali\": \"<SCRIPT-BENGALI>\",\n    \"gurmukhi\": \"<SCRIPT-GURMUKHI>\",\n    \"gujarati\": \"<SCRIPT-GUJARATI>\",\n    \"oriya\": \"<SCRIPT-ORIYA>\",\n    \"tamil\": \"<SCRIPT-TAMIL>\",\n    \"telugu\": \"<SCRIPT-TELUGU>\",\n    \"kannada\": \"<SCRIPT-KANNADA>\",\n    \"malayalam\": \"<SCRIPT-MALAYALAM>\",\n    \"sinhala\": \"<SCRIPT-SINHALA>\",\n    \"thai\": \"<SCRIPT-THAI>\",\n    \"lao\": \"<SCRIPT-LAO>\",\n    \"myanmar\": \"<SCRIPT-MYANMAR>\",\n    \"georgian\": \"<SCRIPT-GEORGIAN>\",\n    \"ethiopic\": \"<SCRIPT-ETHIOPIC>\",\n    \"khmer\": \"<SCRIPT-KHMER>\",\n    \"mongolian\": \"<SCRIPT-MONGOLIAN>\",\n    \"math\": \"<SCRIPT-MATH>\",\n}\n\n@lru_cache(maxsize=1)\ndef script_ranges():\n    script_categories = {\n        # Latin-based scripts (used by English, French, German, etc.)\n        \"latin\": [\n            (0x0041, 0x005A),  # Latin uppercase A-Z\n            (0x0061, 0x007A),  # Latin lowercase a-z\n            (0x0080, 0x00FF),  # Latin-1 Supplement\n            (0x0100, 0x017F),  # Latin Extended-A\n            (0x0180, 0x024F),  # Latin Extended-B\n            (0x0250, 0x02AF),  # IPA Extensions\n            (0x02B0, 0x02FF),  # Spacing Modifier Letters\n            (0x0300, 0x036F),  # Combining Diacritical Marks\n            (0x1E00, 0x1EFF),  # Latin Extended Additional\n            (0x2C60, 0x2C7F),  # Latin Extended-C\n            (0xA720, 0xA7FF),  # Latin Extended-D\n        ],\n        # Punctuation, universal characters, and general symbols\n        \"punctuation\": [\n            (0x0020, 0x0020),  # 
Space\n            (0x0021, 0x002F),  # Basic punctuation and symbols\n            (0x0030, 0x0039),  # Digits 0-9\n            (0x003A, 0x0040),  # More punctuation and symbols\n            (0x005B, 0x0060),  # More punctuation and symbols\n            (0x007B, 0x007F),  # More punctuation and symbols\n            (0x2000, 0x206F),  # General Punctuation\n        ],\n        # Cyrillic scripts (used by Russian, Ukrainian, etc.)\n        \"cyrillic\": [\n            (0x0400, 0x04FF),  # Cyrillic\n            (0x0500, 0x052F),  # Cyrillic Supplement\n        ],\n        # Arabic scripts\n        \"arabic\": [\n            (0x0600, 0x06FF),  # Arabic\n            (0x0750, 0x077F),  # Arabic Supplement\n            (0x08A0, 0x08FF),  # Arabic Extended-A\n        ],\n        # Chinese characters\n        \"chinese\": [\n            (0x4E00, 0x9FFF),  # Common CJK Unified Ideographs\n            (0x3400, 0x4DBF),  # CJK Extension A\n            (0x20000, 0x2A6DF),  # CJK Extension B\n        ],\n        # Japanese-specific scripts (excluding shared CJK)\n        \"japanese\": [\n            (0x3040, 0x30FF),  # Hiragana and Katakana\n        ],\n        # Korean-specific scripts\n        \"korean\": [\n            (0x1100, 0x11FF),  # Hangul Jamo\n            (0x3130, 0x318F),  # Hangul Compatibility Jamo\n            (0xAC00, 0xD7AF),  # Hangul Syllables\n        ],\n        # Various mathematical and technical symbols\n        \"symbols\": [\n            (0x2070, 0x209F),  # Superscripts and Subscripts\n            (0x20A0, 0x20CF),  # Currency Symbols\n            (0x2100, 0x214F),  # Letterlike Symbols\n            (0x2150, 0x218F),  # Number Forms\n            (0x2190, 0x21FF),  # Arrows\n            (0x2200, 0x22FF),  # Mathematical Operators\n            (0x2300, 0x23FF),  # Miscellaneous Technical\n            (0x2500, 0x257F),  # Box Drawing\n            (0x2580, 0x259F),  # Block Elements\n            (0x25A0, 0x25FF),  # Geometric Shapes\n            (0x2600, 
0x26FF),  # Miscellaneous Symbols\n            (0x2700, 0x27BF),  # Dingbats\n            (0x27C0, 0x27EF),  # Miscellaneous Mathematical Symbols-A\n            (0x2980, 0x29FF),  # Miscellaneous Mathematical Symbols-B\n            (0x2A00, 0x2AFF),  # Supplemental Mathematical Operators\n            (0x1D400, 0x1D7FF),  # Mathematical Alphanumeric Symbols\n        ],\n        # Individual scripts for languages with unique writing systems\n        \"greek\": [(0x0370, 0x03FF)],  # Greek and Coptic\n        \"armenian\": [(0x0530, 0x058F)],  # Armenian\n        \"hebrew\": [(0x0590, 0x05FF)],  # Hebrew\n        \"devanagari\": [(0x0900, 0x097F)],  # Devanagari (Hindi, Sanskrit)\n        \"bengali\": [(0x0980, 0x09FF)],  # Bengali\n        \"gurmukhi\": [(0x0A00, 0x0A7F)],  # Gurmukhi (Punjabi)\n        \"gujarati\": [(0x0A80, 0x0AFF)],  # Gujarati\n        \"oriya\": [(0x0B00, 0x0B7F)],  # Oriya\n        \"tamil\": [(0x0B80, 0x0BFF)],  # Tamil\n        \"telugu\": [(0x0C00, 0x0C7F)],  # Telugu\n        \"kannada\": [(0x0C80, 0x0CFF)],  # Kannada\n        \"malayalam\": [(0x0D00, 0x0D7F)],  # Malayalam\n        \"sinhala\": [(0x0D80, 0x0DFF)],  # Sinhala\n        \"thai\": [(0x0E00, 0x0E7F)],  # Thai\n        \"lao\": [(0x0E80, 0x0EFF)],  # Lao\n        \"myanmar\": [(0x1000, 0x109F)],  # Myanmar\n        \"georgian\": [(0x10A0, 0x10FF)],  # Georgian\n        \"ethiopic\": [(0x1200, 0x137F)],  # Ethiopic\n        \"khmer\": [(0x1780, 0x17FF)],  # Khmer\n        \"mongolian\": [(0x1800, 0x18AF)],  # Mongolian\n    }\n\n    # Convert to a flat structure with character ranges\n    flat_ranges = {}\n    for category, ranges in script_categories.items():\n        # Create a set of all characters in this category\n        char_set = set()\n        for start, end in ranges:\n            char_set.update(range(start, end + 1))\n\n        # Store the set in flat_ranges\n        flat_ranges[category] = char_set\n\n    return script_categories, flat_ranges\n\ndef 
get_top_scripts(text: str, max_scripts: int = 5):\n    script_categories, flat_ranges = script_ranges()\n    char_count = {category: 0 for category in script_categories.keys()}\n    for char in text:\n        for category, char_set in flat_ranges.items():\n            if ord(char) in char_set:\n                char_count[category] += 1\n                break\n\n    top_scripts = sorted(char_count.items(), key=lambda x: x[1], reverse=True)\n    top_scripts = [ts[0] for ts in top_scripts if ts[1] > 0]\n    if \"<math\" in text:\n        top_scripts.insert(0, \"math\")\n\n    return top_scripts[:max_scripts]\n\ndef is_flash_attn_2_supported(device: str | torch.device) -> bool:\n    if not torch.cuda.is_available():\n        return False\n\n    if \"cuda\" not in str(device):\n        return False\n\n    # Check CUDA version >= 12.0\n    cuda_version_str = torch.version.cuda\n    if cuda_version_str is None:\n        return False\n    cuda_version = tuple(map(int, cuda_version_str.split(\".\")))\n    if cuda_version < (12, 0):\n        return False\n\n    # Check compute capability of the requested GPU (Ampere, Ada, Hopper need >= 8.0)\n    major, minor = torch.cuda.get_device_capability(device)\n    compute_capability = major + minor / 10\n    if compute_capability < 8.0:\n        return False\n\n    return True\n\n\ndef pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int):\n    current_batch_size = tensor.shape[0]\n    if current_batch_size >= batch_size:\n        return tensor\n\n    pad_size = batch_size - current_batch_size\n    if pad_size < 0:\n        return tensor\n\n    # Repeat the last row pad_size times along dim 0, keeping all other dims (any rank)\n    last_row = tensor[-1:].repeat(pad_size, *([1] * (tensor.dim() - 1)))\n\n    # Concatenate original tensor with repeated last rows\n    return torch.cat([tensor, last_row], dim=0)\n\n\ndef pad_to_batch_size(tensor: torch.Tensor, batch_size: int):\n    current_batch_size = tensor.shape[0]\n    if current_batch_size >= batch_size:\n        return tensor\n\n    pad_size = batch_size - 
current_batch_size\n    padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)\n\n    return F.pad(tensor, padding, mode=\"constant\", value=0)\n"
  },
  {
    "path": "surya/common/xla.py",
    "content": "import math\nfrom surya.settings import settings\n\nif settings.TORCH_DEVICE_MODEL == \"xla\":\n    import torch_xla.core.xla_model as xm\nelse:\n    xm = None\n\n\ndef get_nearest_pad(\n    length: int, pad_multiple: int = settings.FOUNDATION_PAD_TO_NEAREST\n):\n    return math.ceil(length / pad_multiple) * pad_multiple\n\n\ndef get_compile_args(device: str) -> dict:\n    if not settings.FOUNDATION_XLA:\n        return {}\n\n    return {\n        \"backend\": \"openxla\",\n    }\n\n\ndef mark_step():\n    if xm is not None:\n        xm.mark_step()\n"
  },
  {
    "path": "surya/debug/draw.py",
    "content": "from PIL import ImageDraw, ImageFont\n\nfrom surya.debug.fonts import get_font_path\nfrom surya.debug.text import get_text_size\n\n\ndef draw_bboxes_on_image(\n    bboxes, image, labels=None, label_font_size=10, color: str | list = \"red\"\n):\n    polys = []\n    for bb in bboxes:\n        # Clockwise polygon\n        poly = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]\n        polys.append(poly)\n\n    return draw_polys_on_image(\n        polys, image, labels, label_font_size=label_font_size, color=color\n    )\n\n\ndef draw_polys_on_image(\n    corners,\n    image,\n    labels=None,\n    box_padding=-1,\n    label_offset=1,\n    label_font_size=10,\n    color: str | list = \"red\",\n):\n    draw = ImageDraw.Draw(image)\n    font_path = get_font_path()\n    label_font = ImageFont.truetype(font_path, label_font_size)\n\n    for i in range(len(corners)):\n        poly = corners[i]\n        poly = [(int(p[0]), int(p[1])) for p in poly]\n        draw.polygon(\n            poly, outline=color[i] if isinstance(color, list) else color, width=1\n        )\n\n        if labels is not None:\n            label = labels[i]\n            text_position = (\n                min([p[0] for p in poly]) + label_offset,\n                min([p[1] for p in poly]) + label_offset,\n            )\n            text_size = get_text_size(label, label_font)\n            box_position = (\n                text_position[0] - box_padding + label_offset,\n                text_position[1] - box_padding + label_offset,\n                text_position[0] + text_size[0] + box_padding + label_offset,\n                text_position[1] + text_size[1] + box_padding + label_offset,\n            )\n            try:\n                draw.rectangle(box_position, fill=\"white\")\n            except Exception as e:\n                print(f\"Error drawing rectangle at {box_position}: {e}\")\n                continue\n            draw.text(\n                text_position,\n     
           label,\n                fill=color[i] if isinstance(color, list) else color,\n                font=label_font,\n            )\n\n    return image\n"
  },
  {
    "path": "surya/debug/fonts.py",
    "content": "from typing import List, Optional\nimport os\nimport requests\n\nfrom surya.settings import settings\n\n\ndef get_font_path(langs: Optional[List[str]] = None) -> str:\n    font_path = settings.RECOGNITION_RENDER_FONTS[\"all\"]\n    if langs is not None:\n        for k in settings.RECOGNITION_RENDER_FONTS:\n            if k in langs and len(langs) == 1:\n                font_path = settings.RECOGNITION_RENDER_FONTS[k]\n                break\n\n    if not os.path.exists(font_path):\n        os.makedirs(os.path.dirname(font_path), exist_ok=True)\n        font_dl_path = f\"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}\"\n        with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f:\n            r.raise_for_status()\n            for chunk in r.iter_content(chunk_size=8192):\n                f.write(chunk)\n\n    return font_path"
  },
  {
    "path": "surya/debug/katex.js",
    "content": "<style>\n    .katex-display-container {\n        display: inline-block;\n        max-width: 100%;\n        overflow-x: auto;\n        max-height: 100%;\n    }\n\n    .katex-inline-container {\n        display: inline-block;\n        max-width: 100%;\n        overflow-x: auto;\n        max-height: 100%;\n    }\n</style>\n<script src=\"https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/katex.min.js\" onload=\"setTimeout(function() {renderMath()})\" async></script>\n<link rel=\"stylesheet\" href=\"https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/katex.min.css\">\n<script>\n    function htmlUnescape(escapedText) {\n      const htmlEntities = {\n        '&amp;': '&',\n        '&lt;': '<',\n        '&gt;': '>',\n        '&quot;': '\"',\n        '&#39;': \"'\",\n        '&nbsp;': ' '\n      };\n\n      return escapedText.replace(/&amp;|&lt;|&gt;|&quot;|&#39;|&nbsp;/g, match => htmlEntities[match]);\n    }\n\n    const renderMath = (function() {\n    try {\n       const mathElements = document.querySelectorAll('math');\n\n        mathElements.forEach(function(element) {\n          let mathContent = element.innerHTML.trim();\n          mathContent = htmlUnescape(mathContent);\n          const isDisplay = element.getAttribute('display') === 'block';\n\n          const container = document.createElement('span');\n          container.className = isDisplay ? 
'katex-display-container' : 'katex-inline-container';\n          element.parentNode.insertBefore(container, element);\n\n          try {\n            katex.render(mathContent, container, {\n              displayMode: isDisplay,\n              throwOnError: false\n            });\n\n          } catch (err) {\n            console.error('KaTeX rendering error:', err);\n            container.textContent = mathContent; // Fallback to raw text\n          }\n\n          element.parentNode.removeChild(element);\n        });\n\n        console.log('Math rendering complete with', mathElements.length, 'expressions');\n      } catch (err) {\n        console.error('Error in renderMath function:', err);\n      }\n    });\n</script>"
  },
  {
    "path": "surya/debug/render_html.py",
    "content": "import html as htmllib\nimport os.path\nimport re\n\nfilepath = os.path.abspath(__file__)\n\ndef render_text_as_html(\n        bboxes: list[list[int]],\n        texts: list[str],\n        image_size: tuple[int, int],\n        base_font_size: int = 16,\n        scaler: int = 2\n):\n    katex_path = os.path.join(os.path.dirname(filepath), \"katex.js\")\n    with open(katex_path, \"r\") as f:\n        katex_script = f.read()\n\n    html_content = []\n    image_size = tuple([int(s * scaler) for s in image_size])\n    width, height = image_size\n\n\n    html_content.append(f\"\"\"\n<!DOCTYPE html>\n<html>\n<head>\n    <style>\n        body {{\n            margin: 0;\n            padding: 0;\n            width: {width}px;\n            height: {height}px;\n            position: relative;\n            overflow: hidden;\n            background: white;\n            color: black;\n        }}\n        .text-box {{\n            position: absolute;\n            overflow: hidden;\n            display: flex;\n            justify-content: left;\n            font-family: Arial, sans-serif;\n            white-space: pre-wrap;\n        }}\n        .vertical-text {{\n          writing-mode: vertical-rl;  /* Top to bottom, right to left */\n        }}\n    </style>\n    {katex_script}\n</head>\n<body>\n\"\"\")\n\n    for i, (bbox, text) in enumerate(zip(bboxes, texts)):\n        bbox = bbox.copy()\n        bbox = [int(bb * scaler) for bb in bbox]\n        x1, y1, x2, y2 = bbox\n        width = x2 - x1\n        height = y2 - y1\n        min_dim = min(width, height)\n\n        # Scale font size based on the smaller box dimension, capped at base_font_size\n        font_size = min(int(min_dim * 0.75), base_font_size)\n\n        # Build the absolute-positioning style for this box's span\n        div_style = (\n            f\"left: {x1}px; \"\n            f\"top: {y1}px; \"\n            f\"width: {width}px; \"\n            f\"height: {height}px; \"\n            f\"font-size: {font_size}px;\"\n        )\n\n        class_ = \"text-box\"\n        if height > width * 2:\n            class_ += \" vertical-text\"\n\n        # Determine if content is HTML/MathML or plain text\n        if \"<\" in text and \">\" in text and re.search(r\"<(html|math|div|sub|sup|i|u|mark|small|del|b|br|code)\\b\", text.lower()):\n            # Content is already HTML/MathML, include as-is\n            html_content.append(f'<span class=\"{class_}\" id=\"box-{i}\" style=\"{div_style}\">{text}</span>')\n        else:\n            # Plain text, escape it\n            escaped_text = htmllib.escape(text)\n            html_content.append(f'<span class=\"{class_}\" id=\"box-{i}\" style=\"{div_style}\">{escaped_text}</span>')\n\n    html_content.append(\"</body></html>\")\n\n    return \"\\n\".join(html_content), image_size"
  },
  {
    "path": "surya/debug/text.py",
    "content": "import re\nfrom io import BytesIO\nfrom typing import List, Tuple\nfrom PIL import Image, ImageDraw, ImageFont\n\nfrom surya.debug.fonts import get_font_path\nfrom surya.debug.render_html import render_text_as_html\n\ntry:\n    from playwright.sync_api import sync_playwright\n\n    has_playwright = True\nexcept ImportError:\n    has_playwright = False\n\n\ndef strip_html_tags(html_text):\n    pattern = re.compile(r\"<[\\w/][^>]*>\")\n    text_only = pattern.sub(\"\", html_text)\n\n    return text_only\n\n\ndef get_text_size(text, font):\n    im = Image.new(mode=\"P\", size=(0, 0))\n    draw = ImageDraw.Draw(im)\n    _, _, width, height = draw.textbbox((0, 0), text=text, font=font)\n    return width, height\n\n\ndef render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):\n    font = ImageFont.truetype(font_path, box_font_size)\n    text_width, text_height = get_text_size(text, font)\n    while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:\n        box_font_size = box_font_size - 1\n        font = ImageFont.truetype(font_path, box_font_size)\n        text_width, text_height = get_text_size(text, font)\n\n    # Calculate text position (centered in bbox)\n    text_width, text_height = get_text_size(text, font)\n    x = s_bbox[0]\n    y = s_bbox[1] + (bbox_height - text_height) / 2\n\n    draw.text((x, y), text, fill=\"black\", font=font)\n\n\ndef draw_text_with_playwright(\n    bboxes, texts: List[str], image_size: Tuple[int, int]\n) -> Image.Image:\n    html_content, image_size = render_text_as_html(bboxes, texts, image_size)\n    if not has_playwright:\n        raise ImportError(\n            \"Playwright is not installed. 
Please install it using `pip install playwright`\"\n        )\n\n    with sync_playwright() as p:\n        browser = p.chromium.launch(headless=True)\n        page = browser.new_page(\n            viewport={\"width\": image_size[0], \"height\": image_size[1]}\n        )\n        page.set_content(html_content)\n        page.wait_for_timeout(1000)\n        body = page.query_selector(\"body\")\n        image = body.screenshot()\n        browser.close()\n\n    pil_img = Image.open(BytesIO(image))\n    return pil_img\n\n\ndef draw_text_on_image(\n    bboxes,\n    texts,\n    image_size: Tuple[int, int],\n    font_path=None,\n    max_font_size=60,\n    res_upscale=2,\n) -> Image.Image:\n    if has_playwright:\n        return draw_text_with_playwright(bboxes, texts, image_size)\n\n    texts = [strip_html_tags(text) for text in texts]\n    if font_path is None:\n        font_path = get_font_path()\n    new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale)\n    image = Image.new(\"RGB\", new_image_size, color=\"white\")\n    draw = ImageDraw.Draw(image)\n\n    for bbox, text in zip(bboxes, texts):\n        s_bbox = [int(coord * res_upscale) for coord in bbox]\n        bbox_width = s_bbox[2] - s_bbox[0]\n        bbox_height = s_bbox[3] - s_bbox[1]\n\n        # Shrink the text to fit in the bbox if needed\n        box_font_size = max(6, min(int(0.75 * bbox_height), max_font_size))\n        render_text(\n            draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size\n        )\n\n    return image\n"
  },
  {
    "path": "surya/detection/__init__.py",
    "content": "from concurrent.futures import ThreadPoolExecutor\nfrom typing import List, Generator, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom PIL import Image\nfrom tqdm import tqdm\n\nfrom surya.common.predictor import BasePredictor\nfrom surya.common.xla import mark_step\n\nfrom surya.detection.loader import DetectionModelLoader\nfrom surya.detection.parallel import FakeExecutor\nfrom surya.detection.util import get_total_splits, split_image\nfrom surya.detection.schema import TextDetectionResult\nfrom surya.settings import settings\nfrom surya.detection.heatmap import parallel_get_boxes\n\n\nclass DetectionPredictor(BasePredictor):\n    model_loader_cls = DetectionModelLoader\n    batch_size = settings.DETECTOR_BATCH_SIZE\n    default_batch_sizes = {\"cpu\": 8, \"mps\": 8, \"cuda\": 36, \"xla\": 18}\n\n    def __call__(\n        self, images: List[Image.Image], batch_size=None, include_maps=False\n    ) -> List[TextDetectionResult]:\n        detection_generator = self.batch_detection(\n            images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE\n        )\n\n        postprocessing_futures = []\n        max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))\n        parallelize = (\n            not settings.IN_STREAMLIT\n            and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH\n        )\n        executor = ThreadPoolExecutor if parallelize else FakeExecutor\n        with executor(max_workers=max_workers) as e:\n            for preds, orig_sizes in detection_generator:\n                for pred, orig_size in zip(preds, orig_sizes):\n                    postprocessing_futures.append(\n                        e.submit(parallel_get_boxes, pred, orig_size, include_maps)\n                    )\n\n        return [future.result() for future in postprocessing_futures]\n\n    def prepare_image(self, img):\n        new_size = (self.processor.size[\"width\"], 
self.processor.size[\"height\"])\n\n        # This double resize is actually necessary for downstream accuracy\n        img.thumbnail(new_size, Image.Resampling.LANCZOS)\n        img = img.resize(\n            new_size, Image.Resampling.LANCZOS\n        )  # Stretch smaller dimension to fit new size\n\n        img = np.asarray(img, dtype=np.uint8)\n        img = self.processor(img)[\"pixel_values\"][0]\n        img = torch.from_numpy(img)\n        return img\n\n    def batch_detection(\n        self, images: List, batch_size=None, static_cache=False\n    ) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:\n        assert all([isinstance(image, Image.Image) for image in images])\n        if batch_size is None:\n            batch_size = self.get_batch_size()\n        heatmap_count = self.model.config.num_labels\n\n        orig_sizes = [image.size for image in images]\n        splits_per_image = [\n            get_total_splits(size, self.processor.size[\"height\"]) for size in orig_sizes\n        ]\n\n        batches = []\n        current_batch_size = 0\n        current_batch = []\n        for i in range(len(images)):\n            if current_batch_size + splits_per_image[i] > batch_size:\n                if len(current_batch) > 0:\n                    batches.append(current_batch)\n                current_batch = []\n                current_batch_size = 0\n            current_batch.append(i)\n            current_batch_size += splits_per_image[i]\n\n        if len(current_batch) > 0:\n            batches.append(current_batch)\n\n        for batch_idx in tqdm(\n            range(len(batches)), desc=\"Detecting bboxes\", disable=self.disable_tqdm\n        ):\n            batch_image_idxs = batches[batch_idx]\n            batch_images = [images[j].convert(\"RGB\") for j in batch_image_idxs]\n\n            split_index = []\n            split_heights = []\n            image_splits = []\n            for image_idx, image in enumerate(batch_images):\n                image_parts, split_height = split_image(\n                    image, self.processor.size[\"height\"]\n                )\n                image_splits.extend(image_parts)\n                split_index.extend([image_idx] * len(image_parts))\n                split_heights.extend(split_height)\n\n            image_splits = [self.prepare_image(image) for image in image_splits]\n            # Batch images in dim 0\n            batch = torch.stack(image_splits, dim=0).to(self.model.dtype)\n            if static_cache:\n                batch = self.pad_to_batch_size(batch, batch_size)\n\n            with settings.INFERENCE_MODE():\n                pred = self.model(\n                    pixel_values=batch.to(self.model.device)\n                )  # Moving the batch to the device here fixes issues with xla recompilation\n\n            logits = pred.logits\n            correct_shape = [\n                self.processor.size[\"height\"],\n                self.processor.size[\"width\"],\n            ]\n            current_shape = list(logits.shape[2:])\n            if current_shape != correct_shape:\n                logits = F.interpolate(\n                    logits, size=correct_shape, mode=\"bilinear\", align_corners=False\n                )\n            mark_step()\n\n            logits = logits.to(torch.float32).cpu().numpy()\n            preds = []\n            for i, (idx, height) in enumerate(zip(split_index, split_heights)):\n                # If our current prediction length is below the image idx, that means we have a new image\n                # Otherwise, we need to add to the current image\n                if len(preds) <= idx:\n                    preds.append([logits[i][k] for k in range(heatmap_count)])\n                else:\n                    heatmaps = preds[idx]\n                    pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]\n\n                    if height < self.processor.size[\"height\"]:\n                        # Cut off padding to get original height\n                        pred_heatmaps = [\n                            pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps\n                        ]\n\n                    for k in range(heatmap_count):\n                        heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])\n                    preds[idx] = heatmaps\n\n            yield preds, [orig_sizes[j] for j in batch_image_idxs]\n\n        torch.cuda.empty_cache()\n"
  },
  {
    "path": "surya/detection/heatmap.py",
    "content": "from typing import List\n\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\nfrom surya.common.util import clean_boxes\nfrom surya.detection import TextDetectionResult\nfrom surya.common.polygon import PolygonBox\nfrom surya.settings import settings\n\n\ndef get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7):\n    # Find average intensity of top 10% pixels\n    flat_map = linemap.ravel()\n    top_10_count = int(len(flat_map) * 0.9)\n    avg_intensity = np.mean(np.partition(flat_map, top_10_count)[top_10_count:])\n    scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2)\n\n    low_text = np.clip(low_text * scaling_factor, 0.1, 0.6)\n    text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8)\n\n    return text_threshold, low_text\n\n\ndef detect_boxes(linemap, text_threshold, low_text):\n    # From CRAFT - https://github.com/clovaai/CRAFT-pytorch\n    # Modified to return boxes and for speed, accuracy\n    img_h, img_w = linemap.shape\n\n    text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text)\n\n    text_score_comb = (linemap > low_text).astype(np.uint8)\n    label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(\n        text_score_comb, connectivity=4\n    )\n\n    det = []\n    confidences = []\n    max_confidence = 0\n\n    for k in range(1, label_count):\n        # size filtering\n        size = stats[k, cv2.CC_STAT_AREA]\n        if size < 10:\n            continue\n\n        # make segmentation map\n        x, y, w, h = stats[\n            k,\n            [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],\n        ]\n\n        try:\n            niter = int(np.sqrt(min(w, h)))\n        except ValueError:\n            niter = 0\n\n        buffer = 1\n        sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)\n        ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + 
niter + buffer)\n\n        mask = labels[sy:ey, sx:ex] == k\n        selected_linemap = linemap[sy:ey, sx:ex][mask]\n        if selected_linemap.size == 0:\n            continue\n\n        line_max = np.max(selected_linemap)\n\n        # thresholding\n        if line_max < text_threshold:\n            continue\n\n        segmap = mask.astype(np.uint8)\n\n        ksize = buffer + niter\n        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize))\n        selected_segmap = cv2.dilate(segmap, kernel)\n\n        # make box\n        y_inds, x_inds = np.nonzero(selected_segmap)\n        x_inds += sx\n        y_inds += sy\n        np_contours = np.column_stack((x_inds, y_inds))\n        rectangle = cv2.minAreaRect(np_contours)\n        box = cv2.boxPoints(rectangle)\n\n        # align diamond-shape\n        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])\n        box_ratio = max(w, h) / (min(w, h) + 1e-5)\n        if abs(1 - box_ratio) <= 0.1:\n            left, right = np_contours[:, 0].min(), np_contours[:, 0].max()\n            top, bottom = np_contours[:, 1].min(), np_contours[:, 1].max()\n            box = np.array(\n                [[left, top], [right, top], [right, bottom], [left, bottom]],\n                dtype=np.float32,\n            )\n\n        # make clock-wise order\n        startidx = box.sum(axis=1).argmin()\n        box = np.roll(box, 4 - startidx, 0)\n\n        max_confidence = max(max_confidence, line_max)\n\n        confidences.append(line_max)\n        det.append(box)\n\n    if max_confidence > 0:\n        confidences = [c / max_confidence for c in confidences]\n    return det, confidences\n\n\ndef get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:\n    if text_threshold is None:\n        text_threshold = settings.DETECTOR_TEXT_THRESHOLD\n    if low_text is None:\n        low_text = settings.DETECTOR_BLANK_THRESHOLD\n\n    if textmap.dtype != np.float32:\n        textmap = 
textmap.astype(np.float32)\n\n    boxes, confidences = detect_boxes(textmap, text_threshold, low_text)\n    # From point form to box form\n    return [\n        PolygonBox(polygon=box, confidence=confidence)\n        for box, confidence in zip(boxes, confidences)\n    ]\n\n\ndef get_and_clean_boxes(\n    textmap, processor_size, image_size, text_threshold=None, low_text=None\n) -> List[PolygonBox]:\n    bboxes = get_detected_boxes(textmap, text_threshold, low_text)\n    for bbox in bboxes:\n        bbox.rescale(processor_size, image_size)\n        bbox.fit_to_bounds([0, 0, image_size[0], image_size[1]])\n\n    bboxes = clean_boxes(bboxes)\n    return bboxes\n\n\ndef parallel_get_boxes(preds, orig_sizes, include_maps=False):\n    heatmap, affinity_map = preds\n    heat_img, aff_img = None, None\n\n    if include_maps:\n        heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))\n        aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))\n    heatmap_size = list(reversed(heatmap.shape))\n    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)\n    for box in bboxes:\n        # Skip for vertical boxes\n        if box.height < 3 * box.width:\n            box.expand(x_margin=0, y_margin=settings.DETECTOR_BOX_Y_EXPAND_MARGIN)\n            box.fit_to_bounds(\n                [0, 0, orig_sizes[0], orig_sizes[1]]\n            )  # Fix any bad expands\n\n    result = TextDetectionResult(\n        bboxes=bboxes,\n        heatmap=heat_img,\n        affinity_map=aff_img,\n        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]],\n    )\n    return result\n"
  },
  {
    "path": "surya/detection/loader.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom surya.common.load import ModelLoader\nfrom surya.detection.processor import SegformerImageProcessor\n\nfrom surya.detection.model.config import EfficientViTConfig\nfrom surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation\nfrom surya.logging import get_logger\nfrom surya.settings import settings\n\nlogger = get_logger()\n\n\nclass DetectionModelLoader(ModelLoader):\n    def __init__(self, checkpoint: Optional[str] = None):\n        super().__init__(checkpoint)\n\n        if self.checkpoint is None:\n            self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT\n\n    def model(\n        self,\n        device: Optional[torch.device | str] = None,\n        dtype: Optional[torch.dtype | str] = None,\n        attention_implementation: Optional[str] = None,\n    ) -> EfficientViTForSemanticSegmentation:\n        if device is None:\n            device = settings.TORCH_DEVICE_MODEL\n        if dtype is None:\n            dtype = settings.MODEL_DTYPE\n\n        config = EfficientViTConfig.from_pretrained(self.checkpoint)\n        model = EfficientViTForSemanticSegmentation.from_pretrained(\n            self.checkpoint,\n            dtype=dtype,\n            config=config,\n        )\n        model = model.to(device)\n        model = model.eval()\n\n        if settings.COMPILE_ALL or settings.COMPILE_DETECTOR:\n            torch._dynamo.config.cache_size_limit = 1\n            torch._dynamo.config.suppress_errors = False\n\n            logger.info(\n                f\"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}\"\n            )\n            compile_args = {\"backend\": \"openxla\"} if device == \"xla\" else {}\n            model = torch.compile(model, **compile_args)\n\n        logger.debug(\n            f\"Loaded detection model {self.checkpoint} from {EfficientViTForSemanticSegmentation.get_local_path(self.checkpoint)} onto device 
{device} with dtype {dtype}\"\n        )\n        return model\n\n    def processor(\n        self,\n        device: Optional[torch.device | str] = None,\n        dtype: Optional[torch.dtype | str] = None,\n    ) -> SegformerImageProcessor:\n        return SegformerImageProcessor.from_pretrained(self.checkpoint)\n"
  },
  {
    "path": "surya/detection/model/__init__.py",
    "content": ""
  },
  {
    "path": "surya/detection/model/config.py",
    "content": "from transformers import PretrainedConfig\n\nfrom surya.common.s3 import S3DownloaderMixin\n\n\nclass EfficientViTConfig(S3DownloaderMixin, PretrainedConfig):\n    \"\"\"Configuration for the EfficientViT semantic segmentation model used for text detection.\"\"\"\n\n    model_type = \"efficientvit\"\n\n    def __init__(\n        self,\n        num_classes=2,\n        num_channels=3,\n        widths=(32, 64, 128, 256, 512),\n        head_dim=32,\n        num_stages=4,\n        depths=(1, 1, 1, 6, 6),\n        strides=(2, 2, 2, 2, 2),\n        hidden_sizes=(32, 64, 160, 256),\n        patch_size=(7, 7),\n        hidden_dropout_prob=0.0,\n        attention_probs_dropout_prob=0.0,\n        classifier_dropout_prob=0.0,\n        layer_norm_eps=1e-6,\n        decoder_layer_hidden_size=128,\n        decoder_hidden_size=512,\n        semantic_loss_ignore_index=255,\n        initializer_range=0.02,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.num_classes = num_classes\n        self.widths = widths\n        self.head_dim = head_dim\n\n        self.num_channels = num_channels\n        self.num_stages = num_stages\n        self.depths = depths\n        self.strides = strides\n        self.hidden_sizes = hidden_sizes\n        self.patch_size = patch_size\n        self.hidden_dropout_prob = hidden_dropout_prob\n        self.attention_probs_dropout_prob = attention_probs_dropout_prob\n        self.classifier_dropout_prob = classifier_dropout_prob\n        self.layer_norm_eps = layer_norm_eps\n        self.decoder_hidden_size = decoder_hidden_size\n        self.decoder_layer_hidden_size = decoder_layer_hidden_size\n        self.semantic_loss_ignore_index = semantic_loss_ignore_index\n\n        self.initializer_range = initializer_range"
  },
  {
    "path": "surya/detection/model/encoderdecoder.py",
    "content": "\"\"\"\nThis is an implementation of efficientvit, with some modifications (decode head, etc).\n\nOriginal paper at https://arxiv.org/abs/2205.14756\n\nCode adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py\nOriginal code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit\n\nLicense: Apache 2\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Optional, Union, Tuple, List, Any\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom transformers.modeling_outputs import SemanticSegmenterOutput\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.detection.model.config import EfficientViTConfig\n\n\ndef val2list(x: Union[List, Tuple, Any], repeat_time=1):\n    if isinstance(x, (list, tuple)):\n        return list(x)\n    return [x for _ in range(repeat_time)]\n\n\ndef val2tuple(x: Union[List, Tuple, Any], min_len: int = 1, idx_repeat: int = -1):\n    # repeat elements if necessary\n    x = val2list(x)\n    if len(x) > 0:\n        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]\n\n    return tuple(x)\n\n\ndef get_same_padding(\n    kernel_size: Union[int, Tuple[int, ...]],\n) -> Union[int, Tuple[int, ...]]:\n    if isinstance(kernel_size, tuple):\n        return tuple([get_same_padding(ks) for ks in kernel_size])\n    else:\n        assert kernel_size % 2 > 0, \"kernel size should be odd number\"\n        return kernel_size // 2\n\n\ndef get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:\n    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2\n    return padding\n\n\nclass ConvNormAct(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        dilation=1,\n        groups=1,\n        
bias=False,\n        dropout=0.0,\n        norm_layer=nn.BatchNorm2d,\n        act_layer=nn.ReLU,\n    ):\n        super(ConvNormAct, self).__init__()\n        self.dropout = nn.Dropout(dropout, inplace=False)\n        padding = get_padding(kernel_size, stride, dilation)\n        self.conv = nn.Conv2d(\n            in_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n            padding=padding,\n        )\n        self.norm = (\n            norm_layer(num_features=out_channels) if norm_layer else nn.Identity()\n        )\n        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.norm(x)\n        x = self.act(x)\n        return x\n\n\nclass DSConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        use_bias=False,\n        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),\n        act_layer=(nn.ReLU6, None),\n    ):\n        super(DSConv, self).__init__()\n        use_bias = val2tuple(use_bias, 2)\n        norm_layer = val2tuple(norm_layer, 2)\n        act_layer = val2tuple(act_layer, 2)\n\n        self.depth_conv = ConvNormAct(\n            in_channels,\n            in_channels,\n            kernel_size,\n            stride,\n            groups=in_channels,\n            norm_layer=norm_layer[0],\n            act_layer=act_layer[0],\n            bias=use_bias[0],\n        )\n        self.point_conv = ConvNormAct(\n            in_channels,\n            out_channels,\n            1,\n            norm_layer=norm_layer[1],\n            act_layer=act_layer[1],\n            bias=use_bias[1],\n        )\n\n    def forward(self, x):\n        x = self.depth_conv(x)\n        x = self.point_conv(x)\n        return x\n\n\nclass 
ConvBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        mid_channels=None,\n        expand_ratio=1,\n        use_bias=False,\n        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),\n        act_layer=(nn.ReLU6, None),\n    ):\n        super(ConvBlock, self).__init__()\n        use_bias = val2tuple(use_bias, 2)\n        norm_layer = val2tuple(norm_layer, 2)\n        act_layer = val2tuple(act_layer, 2)\n        mid_channels = mid_channels or round(in_channels * expand_ratio)\n\n        self.conv1 = ConvNormAct(\n            in_channels,\n            mid_channels,\n            kernel_size,\n            stride,\n            norm_layer=norm_layer[0],\n            act_layer=act_layer[0],\n            bias=use_bias[0],\n        )\n        self.conv2 = ConvNormAct(\n            mid_channels,\n            out_channels,\n            kernel_size,\n            1,\n            norm_layer=norm_layer[1],\n            act_layer=act_layer[1],\n            bias=use_bias[1],\n        )\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        return x\n\n\nclass MBConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        mid_channels=None,\n        expand_ratio=6,\n        use_bias=False,\n        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),\n        act_layer=(nn.ReLU6, nn.ReLU6, None),\n    ):\n        super(MBConv, self).__init__()\n        use_bias = val2tuple(use_bias, 3)\n        norm_layer = val2tuple(norm_layer, 3)\n        act_layer = val2tuple(act_layer, 3)\n        mid_channels = mid_channels or round(in_channels * expand_ratio)\n\n        self.inverted_conv = ConvNormAct(\n            in_channels,\n            mid_channels,\n            1,\n            stride=1,\n            norm_layer=norm_layer[0],\n            
act_layer=act_layer[0],\n            bias=use_bias[0],\n        )\n        self.depth_conv = ConvNormAct(\n            mid_channels,\n            mid_channels,\n            kernel_size,\n            stride=stride,\n            groups=mid_channels,\n            norm_layer=norm_layer[1],\n            act_layer=act_layer[1],\n            bias=use_bias[1],\n        )\n        self.point_conv = ConvNormAct(\n            mid_channels,\n            out_channels,\n            1,\n            norm_layer=norm_layer[2],\n            act_layer=act_layer[2],\n            bias=use_bias[2],\n        )\n\n    def forward(self, x):\n        x = self.inverted_conv(x)\n        x = self.depth_conv(x)\n        x = self.point_conv(x)\n        return x\n\n\nclass FusedMBConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        mid_channels=None,\n        expand_ratio=6,\n        groups=1,\n        use_bias=False,\n        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),\n        act_layer=(nn.ReLU6, None),\n    ):\n        super(FusedMBConv, self).__init__()\n        use_bias = val2tuple(use_bias, 2)\n        norm_layer = val2tuple(norm_layer, 2)\n        act_layer = val2tuple(act_layer, 2)\n        mid_channels = mid_channels or round(in_channels * expand_ratio)\n\n        self.spatial_conv = ConvNormAct(\n            in_channels,\n            mid_channels,\n            kernel_size,\n            stride=stride,\n            groups=groups,\n            norm_layer=norm_layer[0],\n            act_layer=act_layer[0],\n            bias=use_bias[0],\n        )\n        self.point_conv = ConvNormAct(\n            mid_channels,\n            out_channels,\n            1,\n            norm_layer=norm_layer[1],\n            act_layer=act_layer[1],\n            bias=use_bias[1],\n        )\n\n    def forward(self, x):\n        x = self.spatial_conv(x)\n        x = self.point_conv(x)\n        return 
x\n\n\nclass LiteMLA(nn.Module):\n    \"\"\"Lightweight multi-scale linear attention\"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        heads: Union[int, None] = None,\n        heads_ratio: float = 1.0,\n        dim=8,\n        use_bias=False,\n        norm_layer=(None, nn.BatchNorm2d),\n        act_layer=(None, None),\n        kernel_func=nn.ReLU,\n        scales=(5,),\n        eps=1e-5,\n    ):\n        super(LiteMLA, self).__init__()\n        self.eps = eps\n        heads = heads or int(in_channels // dim * heads_ratio)\n        total_dim = heads * dim\n        use_bias = val2tuple(use_bias, 2)\n        norm_layer = val2tuple(norm_layer, 2)\n        act_layer = val2tuple(act_layer, 2)\n\n        self.dim = dim\n        self.qkv = ConvNormAct(\n            in_channels,\n            3 * total_dim,\n            1,\n            bias=use_bias[0],\n            norm_layer=norm_layer[0],\n            act_layer=act_layer[0],\n        )\n        self.aggreg = nn.ModuleList(\n            [\n                nn.Sequential(\n                    nn.Conv2d(\n                        3 * total_dim,\n                        3 * total_dim,\n                        scale,\n                        padding=get_same_padding(scale),\n                        groups=3 * total_dim,\n                        bias=use_bias[0],\n                    ),\n                    nn.Conv2d(\n                        3 * total_dim,\n                        3 * total_dim,\n                        1,\n                        groups=3 * heads,\n                        bias=use_bias[0],\n                    ),\n                )\n                for scale in scales\n            ]\n        )\n        self.kernel_func = kernel_func(inplace=False)\n\n        self.proj = ConvNormAct(\n            total_dim * (1 + len(scales)),\n            out_channels,\n            1,\n            bias=use_bias[1],\n            norm_layer=norm_layer[1],\n            
act_layer=act_layer[1],\n        )\n\n    def _attn(self, q, k, v):\n        dtype = v.dtype\n        q, k, v = q.float(), k.float(), v.float()\n        kv = k.transpose(-1, -2) @ v\n        out = q @ kv\n        # v was padded with a column of ones in forward(), so the last column of out holds\n        # the linear-attention normalizer q . sum_j(k_j); divide the other columns by it\n        out = out[..., :-1] / (out[..., -1:] + self.eps)\n        return out.to(dtype)\n\n    def forward(self, x):\n        # Shape is B, C, H, W\n        B, _, H, W = x.shape\n\n        # generate multi-scale q, k, v\n        qkv = self.qkv(x)\n        multi_scale_qkv = [qkv]\n        for op in self.aggreg:\n            multi_scale_qkv.append(op(qkv))\n        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)\n        multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(\n            -1, -2\n        )\n        # Shape for each is B, heads (across all scales), HW, head_dim\n        q, k, v = multi_scale_qkv.chunk(3, dim=-1)\n\n        # lightweight global attention\n        q = self.kernel_func(q)\n        k = self.kernel_func(k)\n        v = F.pad(v, (0, 1), mode=\"constant\", value=1.0)\n\n        out = self._attn(q, k, v)\n\n        # final projection\n        out = out.transpose(-1, -2).reshape(B, -1, H, W)\n        out = self.proj(out)\n        return out\n\n\nclass EfficientVitBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        heads_ratio=1.0,\n        head_dim=32,\n        expand_ratio=4,\n        norm_layer=nn.BatchNorm2d,\n        act_layer=nn.Hardswish,\n    ):\n        super(EfficientVitBlock, self).__init__()\n        self.context_module = ResidualBlock(\n            LiteMLA(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                heads_ratio=heads_ratio,\n                dim=head_dim,\n                norm_layer=(None, norm_layer),\n            ),\n            nn.Identity(),\n        )\n        self.local_module = ResidualBlock(\n            MBConv(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                
expand_ratio=expand_ratio,\n                use_bias=(True, True, False),\n                norm_layer=(None, None, norm_layer),\n                act_layer=(act_layer, act_layer, None),\n            ),\n            nn.Identity(),\n        )\n\n    def forward(self, x):\n        x = self.context_module(x)\n        x = self.local_module(x)\n        return x\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(\n        self,\n        main: Optional[nn.Module],\n        shortcut: Optional[nn.Module] = None,\n        pre_norm: Optional[nn.Module] = None,\n    ):\n        super(ResidualBlock, self).__init__()\n        self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()\n        self.main = main\n        self.shortcut = shortcut\n\n    def forward(self, x):\n        res = self.main(self.pre_norm(x))\n        if self.shortcut is not None:\n            res = res + self.shortcut(x)\n        return res\n\n\ndef build_local_block(\n    in_channels: int,\n    out_channels: int,\n    stride: int,\n    kernel_size: int,\n    expand_ratio: float,\n    norm_layer: str,\n    act_layer: str,\n    fewer_norm: bool = False,\n    block_type: str = \"default\",\n):\n    assert block_type in [\"default\", \"large\", \"fused\"]\n    if expand_ratio == 1:\n        if block_type == \"default\":\n            block = DSConv(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                stride=stride,\n                kernel_size=kernel_size,\n                use_bias=(True, False) if fewer_norm else False,\n                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,\n                act_layer=(act_layer, None),\n            )\n        else:\n            block = ConvBlock(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                stride=stride,\n                kernel_size=kernel_size,\n                use_bias=(True, False) if fewer_norm else False,\n                
norm_layer=(None, norm_layer) if fewer_norm else norm_layer,\n                act_layer=(act_layer, None),\n            )\n    else:\n        if block_type == \"default\":\n            block = MBConv(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                stride=stride,\n                kernel_size=kernel_size,\n                expand_ratio=expand_ratio,\n                use_bias=(True, True, False) if fewer_norm else False,\n                norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,\n                act_layer=(act_layer, act_layer, None),\n            )\n        else:\n            block = FusedMBConv(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                stride=stride,\n                kernel_size=kernel_size,\n                expand_ratio=expand_ratio,\n                use_bias=(True, False) if fewer_norm else False,\n                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,\n                act_layer=(act_layer, None),\n            )\n    return block\n\n\nclass Stem(nn.Sequential):\n    def __init__(\n        self,\n        in_chs,\n        out_chs,\n        depth,\n        stride,\n        norm_layer,\n        act_layer,\n        block_type=\"default\",\n    ):\n        super().__init__()\n        self.stride = stride\n\n        self.add_module(\n            \"in_conv\",\n            ConvNormAct(\n                in_chs,\n                out_chs,\n                kernel_size=stride + 1,\n                stride=stride,\n                norm_layer=norm_layer,\n                act_layer=act_layer,\n            ),\n        )\n        stem_block = 0\n        for _ in range(depth):\n            self.add_module(\n                f\"res{stem_block}\",\n                ResidualBlock(\n                    build_local_block(\n                        in_channels=out_chs,\n                        out_channels=out_chs,\n                 
       stride=1,\n                        kernel_size=3,\n                        expand_ratio=1,\n                        norm_layer=norm_layer,\n                        act_layer=act_layer,\n                        block_type=block_type,\n                    ),\n                    nn.Identity(),\n                ),\n            )\n            stem_block += 1\n\n\nclass EfficientVitLargeStage(nn.Module):\n    def __init__(\n        self,\n        in_chs,\n        out_chs,\n        depth,\n        stride,\n        norm_layer,\n        act_layer,\n        head_dim,\n        vit_stage=False,\n        fewer_norm=False,\n    ):\n        super(EfficientVitLargeStage, self).__init__()\n        blocks = [\n            ResidualBlock(\n                build_local_block(\n                    in_channels=in_chs,\n                    out_channels=out_chs,\n                    stride=stride,\n                    kernel_size=stride + 1,\n                    expand_ratio=24 if vit_stage else 16,\n                    norm_layer=norm_layer,\n                    act_layer=act_layer,\n                    fewer_norm=vit_stage or fewer_norm,\n                    block_type=\"default\" if fewer_norm else \"fused\",\n                ),\n                None,\n            )\n        ]\n        in_chs = out_chs\n\n        if vit_stage:\n            # for stage 4\n            for _ in range(depth):\n                blocks.append(\n                    EfficientVitBlock(\n                        in_channels=in_chs,\n                        head_dim=head_dim,\n                        expand_ratio=6,\n                        norm_layer=norm_layer,\n                        act_layer=act_layer,\n                    )\n                )\n        else:\n            # for stage 1, 2, 3\n            for i in range(depth):\n                blocks.append(\n                    ResidualBlock(\n                        build_local_block(\n                            in_channels=in_chs,\n                   
         out_channels=out_chs,\n                            stride=1,\n                            kernel_size=3,\n                            expand_ratio=4,\n                            norm_layer=norm_layer,\n                            act_layer=act_layer,\n                            fewer_norm=fewer_norm,\n                            block_type=\"default\" if fewer_norm else \"fused\",\n                        ),\n                        nn.Identity(),\n                    )\n                )\n\n        self.blocks = nn.Sequential(*blocks)\n\n    def forward(self, x):\n        return self.blocks(x)\n\n\nclass EfficientVitLarge(nn.Module):\n    def __init__(\n        self,\n        config: EfficientViTConfig,\n        norm_layer=nn.BatchNorm2d,\n        act_layer=nn.Hardswish,\n    ):\n        super(EfficientVitLarge, self).__init__()\n        self.grad_checkpointing = False\n        self.num_classes = config.num_classes\n        self.norm_eps = config.layer_norm_eps\n        norm_layer = partial(norm_layer, eps=self.norm_eps)\n\n        # input stem\n        self.stem = Stem(\n            config.num_channels,\n            config.widths[0],\n            config.depths[0],\n            config.strides[0],\n            norm_layer,\n            act_layer,\n            block_type=\"large\",\n        )\n        stride = config.strides[0]\n\n        # stages\n        self.feature_info = []\n        self.stages = nn.Sequential()\n        in_channels = config.widths[0]\n        for i, (w, d, s) in enumerate(\n            zip(config.widths[1:], config.depths[1:], config.strides[1:])\n        ):\n            self.stages.append(\n                EfficientVitLargeStage(\n                    in_channels,\n                    w,\n                    depth=d,\n                    stride=s,\n                    norm_layer=norm_layer,\n                    act_layer=act_layer,\n                    head_dim=config.head_dim,\n                    vit_stage=i >= 3,\n                 
   fewer_norm=i >= 2,\n                )\n            )\n            stride *= s\n            in_channels = w\n            self.feature_info += [\n                dict(num_chs=in_channels, reduction=stride, module=f\"stages.{i}\")\n            ]\n\n        self.num_features = in_channels\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.grad_checkpointing = enable\n\n    def forward(self, x):\n        x = self.stem(x)\n        encoder_hidden_states = []\n        for i, module in enumerate(self.stages):\n            x = module(x)\n            encoder_hidden_states.append(x)\n\n        return encoder_hidden_states\n\n\nclass EfficientViTPreTrainedModel(SuryaPreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = EfficientViTConfig\n    base_model_prefix = \"efficientvit\"\n    main_input_name = \"pixel_values\"\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, (nn.Linear, nn.Conv2d)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n\nclass DecodeMLP(nn.Module):\n    def __init__(self, input_dim, output_dim):\n        super().__init__()\n        self.proj = nn.Linear(input_dim, 
output_dim)\n\n    def forward(self, hidden_states: torch.Tensor):\n        # Input is B, C, H, W\n        hidden_states = hidden_states.flatten(2).transpose(1, 2)\n        # Output is B, HW, C\n        hidden_states = self.proj(hidden_states)\n        return hidden_states\n\n\nclass DecodeHead(EfficientViTPreTrainedModel):\n    def __init__(self, config: EfficientViTConfig):\n        super().__init__(config)\n\n        # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_layer_hidden_size\n        mlps = []\n        for width in config.widths[1:]:\n            mlp = DecodeMLP(\n                input_dim=width, output_dim=config.decoder_layer_hidden_size\n            )\n            mlps.append(mlp)\n        self.linear_c = nn.ModuleList(mlps)\n\n        # the following 3 layers implement the ConvModule of the original implementation\n        self.linear_fuse = nn.Conv2d(\n            in_channels=config.decoder_layer_hidden_size * config.num_stages,\n            out_channels=config.decoder_hidden_size,\n            kernel_size=1,\n            bias=False,\n        )\n        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)\n        self.activation = nn.ReLU()\n\n        self.dropout = nn.Dropout(config.classifier_dropout_prob)\n        self.classifier = nn.Conv2d(\n            config.decoder_hidden_size, config.num_labels, kernel_size=1\n        )\n\n        self.config = config\n\n    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:\n        batch_size = encoder_hidden_states[-1].shape[0]\n\n        all_hidden_states = ()\n        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):\n            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]\n            encoder_hidden_state = mlp(encoder_hidden_state)  # Output is B, HW, C\n            # Permute to B, C, HW\n            encoder_hidden_state = 
encoder_hidden_state.permute(0, 2, 1)\n            encoder_hidden_state = encoder_hidden_state.reshape(\n                batch_size, -1, height, width\n            )\n            # upsample\n            encoder_hidden_state = nn.functional.interpolate(\n                encoder_hidden_state,\n                size=encoder_hidden_states[0].size()[2:],\n                mode=\"bilinear\",\n                align_corners=False,\n            )\n            all_hidden_states += (encoder_hidden_state,)\n\n        hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))\n        hidden_states = self.batch_norm(hidden_states)\n        hidden_states = self.activation(hidden_states)\n\n        # logits are of shape (batch_size, num_labels, height/4, width/4)\n        logits = self.classifier(hidden_states)\n\n        return logits\n\n\nclass EfficientViTForSemanticSegmentation(\n    S3DownloaderMixin, EfficientViTPreTrainedModel\n):\n    def __init__(self, config, **kwargs):\n        super().__init__(config)\n        self.vit = EfficientVitLarge(config)\n        self.decode_head = DecodeHead(config)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n        self, pixel_values: torch.FloatTensor\n    ) -> Union[Tuple, SemanticSegmenterOutput]:\n        # Pixel values should be B,C,H,W\n        encoder_hidden_states = self.vit(\n            pixel_values,\n        )\n\n        logits = self.decode_head(encoder_hidden_states)\n\n        # Apply sigmoid to get 0-1 output\n        logits = torch.special.expit(logits)\n\n        return SemanticSegmenterOutput(\n            loss=None, logits=logits, hidden_states=encoder_hidden_states\n        )\n\n\nclass EfficientViTForSemanticLayoutSegmentation(EfficientViTPreTrainedModel):\n    def __init__(self, config, **kwargs):\n        super().__init__(config, **kwargs)\n        self.vit = EfficientVitLarge(config)\n        self.decode_head = DecodeHead(config)\n\n        
# Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n        self, pixel_values: torch.FloatTensor\n    ) -> Union[Tuple, SemanticSegmenterOutput]:\n        # Pixel values should be B,C,H,W\n        encoder_hidden_states = self.vit(\n            pixel_values,\n        )\n\n        logits = self.decode_head(encoder_hidden_states)\n\n        # Apply sigmoid to get 0-1 output\n        logits = torch.special.expit(logits)\n\n        return SemanticSegmenterOutput(\n            loss=None, logits=logits, hidden_states=encoder_hidden_states\n        )\n"
  },
  {
    "path": "surya/detection/parallel.py",
    "content": "class FakeFuture:\n    def __init__(self, func, *args, **kwargs):\n        self._result = func(*args, **kwargs)\n\n    def result(self):\n        return self._result\n\nclass FakeExecutor:\n    def __init__(self, **kwargs):\n        pass\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, *excinfo):\n        pass\n\n    def submit(self, fn, *args, **kwargs):\n        return FakeFuture(fn, *args, **kwargs)\n"
  },
  {
    "path": "surya/detection/processor.py",
    "content": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Modified image processor class for Segformer based on transformers\"\"\"\n\nimport warnings\nfrom typing import Any, Dict, List, Optional, Union\n\nimport numpy as np\n\nfrom transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict\nfrom transformers.image_transforms import to_channel_dimension_format\nfrom transformers.image_utils import (\n    IMAGENET_DEFAULT_MEAN,\n    IMAGENET_DEFAULT_STD,\n    ChannelDimension,\n    ImageInput,\n    PILImageResampling,\n    infer_channel_dimension_format,\n    make_list_of_images,\n)\nfrom transformers.utils import TensorType\n\n\nimport PIL.Image\nimport torch\n\nfrom surya.common.s3 import S3DownloaderMixin\n\n\nclass SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):\n    r\"\"\"\n    Constructs a Segformer image processor.\n\n    Args:\n        do_resize (`bool`, *optional*, defaults to `True`):\n            Whether to resize the image's (height, width) dimensions to the specified `(size[\"height\"],\n            size[\"width\"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.\n        size (`Dict[str, int]` *optional*, defaults to `{\"height\": 512, \"width\": 512}`):\n            Size of the output image after resizing. 
Can be overridden by the `size` parameter in the `preprocess`\n            method.\n        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):\n            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the\n            `preprocess` method.\n        do_rescale (`bool`, *optional*, defaults to `True`):\n            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`\n            parameter in the `preprocess` method.\n        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):\n            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the\n            `preprocess` method.\n        do_normalize (`bool`, *optional*, defaults to `True`):\n            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`\n            method.\n        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):\n            Mean to use if normalizing the image. This is a float or list of floats the length of the number of\n            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.\n        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):\n            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the\n            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.\n        do_reduce_labels (`bool`, *optional*, defaults to `False`):\n            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is\n            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The\n            background label will be replaced by 255. 
Can be overridden by the `do_reduce_labels` parameter in the\n            `preprocess` method.\n    \"\"\"\n\n    model_input_names = [\"pixel_values\"]\n\n    def __init__(\n        self,\n        do_resize: bool = True,\n        size: Dict[str, int] = None,\n        resample: PILImageResampling = PILImageResampling.BILINEAR,\n        do_rescale: bool = True,\n        rescale_factor: Union[int, float] = 1 / 255,\n        do_normalize: bool = True,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        do_reduce_labels: bool = False,\n        **kwargs,\n    ) -> None:\n        if \"reduce_labels\" in kwargs:\n            warnings.warn(\n                \"The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use \"\n                \"`do_reduce_labels` instead.\",\n                FutureWarning,\n            )\n            do_reduce_labels = kwargs.pop(\"reduce_labels\")\n\n        super().__init__(**kwargs)\n        size = size if size is not None else {\"height\": 512, \"width\": 512}\n        size = get_size_dict(size)\n        self.do_resize = do_resize\n        self.size = size\n        self.resample = resample\n        self.do_rescale = do_rescale\n        self.rescale_factor = rescale_factor\n        self.do_normalize = do_normalize\n        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN\n        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD\n        self.do_reduce_labels = do_reduce_labels\n        self._valid_processor_keys = [\n            \"images\",\n            \"segmentation_maps\",\n            \"do_resize\",\n            \"size\",\n            \"resample\",\n            \"do_rescale\",\n            \"rescale_factor\",\n            \"do_normalize\",\n            \"image_mean\",\n            \"image_std\",\n            \"do_reduce_labels\",\n            
\"return_tensors\",\n            \"data_format\",\n            \"input_data_format\",\n        ]\n\n    @classmethod\n    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):\n        \"\"\"\n        Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image\n        processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,\n        reduce_labels=True)`\n        \"\"\"\n        image_processor_dict = image_processor_dict.copy()\n        if \"reduce_labels\" in kwargs:\n            image_processor_dict[\"reduce_labels\"] = kwargs.pop(\"reduce_labels\")\n        return super().from_dict(image_processor_dict, **kwargs)\n\n    def _preprocess(\n        self,\n        image: ImageInput,\n        do_resize: bool,\n        do_rescale: bool,\n        do_normalize: bool,\n        size: Optional[Dict[str, int]] = None,\n        resample: PILImageResampling = None,\n        rescale_factor: Optional[float] = None,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n    ):\n\n        if do_rescale:\n            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)\n\n        if do_normalize:\n            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)\n\n        return image\n\n    def _preprocess_image(\n        self,\n        image: ImageInput,\n        do_resize: bool = None,\n        size: Dict[str, int] = None,\n        resample: PILImageResampling = None,\n        do_rescale: bool = None,\n        rescale_factor: float = None,\n        do_normalize: bool = None,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        data_format: 
Optional[Union[str, ChannelDimension]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n    ) -> np.ndarray:\n        \"\"\"Preprocesses a single image.\"\"\"\n        # All transformations expect numpy arrays.\n        if input_data_format is None:\n            input_data_format = infer_channel_dimension_format(image)\n\n        image = self._preprocess(\n            image=image,\n            do_resize=do_resize,\n            size=size,\n            resample=resample,\n            do_rescale=do_rescale,\n            rescale_factor=rescale_factor,\n            do_normalize=do_normalize,\n            image_mean=image_mean,\n            image_std=image_std,\n            input_data_format=input_data_format,\n        )\n        if data_format is not None:\n            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)\n        return image\n\n    def __call__(self, images, segmentation_maps=None, **kwargs):\n        \"\"\"\n        Preprocesses a batch of images and optionally segmentation maps.\n\n        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be\n        passed in as positional arguments.\n        \"\"\"\n        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)\n\n    def preprocess(\n        self,\n        images: ImageInput,\n        segmentation_maps: Optional[ImageInput] = None,\n        do_resize: Optional[bool] = None,\n        size: Optional[Dict[str, int]] = None,\n        resample: PILImageResampling = None,\n        do_rescale: Optional[bool] = None,\n        rescale_factor: Optional[float] = None,\n        do_normalize: Optional[bool] = None,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        do_reduce_labels: Optional[bool] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n  
      data_format: ChannelDimension = ChannelDimension.FIRST,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n        **kwargs,\n    ) -> BatchFeature:\n        \"\"\"\n        Preprocess an image or batch of images.\n\n        Args:\n            images (`ImageInput`):\n                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If\n                passing in images with pixel values between 0 and 1, set `do_rescale=False`.\n            segmentation_maps (`ImageInput`, *optional*):\n                Segmentation map to preprocess.\n            do_resize (`bool`, *optional*, defaults to `self.do_resize`):\n                Whether to resize the image.\n            size (`Dict[str, int]`, *optional*, defaults to `self.size`):\n                Size of the image after `resize` is applied.\n            resample (`int`, *optional*, defaults to `self.resample`):\n                Resampling filter to use if resizing the image. 
This can be one of the enum `PILImageResampling`. Only\n                has an effect if `do_resize` is set to `True`.\n            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):\n                Whether to rescale the image values to the range [0, 1].\n            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):\n                Rescale factor to rescale the image by if `do_rescale` is set to `True`.\n            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):\n                Whether to normalize the image.\n            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):\n                Image mean.\n            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):\n                Image standard deviation.\n            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):\n                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0\n                is used for background, and background itself is not included in all classes of a dataset (e.g.\n                ADE20k). The background label will be replaced by 255.\n            return_tensors (`str` or `TensorType`, *optional*):\n                The type of tensors to return. Can be one of:\n                    - Unset: Return a list of `np.ndarray`.\n                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.\n                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.\n                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.\n                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.\n            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):\n                The channel dimension format for the output image. 
Can be one of:\n                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n            input_data_format (`ChannelDimension` or `str`, *optional*):\n                The channel dimension format for the input image. If unset, the channel dimension format is inferred\n                from the input image. Can be one of:\n                - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n                - `\"none\"` or `ChannelDimension.NONE`: image in (height, width) format.\n        \"\"\"\n        do_resize = do_resize if do_resize is not None else self.do_resize\n        do_rescale = do_rescale if do_rescale is not None else self.do_rescale\n        do_normalize = do_normalize if do_normalize is not None else self.do_normalize\n        resample = resample if resample is not None else self.resample\n        size = size if size is not None else self.size\n        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor\n        image_mean = image_mean if image_mean is not None else self.image_mean\n        image_std = image_std if image_std is not None else self.image_std\n\n        images = make_list_of_images(images)\n        images = [\n            self._preprocess_image(\n                image=img,\n                do_resize=do_resize,\n                resample=resample,\n                size=size,\n                do_rescale=do_rescale,\n                rescale_factor=rescale_factor,\n                do_normalize=do_normalize,\n                image_mean=image_mean,\n                image_std=image_std,\n                data_format=data_format,\n                input_data_format=input_data_format,\n            )\n            for img in 
images\n        ]\n\n        data = {\"pixel_values\": images}\n        return BatchFeature(data=data, tensor_type=return_tensors)"
  },
  {
    "path": "surya/detection/schema.py",
    "content": "from typing import List, Optional, Any\n\nfrom pydantic import BaseModel\n\nfrom surya.common.polygon import PolygonBox\n\n\nclass TextDetectionResult(BaseModel):\n    bboxes: List[PolygonBox]\n    heatmap: Optional[Any]\n    affinity_map: Optional[Any]\n    image_bbox: List[float]\n"
  },
  {
    "path": "surya/detection/util.py",
    "content": "import math\nfrom PIL import ImageOps\n\nfrom surya.settings import settings\n\n\ndef get_total_splits(image_size, height):\n    img_height = list(image_size)[1]\n    max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT\n    if img_height > max_height:\n        num_splits = math.ceil(img_height / height)\n        return num_splits\n    return 1\n\n\ndef split_image(img, height):\n    # This will not modify/return the original image - it will either crop, or copy the image\n    img_height = list(img.size)[1]\n    max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT\n    if img_height > max_height:\n        num_splits = math.ceil(img_height / height)\n        splits = []\n        split_heights = []\n        for i in range(num_splits):\n            top = i * height\n            bottom = (i + 1) * height\n            if bottom > img_height:\n                bottom = img_height\n            cropped = img.crop((0, top, img.size[0], bottom))\n            chunk_height = bottom - top\n            if chunk_height < height:\n                cropped = ImageOps.pad(cropped, (img.size[0], height), color=255, centering=(0, 0))\n            splits.append(cropped)\n            split_heights.append(chunk_height)\n        return splits, split_heights\n    return [img.copy()], [img_height]\n"
  },
  {
    "path": "surya/foundation/__init__.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple\nfrom collections import deque\n\nimport cv2\nimport numpy as np\nimport torch\nimport math\nfrom PIL import Image\nfrom tqdm import tqdm\nimport torch.nn.functional as F\n\nfrom surya.common.surya import SuryaModelOutput\nfrom surya.common.xla import mark_step\nfrom surya.common.predictor import BasePredictor\n\nfrom surya.foundation.loader import FoundationModelLoader\nfrom surya.foundation.util import (\n    detect_repeat_token,\n)\nfrom surya.common.surya.schema import TaskNames\nfrom surya.foundation.cache.dynamic_ops import DynamicOpsCache\nfrom surya.foundation.cache.static_ops import StaticOpsCache\n\nfrom surya.settings import settings\nfrom surya.logging import get_logger, configure_logging\n\nconfigure_logging()\nlogger = get_logger()\n\n\n@dataclass\nclass ContinuousBatchInput:\n    input_ids: torch.Tensor\n    input_boxes: torch.Tensor\n    position_ids: torch.Tensor\n    # input_ids and position_ids may be padded, num_valid_tokens tracks the 'real' counts\n    num_valid_tokens: torch.Tensor\n    # count the number of predicted tokens for each batch element so far\n    num_predicted_tokens: torch.Tensor\n    needs_bbox_embedding: torch.Tensor\n\n\n@dataclass\nclass ContinuousBatchOutput:\n    input_ids: torch.Tensor\n    preds: torch.Tensor\n    bbox_preds: torch.Tensor\n    scores: torch.Tensor\n    token_probs: torch.Tensor\n\n\n@dataclass\nclass FoundationPrompt:\n    id: int\n    task_name: TaskNames\n    image: np.ndarray\n    text: str\n    math_mode: bool\n\n\nclass FoundationPredictor(BasePredictor):\n    model_loader_cls = FoundationModelLoader\n    batch_size = (\n        settings.RECOGNITION_BATCH_SIZE\n    )  # Default to the recognition batch size\n    torch_dtype = None  # No default, loader picks the dtype based on device properties - bf16/fp16\n    default_batch_sizes = {\"cpu\": 32, \"mps\": 64, \"cuda\": 
256, \"xla\": 64}\n    encoder_chunk_size: int = 4096  # Default chunk size\n    encoder_chunk_sizes = {\"cpu\": 4096, \"mps\": 4096, \"cuda\": 32768, \"xla\": 32768}\n    extra_token_count = {\n        \"xla\": 128\n    }  # We have to pad the XLA cache since we don't use sliding window\n    min_prefill_ratio: float = 1 if settings.FOUNDATION_XLA else 0.2\n    min_trim_length: int = 50\n    tasks = {\n        TaskNames.ocr_with_boxes: {\n            \"needs_bboxes\": True,\n            \"img_size\": (1024, 512),\n            \"max_tokens\": 224,\n        },\n        TaskNames.ocr_without_boxes: {\n            \"needs_bboxes\": False,\n            \"img_size\": (1024, 512),\n            \"max_tokens\": 224,\n        },\n        TaskNames.block_without_boxes: {\n            \"needs_bboxes\": False,\n            \"img_size\": (1024, 512),\n            \"max_tokens\": 768,\n        },\n        TaskNames.layout: {\n            \"needs_bboxes\": False,\n            \"img_size\": (1024, 1024),\n            \"max_tokens\": 200,\n        },\n        TaskNames.table_structure: {\n            \"needs_bboxes\": False,\n            \"img_size\": (1024, 512),\n            \"max_tokens\": 600,\n        },\n    }\n\n    def __init__(\n        self,\n        checkpoint=None,\n        device=settings.TORCH_DEVICE_MODEL,\n        dtype=None,\n        attention_implementation: Optional[str] = None,\n    ):\n        super().__init__(checkpoint, device, dtype, attention_implementation)\n        self.prompt_queue = deque()\n        self.batch_prompt_mapping = None\n        self.kv_cache = None\n\n        self.beacon_token_interval = self.model.config.beacon_token_interval\n\n        # Setup various tokens on-device\n        self.device_pad_token = torch.tensor(\n            self.processor.pad_token_id, device=self.model.device, dtype=torch.long\n        )\n        self.device_beacon_token = torch.tensor(\n            self.processor.beacon_token_id, device=self.model.device, 
dtype=torch.long\n        )\n        self.special_token_ids = torch.tensor(\n            [self.model.config.image_token_id] + self.model.config.register_token_ids,\n            device=self.model.device,\n        )\n\n        self.pad_to_multiple = (\n            settings.FOUNDATION_PAD_TO_NEAREST\n            if settings.FOUNDATION_STATIC_CACHE\n            else None\n        )\n\n    def to(self, device_dtype: torch.device | str | None = None):\n        super().to(device_dtype)\n        self.special_token_ids = self.special_token_ids.to(device_dtype)\n\n    def get_encoder_chunk_size(self) -> int:\n        if settings.FOUNDATION_CHUNK_SIZE is not None:\n            return settings.FOUNDATION_CHUNK_SIZE\n\n        chunk_size = self.encoder_chunk_size\n        if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes:\n            chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL]\n        return chunk_size\n\n    def setup_cache(self, batch_size: int, max_cache_len: int, max_sliding_window: int):\n        kv_cache_cls = StaticOpsCache if settings.FOUNDATION_XLA else DynamicOpsCache\n        self.kv_cache = kv_cache_cls(\n            self.model.config,\n            batch_size,\n            max_cache_len,\n            text_sliding_window=max_sliding_window,\n            device=self.model.device,\n            dtype=self.model.dtype,\n        )\n        self.prompt_queue.clear()\n        self.batch_prompt_mapping = {i: None for i in range(batch_size)}\n\n    @property\n    def num_empty_slots(self):\n        return sum(v is None for v in self.batch_prompt_mapping.values())\n\n    @property\n    def num_active_slots(self):\n        return len(self.batch_prompt_mapping) - self.num_empty_slots\n\n    def prepare_input(\n        self,\n        task_names: List[str],\n        images: List[Image.Image],\n        input_text: List[str | None],\n        math_modes: List[bool],\n    
):\n        batch = []\n        for image, text, task_name, math_mode in zip(\n            images, input_text, task_names, math_modes\n        ):\n            image_size = self.tasks[task_name][\"img_size\"]\n\n            try:\n                image = self.processor.scale_to_fit(\n                    image, image_size\n                )  # Only resizes if out of bounds (max/min)\n            except cv2.error:\n                # The image is empty if it can't be resized, so just make a blank image\n                image = np.zeros((image_size[1], image_size[0], 3), dtype=np.float32)\n\n            # Task input is the same for all tasks for now\n            text = text or \"\"\n\n            # Remove input text that exceeds max generation tokens (likely invalid)\n            if len(text) > self.tasks[task_name][\"max_tokens\"]:\n                text = \"\"\n            inputs = [\n                {\"type\": \"image\", \"image\": image, \"rotated\": False},\n                {\"type\": \"text\", \"text\": text.strip(), \"math\": math_mode},\n            ]\n            batch.append({\"task\": task_name, \"inputs\": inputs})\n\n        return batch\n\n    def process_outputs(\n        self, outputs: SuryaModelOutput, max_lookahead_tokens: Optional[int] = None\n    ) -> ContinuousBatchOutput:\n        # Predictions are multi-token\n        lm_logits = outputs[\"lm_logits\"].float()  # shape: [batch_size, seq_len, V]\n        bbox_logits = outputs[\"bbox_logits\"].float()  # shape: [batch_size, seq_len, 6]\n\n        if (\n            max_lookahead_tokens is not None\n            and lm_logits.shape[1] > max_lookahead_tokens + 1\n        ):\n            lm_logits = lm_logits[:, : max_lookahead_tokens + 1, :]\n            bbox_logits = bbox_logits[:, : max_lookahead_tokens + 1, :]\n\n        # Get predictions\n        preds = torch.argmax(lm_logits, dim=-1)\n        input_ids = preds.to(torch.long)\n\n        # Confidence scores for all tokens\n        token_probs = 
F.softmax(lm_logits, dim=-1)\n        scores = torch.max(token_probs, dim=-1).values  # shape: [B, T]\n\n        # Update input boxes\n        box_preds = bbox_logits * self.model.config.bbox_size\n        box_preds = box_preds.to(torch.long)\n\n        return ContinuousBatchOutput(\n            input_ids=input_ids,\n            preds=preds,\n            bbox_preds=box_preds,\n            scores=scores,\n            token_probs=token_probs,\n        )\n\n    # Always left pad with beacons, don't worry about attention masking\n    def maybe_insert_beacon_tokens(\n        self,\n        input_ids: torch.Tensor,\n        input_boxes: torch.Tensor,\n        num_predicted_tokens: torch.Tensor,\n        num_new_tokens: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        batch_size, seq_len = (\n            input_ids.shape\n        )  # seq_len can be >1 - In case of multi-token predictions\n\n        # num_predicted tokens **does not include** the current new input_ids, this number is updated **after beacon tokens are inserted**\n        token_positions = num_predicted_tokens + torch.arange(\n            1, seq_len + 1, device=input_ids.device\n        ).unsqueeze(0)\n        beacon_positions = token_positions % self.beacon_token_interval == 0\n\n        # If no beacons needed, return original input\n        needs_beacon = beacon_positions.any(dim=1)  # shape: [batch_size]\n        if not needs_beacon.any():\n            if num_new_tokens is None:\n                # Shape [batch_size, 1] so the squeeze(1) below yields [batch_size]\n                num_new_tokens = (\n                    torch.ones((batch_size, 1), dtype=torch.long, device=input_ids.device)\n                    * seq_len\n                )\n            return input_ids, input_boxes, num_new_tokens.squeeze(1)\n\n        beacon_insert_pos = torch.zeros(\n            batch_size, dtype=torch.long, device=input_ids.device\n        )\n        for i in range(batch_size):\n            if needs_beacon[i]:\n                # Find first position that needs beacon\n        
        beacon_insert_pos[i] = torch.where(beacon_positions[i])[0]\n\n        # Padded input ids.\n        new_input_ids = torch.full(\n            (batch_size, seq_len + 1),\n            self.device_pad_token,\n            dtype=input_ids.dtype,\n            device=input_ids.device,\n        )\n        new_input_boxes = torch.full(\n            (batch_size, seq_len + 1, 6),\n            -100,\n            dtype=input_boxes.dtype,\n            device=input_boxes.device,\n        )\n        # Fill in tokens for each sequence\n        for i in range(batch_size):\n            if needs_beacon[i]:\n                insert_pos = beacon_insert_pos[i]\n                new_input_ids[i, insert_pos] = self.device_beacon_token\n                new_input_boxes[i, insert_pos, :] = -100\n                if insert_pos > 0:\n                    new_input_ids[i, :insert_pos] = input_ids[i, :insert_pos]\n                    new_input_boxes[i, :insert_pos] = input_boxes[i, :insert_pos]\n                new_input_ids[i, insert_pos + 1 :] = input_ids[i, insert_pos:]\n                new_input_boxes[i, insert_pos + 1 :] = input_boxes[i, insert_pos:]\n            else:\n                new_input_ids[i, 1:] = input_ids[i, :]\n                new_input_boxes[i, 1:] = input_boxes[i, :]\n\n        # Calculate valid token counts for both padded and non padded sequences\n        valid_token_counts = torch.where(\n            needs_beacon,\n            torch.tensor(seq_len + 1, device=input_ids.device),\n            torch.tensor(seq_len, device=input_ids.device),\n        )\n\n        return new_input_ids, new_input_boxes, valid_token_counts\n\n    def decode(\n        self,\n        current_inputs: Optional[ContinuousBatchInput] = None,\n        max_lookahead_tokens: Optional[int] = None,\n    ):\n        # Note - If we want to use the outputs from the non-last token, we\n        # need to set the cache position manually to ensure causality. 
The default\n        # behavior only works for the last token currently\n        input_ids = current_inputs.input_ids\n        input_boxes = current_inputs.input_boxes\n        embed_boxes = current_inputs.needs_bbox_embedding\n\n        position_ids = current_inputs.position_ids\n        num_predicted_tokens = current_inputs.num_predicted_tokens\n        num_valid_tokens = current_inputs.num_valid_tokens\n        batch_size = input_ids.shape[0]\n\n        # Pre-shift the attention mask based on the cache update\n        self.kv_cache.decode_attention_mask_update(\n            num_valid_tokens=num_valid_tokens, cache_idxs=list(range(batch_size))\n        )\n\n        cache_position = self.get_cache_position(\n            input_ids.shape[1], self.kv_cache.attention_mask, prefill=False\n        )\n        with settings.INFERENCE_MODE():\n            outputs = self.model(\n                input_ids=input_ids,\n                attention_mask=self.kv_cache.attention_mask,\n                position_ids=position_ids,\n                cache_position=cache_position,\n                use_cache=True,\n                past_key_values=self.kv_cache,\n                prefill=False,\n                num_valid_tokens=num_valid_tokens,\n                input_boxes=input_boxes,\n                embed_boxes=embed_boxes,\n                logits_to_keep=1,\n            )\n\n        processed_output: ContinuousBatchOutput = self.process_outputs(\n            outputs, max_lookahead_tokens=max_lookahead_tokens\n        )\n\n        input_ids = processed_output.input_ids\n        input_boxes = processed_output.bbox_preds\n\n        # Update this **before** inserting beacon tokens\n        tau = settings.FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE\n        if max_lookahead_tokens is not None:\n            num_new_tokens = torch.clamp(\n                (\n                    processed_output.scores.ge(tau)\n                    .to(torch.long)\n                    .cumprod(dim=1)\n                  
  .sum(dim=1, keepdim=True)\n                ),\n                min=1,\n            )\n        else:\n            # Accept every predicted token; keep the [batch_size, 1] shape used above\n            num_new_tokens = torch.full(\n                (batch_size, 1), input_ids.shape[1], dtype=torch.long, device=input_ids.device\n            )\n\n        num_predicted_tokens += num_new_tokens\n        input_ids, input_boxes, num_valid_tokens = self.maybe_insert_beacon_tokens(\n            input_ids, input_boxes, num_predicted_tokens, num_new_tokens\n        )\n        position_ids = position_ids[:, -1:] + torch.arange(\n            1, input_ids.shape[1] + 1, device=input_ids.device\n        )\n        # Some of the input sequences may now have left padding tokens, so we want to account for that\n        # offset is a per-batch offset of the position_ids\n        offset = (input_ids.shape[1] - num_valid_tokens).unsqueeze(1)\n        position_ids -= offset\n\n        new_input = ContinuousBatchInput(\n            input_ids=input_ids,\n            input_boxes=input_boxes,\n            position_ids=position_ids,\n            num_valid_tokens=num_valid_tokens,\n            num_predicted_tokens=num_predicted_tokens,\n            needs_bbox_embedding=current_inputs.needs_bbox_embedding,\n        )\n\n        return new_input, processed_output\n\n    def pad_and_shift_input_ids_position_ids(\n        self,\n        input_ids: torch.Tensor,\n        bbox_preds: torch.Tensor,\n        position_ids: torch.Tensor,\n        new_seq_len: int,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Left-pads input_ids and bbox_preds to new_seq_len\n        and creates updated position_ids\n\n        Returns:\n            padded_input_ids (torch.Tensor): [batch_size, new_seq_len]\n            padded_bbox_preds (torch.Tensor): [batch_size, new_seq_len, 6]\n            updated_position_ids (torch.Tensor): [batch_size, new_seq_len]\n        \"\"\"\n        # No padding\n        if new_seq_len == input_ids.shape[1]:\n            return (\n                input_ids,\n                bbox_preds,\n                position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device),\n      
      )\n\n        pad_len = new_seq_len - input_ids.shape[1]\n        padded_input_ids = torch.nn.functional.pad(\n            input_ids, (pad_len, 0), value=self.device_pad_token\n        )\n\n        padded_bbox_preds = torch.nn.functional.pad(\n            bbox_preds, (0, 0, pad_len, 0), value=-100\n        )\n\n        # Since we have **left padding**, offset the new position_ids by the amount of padding\n        # This ensures that the **true tokens** get the correct position_ids\n        # The position_ids assigned to pad tokens do not matter. They are not cached, and not used for outputs\n        updated_position_ids = position_ids[:, -1:] + torch.arange(\n            1, new_seq_len + 1, device=self.model.device\n        )\n        updated_position_ids -= pad_len\n\n        return padded_input_ids, padded_bbox_preds, updated_position_ids\n\n    def get_cache_position(\n        self,\n        seq_len: int,\n        attention_mask: torch.Tensor,\n        prefill: bool,\n    ):\n        batch_size, target_len = attention_mask.shape\n        base_cache_position = (\n            torch.arange(seq_len, device=attention_mask.device)\n            .unsqueeze(0)\n            .expand(batch_size, -1)\n        )\n        if prefill:\n            return base_cache_position\n\n        # This is a (batch_size) tensor, we can add the seq lens here\n        cache_seqlens = (\n            attention_mask\n            * torch.arange(attention_mask.size(1), device=attention_mask.device)\n        ).argmax(dim=1).to(torch.int32) + 1\n        # Needs to be unsqueezed so broadcasting works\n        return cache_seqlens.unsqueeze(1) + base_cache_position\n\n    def prefill(\n        self,\n        current_inputs: Optional[ContinuousBatchInput] = None,\n        max_lookahead_tokens: Optional[int] = None,\n    ):\n        logger.debug(f\"Prefilling {self.num_empty_slots} slots\")\n\n        prompts: List[FoundationPrompt] = [\n            self.prompt_queue.popleft()\n            for _ 
in range(min(self.num_empty_slots, len(self.prompt_queue)))\n        ]\n        non_active_idxs = [k for k, v in self.batch_prompt_mapping.items() if v is None]\n        idxs_to_merge = non_active_idxs[: len(prompts)]\n\n        for i, prompt in zip(idxs_to_merge, prompts):\n            self.batch_prompt_mapping[i] = prompt.id\n\n        needs_bbox_embedding = torch.tensor(\n            [\n                p.task_name in [TaskNames.layout, TaskNames.table_structure]\n                for p in prompts\n            ],\n            dtype=torch.bool,\n        )\n\n        batch_input = self.prepare_input(\n            task_names=[p.task_name for p in prompts],\n            images=[p.image for p in prompts],\n            input_text=[p.text for p in prompts],\n            math_modes=[\n                p.math_mode for p in prompts\n            ],  # Pass math mode to the processor\n        )\n        processed_inputs = self.processor(\n            batch_input,\n            padding_side=\"left\",\n            device=self.model.device,\n            pad_to_multiple=self.pad_to_multiple,\n        )\n\n        input_ids = processed_inputs[\"input_ids\"].to(dtype=torch.long)\n        attention_mask = processed_inputs[\"attention_mask\"].to(dtype=torch.long)\n        position_ids = processed_inputs[\"position_ids\"].to(dtype=torch.long)\n        valid_batch_size = len(idxs_to_merge)\n\n        # Keep these off device until later\n        image_tiles = processed_inputs[\"image_tiles\"].to(dtype=self.model.dtype)\n        grid_thw = processed_inputs[\"grid_thw\"].to(dtype=torch.long)\n\n        if settings.FOUNDATION_STATIC_CACHE:\n            input_ids = self.pad_to_batch_size(\n                input_ids, batch_size=self.kv_cache.max_batch_size\n            )\n            attention_mask = self.pad_to_batch_size(\n                attention_mask, batch_size=self.kv_cache.max_batch_size\n            )\n            position_ids = self.pad_to_batch_size(\n                position_ids, 
batch_size=self.kv_cache.max_batch_size\n            )\n            needs_bbox_embedding = self.pad_to_batch_size(\n                needs_bbox_embedding, batch_size=self.kv_cache.max_batch_size\n            )\n\n        # Move to device after padding\n        input_ids = input_ids.to(device=self.model.device)\n        attention_mask = attention_mask.to(device=self.model.device)\n        position_ids = position_ids.to(device=self.model.device)\n        needs_bbox_embedding = needs_bbox_embedding.to(device=self.model.device)\n\n        # Find text lengths of each\n        # Oddly, this is optimal on GPU - causes a 30% slowdown if \"optimized\"\n        # Be very careful with the type and device of this - can cause\n        # a big slowdown if put on device\n        is_special = (\n            (input_ids.unsqueeze(-1) == self.special_token_ids).any(-1).cpu()\n        )  # (batch, seq_len)\n        text_lengths = []\n        for i in range(input_ids.shape[0]):\n            special_positions = is_special[i].nonzero(as_tuple=True)[0]\n            if len(special_positions) > 0:\n                # Assuming special tokens are contiguous at the start\n                prefix_len = special_positions[-1].item() + 1\n            else:\n                prefix_len = 0\n            text_lengths.append(input_ids.shape[1] - prefix_len)\n        text_lengths = torch.tensor(text_lengths, dtype=torch.long)\n\n        cache_position = self.get_cache_position(\n            input_ids.shape[1], attention_mask, prefill=True\n        )\n        with settings.INFERENCE_MODE():\n            image_embeddings = self.model.get_image_embeddings(\n                pixel_values=image_tiles,\n                grid_thw=grid_thw,\n                encoder_chunk_size=self.get_encoder_chunk_size(),\n                valid_batch_size=valid_batch_size,\n                max_batch_size=self.kv_cache.max_batch_size,\n            )\n            mark_step()\n\n            outputs = self.model(\n                
input_ids=input_ids,\n                image_embeddings=image_embeddings,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                cache_position=cache_position,\n                inputs_embeds=None,\n                past_key_values=self.kv_cache,\n                use_cache=True,\n                encoder_chunk_size=self.get_encoder_chunk_size(),\n                cache_idxs=idxs_to_merge,\n                prefill=True,\n                num_valid_tokens=None,  # Not required during prefill\n                text_lengths=text_lengths,\n                valid_batch_size=valid_batch_size,\n                logits_to_keep=1,\n            )\n\n        # Process outputs\n        processed_outputs = self.process_outputs(\n            outputs, max_lookahead_tokens=max_lookahead_tokens\n        )\n        # Multi-token prediction\n        predicted_tokens = processed_outputs.input_ids.shape[1]\n        num_valid_tokens = (\n            torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long)\n            * predicted_tokens\n        )\n        num_predicted_tokens = (\n            torch.ones(\n                (input_ids.shape[0], 1), device=self.model.device, dtype=torch.long\n            )\n            * predicted_tokens\n        )\n\n        self.kv_cache.prefill_attention_mask_update(\n            attention_mask, idxs_to_merge, valid_batch_size, text_lengths\n        )\n        self.kv_cache.update_text_counts(idxs_to_merge, valid_batch_size, text_lengths)\n\n        full_batch = len(idxs_to_merge) == self.kv_cache.max_batch_size\n\n        # If full batch, then we can ignore current_inputs\n        if current_inputs is None or full_batch:\n            new_seq_len = processed_outputs.input_ids.shape[1]\n            # No padding tokens - So we can safely set position_ids this way\n            position_ids = position_ids[:, -1:] + torch.arange(\n                1, new_seq_len + 1, device=position_ids.device\n  
          )\n            new_input = ContinuousBatchInput(\n                input_ids=processed_outputs.input_ids,\n                input_boxes=processed_outputs.bbox_preds,\n                position_ids=position_ids,\n                num_valid_tokens=num_valid_tokens,\n                num_predicted_tokens=num_predicted_tokens,\n                needs_bbox_embedding=needs_bbox_embedding,\n            )\n\n            return (\n                new_input,\n                processed_outputs,\n                range(processed_outputs.input_ids.shape[0]),\n            )\n\n        # Merging inputs for next steps\n        current_input_ids = current_inputs.input_ids\n        current_position_ids = current_inputs.position_ids\n        current_input_boxes = current_inputs.input_boxes\n\n        current_needs_bbox_embedding = current_inputs.needs_bbox_embedding\n\n        assert current_input_ids.shape[1] == current_position_ids.shape[1]\n        input_ids, bbox_preds, position_ids = self.pad_and_shift_input_ids_position_ids(\n            processed_outputs.input_ids,\n            processed_outputs.bbox_preds,\n            position_ids,\n            new_seq_len=current_input_ids.shape[1],\n        )\n\n        current_input_ids[idxs_to_merge] = input_ids[:valid_batch_size]\n        current_input_boxes[idxs_to_merge] = bbox_preds[:valid_batch_size]\n        current_position_ids[idxs_to_merge] = position_ids[:valid_batch_size]\n\n        current_num_valid_tokens = current_inputs.num_valid_tokens\n        current_num_valid_tokens[idxs_to_merge] = num_valid_tokens[:valid_batch_size]\n\n        current_num_predicted_tokens = current_inputs.num_predicted_tokens\n        current_num_predicted_tokens[idxs_to_merge] = num_predicted_tokens[\n            :valid_batch_size\n        ]\n        current_needs_bbox_embedding[idxs_to_merge] = needs_bbox_embedding[\n            :valid_batch_size\n        ]\n\n        new_input = ContinuousBatchInput(\n            input_ids=current_input_ids,\n  
          input_boxes=current_input_boxes,\n            position_ids=current_position_ids,\n            num_valid_tokens=current_num_valid_tokens,\n            num_predicted_tokens=current_num_predicted_tokens,\n            needs_bbox_embedding=current_needs_bbox_embedding,\n        )\n\n        return new_input, processed_outputs, idxs_to_merge\n\n    def get_max_image_token_count(\n        self, images: list[np.ndarray], tasks: List[TaskNames]\n    ) -> int:\n        def compute_scaled_size(\n            H: int, W: int, max_size: Tuple[int, int]\n        ) -> Tuple[int, int]:\n            max_W, max_H = max_size\n            min_W, min_H = (168, 168)\n\n            current_pixels = H * W\n            max_pixels = max_H * max_W\n            min_pixels = min_H * min_W\n            current_pixels = max(1, current_pixels)  # Avoid zero division\n\n            if current_pixels > max_pixels:\n                scale = (max_pixels / current_pixels) ** 0.5\n                return math.floor(H * scale), math.floor(W * scale)\n            elif current_pixels < min_pixels:\n                scale = (min_pixels / current_pixels) ** 0.5\n                return math.ceil(H * scale), math.ceil(W * scale)\n            return H, W\n\n        def get_tile_count(H: int, W: int, factor: int) -> int:\n            H_bar = math.ceil(H / factor) * factor\n            W_bar = math.ceil(W / factor) * factor\n            grid_h = H_bar // self.processor.patch_size\n            grid_w = W_bar // self.processor.patch_size\n            return grid_h * grid_w\n\n        max_tokens = 0\n        factor = self.processor.patch_size * self.processor.merge_size\n\n        for image, task in zip(images, tasks):\n            H, W = image.shape[:2]\n            max_size = self.tasks[task][\"img_size\"]\n            scaled_H, scaled_W = compute_scaled_size(H, W, max_size)\n            token_count = get_tile_count(scaled_H, scaled_W, factor) / (\n                self.processor.merge_size**2\n            )\n 
           max_tokens = max(max_tokens, token_count)\n\n        # Extra 10 to account for EOS/BOS/Rotation token etc.\n        return 10 + self.processor.num_register_tokens + int(max_tokens)\n\n    def prediction_loop(\n        self,\n        images: List[np.ndarray],\n        input_texts: List[str],\n        task_names: List[TaskNames],\n        batch_size: int | None = None,\n        max_tokens: int | None = None,\n        max_sliding_window: int | None = None,\n        math_mode: bool = True,\n        drop_repeated_tokens: bool = True,\n        max_lookahead_tokens: Optional[int] = None,\n        top_k: int = 0,\n        tqdm_desc: str = \"Recognizing Text\"\n    ) -> tuple:\n        allowed_tasks = self.tasks.keys()\n        assert all([task_name in allowed_tasks for task_name in task_names]), (\n            f\"One or more tasks in {task_names} is not supported. Supported tasks are {allowed_tasks}\"\n        )\n\n        predicted_tokens = [[] for _ in range(len(images))]\n        scores = [[] for _ in range(len(images))]\n        topk_probs = [[] for _ in range(len(images))]\n\n        if batch_size is None:\n            batch_size = self.get_batch_size()\n\n        batch_size = min(len(images), batch_size)\n        current_inputs = None\n\n        max_image_tokens = self.get_max_image_token_count(images, task_names)\n        if max_sliding_window is None:\n            max_sliding_window = self.model.config.sliding_window\n        self.setup_cache(\n            batch_size,\n            max_cache_len=max_image_tokens + max_sliding_window + self.extra_token_count.get(settings.TORCH_DEVICE_MODEL, 0),\n            max_sliding_window=max_sliding_window,\n        )\n\n        batch_max_tokens = {}\n        for idx, (img, txt, task) in enumerate(zip(images, input_texts, task_names)):\n            self.prompt_queue.append(\n                FoundationPrompt(\n                    id=idx, task_name=task, text=txt, image=img, math_mode=math_mode\n                )\n      
      )\n            batch_max_tokens[idx] = (\n                max_tokens\n                or settings.FOUNDATION_MAX_TOKENS\n                or self.tasks[task][\"max_tokens\"]\n            )\n\n        overall_max_tokens = max(batch_max_tokens.values())\n\n        pbar = tqdm(\n            total=len(self.prompt_queue),\n            desc=tqdm_desc,\n            disable=self.disable_tqdm,\n        )\n\n        batch_bboxes = torch.zeros(len(images), overall_max_tokens, 6)\n        batch_pos = [0] * len(images)\n\n        while self.prompt_queue or self.num_active_slots > 0:\n            if (\n                self.num_empty_slots / batch_size\n            ) >= self.min_prefill_ratio and self.prompt_queue:\n                updated_inputs, outputs, merge_idxs = self.prefill(\n                    current_inputs, max_lookahead_tokens=0\n                )\n\n                predicted_tokens_cpu = outputs.preds.cpu()\n                scores_cpu = outputs.scores.cpu()\n                bbox_preds_cpu = outputs.bbox_preds.cpu()\n\n                if top_k > 0:\n                    batch_top_k_probs, batch_top_k_indices = torch.topk(\n                        outputs.token_probs, k=top_k, dim=-1\n                    )\n                    batch_top_k_probs_cpu = batch_top_k_probs.cpu()\n                    batch_top_k_indices_cpu = batch_top_k_indices.cpu()\n\n                for temp_idx, b_idx in enumerate(merge_idxs):\n                    if self.batch_prompt_mapping[b_idx] is not None:\n                        p_idx = self.batch_prompt_mapping[b_idx]\n                        seq_len = predicted_tokens_cpu.shape[1]\n                        for t_idx in range(seq_len):\n                            token = predicted_tokens_cpu[temp_idx, t_idx].item()\n                            predicted_tokens[p_idx].append(token)\n                            batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[\n                                temp_idx, t_idx\n                           
 ]\n                            batch_pos[p_idx] += 1\n                            scores[p_idx].append(scores_cpu[temp_idx, t_idx].item())\n\n                            if top_k > 0:\n                                top_k_scores = {\n                                    batch_top_k_indices_cpu[temp_idx, t_idx][\n                                        k\n                                    ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][\n                                        k\n                                    ].item()\n                                    for k in range(top_k)\n                                }\n                                topk_probs[p_idx].append(top_k_scores)\n\n                            if token in [\n                                self.processor.eos_token_id,\n                                self.processor.no_output_token,\n                            ]:\n                                self.batch_prompt_mapping[b_idx] = None\n                                pbar.update(1)\n                                break\n            else:\n                updated_inputs, outputs = self.decode(\n                    current_inputs, max_lookahead_tokens=max_lookahead_tokens\n                )\n                mark_step()\n\n                predicted_tokens_cpu = outputs.preds.cpu()\n                scores_cpu = outputs.scores.cpu()\n                bbox_preds_cpu = outputs.bbox_preds.cpu()\n\n                if top_k > 0:\n                    batch_top_k_probs, batch_top_k_indices = torch.topk(\n                        outputs.token_probs, k=top_k, dim=-1\n                    )\n                    batch_top_k_probs_cpu = batch_top_k_probs.cpu()\n                    batch_top_k_indices_cpu = batch_top_k_indices.cpu()\n\n                for b_idx, p_idx in self.batch_prompt_mapping.items():\n                    if p_idx is not None:\n                        seq_len = predicted_tokens_cpu.shape[1]\n                        num_tokens = 
updated_inputs.num_valid_tokens[b_idx].item()\n                        should_stop = False\n\n                        for t_idx in range(seq_len):\n                            # don't use multitoken prediction for lower confidence tokens\n                            if t_idx > 0 and num_tokens < seq_len:\n                                # roll so tokens are right aligned\n                                updated_inputs.input_ids[b_idx] = (\n                                    updated_inputs.input_ids[b_idx].roll(\n                                        shifts=seq_len - num_tokens, dims=0\n                                    )\n                                )\n                                # don't need to roll position_ids because that's handled in `decode` (and when we do beacon tokens)\n                                break\n\n                            token = predicted_tokens_cpu[b_idx, t_idx].item()\n                            predicted_tokens[p_idx].append(token)\n                            batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[\n                                b_idx, t_idx\n                            ]\n                            batch_pos[p_idx] += 1\n                            scores[p_idx].append(scores_cpu[b_idx, t_idx].item())\n\n                            if top_k > 0:\n                                top_k_scores = {\n                                    batch_top_k_indices_cpu[b_idx, t_idx][\n                                        k\n                                    ].item(): batch_top_k_probs_cpu[b_idx, t_idx][\n                                        k\n                                    ].item()\n                                    for k in range(top_k)\n                                }\n                                topk_probs[p_idx].append(top_k_scores)\n\n                            repeats = len(predicted_tokens[p_idx]) >= batch_max_tokens[\n                                p_idx\n                         
   ] or (\n                                drop_repeated_tokens\n                                and detect_repeat_token(predicted_tokens[p_idx])\n                                and task_names[p_idx]\n                                in [\n                                    TaskNames.ocr_with_boxes,\n                                    TaskNames.ocr_without_boxes,\n                                ]\n                            )\n                            if (\n                                token\n                                in [\n                                    self.processor.eos_token_id,\n                                    self.processor.pad_token_id,\n                                ]\n                                or repeats\n                            ):\n                                should_stop = True\n                                break\n\n                        if should_stop:\n                            self.batch_prompt_mapping[b_idx] = None\n                            pbar.update(1)\n\n            # Carry the updated inputs into the next iteration (the XLA step is marked in the decode branch)\n            current_inputs = updated_inputs\n\n        pbar.close()\n\n        del self.kv_cache\n        self.kv_cache = None\n        torch.cuda.empty_cache()\n\n        return predicted_tokens, batch_bboxes, scores, topk_probs\n"
  },
  {
    "path": "surya/foundation/cache/__init__.py",
    "content": ""
  },
  {
    "path": "surya/foundation/cache/dynamic_ops.py",
    "content": "from typing import Any, Dict, List, Optional, Tuple\nimport torch\nfrom transformers import PretrainedConfig\n\n\"\"\"\nSpecial cache class for the surya foundation model that supports - \n1) Static shape\n2) A custom sliding window, where image tokens stay in cache, and text tokens are popped\n3) Continuous batching - merging etc\n4) Attention mask management - To match with what's currently in the cache\n\nHeavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079\n\"\"\"\n\n\nclass DynamicOpsCache:\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        batch_size: int,\n        max_cache_len: int,\n        text_sliding_window: int,\n        device: int,\n        dtype: int,\n    ):\n        self.text_sliding_window = text_sliding_window\n        self.num_layers = config.num_hidden_layers\n        self.max_batch_size = batch_size\n        self.max_cache_len = max_cache_len\n        self.head_dim = (\n            getattr(config, \"head_dim\", None)\n            or config.hidden_size // config.num_attention_heads\n        )\n        self._dtype = dtype\n        self.num_key_value_heads = (\n            config.num_attention_heads\n            if getattr(config, \"num_key_value_heads\", None) is None\n            else config.num_key_value_heads\n        )\n\n        # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125\n        self.key_cache: list[torch.Tensor] = []\n        self.value_cache: list[torch.Tensor] = []\n        cache_shape = (\n            self.max_batch_size,\n            self.num_key_value_heads,\n            self.max_cache_len,\n            self.head_dim,\n        )\n        device = torch.device(device) if device is not None else None\n        for _ in range(config.num_hidden_layers):\n           
 new_layer_key_cache = torch.zeros(\n                cache_shape, dtype=self._dtype, device=device\n            )\n            new_layer_value_cache = torch.zeros(\n                cache_shape, dtype=self._dtype, device=device\n            )\n            torch._dynamo.mark_static_address(new_layer_key_cache)\n            torch._dynamo.mark_static_address(new_layer_value_cache)\n            self.key_cache.append(new_layer_key_cache)\n            self.value_cache.append(new_layer_value_cache)\n\n        self.attention_mask = torch.zeros(\n            (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long\n        )\n        self.text_token_counts = [\n            torch.zeros(self.max_batch_size, dtype=torch.long, device=device)\n            for _ in range(self.num_layers)\n        ]\n\n        self.dtype = dtype\n        self.device = device\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        prefill = cache_kwargs.get(\"prefill\", False)\n        update_fn = self._prefill_update if prefill else self._decode_update\n        return update_fn(\n            self.key_cache[layer_idx],\n            self.value_cache[layer_idx],\n            key_states,\n            value_states,\n            self.text_token_counts[layer_idx],\n            cache_kwargs,\n        )\n\n    def update_text_counts(\n        self,\n        merge_idxs: torch.Tensor,\n        valid_batch_size: torch.Tensor,\n        new_text_lens: torch.Tensor,\n    ):\n        new_text_len_tensor = new_text_lens.to(device=self.device)\n\n        for layer_idx in range(self.num_layers):\n            self.text_token_counts[layer_idx][merge_idxs] = new_text_len_tensor[\n                :valid_batch_size\n            ]\n\n    # Mirrors the logic from _prefill_update\n    # Logic is better explained in this 
function\n    def prefill_attention_mask_update(\n        self,\n        prefill_attention_mask: torch.Tensor,\n        merge_idxs: torch.Tensor,\n        valid_batch_mask: torch.Tensor,\n        text_lengths: List[int],\n    ):\n        seq_len = prefill_attention_mask.shape[1]\n        sliding_window = self.text_sliding_window\n        total_cache_len = self.max_cache_len\n        prefix_cache_space = total_cache_len - sliding_window\n\n        for batch_idx, cache_idx in enumerate(merge_idxs):\n            text_len = text_lengths[batch_idx]\n            prefix_len = seq_len - text_len\n            self.attention_mask[cache_idx] = 0  # Set default\n\n            assert prefix_len > 0, \"There are no prefix (image) tokens!\"\n\n            end_pos = prefix_cache_space\n            # Handle prefix part - Which may be left padded\n            if prefix_len <= prefix_cache_space:\n                start_pos = prefix_cache_space - prefix_len\n                self.attention_mask[cache_idx, start_pos:end_pos] = (\n                    prefill_attention_mask[batch_idx, :prefix_len]\n                )\n            else:\n                self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[\n                    batch_idx, prefix_len - prefix_cache_space : prefix_len\n                ]\n\n            # Handle text part, keeping sliding window in consideration\n            # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here\n            if text_len > 0:\n                text_cache_start = prefix_cache_space\n                if text_len <= sliding_window:\n                    self.attention_mask[\n                        cache_idx, text_cache_start : text_cache_start + text_len\n                    ] = 1\n                else:\n                    self.attention_mask[cache_idx, -sliding_window:] = 1\n\n    # Slow impl for now - Prefill time is dominated by the large sequence length forward pass\n    def _prefill_update(\n        self,\n        key_cache: torch.Tensor,\n        value_cache: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        text_token_counts: torch.Tensor,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        cache_idxs: List[int] = cache_kwargs.get(\"cache_idxs\", None)\n        text_lengths: List[int] = cache_kwargs.get(\"text_lengths\", None)\n        assert cache_idxs is not None, \"cache_idxs must be specified during prefill\"\n        assert text_lengths is not None, \"text_lengths must be specified during prefill\"\n\n        _, _, seq_len, _ = key_states.shape\n        total_cache_len = self.max_cache_len\n        sliding_window = self.text_sliding_window\n        prefix_cache_space = total_cache_len - sliding_window\n\n        for batch_idx, cache_idx in enumerate(cache_idxs):\n            text_len = text_lengths[batch_idx]\n            prefix_len = seq_len - text_len\n\n            ###### Handle Image Tokens (Prefix) #####\n            # Place image tokens in appropriate cache space, aligned to the **right edge**\n            assert prefix_len > 0, \"There are no prefix (image) tokens!\"\n\n            # prefix_len may be greater than the prefix cache space due to left padding - This happens when\n            # a different batch element has a large input text during prefill, causing others to have a lot of\n            # left padding. 
We can safely take the last `prefix_cache_space` elements from the kv states, since\n            # `prefix_cache_space` is large enough to fit any image, and the rest **has to be** padding\n            end_pos = prefix_cache_space\n            if prefix_len <= prefix_cache_space:\n                start_pos = prefix_cache_space - prefix_len\n                key_cache[cache_idx, :, start_pos:end_pos] = key_states[\n                    batch_idx, :, :prefix_len\n                ]\n                value_cache[cache_idx, :, start_pos:end_pos] = value_states[\n                    batch_idx, :, :prefix_len\n                ]\n            else:\n                key_cache[cache_idx, :, :end_pos] = key_states[\n                    batch_idx, :, prefix_len - prefix_cache_space : prefix_len\n                ]\n                value_cache[cache_idx, :, :end_pos] = value_states[\n                    batch_idx, :, prefix_len - prefix_cache_space : prefix_len\n                ]\n\n            ###### Handle Text Tokens #####\n            # Text tokens start at the **left edge** of sliding window cache space\n            if text_len > 0:\n                text_cache_start = prefix_cache_space\n\n                if text_len <= sliding_window:\n                    key_cache[\n                        cache_idx, :, text_cache_start : text_cache_start + text_len\n                    ] = key_states[batch_idx, :, prefix_len : prefix_len + text_len]\n                    value_cache[\n                        cache_idx, :, text_cache_start : text_cache_start + text_len\n                    ] = value_states[batch_idx, :, prefix_len : prefix_len + text_len]\n                else:\n                    start_in_text = text_len - sliding_window\n                    key_cache[\n                        cache_idx,\n                        :,\n                        text_cache_start : text_cache_start + sliding_window,\n                    ] = key_states[\n                        batch_idx, :, 
prefix_len + start_in_text : prefix_len + text_len\n                    ]\n                    value_cache[\n                        cache_idx,\n                        :,\n                        text_cache_start : text_cache_start + sliding_window,\n                    ] = value_states[\n                        batch_idx, :, prefix_len + start_in_text : prefix_len + text_len\n                    ]\n\n        # Return the full key/value states (not just cached) for use in subsequent layers\n        return key_states, value_states\n\n    # \"\"\"\n    # Matches the logic of the decode update, but needs to be called before the updates\n    # since some parts of the model depend on the attention mask\n    # \"\"\"\n    def decode_attention_mask_update(\n        self, num_valid_tokens: torch.Tensor, cache_idxs: List[int]\n    ):\n        sliding_window = self.text_sliding_window\n        text_cache_start = self.max_cache_len - sliding_window\n\n        # Using text_token_counts of first layer, should be same for all though\n        current_text_lens = self.text_token_counts[0]\n        cache_idxs_tensor = torch.tensor(cache_idxs, device=current_text_lens.device)\n\n        # Get current text lengths for the relevant cache indices\n        current_lens = current_text_lens[cache_idxs_tensor]\n        new_text_lens = current_lens + num_valid_tokens\n        is_full = new_text_lens > sliding_window\n\n        # Handle full caches - set entire sliding window to 1\n        if is_full.any():\n            full_mask = is_full\n            full_cache_idxs = cache_idxs_tensor[full_mask]\n            self.attention_mask[full_cache_idxs, text_cache_start:] = 1\n\n        # Handle non-full caches - set specific ranges to 1\n        if (~is_full).any():\n            non_full_mask = ~is_full\n            non_full_cache_idxs = cache_idxs_tensor[non_full_mask]\n            non_full_current_lens = current_lens[non_full_mask]\n            non_full_valid_tokens = 
num_valid_tokens[non_full_mask]\n\n            max_valid_tokens = (\n                non_full_valid_tokens.max().item()\n                if len(non_full_valid_tokens) > 0\n                else 0\n            )\n            if max_valid_tokens > 0:\n                batch_size = len(non_full_cache_idxs)\n                offset_range = torch.arange(\n                    max_valid_tokens, device=current_text_lens.device\n                )\n                batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1)\n                start_positions = non_full_current_lens.unsqueeze(1)\n                valid_token_counts = non_full_valid_tokens.unsqueeze(1)\n\n                position_indices = start_positions + batch_offsets\n                valid_mask = batch_offsets < valid_token_counts\n\n                row_indices = non_full_cache_idxs.unsqueeze(1).expand(\n                    -1, max_valid_tokens\n                )[valid_mask]\n                col_indices = text_cache_start + position_indices[valid_mask]\n\n                self.attention_mask[row_indices, col_indices] = 1\n\n    \"\"\"\n    Static cache update\n    - respects per-batch text token limits\n    - per-batch valid token lengths (right-padded inputs)\n\n    kv states are expected to have shape [batch_size, kv_heads, T_pad, head_dim]\n    They may have different `true` lengths, to account for multi token preds, or beacon tokens\n    Expects `num_valid_tokens` in cache_kwargs: a tensor of shape (B,) indicating the number\n    of actual (non-padded) tokens to add per batch element.\n    \"\"\"\n\n    def _decode_update(\n        self,\n        key_cache: torch.Tensor,\n        value_cache: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        text_token_counts: torch.Tensor,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        num_valid_tokens: torch.Tensor = cache_kwargs.get(\n            
\"num_valid_tokens\"\n        )  # shape: (B,)\n        assert num_valid_tokens is not None, (\n            \"`num_valid_tokens` must be provided in `cache_kwargs`\"\n        )\n        device = key_states.device\n\n        batch_size, num_head, seq_len, head_dim = key_states.shape\n        sliding_window = self.text_sliding_window\n        max_cache_len = self.max_cache_len\n        cache_text_start = max_cache_len - sliding_window\n        new_text_lengths = text_token_counts + num_valid_tokens\n        slide_amounts = torch.clamp(new_text_lengths - sliding_window, min=0)\n        needs_rotate = slide_amounts > 0\n\n        # Rotate the cache if needed\n        if torch.any(needs_rotate):\n            k_slice = key_cache[:, :, -sliding_window:]  # shape: [B, H, W, D]\n            v_slice = value_cache[:, :, -sliding_window:]  # same shape\n\n            cache_indices = (\n                torch.arange(sliding_window, device=device)\n                .unsqueeze(0)\n                .repeat(batch_size, 1)\n            )  # [B, W]\n            rolled_indices = (\n                cache_indices + slide_amounts.unsqueeze(1)\n            ) % sliding_window  # [B, W]\n\n            # We need to expand indices to shape: [B, 1, W, 1] to broadcast with k_slice\n            rolled_indices = (\n                rolled_indices.unsqueeze(1)\n                .unsqueeze(-1)\n                .expand(-1, num_head, -1, head_dim)\n            )\n\n            k_slice_rolled = k_slice.gather(dim=2, index=rolled_indices)\n            v_slice_rolled = v_slice.gather(dim=2, index=rolled_indices)\n\n            key_cache[:, :, -sliding_window:] = k_slice_rolled\n            value_cache[:, :, -sliding_window:] = v_slice_rolled\n\n        # Insert only **valid tokens** into the cache. 
These are **right aligned** within the input sequence\n        insert_positions = torch.where(\n            needs_rotate,\n            max_cache_len - num_valid_tokens,\n            text_token_counts + cache_text_start,\n        )\n\n        max_tokens = num_valid_tokens.max().item()\n        offsets = torch.arange(max_tokens, device=device).unsqueeze(0)  # [1, max_T]\n        valid_mask = offsets < num_valid_tokens.unsqueeze(1)  # [B, max_T]\n        src_indices = (seq_len - num_valid_tokens).unsqueeze(1) + offsets  # [B, max_T]\n        src_indices = src_indices.clamp(max=seq_len - 1)  # safety\n\n        tgt_indices = insert_positions.unsqueeze(1) + offsets  # [B, max_T]\n        tgt_indices = tgt_indices.clamp(max=max_cache_len - 1)  # safety\n\n        src_idx_exp = (\n            src_indices.unsqueeze(1)\n            .unsqueeze(-1)\n            .expand(batch_size, num_head, max_tokens, head_dim)\n        )\n        tgt_idx_exp = (\n            tgt_indices.unsqueeze(1)\n            .unsqueeze(-1)\n            .expand(batch_size, num_head, max_tokens, head_dim)\n        )\n        valid_mask_exp = (\n            valid_mask.unsqueeze(1)\n            .unsqueeze(-1)\n            .expand(batch_size, num_head, max_tokens, head_dim)\n        )\n\n        k_src = torch.gather(key_states, 2, src_idx_exp)\n        v_src = torch.gather(value_states, 2, src_idx_exp)\n        k_src = k_src * valid_mask_exp\n        v_src = v_src * valid_mask_exp\n\n        # Write into cache\n        key_cache.scatter_(2, tgt_idx_exp, k_src)\n        value_cache.scatter_(2, tgt_idx_exp, v_src)\n\n        # In-place edit - Mutates\n        text_token_counts += num_valid_tokens\n        text_token_counts.clamp_(max=sliding_window)\n\n        return key_cache, value_cache\n\n    # We have a non-uniform cache, so it's better not to return its length and to handle any logic\n    # that requires it ourselves\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        raise NotImplementedError()\n"
  },
  {
    "path": "surya/foundation/cache/static_ops.py",
    "content": "from typing import Any, Dict, List, Optional, Tuple\nimport torch\nfrom transformers import PretrainedConfig\n\nfrom surya.foundation.cache.dynamic_ops import DynamicOpsCache\n\n\"\"\"\nSpecial cache class for the surya foundation model that supports - \n1) Static shape\n2) A custom sliding window, where image tokens stay in cache, and text tokens are popped\n3) Continuous batching - merging etc\n4) Attention mask management - To match with what's currently in the cache\n\nHeavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079\n\"\"\"\n\n\nclass StaticOpsCache(DynamicOpsCache):\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        batch_size: int,\n        max_cache_len: int,\n        text_sliding_window: int,\n        device: int,\n        dtype: int,\n    ):\n        self.text_sliding_window = text_sliding_window\n        self.num_layers = config.num_hidden_layers\n        self.max_batch_size = batch_size\n        self.max_cache_len = max_cache_len\n        self.head_dim = (\n            getattr(config, \"head_dim\", None)\n            or config.hidden_size // config.num_attention_heads\n        )\n        self._dtype = dtype\n        self.num_key_value_heads = (\n            config.num_attention_heads\n            if getattr(config, \"num_key_value_heads\", None) is None\n            else config.num_key_value_heads\n        )\n\n        # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125\n        self.key_cache: list[torch.Tensor] = []\n        self.value_cache: list[torch.Tensor] = []\n        cache_shape = (\n            self.max_batch_size,\n            self.num_key_value_heads,\n            self.max_cache_len,\n            self.head_dim,\n        )\n        device = torch.device(device) if device is 
not None else None\n        for _ in range(config.num_hidden_layers):\n            new_layer_key_cache = torch.zeros(\n                cache_shape, dtype=self._dtype, device=device\n            )\n            new_layer_value_cache = torch.zeros(\n                cache_shape, dtype=self._dtype, device=device\n            )\n            torch._dynamo.mark_static_address(new_layer_key_cache)\n            torch._dynamo.mark_static_address(new_layer_value_cache)\n            self.key_cache.append(new_layer_key_cache)\n            self.value_cache.append(new_layer_value_cache)\n\n        self.attention_mask = torch.zeros(\n            (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long\n        )\n        self.text_token_counts = [\n            torch.zeros(self.max_batch_size, dtype=torch.long, device=device)\n            for _ in range(self.num_layers)\n        ]\n\n        self.dtype = dtype\n        self.device = device\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        prefill = cache_kwargs.get(\"prefill\", False)\n        update_fn = self._prefill_update if prefill else self._decode_update\n        return update_fn(\n            self.key_cache[layer_idx],\n            self.value_cache[layer_idx],\n            key_states,\n            value_states,\n            self.text_token_counts[layer_idx],\n            cache_kwargs,\n        )\n\n    def _prefill_update(\n        self,\n        key_cache: torch.Tensor,\n        value_cache: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        text_token_counts: torch.Tensor,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        cache_idxs: torch.Tensor = cache_kwargs.get(\"cache_idxs\", None)\n        text_lengths: List[int] = 
cache_kwargs.get(\"text_lengths\", None)\n        assert cache_idxs is not None, \"cache_idxs must be specified during prefill\"\n        assert text_lengths is not None, \"text_lengths must be specified during prefill\"\n\n        cache_idx_length = len(cache_idxs)\n        full_batch = len(cache_idxs) == self.max_batch_size\n\n        # Insert key and value states at the end of the cache\n        new_tokens = key_states.shape[2]\n\n        # Direct right-aligned assignment\n        if full_batch:\n            key_cache[:, :, -new_tokens:] = key_states\n            value_cache[:, :, -new_tokens:] = value_states\n        else:\n            key_cache[cache_idxs, :, -new_tokens:] = key_states[:cache_idx_length]\n            value_cache[cache_idxs, :, -new_tokens:] = value_states[:cache_idx_length]\n\n        return key_states, value_states\n\n    # \"\"\"\n    # Matches the logic of the decode update, but needs to be called before the updates\n    # since some parts of the model depend on the attention mask\n    # \"\"\"\n    def decode_attention_mask_update(\n        self, num_valid_tokens: torch.Tensor, cache_idxs: List[int]\n    ):\n        max_valid_tokens = num_valid_tokens.max().item()\n        if max_valid_tokens == 0:\n            # If no valid tokens, we don't need to update the attention mask\n            return\n\n        # Shift the attention mask to the left by max_valid_tokens\n        self.attention_mask = self.attention_mask.roll(-1 * max_valid_tokens, dims=1)\n        self.attention_mask[:, -max_valid_tokens:] = (\n            1  # Full attention to all new tokens\n        )\n\n    # Mirrors the logic from _prefill_update\n    def prefill_attention_mask_update(\n        self,\n        attention_mask: torch.Tensor,\n        merge_idxs: torch.Tensor,\n        valid_batch_size: torch.Tensor,\n        text_lengths: List[int],\n    ):\n        # Set from -(image_length + text_length) to end to 1 for each batch element\n        seq_len = 
attention_mask.shape[1]\n        self.attention_mask[merge_idxs] = (\n            0  # Reset the attention mask for the current batch elements\n        )\n        self.attention_mask[merge_idxs, -seq_len:] = attention_mask[:valid_batch_size]\n\n    def _decode_update(\n        self,\n        key_cache: torch.Tensor,\n        value_cache: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        text_token_counts: torch.Tensor,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # Naive, always assumes we'll roll by a fixed amount\n        # Needs left padding with beacons to work properly\n\n        num_valid_tokens: torch.Tensor = cache_kwargs.get(\n            \"num_valid_tokens\"\n        )  # shape: (B,)\n        assert num_valid_tokens is not None, (\n            \"`num_valid_tokens` must be provided in `cache_kwargs`\"\n        )\n        # (B, H, L, D)\n\n        valid_tokens = key_states.shape[2]\n\n        key_cache.copy_(torch.roll(key_cache, -valid_tokens, dims=2))\n        value_cache.copy_(torch.roll(value_cache, -valid_tokens, dims=2))\n\n        key_cache[:, :, -valid_tokens:, :] = key_states\n        value_cache[:, :, -valid_tokens:, :] = value_states\n\n        # In-place edit - Mutates\n        text_token_counts += num_valid_tokens\n        text_token_counts.clamp_(max=self.text_sliding_window)\n        return key_cache, value_cache\n\n    # The attention mask managed by our kv cache automatically masks the tokens\n    # in the cache, so we can return full length for HF to use in other places\n    # This is mainly utilized in the cache_positions creation\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        return self.max_cache_len\n"
  },
  {
    "path": "surya/foundation/loader.py",
    "content": "from typing import Optional\n\nimport torch\nfrom transformers.utils import is_flash_attn_2_available\n\nfrom surya.common.load import ModelLoader\nfrom surya.common.surya.config import SuryaModelConfig\nfrom surya.common.surya import SuryaModel, SuryaXLAModel\nfrom surya.common.surya.processor import SuryaOCRProcessor\nfrom surya.common.surya.processor.tokenizer import SuryaOCRTokenizer\nfrom surya.common.util import is_flash_attn_2_supported\nfrom surya.common.xla import get_compile_args\nfrom surya.logging import get_logger\nfrom surya.settings import settings\n\nlogger = get_logger()\n\n\nclass FoundationModelLoader(ModelLoader):\n    def __init__(self, checkpoint: Optional[str] = None):\n        super().__init__(checkpoint)\n\n        if self.checkpoint is None:\n            self.checkpoint = settings.FOUNDATION_MODEL_CHECKPOINT\n\n    def model(\n        self,\n        device=settings.TORCH_DEVICE_MODEL,\n        dtype=None,\n        attention_implementation: Optional[str] = None,\n    ) -> SuryaModel:\n        if device is None:\n            device = settings.TORCH_DEVICE_MODEL\n        if dtype is None:\n            # See https://github.com/pytorch/pytorch/issues/118122 - T4 (device version 7.5) will return true since it supports\n            # emulated bf16, but falls back to very slow kernels, especially for SDPA\n            dtype = settings.MODEL_DTYPE_BFLOAT\n            if device == \"cuda\" and not torch.cuda.is_bf16_supported(\n                including_emulation=False\n            ):\n                # If the device is cuda, we check if bf16 is supported, and if not, we use float16\n                dtype = settings.MODEL_DTYPE\n        elif dtype == torch.float16:\n            dtype = torch.bfloat16  # Model weights in bfloat16\n\n        config = SuryaModelConfig.from_pretrained(self.checkpoint)\n\n        if attention_implementation is not None:\n            config.decoder._attn_implementation = attention_implementation\n          
  config.vision_encoder._attn_implementation = attention_implementation\n        elif is_flash_attn_2_available() and is_flash_attn_2_supported(device):\n            config.decoder._attn_implementation = \"flash_attention_2\"\n            config.vision_encoder._attn_implementation = \"flash_attention_2\"\n        elif device == \"xla\":\n            config.decoder._attn_implementation = \"sdpa\"\n            config.vision_encoder._attn_implementation = \"sdpa\"\n        else:\n            config.decoder._attn_implementation = \"sdpa\"\n            config.vision_encoder._attn_implementation = \"sdpa\"\n\n        model_cls = SuryaModel\n        if device == \"xla\":\n            model_cls = SuryaXLAModel\n\n        config._attn_implementation_autoset = True\n        config.vision_encoder._attn_implementation_autoset = True\n        config.decoder._attn_implementation_autoset = True\n\n        model = model_cls.from_pretrained(\n            self.checkpoint, dtype=dtype, config=config, ignore_mismatched_sizes=True\n        ).to(device)\n        model = model.eval()\n\n        if settings.COMPILE_ALL or settings.COMPILE_FOUNDATION:\n            torch._dynamo.config.cache_size_limit = 1000\n            torch._dynamo.config.suppress_errors = True\n            torch._dynamo.config.specialize_int = False\n            torch._dynamo.config.allow_unspec_int_on_nn_module = True\n            torch._dynamo.config.capture_scalar_outputs = True\n            torch._dynamo.config.recompile_limit = 32\n\n            logger.info(\n                f\"Compiling foundation model {self.checkpoint} on device {device} with dtype {dtype}\"\n            )\n            compile_args = get_compile_args(device)\n            model.vision_encoder = torch.compile(model.vision_encoder, **compile_args)\n            model.decoder = torch.compile(model.decoder, **compile_args)\n\n        logger.debug(\n            f\"Loaded recognition model {self.checkpoint} from 
{SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}.\"\n        )\n        return model\n\n    def processor(\n        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE_BFLOAT\n    ) -> SuryaOCRProcessor:\n        config: SuryaModelConfig = SuryaModelConfig.from_pretrained(self.checkpoint)\n\n        ocr_tokenizer = SuryaOCRTokenizer(\n            special_tokens=config.special_ocr_tokens, model_checkpoint=self.checkpoint\n        )\n\n        processor = SuryaOCRProcessor(\n            ocr_tokenizer=ocr_tokenizer,\n            blank_bbox_token_id=config.blank_bbox_token_id,\n            num_register_tokens=config.num_register_tokens,\n            sequence_length=None,\n            patch_size=config.vision_encoder.patch_size,\n            merge_size=config.vision_encoder.spatial_merge_size,\n            model_device=device,\n            num_beacon_tokens=config.num_beacon_tokens,\n            beacon_token_interval=config.beacon_token_interval,\n        )\n\n        return processor\n"
  },
  {
    "path": "surya/foundation/util.py",
    "content": "from typing import List, Tuple\nimport numpy as np\nimport torch\n\ndef detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40):\n    if len(predicted_tokens) < max_repeats:\n        return False\n\n    # Detect short loops (at most 5 distinct tokens) repeating at the end of the sequence\n    last_n = predicted_tokens[-max_repeats:]\n    unique_tokens = len(set(last_n))\n    if unique_tokens > 5:\n        return False\n\n    return last_n[-unique_tokens:] == last_n[-unique_tokens * 2 : -unique_tokens]\n\ndef prediction_to_polygon_batch(\n    pred: torch.Tensor,\n    img_sizes: List[Tuple[int, int]],\n    bbox_scaler,\n    skew_scaler,\n    skew_min=0.001,\n):\n    img_sizes = torch.from_numpy(np.array(img_sizes, dtype=np.float32)).to(\n        pred.device\n    )\n    w_scale = (img_sizes[:, 1] / bbox_scaler)[:, None, None]\n    h_scale = (img_sizes[:, 0] / bbox_scaler)[:, None, None]\n\n    cx = pred[:, :, 0]\n    cy = pred[:, :, 1]\n    width = pred[:, :, 2]\n    height = pred[:, :, 3]\n\n    x1 = cx - width / 2\n    y1 = cy - height / 2\n    x2 = cx + width / 2\n    y2 = cy + height / 2\n\n    skew_x = torch.floor((pred[:, :, 4] - skew_scaler) / 2)\n    skew_y = torch.floor((pred[:, :, 5] - skew_scaler) / 2)\n\n    skew_x[torch.abs(skew_x) < skew_min] = 0\n    skew_y[torch.abs(skew_y) < skew_min] = 0\n\n    polygons_flat = torch.stack(\n        [\n            x1 - skew_x,\n            y1 - skew_y,\n            x2 - skew_x,\n            y1 + skew_y,\n            x2 + skew_x,\n            y2 + skew_y,\n            x1 + skew_x,\n            y2 - skew_y,\n        ],\n        dim=2,\n    )\n\n    batch_size, seq_len, _ = pred.shape\n    polygons = polygons_flat.view(batch_size, seq_len, 4, 2)\n\n    polygons[:, :, :, 0] *= w_scale\n    polygons[:, :, :, 1] *= h_scale\n\n    return polygons"
  },
  {
    "path": "surya/input/load.py",
    "content": "from typing import List\nimport PIL\n\nfrom surya.input.processing import open_pdf, get_page_images\nfrom surya.logging import get_logger\nfrom surya.settings import settings\nimport os\nimport filetype\nfrom PIL import Image\nimport json\n\nlogger = get_logger()\n\n\ndef get_name_from_path(path):\n    return os.path.basename(path).split(\".\")[0]\n\n\ndef load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI):\n    doc = open_pdf(pdf_path)\n    last_page = len(doc)\n\n    if page_range:\n        assert all([0 <= page < last_page for page in page_range]), (\n            f\"Invalid page range: {page_range}\"\n        )\n    else:\n        page_range = list(range(last_page))\n\n    images = get_page_images(doc, page_range, dpi=dpi)\n    doc.close()\n    names = [get_name_from_path(pdf_path) for _ in page_range]\n    return images, names\n\n\ndef load_image(image_path):\n    image = Image.open(image_path).convert(\"RGB\")\n    name = get_name_from_path(image_path)\n    return [image], [name]\n\n\ndef load_from_file(\n    input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI\n):\n    input_type = filetype.guess(input_path)\n    if input_type and input_type.extension == \"pdf\":\n        return load_pdf(input_path, page_range, dpi=dpi)\n    else:\n        return load_image(input_path)\n\n\ndef load_from_folder(\n    folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI\n):\n    image_paths = [\n        os.path.join(folder_path, image_name)\n        for image_name in os.listdir(folder_path)\n        if not image_name.startswith(\".\")\n    ]\n    image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]\n\n    images = []\n    names = []\n    for path in image_paths:\n        extension = filetype.guess(path)\n        if extension and extension.extension == \"pdf\":\n            image, name = load_pdf(path, page_range, dpi=dpi)\n            images.extend(image)\n            
names.extend(name)\n        else:\n            try:\n                image, name = load_image(path)\n                images.extend(image)\n                names.extend(name)\n            except PIL.UnidentifiedImageError:\n                logger.warning(f\"Could not load image {path}\")\n                continue\n    return images, names\n\n\ndef load_lang_file(lang_path, names):\n    with open(lang_path, \"r\") as f:\n        lang_dict = json.load(f)\n    return [lang_dict[name].copy() for name in names]\n"
  },
  {
    "path": "surya/input/processing.py",
    "content": "from typing import List\n\nimport cv2\nimport numpy as np\nimport pypdfium2\nfrom PIL import Image\n\nfrom surya.logging import get_logger\nfrom surya.settings import settings\n\nlogger = get_logger()\n\n\ndef convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:\n    new_images = []\n    for image in images:\n        if image.mode != \"RGB\":\n            image = image.convert(\"RGB\")\n        new_images.append(image)\n    return new_images\n\n\ndef open_pdf(pdf_filepath):\n    return pypdfium2.PdfDocument(pdf_filepath)\n\n\ndef get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):\n    images = [\n        doc[i].render(scale=dpi / 72, draw_annots=False).to_pil() for i in indices\n    ]\n    images = [image.convert(\"RGB\") for image in images]\n    return images\n\n\ndef slice_bboxes_from_image(image: np.ndarray, bboxes):\n    lines = []\n    for bbox in bboxes:\n        bbox = np.array(bbox, dtype=np.int32)\n        bbox = np.clip(bbox, 0, None)  # Ensure no negative indices\n        # Ensure bbox is within the image bounds\n        if bbox[3] <= bbox[1]:\n            bbox[3] = bbox[1] + 1\n\n        if bbox[2] <= bbox[0]:\n            bbox[2] = bbox[0] + 1\n\n        bbox[2] = min(bbox[2], image.shape[1])\n        bbox[3] = min(bbox[3], image.shape[0])\n\n        line = image[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()\n        if line.size == 0:\n            logger.warning(f\"Warning: found an empty line with bbox {bbox}\")\n        lines.append(line)\n    return lines\n\n\ndef slice_polys_from_image(image: np.ndarray, polys):\n    lines = []\n    for idx, poly in enumerate(polys):\n        lines.append(slice_and_pad_poly(image, poly))\n    return lines\n\n\ndef slice_and_pad_poly(image_array: np.array, coordinates):\n    # Draw polygon onto mask\n    coordinates = [(corner[0], corner[1]) for corner in coordinates]\n    bbox = [\n        min([x[0] for x in coordinates]),\n        min([x[1] for x in coordinates]),\n        
max([x[0] for x in coordinates]),\n        max([x[1] for x in coordinates]),\n    ]\n\n    # We mask out anything not in the polygon\n    cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()\n    height, width = cropped_polygon.shape[:2]\n\n    coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]\n\n    # Validate the cropped area\n    if any(\n        [\n            bbox[3] <= bbox[1] or bbox[2] <= bbox[0],\n            len(coordinates) < 3,\n            height == 0,\n            width == 0,\n        ]\n    ):\n        return cropped_polygon\n\n    # Pad the area outside the polygon with the pad value\n    try:\n        mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)\n        cv2.fillPoly(mask, [np.int32(coordinates)], 1)\n        mask = np.stack([mask] * 3, axis=-1)\n\n        cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE\n    except cv2.error as e:\n        logger.warning(f\"Warning: issue while processing polygon: {e}\")\n\n    return cropped_polygon\n"
  },
  {
    "path": "surya/layout/__init__.py",
    "content": "from typing import List\n\nfrom PIL import Image\n\nfrom surya.common.predictor import BasePredictor\nfrom surya.layout.schema import LayoutBox, LayoutResult\nfrom surya.settings import settings\nfrom surya.foundation import FoundationPredictor, TaskNames\nfrom surya.foundation.util import prediction_to_polygon_batch\nfrom surya.input.processing import convert_if_not_rgb\nfrom surya.layout.label import LAYOUT_PRED_RELABEL\nfrom surya.common.util import clean_boxes\n\n\nclass LayoutPredictor(BasePredictor):\n    batch_size = settings.LAYOUT_BATCH_SIZE\n    default_batch_sizes = {\"cpu\": 4, \"mps\": 4, \"cuda\": 32, \"xla\": 16}\n\n    # Override base init - Do not load model\n    def __init__(self, foundation_predictor: FoundationPredictor):\n        self.foundation_predictor = foundation_predictor\n        self.processor = self.foundation_predictor.processor\n        self.bbox_size = self.foundation_predictor.model.config.bbox_size\n        self.tasks = self.foundation_predictor.tasks\n\n    # Special handling for disable tqdm to pass into foundation predictor\n    # Make sure they are kept in sync\n    @property\n    def disable_tqdm(self) -> bool:\n        return super().disable_tqdm\n\n    @disable_tqdm.setter\n    def disable_tqdm(self, value: bool) -> None:\n        self._disable_tqdm = bool(value)\n        self.foundation_predictor.disable_tqdm = bool(value)\n\n    def __call__(\n        self, images: List[Image.Image], batch_size: int | None = None, top_k: int = 5\n    ) -> List[LayoutResult]:\n        assert all([isinstance(image, Image.Image) for image in images])\n        if batch_size is None:\n            batch_size = self.get_batch_size()\n\n        if len(images) == 0:\n            return []\n\n        images = convert_if_not_rgb(images)\n        images = [self.processor.image_processor(image) for image in images]\n\n        predicted_tokens, batch_bboxes, scores, topk_scores = (\n            
self.foundation_predictor.prediction_loop(\n                images=images,\n                input_texts=[\"\" for _ in range(len(images))],\n                task_names=[TaskNames.layout for _ in range(len(images))],\n                batch_size=batch_size,\n                max_lookahead_tokens=0,  # Do not do MTP for layout\n                top_k=top_k,\n                max_sliding_window=576,\n                max_tokens=500,\n                tqdm_desc=\"Recognizing Layout\"\n            )\n        )\n\n        image_sizes = [img.shape for img in images]\n        predicted_polygons = prediction_to_polygon_batch(\n            batch_bboxes, image_sizes, self.bbox_size, self.bbox_size // 2\n        )\n        layout_results = []\n        for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip(\n            images, predicted_tokens, predicted_polygons, scores, topk_scores\n        ):\n            layout_boxes = []\n            for z, (tok, poly, score, tok_topk) in enumerate(\n                zip(image_tokens, image_polygons, image_scores, image_topk_scores)\n            ):\n                if tok == self.processor.eos_token_id:\n                    break\n\n                predicted_label = self.processor.decode([tok], \"layout\")\n                label = LAYOUT_PRED_RELABEL.get(predicted_label)\n                if not label:\n                    # Layout can sometimes return unknown labels from other objectives\n                    continue\n\n                top_k_dict = {}\n                for k, v in tok_topk.items():\n                    topk_label = self.processor.decode([k], \"layout\")\n                    if topk_label in LAYOUT_PRED_RELABEL:\n                        topk_label = LAYOUT_PRED_RELABEL[topk_label]\n                    if not topk_label.strip():\n                        continue\n                    top_k_dict.update({topk_label: v})\n                layout_boxes.append(\n                    LayoutBox(\n                        
polygon=poly.tolist(),\n                        label=label,\n                        position=z,\n                        top_k=top_k_dict,\n                        confidence=score,\n                    )\n                )\n            layout_boxes = clean_boxes(layout_boxes)\n            layout_results.append(\n                LayoutResult(\n                    bboxes=layout_boxes,\n                    image_bbox=[0, 0, image.shape[1], image.shape[0]],\n                )  # Image is numpy array\n            )\n\n        assert len(layout_results) == len(images)\n        return layout_results\n"
  },
  {
    "path": "surya/layout/label.py",
    "content": "LAYOUT_PRED_RELABEL = {\n    \"<page-header>\": \"PageHeader\",\n    \"<page-footer>\": \"PageFooter\",\n    \"<footnote>\": \"Footnote\",\n    \"<image>\": \"Picture\",\n    \"<figure>\": \"Figure\",\n    \"<text>\": \"Text\",\n    \"<caption>\": \"Caption\",\n    \"<list-item>\": \"ListItem\",\n    \"<section-header>\": \"SectionHeader\",\n    \"<table>\": \"Table\",\n    \"<table-of-contents>\": \"TableOfContents\",\n    \"<form>\": \"Form\",\n    \"<equation-block>\": \"Equation\",\n    \"<code-block>\": \"Code\",\n    \"<complex-block>\": \"Figure\",\n}\n"
  },
  {
    "path": "surya/layout/schema.py",
    "content": "from typing import Optional, Dict, List\n\nfrom pydantic import BaseModel\n\nfrom surya.common.polygon import PolygonBox\n\n\nclass LayoutBox(PolygonBox):\n    label: str\n    position: int\n    top_k: Optional[Dict[str, float]] = None\n\n\nclass LayoutResult(BaseModel):\n    bboxes: List[LayoutBox]\n    image_bbox: List[float]\n    sliced: bool = False  # Whether the image was sliced and reconstructed\n"
  },
  {
    "path": "surya/logging.py",
    "content": "import logging\nimport warnings\nfrom surya.settings import settings\n\n\ndef configure_logging():\n    logger = get_logger()\n\n    # Remove any existing handlers to prevent duplicates\n    for handler in logger.handlers[:]:\n        logger.removeHandler(handler)\n\n    # Add our handler\n    handler = logging.StreamHandler()\n    formatter = logging.Formatter(\"%(asctime)s [%(levelname)s] %(name)s: %(message)s\")\n    handler.setFormatter(formatter)\n    logger.addHandler(handler)\n\n    # Prevent propagation to parent loggers to avoid double logging\n    logger.propagate = False\n\n    logger.setLevel(settings.LOGLEVEL)\n    warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\n\ndef get_logger():\n    return logging.getLogger(\"surya\")\n"
  },
  {
    "path": "surya/models.py",
    "content": "from typing import Dict\n\nimport torch\n\nfrom surya.common.predictor import BasePredictor\nfrom surya.detection import DetectionPredictor\nfrom surya.layout import LayoutPredictor\nfrom surya.logging import configure_logging\nfrom surya.ocr_error import OCRErrorPredictor\nfrom surya.foundation import FoundationPredictor\nfrom surya.recognition import RecognitionPredictor\nfrom surya.table_rec import TableRecPredictor\nfrom surya.settings import settings\n\nconfigure_logging()\n\n\ndef load_predictors(\n    device: str | torch.device | None = None, dtype: torch.dtype | str | None = None\n) -> Dict[str, BasePredictor]:\n    return {\n        \"layout\": LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)),\n        \"ocr_error\": OCRErrorPredictor(device=device, dtype=dtype),\n        \"recognition\": RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT)),\n        \"detection\": DetectionPredictor(device=device, dtype=dtype),\n        \"table_rec\": TableRecPredictor(device=device, dtype=dtype),\n    }\n"
  },
  {
    "path": "surya/ocr_error/__init__.py",
    "content": "import math\nfrom typing import List, Optional\n\nfrom tqdm import tqdm\n\nfrom surya.common.predictor import BasePredictor\nfrom surya.ocr_error.loader import OCRErrorModelLoader\nfrom surya.ocr_error.model.config import ID2LABEL\nfrom surya.ocr_error.schema import OCRErrorDetectionResult\nfrom surya.settings import settings\nfrom surya.common.xla import mark_step\n\n\nclass OCRErrorPredictor(BasePredictor):\n    model_loader_cls = OCRErrorModelLoader\n    batch_size = settings.OCR_ERROR_BATCH_SIZE\n    default_batch_sizes = {\"cpu\": 8, \"mps\": 8, \"cuda\": 64, \"xla\": 32}\n\n    def __call__(self, texts: List[str], batch_size: Optional[int] = None):\n        return self.batch_ocr_error_detection(texts, batch_size)\n\n    def batch_ocr_error_detection(\n        self, texts: List[str], batch_size: Optional[int] = None\n    ):\n        if batch_size is None:\n            batch_size = self.get_batch_size()\n\n        num_batches = math.ceil(len(texts) / batch_size)\n        texts_processed = self.processor(\n            texts, padding=\"longest\", truncation=True, return_tensors=\"pt\"\n        )\n        predictions = []\n        for batch_idx in tqdm(\n            range(num_batches),\n            desc=\"Running OCR Error Detection\",\n            disable=self.disable_tqdm,\n        ):\n            start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size\n            batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(\n                self.model.device\n            )\n            batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(\n                self.model.device\n            )\n\n            # Pad to batch size\n            current_batch_size = batch_input_ids.shape[0]\n            if settings.OCR_ERROR_STATIC_CACHE:\n                batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)\n                batch_attention_mask = self.pad_to_batch_size(\n                    
batch_attention_mask, batch_size\n                )\n\n            with settings.INFERENCE_MODE():\n                pred = self.model(batch_input_ids, attention_mask=batch_attention_mask)\n\n                logits = pred.logits.argmax(dim=1).cpu().tolist()[:current_batch_size]\n                predictions.extend(logits)\n            mark_step()\n\n        return OCRErrorDetectionResult(\n            texts=texts, labels=[ID2LABEL[p] for p in predictions]\n        )\n"
  },
  {
    "path": "surya/ocr_error/loader.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom surya.common.load import ModelLoader\nfrom surya.logging import get_logger\nfrom surya.ocr_error.model.config import DistilBertConfig\nfrom surya.ocr_error.model.encoder import DistilBertForSequenceClassification\nfrom surya.ocr_error.tokenizer import DistilBertTokenizer\nfrom surya.settings import settings\n\nlogger = get_logger()\n\n\nclass OCRErrorModelLoader(ModelLoader):\n    def __init__(self, checkpoint: Optional[str] = None):\n        super().__init__(checkpoint)\n\n        if self.checkpoint is None:\n            self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT\n\n    def model(\n        self,\n        device=settings.TORCH_DEVICE_MODEL,\n        dtype=settings.MODEL_DTYPE,\n        attention_implementation: Optional[str] = None,\n    ) -> DistilBertForSequenceClassification:\n        if device is None:\n            device = settings.TORCH_DEVICE_MODEL\n        if dtype is None:\n            dtype = settings.MODEL_DTYPE\n\n        config = DistilBertConfig.from_pretrained(self.checkpoint)\n        model = (\n            DistilBertForSequenceClassification.from_pretrained(\n                self.checkpoint,\n                dtype=dtype,\n                config=config,\n            )\n            .to(device)\n            .eval()\n        )\n\n        if settings.COMPILE_ALL or settings.COMPILE_OCR_ERROR:\n            torch._dynamo.config.cache_size_limit = 1\n            torch._dynamo.config.suppress_errors = False\n\n            logger.info(\n                f\"Compiling OCR error model {self.checkpoint} from {DistilBertForSequenceClassification.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}\"\n            )\n            compile_args = {\"backend\": \"openxla\"} if device == \"xla\" else {}\n            model = torch.compile(model, **compile_args)\n\n        return model\n\n    def processor(\n        self, device=settings.TORCH_DEVICE_MODEL, 
dtype=settings.MODEL_DTYPE\n    ) -> DistilBertTokenizer:\n        return DistilBertTokenizer.from_pretrained(self.checkpoint)\n"
  },
  {
    "path": "surya/ocr_error/model/__init__.py",
    "content": ""
  },
  {
    "path": "surya/ocr_error/model/config.py",
    "content": "from collections import OrderedDict\nfrom typing import Mapping\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.onnx import OnnxConfig\n\nfrom surya.common.s3 import S3DownloaderMixin\n\nID2LABEL = {\n    0: 'good',\n    1: 'bad'\n}\n\nclass DistilBertConfig(S3DownloaderMixin, PretrainedConfig):\n    model_type = \"distilbert\"\n    attribute_map = {\n        \"hidden_size\": \"dim\",\n        \"num_attention_heads\": \"n_heads\",\n        \"num_hidden_layers\": \"n_layers\",\n    }\n\n    def __init__(\n        self,\n        vocab_size=30522,\n        max_position_embeddings=512,\n        sinusoidal_pos_embds=False,\n        n_layers=6,\n        n_heads=12,\n        dim=768,\n        hidden_dim=4 * 768,\n        dropout=0.1,\n        attention_dropout=0.1,\n        activation=\"gelu\",\n        initializer_range=0.02,\n        qa_dropout=0.1,\n        seq_classif_dropout=0.2,\n        pad_token_id=0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.sinusoidal_pos_embds = sinusoidal_pos_embds\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.dim = dim\n        self.hidden_dim = hidden_dim\n        self.dropout = dropout\n        self.attention_dropout = attention_dropout\n        self.activation = activation\n        self.initializer_range = initializer_range\n        self.qa_dropout = qa_dropout\n        self.seq_classif_dropout = seq_classif_dropout\n        super().__init__(**kwargs, pad_token_id=pad_token_id)\n\n\nclass DistilBertOnnxConfig(OnnxConfig):\n    @property\n    def inputs(self) -> Mapping[str, Mapping[int, str]]:\n        if self.task == \"multiple-choice\":\n            dynamic_axis = {0: \"batch\", 1: \"choice\", 2: \"sequence\"}\n        else:\n            dynamic_axis = {0: \"batch\", 1: \"sequence\"}\n        return OrderedDict(\n            [\n                
(\"input_ids\", dynamic_axis),\n                (\"attention_mask\", dynamic_axis),\n            ]\n        )"
  },
  {
    "path": "surya/ocr_error/model/encoder.py",
    "content": "from __future__ import annotations\n\nimport math\nfrom typing import Optional, Set, List, Tuple, Union, Dict\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F, MSELoss, CrossEntropyLoss, BCEWithLogitsLoss\nfrom transformers import apply_chunking_to_forward\nfrom transformers.activations import get_activation\nfrom transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput\nfrom transformers.pytorch_utils import (\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\n\nfrom transformers.utils import (\n    is_flash_attn_greater_or_equal_2_10,\n)\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\n\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.ocr_error.model.config import DistilBertConfig\n\n\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\ndef create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):\n    position_enc = np.array(\n        [\n            [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]\n            for pos in range(n_pos)\n        ]\n    )\n    out.requires_grad = False\n    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))\n    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))\n    out.detach_()\n\n\nclass Embeddings(nn.Module):\n    def __init__(self, config: DistilBertConfig):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(\n            config.vocab_size, config.dim, padding_idx=config.pad_token_id\n        )\n        self.position_embeddings = nn.Embedding(\n            
config.max_position_embeddings, config.dim\n        )\n\n        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)\n        self.dropout = nn.Dropout(config.dropout)\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(config.max_position_embeddings).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(\n        self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        \"\"\"\n        Parameters:\n            input_ids (torch.Tensor):\n                torch.tensor(bs, max_seq_length) The token ids to embed.\n            input_embeds (*optional*, torch.Tensor):\n                The pre-computed word embeddings. Can only be passed if the input ids are `None`.\n\n\n        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type\n        embeddings)\n        \"\"\"\n        if input_ids is not None:\n            input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)\n\n        seq_length = input_embeds.size(1)\n\n        # Setting the position-ids to the registered buffer in constructor, it helps\n        # when tracing the model without passing position-ids, solves\n        # issues similar to issue #5664\n        if hasattr(self, \"position_ids\"):\n            position_ids = self.position_ids[:, :seq_length]\n        else:\n            position_ids = torch.arange(\n                seq_length, dtype=torch.long, device=input_ids.device\n            )  # (max_seq_length)\n            position_ids = position_ids.unsqueeze(0).expand_as(\n                input_ids\n            )  # (bs, max_seq_length)\n\n        position_embeddings = self.position_embeddings(\n            position_ids\n        )  # (bs, max_seq_length, dim)\n\n        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)\n        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)\n
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)\n        return embeddings\n\n\nclass MultiHeadSelfAttention(nn.Module):\n    def __init__(self, config: DistilBertConfig):\n        super().__init__()\n        self.config = config\n\n        self.n_heads = config.n_heads\n        self.dim = config.dim\n        self.dropout = nn.Dropout(p=config.attention_dropout)\n        self.is_causal = False\n\n        # Have an even number of multi heads that divide the dimensions\n        if self.dim % self.n_heads != 0:\n            # Raise value errors for even multi-head attention nodes\n            raise ValueError(\n                f\"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly\"\n            )\n\n        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n\n        self.pruned_heads: Set[int] = set()\n        self.attention_head_size = self.dim // self.n_heads\n\n    def prune_heads(self, heads: List[int]):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads, self.n_heads, self.attention_head_size, self.pruned_heads\n        )\n        # Prune linear layers\n        self.q_lin = prune_linear_layer(self.q_lin, index)\n        self.k_lin = prune_linear_layer(self.k_lin, index)\n        self.v_lin = prune_linear_layer(self.v_lin, index)\n        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)\n        # Update hyper params\n        self.n_heads = self.n_heads - len(heads)\n        self.dim = self.attention_head_size * self.n_heads\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: 
torch.Tensor,\n        value: torch.Tensor,\n        mask: torch.Tensor,\n        head_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n    ) -> Tuple[torch.Tensor, ...]:\n        \"\"\"\n        Parameters:\n            query: torch.tensor(bs, seq_length, dim)\n            key: torch.tensor(bs, seq_length, dim)\n            value: torch.tensor(bs, seq_length, dim)\n            mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,\n            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`\n        \"\"\"\n        bs, q_length, dim = query.size()\n        k_length = key.size(1)\n        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'\n        # assert key.size() == value.size()\n\n        dim_per_head = self.dim // self.n_heads\n\n        mask_reshp = (bs, 1, 1, k_length)\n\n        def shape(x: torch.Tensor) -> torch.Tensor:\n            \"\"\"separate heads\"\"\"\n            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)\n\n        def unshape(x: torch.Tensor) -> torch.Tensor:\n            \"\"\"group heads\"\"\"\n            return (\n                x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)\n            )\n\n        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)\n        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)\n        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)\n\n        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)\n        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)\n        mask = (\n            (mask == 0).view(mask_reshp).expand_as(scores)\n        )  # (bs, n_heads, q_length, k_length)\n        scores = scores.masked_fill(\n            mask, 
torch.tensor(torch.finfo(scores.dtype).min)\n        )  # (bs, n_heads, q_length, k_length)\n\n        weights = nn.functional.softmax(\n            scores, dim=-1\n        )  # (bs, n_heads, q_length, k_length)\n        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            weights = weights * head_mask\n\n        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)\n        context = unshape(context)  # (bs, q_length, dim)\n        context = self.out_lin(context)  # (bs, q_length, dim)\n\n        if output_attentions:\n            return (context, weights)\n        else:\n            return (context,)\n\n\nclass DistilBertFlashAttention2(MultiHeadSelfAttention):\n    \"\"\"\n    DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module\n    stay untouched. The only required change is in the forward pass, which needs to correctly call the public\n    API of flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        mask: torch.Tensor,\n        head_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n    ) -> Tuple[torch.Tensor, ...]:\n        \"\"\"\n        Parameters:\n            query: torch.tensor(bs, seq_length, dim)\n            key: torch.tensor(bs, seq_length, dim)\n            value: torch.tensor(bs, seq_length, dim)\n            mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,\n            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`\n        \"\"\"\n        batch_size, q_length, dim = query.size()\n\n        dim_per_head = self.dim // self.n_heads\n\n        def reshape(x: torch.Tensor) -> torch.Tensor:\n            \"\"\"separate heads\"\"\"\n            return x.view(batch_size, -1, self.n_heads, dim_per_head)\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x n_heads x head_dim\n        query_states = reshape(self.q_lin(query))\n        key_states = reshape(self.k_lin(key))\n        value_states = reshape(self.v_lin(value))\n\n        attn_dropout = self.config.attention_dropout if self.training else 0.0\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,\n        # therefore the input hidden states get silently cast to float32. 
Hence, we need to\n        # cast them back to the correct dtype just to be sure everything works as expected.\n        # This might slow down training & inference, so it is recommended not to cast the LayerNorms\n        # in fp32. (LlamaRMSNorm handles it correctly)\n\n        if query_states.dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            # Handle the case where the model is quantized\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.q_lin.weight.dtype\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_weights = self._flash_attention_forward(\n            query_states, key_states, value_states, mask, q_length, dropout=attn_dropout\n        )\n\n        attn_weights_reshaped = attn_weights.reshape(\n            batch_size, q_length, self.n_heads * dim_per_head\n        )\n        attn_output = self.out_lin(attn_weights_reshaped)\n\n        if output_attentions:\n            return (attn_output, attn_weights)\n        else:\n            return (attn_output,)\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        query_length,\n        dropout=0.0,\n        softmax_scale=None,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,\n        first unpads the input, then computes the attention scores and pads the final attention scores.\n\n        Args:\n            query_states 
(`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            query_length (`int`):\n                The length of the query sequence.\n            dropout (`float`):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).\n        \"\"\"\n        from flash_attn import flash_attn_func, flash_attn_varlen_func\n        from flash_attn.bert_padding import pad_input\n\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. 
For details, please see the comment in LlamaFlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            (\n                query_states,\n                key_states,\n                value_states,\n                indices_q,\n                cu_seq_lens,\n                max_seq_lens,\n            ) = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            attn_output_unpad = flash_attn_varlen_func(\n                query_states,\n                key_states,\n                value_states,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                dropout_p=dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n            attn_output = pad_input(\n                attn_output_unpad, indices_q, batch_size, query_length\n            )\n        else:\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states,\n                dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n        return attn_output\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads\n    def _upad_input(\n        self, query_layer, key_layer, value_layer, attention_mask, query_length\n    ):\n        from flash_attn.bert_padding import index_first_axis, unpad_input\n\n        indices_k, cu_seqlens_k, 
max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n        key_layer = index_first_axis(\n            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        value_layer = index_first_axis(\n            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),\n                indices_k,\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(\n                query_layer, attention_mask\n            )\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\nclass FFN(nn.Module):\n    def __init__(self, config: DistilBertConfig):\n        super().__init__()\n        self.dropout = nn.Dropout(p=config.dropout)\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.lin1 = nn.Linear(in_features=config.dim, 
out_features=config.hidden_dim)\n        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)\n        self.activation = get_activation(config.activation)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return apply_chunking_to_forward(\n            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input\n        )\n\n    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:\n        x = self.lin1(input)\n        x = self.activation(x)\n        x = self.lin2(x)\n        x = self.dropout(x)\n        return x\n\n\nDISTILBERT_ATTENTION_CLASSES = {\n    \"eager\": MultiHeadSelfAttention,\n    \"flash_attention_2\": DistilBertFlashAttention2,\n}\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, config: DistilBertConfig):\n        super().__init__()\n\n        # The number of heads must divide the hidden dimension evenly\n        if config.dim % config.n_heads != 0:\n            raise ValueError(\n                f\"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly\"\n            )\n\n        self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](\n            config\n        )\n        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n        self.ffn = FFN(config)\n        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        attn_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n    ) -> Tuple[torch.Tensor, ...]:\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim)\n            attn_mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:\n            torch.tensor(bs, seq_length, dim) The output of the transformer block 
contextualization.\n        \"\"\"\n        # Self-Attention\n        sa_output = self.attention(\n            query=x,\n            key=x,\n            value=x,\n            mask=attn_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n        )\n        if output_attentions:\n            sa_output, sa_weights = (\n                sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)\n            )\n        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples\n            sa_output = sa_output[0]\n\n        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)\n\n        # Feed Forward Network\n        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)\n        ffn_output: torch.Tensor = self.output_layer_norm(\n            ffn_output + sa_output\n        )  # (bs, seq_length, dim)\n\n        output = (ffn_output,)\n        if output_attentions:\n            output = (sa_weights,) + output\n        return output\n\n\nclass Transformer(nn.Module):\n    def __init__(self, config: DistilBertConfig):\n        super().__init__()\n        self.n_layers = config.n_layers\n        self.layer = nn.ModuleList(\n            [TransformerBlock(config) for _ in range(config.n_layers)]\n        )\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        attn_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: Optional[bool] = None,\n    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:  # docstyle-ignore\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.\n            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.\n\n        Returns:\n            hidden_state: 
torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)\n            layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]\n                Tuple of length n_layers with the hidden states from each layer.\n                Optional: only if output_hidden_states=True\n            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]\n                Tuple of length n_layers with the attention weights from each layer\n                Optional: only if output_attentions=True\n        \"\"\"\n        all_hidden_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n\n        hidden_state = x\n        for i, layer_module in enumerate(self.layer):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_state,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    layer_module.__call__,\n                    hidden_state,\n                    attn_mask,\n                    head_mask[i],\n                    output_attentions,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_state,\n                    attn_mask,\n                    head_mask[i],\n                    output_attentions,\n                )\n\n            hidden_state = layer_outputs[-1]\n\n            if output_attentions:\n                if len(layer_outputs) != 2:\n                    raise ValueError(\n                        f\"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}\"\n                    )\n\n                attentions = layer_outputs[0]\n                all_attentions = all_attentions + (attentions,)\n            else:\n                if len(layer_outputs) != 1:\n                    raise ValueError(\n                        f\"The length of the 
layer_outputs should be 1, but it is {len(layer_outputs)}\"\n                    )\n\n        # Add last layer\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_state,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_state, all_hidden_states, all_attentions]\n                if v is not None\n            )\n        return BaseModelOutput(\n            last_hidden_state=hidden_state,\n            hidden_states=all_hidden_states,\n            attentions=all_attentions,\n        )\n\n\nclass DistilBertPreTrainedModel(SuryaPreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = DistilBertConfig\n    load_tf_weights = None\n    base_model_prefix = \"distilbert\"\n    supports_gradient_checkpointing = True\n    _supports_flash_attn_2 = True\n\n    def _init_weights(self, module: nn.Module):\n        \"\"\"Initialize the weights.\"\"\"\n        if isinstance(module, nn.Linear):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:\n            create_sinusoidal_embeddings(\n                
self.config.max_position_embeddings,\n                self.config.dim,\n                module.position_embeddings.weight,\n            )\n\n\nclass DistilBertModel(DistilBertPreTrainedModel):\n    def __init__(self, config: DistilBertConfig):\n        super().__init__(config)\n\n        self.embeddings = Embeddings(config)  # Embeddings\n        self.transformer = Transformer(config)  # Encoder\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_position_embeddings(self) -> nn.Embedding:\n        \"\"\"\n        Returns the position embeddings\n        \"\"\"\n        return self.embeddings.position_embeddings\n\n    def resize_position_embeddings(self, new_num_position_embeddings: int):\n        \"\"\"\n        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.\n\n        Arguments:\n            new_num_position_embeddings (`int`):\n                The number of new position embedding matrix. If position embeddings are learned, increasing the size\n                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the\n                end. 
If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the\n                size will add correct vectors at the end following the position encoding algorithm, whereas reducing\n                the size will remove vectors from the end.\n        \"\"\"\n        num_position_embeds_diff = (\n            new_num_position_embeddings - self.config.max_position_embeddings\n        )\n\n        # no resizing needs to be done if the length stays the same\n        if num_position_embeds_diff == 0:\n            return\n\n        self.config.max_position_embeddings = new_num_position_embeddings\n\n        old_position_embeddings_weight = (\n            self.embeddings.position_embeddings.weight.clone()\n        )\n\n        self.embeddings.position_embeddings = nn.Embedding(\n            self.config.max_position_embeddings, self.config.dim\n        )\n\n        if self.config.sinusoidal_pos_embds:\n            create_sinusoidal_embeddings(\n                n_pos=self.config.max_position_embeddings,\n                dim=self.config.dim,\n                out=self.embeddings.position_embeddings.weight,\n            )\n        else:\n            with torch.no_grad():\n                if num_position_embeds_diff > 0:\n                    self.embeddings.position_embeddings.weight[\n                        :-num_position_embeds_diff\n                    ] = nn.Parameter(old_position_embeddings_weight)\n                else:\n                    self.embeddings.position_embeddings.weight = nn.Parameter(\n                        old_position_embeddings_weight[:num_position_embeds_diff]\n                    )\n        # move position_embeddings to correct device\n        self.embeddings.position_embeddings.to(self.device)\n\n    def get_input_embeddings(self) -> nn.Embedding:\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, new_embeddings: nn.Embedding):\n        self.embeddings.word_embeddings = new_embeddings\n\n    def 
_prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]):\n        \"\"\"\n        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.transformer.layer[layer].attention.prune_heads(heads)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)\n            input_shape = input_ids.size()\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        # Prepare head mask if needed\n        
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)\n\n        if self._use_flash_attention_2:\n            attention_mask = (\n                attention_mask\n                if (attention_mask is not None and 0 in attention_mask)\n                else None\n            )\n        else:\n            if attention_mask is None:\n                attention_mask = torch.ones(\n                    input_shape, device=device\n                )  # (bs, seq_length)\n\n        return self.transformer(\n            x=embeddings,\n            attn_mask=attention_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n\nclass DistilBertForSequenceClassification(S3DownloaderMixin, DistilBertPreTrainedModel):\n    def __init__(self, config: DistilBertConfig, **kwargs):\n        super().__init__(config, **kwargs)\n        self.num_labels = config.num_labels\n        self.config = config\n\n        self.distilbert = DistilBertModel(config)\n        self.pre_classifier = nn.Linear(config.dim, config.dim)\n        self.classifier = nn.Linear(config.dim, config.num_labels)\n        self.dropout = nn.Dropout(config.seq_classif_dropout)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_position_embeddings(self) -> nn.Embedding:\n        \"\"\"\n        Returns the position embeddings\n        \"\"\"\n        return self.distilbert.get_position_embeddings()\n\n    def resize_position_embeddings(self, new_num_position_embeddings: int):\n        \"\"\"\n        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.\n\n        Arguments:\n            new_num_position_embeddings (`int`):\n                The number of new position 
embedding matrix. If position embeddings are learned, increasing the size\n                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the\n                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the\n                size will add correct vectors at the end following the position encoding algorithm, whereas reducing\n                the size will remove vectors from the end.\n        \"\"\"\n        self.distilbert.resize_position_embeddings(new_num_position_embeddings)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        distilbert_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)\n        pooled_output = hidden_state[:, 0]  # (bs, dim)\n        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)\n        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)\n        pooled_output = self.dropout(pooled_output)  # (bs, dim)\n        logits = self.classifier(pooled_output)  # (bs, num_labels)\n\n        loss = None\n        if labels is not None:\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = 
CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(logits, labels)\n\n        if not return_dict:\n            output = (logits,) + distilbert_output[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=distilbert_output.hidden_states,\n            attentions=distilbert_output.attentions,\n        )\n"
  },
  {
    "path": "surya/ocr_error/schema.py",
    "content": "from typing import List\n\nfrom pydantic import BaseModel\n\n\nclass OCRErrorDetectionResult(BaseModel):\n    texts: List[str]\n    labels: List[str]\n"
  },
  {
    "path": "surya/ocr_error/tokenizer.py",
    "content": "import collections\nimport os\nimport json\nimport unicodedata\nfrom typing import List, Optional, Tuple\n\nfrom tokenizers import normalizers\n\nfrom transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace\nfrom transformers.tokenization_utils_fast import PreTrainedTokenizerFast\n\nfrom surya.common.s3 import S3DownloaderMixin\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"vocab.txt\"}\n\n# Copied from transformers.models.bert.tokenization_bert.load_vocab\ndef load_vocab(vocab_file):\n    \"\"\"Loads a vocabulary file into a dictionary.\"\"\"\n    vocab = collections.OrderedDict()\n    with open(vocab_file, \"r\", encoding=\"utf-8\") as reader:\n        tokens = reader.readlines()\n    for index, token in enumerate(tokens):\n        token = token.rstrip(\"\\n\")\n        vocab[token] = index\n    return vocab\n\n\n# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize\ndef whitespace_tokenize(text):\n    \"\"\"Runs basic whitespace cleaning and splitting on a piece of text.\"\"\"\n    text = text.strip()\n    if not text:\n        return []\n    tokens = text.split()\n    return tokens\n\n\nclass DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):\n    r\"\"\"\n    Construct a DistilBERT tokenizer. Based on WordPiece.\n\n    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. 
Users should refer to\n    this superclass for more information regarding those methods.\n\n    Args:\n        vocab_file (`str`):\n            File containing the vocabulary.\n        do_lower_case (`bool`, *optional*, defaults to `True`):\n            Whether or not to lowercase the input when tokenizing.\n        do_basic_tokenize (`bool`, *optional*, defaults to `True`):\n            Whether or not to do basic tokenization before WordPiece.\n        never_split (`Iterable`, *optional*):\n            Collection of tokens which will never be split during tokenization. Only has an effect when\n            `do_basic_tokenize=True`\n        unk_token (`str`, *optional*, defaults to `\"[UNK]\"`):\n            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this\n            token instead.\n        sep_token (`str`, *optional*, defaults to `\"[SEP]\"`):\n            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for\n            sequence classification or for a text and a question for question answering. It is also used as the last\n            token of a sequence built with special tokens.\n        pad_token (`str`, *optional*, defaults to `\"[PAD]\"`):\n            The token used for padding, for example when batching sequences of different lengths.\n        cls_token (`str`, *optional*, defaults to `\"[CLS]\"`):\n            The classifier token which is used when doing sequence classification (classification of the whole sequence\n            instead of per-token classification). It is the first token of the sequence when built with special tokens.\n        mask_token (`str`, *optional*, defaults to `\"[MASK]\"`):\n            The token used for masking values. This is the token used when training this model with masked language\n            modeling. 
This is the token which the model will try to predict.\n        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):\n            Whether or not to tokenize Chinese characters.\n\n            This should likely be deactivated for Japanese (see this\n            [issue](https://github.com/huggingface/transformers/issues/328)).\n        strip_accents (`bool`, *optional*):\n            Whether or not to strip all accents. If this option is not specified, then it will be determined by the\n            value for `lowercase` (as in the original BERT).\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n\n    def __init__(\n        self,\n        vocab_file,\n        do_lower_case=True,\n        do_basic_tokenize=True,\n        never_split=None,\n        unk_token=\"[UNK]\",\n        sep_token=\"[SEP]\",\n        pad_token=\"[PAD]\",\n        cls_token=\"[CLS]\",\n        mask_token=\"[MASK]\",\n        tokenize_chinese_chars=True,\n        strip_accents=None,\n        **kwargs,\n    ):\n        if not os.path.isfile(vocab_file):\n            raise ValueError(\n                f\"Can't find a vocabulary file at path '{vocab_file}'. 
To load the vocabulary from a Google pretrained\"\n                \" model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`\"\n            )\n        self.vocab = load_vocab(vocab_file)\n        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])\n        self.do_basic_tokenize = do_basic_tokenize\n        if do_basic_tokenize:\n            self.basic_tokenizer = BasicTokenizer(\n                do_lower_case=do_lower_case,\n                never_split=never_split,\n                tokenize_chinese_chars=tokenize_chinese_chars,\n                strip_accents=strip_accents,\n            )\n        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))\n\n        super().__init__(\n            do_lower_case=do_lower_case,\n            do_basic_tokenize=do_basic_tokenize,\n            never_split=never_split,\n            unk_token=unk_token,\n            sep_token=sep_token,\n            pad_token=pad_token,\n            cls_token=cls_token,\n            mask_token=mask_token,\n            tokenize_chinese_chars=tokenize_chinese_chars,\n            strip_accents=strip_accents,\n            **kwargs,\n        )\n\n    @property\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case\n    def do_lower_case(self):\n        return self.basic_tokenizer.do_lower_case\n\n    @property\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size\n    def vocab_size(self):\n        return len(self.vocab)\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab\n    def get_vocab(self):\n        return dict(self.vocab, **self.added_tokens_encoder)\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize\n    def _tokenize(self, text, split_special_tokens=False):\n        split_tokens = []\n        if self.do_basic_tokenize:\n            for token in 
self.basic_tokenizer.tokenize(\n                text, never_split=self.all_special_tokens if not split_special_tokens else None\n            ):\n                # If the token is part of the never_split set\n                if token in self.basic_tokenizer.never_split:\n                    split_tokens.append(token)\n                else:\n                    split_tokens += self.wordpiece_tokenizer.tokenize(token)\n        else:\n            split_tokens = self.wordpiece_tokenizer.tokenize(text)\n        return split_tokens\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n        return self.vocab.get(token, self.vocab.get(self.unk_token))\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        return self.ids_to_tokens.get(index, self.unk_token)\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string\n    def convert_tokens_to_string(self, tokens):\n        \"\"\"Converts a sequence of tokens (string) in a single string.\"\"\"\n        out_string = \" \".join(tokens).replace(\" ##\", \"\").strip()\n        return out_string\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens\n    def build_inputs_with_special_tokens(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. 
A BERT sequence has the following format:\n\n        - single sequence: `[CLS] X [SEP]`\n        - pair of sequences: `[CLS] A [SEP] B [SEP]`\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.\n        \"\"\"\n        if token_ids_1 is None:\n            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]\n        cls = [self.cls_token_id]\n        sep = [self.sep_token_id]\n        return cls + token_ids_0 + sep + token_ids_1 + sep\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask\n    def get_special_tokens_mask(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False\n    ) -> List[int]:\n        \"\"\"\n        Retrieve sequence ids from a token list that has no special tokens added. 
This method is called when adding\n        special tokens using the tokenizer `prepare_for_model` method.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n            already_has_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True\n            )\n\n        if token_ids_1 is not None:\n            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]\n        return [1] + ([0] * len(token_ids_0)) + [1]\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences\n    def create_token_type_ids_from_sequences(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Create a mask from the two sequences passed to be used in a sequence-pair classification task. 
A BERT sequence\n        pair mask has the following format:\n\n        ```\n        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n        | first sequence    | second sequence |\n        ```\n\n        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).\n        \"\"\"\n        sep = [self.sep_token_id]\n        cls = [self.cls_token_id]\n        if token_ids_1 is None:\n            return len(cls + token_ids_0 + sep) * [0]\n        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]\n\n    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary\n    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        index = 0\n        if os.path.isdir(save_directory):\n            vocab_file = os.path.join(\n                save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"]\n            )\n        else:\n            vocab_file = (filename_prefix + \"-\" if filename_prefix else \"\") + save_directory\n        with open(vocab_file, \"w\", encoding=\"utf-8\") as writer:\n            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):\n                if index != token_index:\n                    # logger.warning(\n                    #     f\"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.\"\n                    #     \" Please check that the vocabulary is not corrupted!\"\n                    # )\n                    index = token_index\n                writer.write(token + \"\\n\")\n                
index += 1\n        return (vocab_file,)\n\n\n# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer\nclass BasicTokenizer(object):\n    \"\"\"\n    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).\n\n    Args:\n        do_lower_case (`bool`, *optional*, defaults to `True`):\n            Whether or not to lowercase the input when tokenizing.\n        never_split (`Iterable`, *optional*):\n            Collection of tokens which will never be split during tokenization. Only has an effect when\n            `do_basic_tokenize=True`\n        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):\n            Whether or not to tokenize Chinese characters.\n\n            This should likely be deactivated for Japanese (see this\n            [issue](https://github.com/huggingface/transformers/issues/328)).\n        strip_accents (`bool`, *optional*):\n            Whether or not to strip all accents. If this option is not specified, then it will be determined by the\n            value for `lowercase` (as in the original BERT).\n        do_split_on_punc (`bool`, *optional*, defaults to `True`):\n            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture\n            the full context of the words, such as contractions.\n    \"\"\"\n\n    def __init__(\n        self,\n        do_lower_case=True,\n        never_split=None,\n        tokenize_chinese_chars=True,\n        strip_accents=None,\n        do_split_on_punc=True,\n    ):\n        if never_split is None:\n            never_split = []\n        self.do_lower_case = do_lower_case\n        self.never_split = set(never_split)\n        self.tokenize_chinese_chars = tokenize_chinese_chars\n        self.strip_accents = strip_accents\n        self.do_split_on_punc = do_split_on_punc\n\n    def tokenize(self, text, never_split=None):\n        \"\"\"\n        Basic Tokenization of a piece of 
text. For sub-word tokenization, see WordPieceTokenizer.\n\n        Args:\n            never_split (`List[str]`, *optional*)\n                Kept for backward compatibility purposes. Now implemented directly at the base class level (see\n                [`PreTrainedTokenizer.tokenize`]) List of token not to split.\n        \"\"\"\n        # union() returns a new set by concatenating the two sets.\n        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split\n        text = self._clean_text(text)\n\n        # This was added on November 1st, 2018 for the multilingual and Chinese\n        # models. This is also applied to the English models now, but it doesn't\n        # matter since the English models were not trained on any Chinese data\n        # and generally don't have any Chinese data in them (there are Chinese\n        # characters in the vocabulary because Wikipedia does have some Chinese\n        # words in the English Wikipedia.).\n        if self.tokenize_chinese_chars:\n            text = self._tokenize_chinese_chars(text)\n        # prevents treating the same character with different unicode codepoints as different characters\n        unicode_normalized_text = unicodedata.normalize(\"NFC\", text)\n        orig_tokens = whitespace_tokenize(unicode_normalized_text)\n        split_tokens = []\n        for token in orig_tokens:\n            if token not in never_split:\n                if self.do_lower_case:\n                    token = token.lower()\n                    if self.strip_accents is not False:\n                        token = self._run_strip_accents(token)\n                elif self.strip_accents:\n                    token = self._run_strip_accents(token)\n            split_tokens.extend(self._run_split_on_punc(token, never_split))\n\n        output_tokens = whitespace_tokenize(\" \".join(split_tokens))\n        return output_tokens\n\n    def _run_strip_accents(self, text):\n        \"\"\"Strips accents 
from a piece of text.\"\"\"\n        text = unicodedata.normalize(\"NFD\", text)\n        output = []\n        for char in text:\n            cat = unicodedata.category(char)\n            if cat == \"Mn\":\n                continue\n            output.append(char)\n        return \"\".join(output)\n\n    def _run_split_on_punc(self, text, never_split=None):\n        \"\"\"Splits punctuation on a piece of text.\"\"\"\n        if not self.do_split_on_punc or (never_split is not None and text in never_split):\n            return [text]\n        chars = list(text)\n        i = 0\n        start_new_word = True\n        output = []\n        while i < len(chars):\n            char = chars[i]\n            if _is_punctuation(char):\n                output.append([char])\n                start_new_word = True\n            else:\n                if start_new_word:\n                    output.append([])\n                start_new_word = False\n                output[-1].append(char)\n            i += 1\n\n        return [\"\".join(x) for x in output]\n\n    def _tokenize_chinese_chars(self, text):\n        \"\"\"Adds whitespace around any CJK character.\"\"\"\n        output = []\n        for char in text:\n            cp = ord(char)\n            if self._is_chinese_char(cp):\n                output.append(\" \")\n                output.append(char)\n                output.append(\" \")\n            else:\n                output.append(char)\n        return \"\".join(output)\n\n    def _is_chinese_char(self, cp):\n        \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n        # This defines a \"chinese character\" as anything in the CJK Unicode block:\n        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n        #\n        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,\n        # despite its name. 
The modern Korean Hangul alphabet is a different block,\n        # as is Japanese Hiragana and Katakana. Those alphabets are used to write\n        # space-separated words, so they are not treated specially and handled\n        # like all of the other languages.\n        if (\n            (cp >= 0x4E00 and cp <= 0x9FFF)\n            or (cp >= 0x3400 and cp <= 0x4DBF)  #\n            or (cp >= 0x20000 and cp <= 0x2A6DF)  #\n            or (cp >= 0x2A700 and cp <= 0x2B73F)  #\n            or (cp >= 0x2B740 and cp <= 0x2B81F)  #\n            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #\n            or (cp >= 0xF900 and cp <= 0xFAFF)\n            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #\n        ):  #\n            return True\n\n        return False\n\n    def _clean_text(self, text):\n        \"\"\"Performs invalid character removal and whitespace cleanup on text.\"\"\"\n        output = []\n        for char in text:\n            cp = ord(char)\n            if cp == 0 or cp == 0xFFFD or _is_control(char):\n                continue\n            if _is_whitespace(char):\n                output.append(\" \")\n            else:\n                output.append(char)\n        return \"\".join(output)\n\n\n# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer\nclass WordpieceTokenizer(object):\n    \"\"\"Runs WordPiece tokenization.\"\"\"\n\n    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):\n        self.vocab = vocab\n        self.unk_token = unk_token\n        self.max_input_chars_per_word = max_input_chars_per_word\n\n    def tokenize(self, text):\n        \"\"\"\n        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform\n        tokenization using the given vocabulary.\n\n        For example, `input = \"unaffable\"` will return as output `[\"un\", \"##aff\", \"##able\"]`.\n\n        Args:\n            text: A single token or whitespace separated tokens. 
This should have\n                already been passed through *BasicTokenizer*.\n\n        Returns:\n            A list of wordpiece tokens.\n        \"\"\"\n\n        output_tokens = []\n        for token in whitespace_tokenize(text):\n            chars = list(token)\n            if len(chars) > self.max_input_chars_per_word:\n                output_tokens.append(self.unk_token)\n                continue\n\n            is_bad = False\n            start = 0\n            sub_tokens = []\n            while start < len(chars):\n                end = len(chars)\n                cur_substr = None\n                while start < end:\n                    substr = \"\".join(chars[start:end])\n                    if start > 0:\n                        substr = \"##\" + substr\n                    if substr in self.vocab:\n                        cur_substr = substr\n                        break\n                    end -= 1\n                if cur_substr is None:\n                    is_bad = True\n                    break\n                sub_tokens.append(cur_substr)\n                start = end\n\n            if is_bad:\n                output_tokens.append(self.unk_token)\n            else:\n                output_tokens.extend(sub_tokens)\n        return output_tokens"
  },
  {
    "path": "surya/recognition/__init__.py",
    "content": "from __future__ import annotations\n\nimport re\nfrom typing import List\n\nimport numpy as np\nimport torch\nfrom PIL import Image\nimport torch.nn.functional as F\n\nfrom surya.common.polygon import PolygonBox\nfrom surya.common.surya.processor import NOMATH_TOKEN\nfrom surya.common.predictor import BasePredictor\nfrom surya.detection import DetectionPredictor\nfrom surya.foundation import FoundationPredictor\n\nfrom surya.input.processing import (\n    convert_if_not_rgb,\n    slice_polys_from_image,\n    slice_bboxes_from_image,\n)\nfrom surya.recognition.postprocessing import fix_unbalanced_tags\nfrom surya.recognition.util import (\n    sort_text_lines,\n    clean_close_polygons,\n    unwrap_math,\n    clean_math_tags,\n    filter_blacklist_tags,\n    words_from_chars\n)\nfrom surya.foundation.util import detect_repeat_token, prediction_to_polygon_batch\nfrom surya.recognition.schema import TextLine, OCRResult, TextChar\nfrom surya.common.surya.schema import TaskNames\nfrom surya.settings import settings\nfrom surya.logging import get_logger, configure_logging\n\nconfigure_logging()\nlogger = get_logger()\n\nclass RecognitionPredictor(BasePredictor):\n    batch_size = settings.RECOGNITION_BATCH_SIZE\n    default_batch_sizes = {\"cpu\": 32, \"mps\": 64, \"cuda\": 256, \"xla\": 128}\n\n    # Override base init - Do not load model\n    def __init__(self, foundation_predictor: FoundationPredictor):\n        self.foundation_predictor = foundation_predictor\n        self.processor = self.foundation_predictor.processor\n        self.bbox_size = self.foundation_predictor.model.config.bbox_size\n        self.tasks = self.foundation_predictor.tasks\n\n    # Special handling for disable tqdm to pass into foundation predictor\n    # Make sure they are kept in sync\n    @property\n    def disable_tqdm(self) -> bool:\n        return super().disable_tqdm\n\n    @disable_tqdm.setter\n    def disable_tqdm(self, value: bool) -> None:\n        self._disable_tqdm 
= bool(value)\n        self.foundation_predictor.disable_tqdm = bool(value)\n\n    def detect_and_slice_bboxes(\n        self,\n        images: List[Image.Image],\n        task_names: List[str],\n        det_predictor: DetectionPredictor,\n        detection_batch_size: int | None = None,\n        highres_images: List[Image.Image] | None = None,\n    ):\n        det_predictions = det_predictor(images, batch_size=detection_batch_size)\n\n        all_slices = []\n        slice_map = []\n        all_polygons = []\n        all_task_names = []\n        all_res_scales = []\n\n        for idx, (det_pred, image, highres_image, task_name) in enumerate(\n            zip(det_predictions, images, highres_images, task_names)\n        ):\n            polygons = [p.polygon for p in det_pred.bboxes]\n            if highres_image:\n                width_scaler = highres_image.size[0] / image.size[0]\n                height_scaler = highres_image.size[1] / image.size[1]\n                scaled_polygons = [\n                    [\n                        [int(p[0] * width_scaler), int(p[1] * height_scaler)]\n                        for p in polygon\n                    ]\n                    for polygon in polygons\n                ]\n                highres_image = self.processor.image_processor(highres_image)\n                slices = slice_polys_from_image(highres_image, scaled_polygons)\n                res_scales = [(width_scaler, height_scaler) for _ in range(len(slices))]\n            else:\n                image = self.processor.image_processor(image)\n                slices = slice_polys_from_image(image, polygons)\n                res_scales = [(1, 1) for _ in range(len(slices))]\n\n            slice_map.append(len(slices))\n            all_slices.extend(slices)\n            all_polygons.extend(polygons)\n            all_task_names.extend([task_name] * len(slices))\n            all_res_scales.extend(res_scales)\n\n        assert (\n            len(all_slices)\n            == 
sum(slice_map)\n            == len(all_polygons)\n            == len(all_task_names)\n            == len(all_res_scales)\n        )\n\n        return {\n            \"slices\": all_slices,\n            \"slice_map\": slice_map,\n            \"polygons\": all_polygons,\n            \"task_names\": all_task_names,\n            \"input_text\": [None] * len(all_slices),\n            \"res_scales\": all_res_scales,\n        }\n\n    def slice_bboxes(\n        self,\n        images: List[Image.Image],\n        task_names: List[str],\n        bboxes: List[List[List[int]]] | None = None,\n        polygons: List[List[List[List[int]]]] | None = None,\n        input_text: List[List[str | None]] | None = None,\n    ) -> dict:\n        assert bboxes is not None or polygons is not None\n        slice_map = []\n        all_slices = []\n        all_polygons = []\n        all_text = []\n        all_task_names = []\n\n        for idx, image in enumerate(images):\n            image = self.processor.image_processor(image)\n            if polygons is not None:\n                polys = polygons[idx]\n                slices = slice_polys_from_image(image, polys)\n            else:\n                slices = slice_bboxes_from_image(image, bboxes[idx])\n                polys = [\n                    [\n                        [bbox[0], bbox[1]],\n                        [bbox[2], bbox[1]],\n                        [bbox[2], bbox[3]],\n                        [bbox[0], bbox[3]],\n                    ]\n                    for bbox in bboxes[idx]\n                ]\n            slice_map.append(len(slices))\n            all_slices.extend(slices)\n            all_polygons.extend(polys)\n            all_task_names.extend([task_names[idx]] * len(slices))\n\n            if input_text is None:\n                all_text.extend([None] * len(slices))\n            else:\n                all_text.extend(input_text[idx])\n\n        assert (\n            len(all_slices)\n            == sum(slice_map)\n   
         == len(all_polygons)\n            == len(all_text)\n            == len(all_task_names)\n        ), (\n            f\"Mismatch in lengths: {len(all_slices)}, {sum(slice_map)}, {len(all_polygons)}, {len(all_text)}, {len(all_task_names)}\"\n        )\n\n        return {\n            \"slices\": all_slices,\n            \"slice_map\": slice_map,\n            \"polygons\": all_polygons,\n            \"input_text\": all_text,\n            \"task_names\": all_task_names,\n            \"res_scales\": [(1, 1) for _ in range(len(all_slices))],\n        }\n\n    def get_bboxes_text(\n        self,\n        flat: dict,\n        predicted_tokens: list,\n        scores: list,\n        predicted_polygons: list,\n        drop_repeated_text: bool = False,\n    ) -> list:\n        char_predictions = []\n        needs_boxes = [\n            self.tasks[task_name][\"needs_bboxes\"] for task_name in flat[\"task_names\"]\n        ]\n\n        for slice_idx, (\n            slice_image,\n            image_tokens,\n            image_polygons,\n            image_scores,\n            needs_box,\n        ) in enumerate(\n            zip(\n                flat[\"slices\"],\n                predicted_tokens,\n                predicted_polygons,\n                scores,\n                needs_boxes,\n            )\n        ):\n            blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]]\n            if self.processor.no_output_token in image_tokens:\n                char_predictions.append(None)\n                continue\n\n            # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely\n            if drop_repeated_text and detect_repeat_token(image_tokens):\n                char_predictions.append(\n                    [\n                        TextChar(\n                            text=\"\",\n                            polygon=blank_bbox,\n                            confidence=0,\n                            
bbox_valid=False,\n                        )\n                    ]\n                )\n                continue\n\n            image_polygons = image_polygons[: len(image_tokens)].cpu().numpy().tolist()\n\n            detokenize_sequences = []\n            detokenize_sequence = []\n            past_char_qwen_token = False\n\n            def _add_detokenize_sequence(\n                special_token: bool,\n                past_special_token: bool,\n                force: bool = False,\n            ):\n                nonlocal detokenize_sequence, detokenize_sequences\n\n                if (\n                    special_token\n                    or past_special_token\n                    or force\n                ) and detokenize_sequence:\n                    chars = [dt[0] for dt in detokenize_sequence]\n                    scores = [dt[1] for dt in detokenize_sequence]\n                    bboxes = [dt[2] for dt in detokenize_sequence]\n\n                    if past_special_token:\n                        detokenize_sequences.append((chars, scores, None, \"special\"))\n                    else:\n                        detokenize_sequences.append((chars, scores, bboxes, \"ocr\"))\n\n                    detokenize_sequence = []\n\n            # Split up into sequences to detokenize separately\n            past_special_token = False\n            for bbox, char_id, score in zip(image_polygons, image_tokens, image_scores):\n                if char_id in [\n                    self.processor.eos_token_id,\n                    self.processor.pad_token_id,\n                ]:\n                    break\n\n                special_token = (\n                    char_id >= self.processor.ocr_tokenizer.ocr_tokenizer.SPECIAL_BASE\n                )\n                _add_detokenize_sequence(\n                    special_token, past_special_token\n                )\n                detokenize_sequence.append((char_id, score, bbox))\n                past_special_token = 
special_token\n\n            _add_detokenize_sequence(\n                False, past_special_token, force=True\n            )\n\n            img_chars = []\n            for sequence in detokenize_sequences:\n                token_ids, seq_score, bboxes, token_type = sequence\n                if token_type == \"ocr\":\n                    text = self.processor.ocr_tokenizer.decode(\n                        token_ids, task=TaskNames.ocr_with_boxes\n                    )\n                    bboxes = clean_close_polygons(\n                        bboxes\n                    )  # clean out bboxes that are close, like what happens with multiple utf-16 tokens per char\n                    bbox_idx = 0\n                    for text_idx, text_line in enumerate(text):\n                        img_chars.append(\n                            TextChar(\n                                text=text_line,\n                                polygon=bboxes[bbox_idx],\n                                confidence=seq_score[bbox_idx],\n                                bbox_valid=True,\n                            )\n                        )\n\n                        # Ensure we don't exceed the bbox count\n                        # Use the last bbox for the rest of the text\n                        if bbox_idx < len(bboxes) - 1:\n                            bbox_idx += 1\n                elif token_type == \"special\":\n                    text = self.processor.ocr_tokenizer.decode(\n                        token_ids, task=\"ocr_without_boxes\"\n                    )\n                    if text in [NOMATH_TOKEN] or re.match(r\"<SCRIPT-\\w+>\", text):\n                        continue\n\n                    img_chars.append(\n                        TextChar(\n                            text=text,\n                            polygon=blank_bbox,\n                            confidence=seq_score[0],\n                            bbox_valid=False,\n                        )\n                  
  )\n                else:\n                    text = self.processor.ocr_tokenizer.decode(\n                        token_ids, task=TaskNames.block_without_boxes\n                    )\n                    img_chars.append(\n                        TextChar(\n                            text=text,\n                            polygon=blank_bbox,\n                            confidence=seq_score[0],\n                            bbox_valid=False,\n                        )\n                    )\n\n            char_predictions.append(img_chars)\n\n        return char_predictions\n\n    def __call__(\n        self,\n        images: List[Image.Image],\n        task_names: List[str] | None = None,\n        det_predictor: DetectionPredictor | None = None,\n        detection_batch_size: int | None = None,\n        recognition_batch_size: int | None = None,\n        highres_images: List[Image.Image] | None = None,\n        bboxes: List[List[List[int]]] | None = None,\n        polygons: List[List[List[List[int]]]] | None = None,\n        input_text: List[List[str | None]] | None = None,\n        sort_lines: bool = False,\n        math_mode: bool = True,\n        return_words: bool = False,\n        drop_repeated_text: bool = False,\n        max_sliding_window: int | None = None,\n        max_tokens: int | None = None,\n        filter_tag_list: List[str] = None\n    ) -> List[OCRResult]:\n        if task_names is None:\n            task_names = [TaskNames.ocr_with_boxes] * len(images)\n        if recognition_batch_size is None:\n            recognition_batch_size = self.get_batch_size()\n\n        assert len(images) == len(task_names), (\n            \"You need to pass in one task name for each image\"\n        )\n\n        images = convert_if_not_rgb(images)\n        if highres_images is not None:\n            assert len(images) == len(highres_images), (\n                \"You need to pass in one highres image for each image\"\n            )\n\n        highres_images = (\n 
           convert_if_not_rgb(highres_images)\n            if highres_images is not None\n            else [None] * len(images)\n        )\n\n        if bboxes is None and polygons is None:\n            assert det_predictor is not None, (\n                \"You need to pass in a detection predictor if you don't provide bboxes or polygons\"\n            )\n\n            # Detect then slice\n            flat = self.detect_and_slice_bboxes(\n                images,\n                task_names,\n                det_predictor,\n                detection_batch_size=detection_batch_size,\n                highres_images=highres_images,\n            )\n        else:\n            if bboxes is not None:\n                assert len(images) == len(bboxes), (\n                    \"You need to pass in one list of bboxes for each image\"\n                )\n            if polygons is not None:\n                assert len(images) == len(polygons), (\n                    \"You need to pass in one list of polygons for each image\"\n                )\n\n            flat = self.slice_bboxes(\n                images,\n                bboxes=bboxes,\n                polygons=polygons,\n                input_text=input_text,\n                task_names=task_names,\n            )\n\n        # No images passed, or no boxes passed, or no text detected in the images\n        if len(flat[\"slices\"]) == 0:\n            return [\n                OCRResult(\n                    text_lines=[], image_bbox=[0, 0, im.size[0], im.size[1]]\n                )\n                for im in images\n            ]\n\n        # Sort by image sizes. 
Negative so that longer images come first, fits in with continuous batching better\n        sorted_pairs = sorted(\n            enumerate(flat[\"slices\"]),\n            key=lambda x: -(x[1].shape[0] * x[1].shape[1])  # height * width\n        )\n        indices, sorted_slices = zip(*sorted_pairs)\n\n        # Reorder input_text and task_names based on the new order\n        flat[\"slices\"] = list(sorted_slices)\n        flat[\"input_text\"] = [flat[\"input_text\"][i] for i in indices]\n        flat[\"task_names\"] = [flat[\"task_names\"][i] for i in indices]\n\n        # Make predictions\n        predicted_tokens, batch_bboxes, scores, _ = self.foundation_predictor.prediction_loop(\n            images=flat[\"slices\"],\n            input_texts=flat[\"input_text\"],\n            task_names=flat[\"task_names\"],\n            batch_size=recognition_batch_size,\n            math_mode=math_mode,\n            drop_repeated_tokens=True,\n            max_lookahead_tokens=self.foundation_predictor.model.config.multi_output_distance,\n            max_sliding_window=max_sliding_window,\n            max_tokens=max_tokens,\n            tqdm_desc=\"Recognizing Text\"\n        )\n\n        # Get text and bboxes in structured form\n        bbox_size = self.bbox_size\n        image_sizes = [img.shape for img in flat[\"slices\"]]\n        predicted_polygons = prediction_to_polygon_batch(\n            batch_bboxes, image_sizes, bbox_size, bbox_size // 2\n        )\n        char_predictions = self.get_bboxes_text(\n            flat,\n            predicted_tokens,\n            scores,\n            predicted_polygons,\n            drop_repeated_text=drop_repeated_text,\n        )\n\n        char_predictions = sorted(zip(indices, char_predictions), key=lambda x: x[0])\n        char_predictions = [pred for _, pred in char_predictions]\n\n        predictions_by_image = []\n        slice_start = 0\n        for idx, image in enumerate(images):\n            slice_end = slice_start + 
flat[\"slice_map\"][idx]\n            image_lines = char_predictions[slice_start:slice_end]\n            polygons = flat[\"polygons\"][slice_start:slice_end]\n            res_scales = flat[\"res_scales\"][slice_start:slice_end]\n            slice_start = slice_end\n\n            lines = []\n            for text_line, polygon, res_scale in zip(image_lines, polygons, res_scales):\n                # Special case when input text is good\n                if not text_line:\n                    lines.append(\n                        TextLine(\n                            text=\"\",\n                            polygon=polygon,\n                            chars=[],\n                            confidence=1,\n                            original_text_good=True,\n                        )\n                    )\n                else:\n                    confidence = (\n                        float(np.mean([char.confidence for char in text_line]))\n                        if len(text_line) > 0\n                        else 0\n                    )\n                    poly_box = PolygonBox(polygon=polygon)\n                    for char in text_line:\n                        char.rescale(\n                            res_scale, (1, 1)\n                        )  # Rescale from highres if needed\n                        char.shift(\n                            poly_box.bbox[0], poly_box.bbox[1]\n                        )  # Ensure character boxes match line boxes (relative to page)\n                        char.clamp(poly_box.bbox)\n\n                    text_line = fix_unbalanced_tags(\n                        text_line, self.processor.ocr_tokenizer.special_tokens\n                    )\n                    text_line = filter_blacklist_tags(text_line, filter_tag_list)\n                    text = \"\".join([char.text for char in text_line])\n                    text = unwrap_math(text)\n                    text = clean_math_tags(text)\n                    lines.append(\n     
                   TextLine(\n                            text=text,\n                            polygon=polygon,\n                            chars=text_line,\n                            confidence=confidence,\n                            words=words_from_chars(text_line, poly_box)\n                            if return_words\n                            else [],\n                        )\n                    )\n\n            if sort_lines:\n                lines = sort_text_lines(lines)\n            predictions_by_image.append(\n                OCRResult(\n                    text_lines=lines, image_bbox=[0, 0, image.size[0], image.size[1]]\n                )\n            )\n\n        return predictions_by_image\n"
  },
  {
    "path": "surya/recognition/languages.py",
    "content": "CODE_TO_LANGUAGE = {\n    \"_math\": \"Math\",\n    \"af\": \"Afrikaans\",\n    \"am\": \"Amharic\",\n    \"ar\": \"Arabic\",\n    \"as\": \"Assamese\",\n    \"az\": \"Azerbaijani\",\n    \"be\": \"Belarusian\",\n    \"bg\": \"Bulgarian\",\n    \"bn\": \"Bengali\",\n    \"br\": \"Breton\",\n    \"bs\": \"Bosnian\",\n    \"ca\": \"Catalan\",\n    \"cs\": \"Czech\",\n    \"cy\": \"Welsh\",\n    \"da\": \"Danish\",\n    \"de\": \"German\",\n    \"el\": \"Greek\",\n    \"en\": \"English\",\n    \"eo\": \"Esperanto\",\n    \"es\": \"Spanish\",\n    \"et\": \"Estonian\",\n    \"eu\": \"Basque\",\n    \"fa\": \"Persian\",\n    \"fi\": \"Finnish\",\n    \"fr\": \"French\",\n    \"fy\": \"Western Frisian\",\n    \"ga\": \"Irish\",\n    \"gd\": \"Scottish Gaelic\",\n    \"gl\": \"Galician\",\n    \"gu\": \"Gujarati\",\n    \"ha\": \"Hausa\",\n    \"he\": \"Hebrew\",\n    \"hi\": \"Hindi\",\n    \"hr\": \"Croatian\",\n    \"hu\": \"Hungarian\",\n    \"hy\": \"Armenian\",\n    \"id\": \"Indonesian\",\n    \"is\": \"Icelandic\",\n    \"it\": \"Italian\",\n    \"ja\": \"Japanese\",\n    \"jv\": \"Javanese\",\n    \"ka\": \"Georgian\",\n    \"kk\": \"Kazakh\",\n    \"km\": \"Khmer\",\n    \"kn\": \"Kannada\",\n    \"ko\": \"Korean\",\n    \"ku\": \"Kurdish\",\n    \"ky\": \"Kyrgyz\",\n    \"la\": \"Latin\",\n    \"lo\": \"Lao\",\n    \"lt\": \"Lithuanian\",\n    \"lv\": \"Latvian\",\n    \"mg\": \"Malagasy\",\n    \"mk\": \"Macedonian\",\n    \"ml\": \"Malayalam\",\n    \"mn\": \"Mongolian\",\n    \"mr\": \"Marathi\",\n    \"ms\": \"Malay\",\n    \"my\": \"Burmese\",\n    \"ne\": \"Nepali\",\n    \"nl\": \"Dutch\",\n    \"no\": \"Norwegian\",\n    \"om\": \"Oromo\",\n    \"or\": \"Oriya\",\n    \"pa\": \"Punjabi\",\n    \"pl\": \"Polish\",\n    \"ps\": \"Pashto\",\n    \"pt\": \"Portuguese\",\n    \"ro\": \"Romanian\",\n    \"ru\": \"Russian\",\n    \"sa\": \"Sanskrit\",\n    \"sd\": \"Sindhi\",\n    \"si\": \"Sinhala\",\n    \"sk\": \"Slovak\",\n    \"sl\": 
\"Slovenian\",\n    \"so\": \"Somali\",\n    \"sq\": \"Albanian\",\n    \"sr\": \"Serbian\",\n    \"su\": \"Sundanese\",\n    \"sv\": \"Swedish\",\n    \"sw\": \"Swahili\",\n    \"ta\": \"Tamil\",\n    \"te\": \"Telugu\",\n    \"th\": \"Thai\",\n    \"tl\": \"Tagalog\",\n    \"tr\": \"Turkish\",\n    \"ug\": \"Uyghur\",\n    \"uk\": \"Ukrainian\",\n    \"ur\": \"Urdu\",\n    \"uz\": \"Uzbek\",\n    \"vi\": \"Vietnamese\",\n    \"xh\": \"Xhosa\",\n    \"yi\": \"Yiddish\",\n    \"zh\": \"Chinese\",\n}\n\nLANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}\n"
  },
  {
    "path": "surya/recognition/postprocessing.py",
    "content": "import re\nfrom typing import List, Dict\n\nfrom surya.recognition.schema import TextChar\n\n\ndef truncate_repetitions(text: str, min_len=15):\n    # From nougat, with some cleanup\n    if len(text) < 2 * min_len:\n        return text\n\n    # try to find a length at which the tail is repeating\n    max_rep_len = None\n    for rep_len in range(min_len, int(len(text) / 2)):\n        # check if there is a repetition at the end\n        same = True\n        for i in range(0, rep_len):\n            if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]:\n                same = False\n                break\n\n        if same:\n            max_rep_len = rep_len\n\n    if max_rep_len is None:\n        return text\n\n    lcs = text[-max_rep_len:]\n\n    # remove all but the last repetition\n    text_to_truncate = text\n    while text_to_truncate.endswith(lcs):\n        text_to_truncate = text_to_truncate[:-max_rep_len]\n\n    return text[: len(text_to_truncate)]\n\n\ndef extract_tags(proposed_tags: List[str]) -> List[str]:\n    tags = []\n    for tag in proposed_tags:\n        tag_match = re.match(tag_pattern, tag)\n        if not tag_match:\n            continue\n\n        if not tag_match.group(1) == \"/\":\n            continue\n\n        tags.append(tag_match.group(2))\n    return tags\n\n\ntag_pattern = re.compile(r\"<(/?)([a-z]+)([^>]*)>?\", re.IGNORECASE)\n\n\ndef cleanup_math(line: str):\n    matches = re.finditer(r\"(<math[^>]*>)(.*?)</math>\", line, re.DOTALL)\n    result = line\n\n    for match in matches:\n        opening_tag = match.group(1)  # The opening <math> tag with attributes\n        full_match = match.group(0)  # The entire <math>content</math> tag\n        block_content = match.group(2)  # Just the content inside the tags\n\n        clean_block = re.sub(r\"<[^>]+>\", \"\", block_content)\n\n        if not re.search(r\"[\\\\\\_]\", clean_block):\n            result = result.replace(full_match, clean_block)\n        else:\n    
        result = result.replace(full_match, f\"{opening_tag}{clean_block}</math>\")\n\n    return result\n\n\ndef fix_unbalanced_tags(\n    text_chars: List[TextChar], special_tokens: Dict[str, list]\n) -> List[TextChar]:\n    self_closing_tags = [\"br\"]\n\n    open_tags = []\n\n    format_tags = extract_tags(special_tokens[\"formatting\"]) + extract_tags(\n        special_tokens[\"math_external\"]\n    )\n\n    for char in text_chars:\n        if len(char.text) <= 1:\n            continue\n\n        tag_match = re.match(tag_pattern, char.text)\n        if not tag_match:\n            continue\n\n        is_closing = tag_match.group(1) == \"/\"\n        tag_name = tag_match.group(2).lower()\n\n        if tag_name not in format_tags:\n            continue\n\n        if tag_name in self_closing_tags:\n            continue\n\n        # Self-closing tags\n        if tag_match.group(3) and tag_match.group(3).strip().endswith(\"/\"):\n            continue\n\n        if is_closing:\n            if open_tags and open_tags[-1] == tag_name:\n                open_tags.pop()\n        else:\n            open_tags.append(tag_name)\n\n    for tag in open_tags:\n        text_chars.append(\n            TextChar(\n                text=f\"</{tag}>\",\n                confidence=0,\n                polygon=[[0, 0], [1, 0], [1, 1], [0, 1]],\n                bbox_valid=False,\n            )\n        )\n    return text_chars\n"
  },
  {
    "path": "surya/recognition/schema.py",
    "content": "import math\nimport numpy as np\nfrom typing import Optional, List\n\nfrom pydantic import BaseModel, field_validator\n\nfrom surya.common.polygon import PolygonBox\n\n\nclass BaseChar(PolygonBox):\n    text: str\n    confidence: Optional[float] = 0\n\n    @field_validator(\"confidence\", mode=\"before\")\n    @classmethod\n    def validate_confidence(cls, v: float) -> float:\n        if v is None:\n            return 0\n        elif math.isnan(v) or np.isnan(v):\n            return 0\n        return v\n\n\nclass TextChar(BaseChar):\n    bbox_valid: bool = True  # This is false when the given bbox is not valid\n\n\nclass TextWord(BaseChar):\n    bbox_valid: bool = True\n\n\nclass TextLine(BaseChar):\n    chars: List[TextChar]  # Individual characters in the line\n    original_text_good: bool = False\n    words: List[TextWord] | None = None\n\n\nclass OCRResult(BaseModel):\n    text_lines: List[TextLine]\n    image_bbox: List[float]\n"
  },
  {
    "path": "surya/recognition/util.py",
    "content": "import re\nfrom typing import List, Tuple\n\nimport numpy\nimport torch\n\nfrom surya.common.polygon import PolygonBox\nfrom surya.recognition.schema import TextLine, TextWord, TextChar\n\nMATH_SYMBOLS = [\"+\", \"-\", \"*\", \"=\", \"^\", \"_\", \"\\\\\", \"{\", \"}\"]\n\n\ndef unwrap_math(text: str) -> str:\n    if len(text) > 50:\n        return text\n\n    # Detected as math, but does not contain LaTeX commands\n    if (\n        re.match(r'^\\s*<math(?:\\s+display=\"inline\")?.*?</math>\\s*$', text, re.DOTALL)\n        and text.count(\"<math\") == 1\n        and not any([symb in text for symb in MATH_SYMBOLS])\n    ):\n        # Remove math tags\n        text = re.sub(r\"<math.*?>\", \"\", text)\n        text = re.sub(r\"</math>\", \"\", text)\n\n    return text\n\n\nMATH_BLOCK = re.compile(r\"(<math\\b[^>]*>)(.*?)</math>\", flags=re.I | re.S)\nSTRIP_TAGS = re.compile(r\"</?(?:br|u|del|mark|i|b|sup|sub)\\b[^>]*>\", flags=re.I | re.S)\nDEFAULT_TAGS_TO_FILTER = [\"p\", \"li\", \"ul\", \"ol\", \"table\", \"td\", \"tr\", \"th\", \"tbody\", \"pre\"]\n\ndef filter_blacklist_tags(text_chars: List[TextChar], tags_to_filter: List[str] = None) -> List[TextChar]:\n    filtered_chars = []\n    char_buffer = []\n    in_tag = False\n    if tags_to_filter is None:\n        tags_to_filter = DEFAULT_TAGS_TO_FILTER\n\n    for text_char in text_chars:\n        char = text_char.text\n\n        if char.startswith(\"<\") or in_tag:\n            in_tag = True\n            char_buffer.append(text_char)\n            if char.endswith(\">\"):\n                full_tag = ''.join(c.text for c in char_buffer)\n                inner = full_tag[1:-1].strip()  # remove < >\n                inner = inner.strip(\"/\")  # remove '/'\n                \n                # Possible that it is just an empty <>\n                if not inner:\n                    filtered_chars.extend(char_buffer)\n                    in_tag = False\n                    char_buffer = []\n                
    continue\n                \n                tag_name_candidate = inner.split()[0]   # remove any attributes\n                if tag_name_candidate in tags_to_filter:\n                    # Discard tag\n                    pass\n                else:\n                    # Keep tag\n                    filtered_chars.extend(char_buffer)\n\n                in_tag = False\n                char_buffer = []\n        else:\n            filtered_chars.append(text_char)\n\n    # Flush buffer if we never reached a tag close\n    if char_buffer:\n        filtered_chars.extend(char_buffer)\n\n    return filtered_chars\n\n\ndef clean_math_tags(html: str) -> str:\n    # strip unwanted tags inside every well‑formed <math>…</math>\n    def _inner(m):\n        inner = STRIP_TAGS.sub(\"\", m.group(2))\n        return f\"{m.group(1)}{inner}</math>\" if inner.strip() else \"\"\n\n    cleaned = MATH_BLOCK.sub(_inner, html)\n\n    # drop only orphan *closing* </math> tags\n    depth = 0\n    parts = []\n    for token in re.split(r\"(</?math[^>]*>)\", cleaned, flags=re.I):\n        if token.lower().startswith(\"<math\"):\n            depth += 1\n            parts.append(token)\n        elif token.lower() == \"</math>\":\n            if depth:  # keep it only if it matches an open\n                depth -= 1\n                parts.append(token)\n            # else: skip orphan closing tag\n        else:\n            parts.append(token)\n    return \"\".join(parts)\n\n\ndef sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25):\n    # Sorts in reading order.  
Not 100% accurate, this should only\n    # be used as a starting point for more advanced sorting.\n    vertical_groups = {}\n    for line in lines:\n        # Bucket lines by their top y-coordinate (works for both TextLine objects and dicts)\n        top = line.bbox[1] if isinstance(line, TextLine) else line[\"bbox\"][1]\n        group_key = round(top / tolerance) * tolerance\n        if group_key not in vertical_groups:\n            vertical_groups[group_key] = []\n        vertical_groups[group_key].append(line)\n\n    # Sort each group horizontally and flatten the groups into a single list\n    sorted_lines = []\n    for _, group in sorted(vertical_groups.items()):\n        sorted_group = sorted(\n            group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x[\"bbox\"][0]\n        )\n        sorted_lines.extend(sorted_group)\n\n    return sorted_lines\n\n\ndef clean_close_polygons(bboxes: List[List[List[int]]], thresh: float = 0.1):\n    if len(bboxes) < 2:\n        return bboxes\n\n    new_bboxes = [bboxes[0]]\n    for i in range(1, len(bboxes)):\n        close = True\n        prev_bbox = bboxes[i - 1]\n        bbox = bboxes[i]\n        for j in range(4):\n            if (\n                abs(bbox[j][0] - prev_bbox[j][0]) > thresh\n                or abs(bbox[j][1] - prev_bbox[j][1]) > thresh\n            ):\n                close = False\n                break\n\n        if not close:\n            new_bboxes.append(bboxes[i])\n\n    return new_bboxes\n\n\ndef words_from_chars(chars: List[TextChar], line_box: PolygonBox):\n    words = []\n    word = None\n    for i, char in enumerate(chars):\n        if not char.bbox_valid:\n            if word:\n                words.append(word)\n                word = None\n            continue\n\n        if not word:\n            word = TextWord(**char.model_dump())\n\n            # Fit bounds to line if first word\n            if i == 0:\n                word.merge_left(line_box)\n\n        elif not 
char.text.strip():\n            if word:\n                words.append(word)\n            word = None\n        else:\n            # Merge bboxes\n            word.merge(char)\n            word.text = word.text + char.text\n\n            if i == len(chars) - 1:\n                word.merge_right(line_box)\n    if word:\n        words.append(word)\n\n    return words"
  },
  {
    "path": "surya/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "surya/scripts/config.py",
    "content": "from typing import List\n\nimport click\nimport os\nfrom surya.input.load import load_from_folder, load_from_file\nfrom surya.settings import settings\n\n\nclass CLILoader:\n    def __init__(self, filepath: str, cli_options: dict, highres: bool = False):\n        self.page_range = cli_options.get(\"page_range\")\n        if self.page_range:\n            self.page_range = self.parse_range_str(self.page_range)\n        self.filepath = filepath\n        self.config = cli_options\n        self.save_images = cli_options.get(\"images\", False)\n        self.debug = cli_options.get(\"debug\", False)\n        self.output_dir = cli_options.get(\"output_dir\")\n\n        self.load(highres)\n\n    @staticmethod\n    def common_options(fn):\n        fn = click.argument(\"input_path\", type=click.Path(exists=True), required=True)(fn)\n        fn = click.option(\"--output_dir\", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, \"surya\"), help=\"Directory to save output.\")(fn)\n        fn = click.option(\"--page_range\", type=str, default=None, help=\"Page range to convert, specify comma separated page numbers or ranges.  
Example: 0,5-10,20\")(fn)\n        fn = click.option(\"--images\", is_flag=True, help=\"Save images of detected bboxes.\", default=False)(fn)\n        fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn)\n        return fn\n\n    def load(self, highres: bool = False):\n        highres_images = None\n        if os.path.isdir(self.filepath):\n            images, names = load_from_folder(self.filepath, self.page_range)\n            folder_name = os.path.basename(self.filepath)\n            if highres:\n                highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)\n        else:\n            images, names = load_from_file(self.filepath, self.page_range)\n            folder_name = os.path.basename(self.filepath).split(\".\")[0]\n            if highres:\n                highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)\n\n\n        self.images = images\n        self.highres_images = highres_images\n        self.names = names\n\n        self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name))\n        os.makedirs(self.result_path, exist_ok=True)\n\n    @staticmethod\n    def parse_range_str(range_str: str) -> List[int]:\n        range_lst = range_str.split(\",\")\n        page_lst = []\n        for i in range_lst:\n            if \"-\" in i:\n                start, end = i.split(\"-\")\n                page_lst += list(range(int(start), int(end) + 1))\n            else:\n                page_lst.append(int(i))\n        page_lst = sorted(list(set(page_lst)))  # Deduplicate page numbers and sort in order\n        return page_lst"
  },
  {
    "path": "surya/scripts/detect_layout.py",
    "content": "import time\nimport click\nimport copy\nimport json\nfrom collections import defaultdict\n\nfrom surya.foundation import FoundationPredictor\nfrom surya.layout import LayoutPredictor\nfrom surya.debug.draw import draw_polys_on_image\nfrom surya.logging import configure_logging, get_logger\nfrom surya.scripts.config import CLILoader\nfrom surya.settings import settings\nimport os\n\nconfigure_logging()\nlogger = get_logger()\n\n\n@click.command(help=\"Detect layout of an input file or folder (PDFs or image).\")\n@CLILoader.common_options\ndef detect_layout_cli(input_path: str, **kwargs):\n    loader = CLILoader(input_path, kwargs)\n\n    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)\n    layout_predictor = LayoutPredictor(foundation_predictor)\n\n    start = time.time()\n    layout_predictions = layout_predictor(loader.images)\n\n    if loader.debug:\n        logger.debug(f\"Layout took {time.time() - start} seconds\")\n\n    if loader.save_images:\n        for idx, (image, layout_pred, name) in enumerate(\n            zip(loader.images, layout_predictions, loader.names)\n        ):\n            polygons = [p.polygon for p in layout_pred.bboxes]\n            labels = [f\"{p.label}-{p.position}\" for p in layout_pred.bboxes]\n            bbox_image = draw_polys_on_image(\n                polygons, copy.deepcopy(image), labels=labels\n            )\n            bbox_image.save(\n                os.path.join(loader.result_path, f\"{name}_{idx}_layout.png\")\n            )\n\n    predictions_by_page = defaultdict(list)\n    for idx, (pred, name, image) in enumerate(\n        zip(layout_predictions, loader.names, loader.images)\n    ):\n        out_pred = pred.model_dump()\n        out_pred[\"page\"] = len(predictions_by_page[name]) + 1\n        predictions_by_page[name].append(out_pred)\n\n    with open(\n        os.path.join(loader.result_path, \"results.json\"), \"w+\", encoding=\"utf-8\"\n    ) as f:\n        
json.dump(predictions_by_page, f, ensure_ascii=False)\n\n    logger.info(f\"Wrote results to {loader.result_path}\")\n"
  },
  {
    "path": "surya/scripts/detect_text.py",
    "content": "import click\nimport copy\nimport json\nimport time\nfrom collections import defaultdict\n\nfrom surya.detection import DetectionPredictor\nfrom surya.debug.draw import draw_polys_on_image\nfrom surya.logging import configure_logging, get_logger\nfrom surya.scripts.config import CLILoader\nimport os\n\nconfigure_logging()\nlogger = get_logger()\n\n\n@click.command(help=\"Detect bboxes in an input file or folder (PDFs or image).\")\n@CLILoader.common_options\ndef detect_text_cli(input_path: str, **kwargs):\n    loader = CLILoader(input_path, kwargs)\n\n    det_predictor = DetectionPredictor()\n\n    start = time.time()\n    predictions = det_predictor(loader.images, include_maps=loader.debug)\n    end = time.time()\n    if loader.debug:\n        logger.debug(f\"Detection took {end - start} seconds\")\n\n    if loader.save_images:\n        for idx, (image, pred, name) in enumerate(\n            zip(loader.images, predictions, loader.names)\n        ):\n            polygons = [p.polygon for p in pred.bboxes]\n            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))\n            bbox_image.save(os.path.join(loader.result_path, f\"{name}_{idx}_bbox.png\"))\n\n            if loader.debug:\n                heatmap = pred.heatmap\n                heatmap.save(os.path.join(loader.result_path, f\"{name}_{idx}_heat.png\"))\n\n    predictions_by_page = defaultdict(list)\n    for idx, (pred, name, image) in enumerate(\n        zip(predictions, loader.names, loader.images)\n    ):\n        out_pred = pred.model_dump(exclude=[\"heatmap\", \"affinity_map\"])\n        out_pred[\"page\"] = len(predictions_by_page[name]) + 1\n        predictions_by_page[name].append(out_pred)\n\n    with open(\n        os.path.join(loader.result_path, \"results.json\"), \"w+\", encoding=\"utf-8\"\n    ) as f:\n        json.dump(predictions_by_page, f, ensure_ascii=False)\n\n    logger.info(f\"Wrote results to {loader.result_path}\")\n"
  },
  {
    "path": "surya/scripts/finetune_ocr.py",
    "content": "from __future__ import annotations\nfrom dataclasses import dataclass, field\nfrom typing import Optional, Tuple\nfrom datasets import load_dataset\nimport numpy as np\nimport torch\nfrom transformers import (\n    HfArgumentParser,\n    TrainingArguments,\n    Trainer,\n)\n\nfrom surya.common.surya import SuryaModel\nfrom surya.common.surya.processor import SuryaOCRProcessor\nfrom surya.foundation import FoundationPredictor\nfrom surya.common.surya.processor.schema import ImageInput, TextInput\nfrom surya.common.surya.schema import TaskNames\nfrom surya.common.util import get_top_scripts, SCRIPT_TOKEN_MAPPING\n\n# Do not change these defaults\nOCR_TASK_NAME = TaskNames.ocr_with_boxes\nOCR_MAX_IMAGE_SIZE = (1024, 512)\n\n# Simple wrapper for huggingface dataset\nclass SuryaOCRDataset(torch.utils.data.Dataset):\n    def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments):\n        super().__init__()\n        self.hf_dataset = load_dataset(data_args.dataset_name, num_proc=data_args.num_loading_proc, split=\"train\")\n        self.processor = processor\n\n    def __len__(self):\n        return len(self.hf_dataset)\n\n    def get_script_text(self, text: str) -> str:\n        scripts = get_top_scripts(text)\n        script_text = \"\".join(SCRIPT_TOKEN_MAPPING[script] for script in scripts)\n        return script_text\n\n    def __getitem__(self, index):\n        try:\n            data = self.hf_dataset[index]\n            image = data[\"image\"]\n            image = image.convert(\"RGB\")\n            image = np.asarray(image, dtype=np.float32)\n            image = self.processor.scale_to_fit(image, max_size=OCR_MAX_IMAGE_SIZE)\n\n            # Add in script information\n            gt_text = data[\"text\"]\n            gt_text = self.get_script_text(gt_text) + gt_text\n\n            return_dict = {\n                \"task\": TaskNames.ocr_with_boxes,\n                \"inputs\": [\n                    
ImageInput(type=\"image\", image=image, rotated=False),\n                    # This empty TextInput **must be included** to match the original format\n                    TextInput(type=\"text\", text=\"\"),\n                    TextInput(type=\"text\", text=gt_text),\n                ],\n            }\n            return return_dict\n        except Exception:\n            # Log the failure and fall back to the next sample instead of crashing the run\n            import traceback\n\n            traceback.print_exc()\n            return self.__getitem__((index + 1) % self.__len__())\n\nclass SuryaOCRDataCollator:\n    def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments):\n        self.processor = processor\n        self.max_sequence_length = data_args.max_sequence_length\n\n    def __call__(self, inputs):\n        # Use right padding for training. Defaults to left for inference\n        processed_batch = self.processor(inputs, padding_side=\"right\")\n\n        if self.max_sequence_length is not None:\n            processed_batch[\"input_ids\"] = processed_batch[\"input_ids\"][:, :self.max_sequence_length]\n            processed_batch[\"attention_mask\"] = processed_batch[\"attention_mask\"][:, :self.max_sequence_length]\n            processed_batch[\"position_ids\"] = processed_batch[\"position_ids\"][:, :self.max_sequence_length]\n\n        lm_labels = processed_batch[\"input_ids\"].clone()\n        skip_label_mask = (\n            (lm_labels == self.processor.pad_token_id)\n            | (lm_labels == self.processor.bos_token_id[TaskNames.ocr_with_boxes])\n            | (lm_labels == self.processor.eoi_token_id)\n            | (lm_labels == self.processor.image_token_id)\n        )\n        lm_labels[skip_label_mask] = -100\n        processed_batch[\"labels\"] = lm_labels\n\n        return processed_batch\n\ndef load_model_and_processor(checkpoint_path: Optional[str] = None) -> Tuple[SuryaModel, SuryaOCRProcessor]:\n    foundation_predictor = FoundationPredictor(checkpoint=checkpoint_path)\n    return foundation_predictor.model, 
foundation_predictor.processor\n\n@dataclass\nclass SuryaOCRModelArguments:\n    pretrained_checkpoint_path: Optional[str] = field(default=None)\n\n@dataclass\nclass SuryaOCRDataArguments:\n    dataset_name: str = field(default=\"datalab-to/ocr_finetune_example\")\n    num_loading_proc: int = field(default=16)\n    max_sequence_length: Optional[int] = field(default=None)\n\n@dataclass\nclass SuryaOCRTrainingArguments(TrainingArguments):\n    remove_unused_columns: bool = field(default=False)\n    \ndef main():\n    parser = HfArgumentParser((SuryaOCRModelArguments, SuryaOCRDataArguments, SuryaOCRTrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n\n    model, processor = load_model_and_processor(model_args.pretrained_checkpoint_path)\n    dataset = SuryaOCRDataset(processor, data_args)\n    collator = SuryaOCRDataCollator(processor, data_args)\n\n    trainer = Trainer(\n        model=model,\n        args=training_args,\n        train_dataset=dataset,\n        data_collator=collator\n    )\n\n    trainer.train()\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "surya/scripts/hf_to_s3.py",
    "content": "import json\nimport shutil\nimport datetime\nfrom pathlib import Path\nimport boto3\n\nfrom huggingface_hub import snapshot_download\n\nimport click\nfrom tqdm import tqdm\n\nS3_API_URL = \"https://1afbe4656a6b40d982ab5e730a39f6b9.r2.cloudflarestorage.com\"\n\n\n# Example usage - python scripts/hf_to_s3.py <REPO_NAME> layout\n# This will upload to s3://<bucket_name>/layout/YYYY_MM_DD (bucket_name defaults to datalab)\n@click.command(help=\"Uploads the data from huggingface to an S3 bucket\")\n@click.argument(\"hf_repo_id\", type=str)\n@click.argument(\"s3_path\", type=str)\n@click.option(\"--bucket_name\", type=str, default=\"datalab\")\n@click.option(\"--revision_hash\", type=str, default=None)\n@click.option(\"--access_key_id\", type=str, default=\"<access_key_id>\")\n@click.option(\"--access_key_secret\", type=str, default=\"<access_key_secret>\")\n@click.option(\"--suffix\", type=str, default=\"\")\ndef main(\n    hf_repo_id: str,\n    s3_path: str,\n    bucket_name: str,\n    revision_hash: str,\n    access_key_id: str,\n    access_key_secret: str,\n    suffix: str,\n):\n    curr_date = datetime.datetime.now().strftime(\"%Y_%m_%d\")\n    s3_path = f\"{s3_path}/{curr_date}\"\n    if suffix:\n        s3_path = f\"{s3_path}_{suffix}\"\n\n    download_folder = snapshot_download(repo_id=hf_repo_id, revision=revision_hash)\n    download_folder = Path(download_folder)\n    contained_files = list(download_folder.glob(\"*\"))\n    contained_files = [f.name for f in contained_files]  # Just get the base name\n    manifest_file = download_folder / \"manifest.json\"\n\n    with open(manifest_file, \"w\") as f:\n        json.dump({\"files\": contained_files}, f)\n\n    # Upload the files to S3\n    s3_client = boto3.client(\n        service_name=\"s3\",\n        endpoint_url=S3_API_URL,\n        aws_access_key_id=access_key_id,\n        aws_secret_access_key=access_key_secret,\n        region_name=\"auto\",\n    )\n\n    # Iterate through all files in the folder\n    for file_path in tqdm(\n        
download_folder.glob(\"*\"), desc=\"Uploading files\", unit=\"file\"\n    ):\n        s3_key = f\"{s3_path}/{file_path.name}\"\n\n        try:\n            s3_client.upload_file(str(file_path), bucket_name, s3_key)\n        except Exception as e:\n            print(f\"Error uploading {file_path}: {str(e)}\")\n\n    shutil.rmtree(download_folder)\n\n    print(f\"Uploaded files to {s3_path}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "surya/scripts/ocr_latex.py",
    "content": "import os\n\nimport click\nimport json\nimport time\nfrom collections import defaultdict\n\nfrom surya.logging import configure_logging, get_logger\nfrom surya.scripts.config import CLILoader\nfrom surya.foundation import FoundationPredictor\nfrom surya.recognition import RecognitionPredictor\nfrom surya.common.surya.schema import TaskNames\n\nconfigure_logging()\nlogger = get_logger()\n\n\n@click.command(help=\"OCR LaTeX equations.\")\n@CLILoader.common_options\ndef ocr_latex_cli(input_path: str, **kwargs):\n    loader = CLILoader(input_path, kwargs, highres=True)\n\n    foundation_predictor = FoundationPredictor()\n    texify_predictor = RecognitionPredictor(foundation_predictor)\n    tasks = [TaskNames.block_without_boxes] * len(loader.images)\n    bboxes = [[[0, 0, image.width, image.height]] for image in loader.images]\n\n    start = time.time()\n    predictions_by_image = texify_predictor(\n        loader.images,\n        tasks,\n        bboxes=bboxes,\n    )\n\n    latex_predictions = [p.text_lines[0].text for p in predictions_by_image]\n\n    if loader.debug:\n        logger.debug(f\"OCR took {time.time() - start:.2f} seconds\")\n        max_chars = max([len(latex) for latex in latex_predictions])\n        logger.debug(f\"Max chars: {max_chars}\")\n\n    out_preds = defaultdict(list)\n    for name, pred, image in zip(loader.names, latex_predictions, loader.images):\n        out_pred = {\n            \"equation\": pred,\n            \"page\": len(out_preds[name]) + 1,\n        }\n        out_preds[name].append(out_pred)\n\n    with open(\n        os.path.join(loader.result_path, \"results.json\"), \"w+\", encoding=\"utf-8\"\n    ) as f:\n        json.dump(out_preds, f, ensure_ascii=False)\n\n    logger.info(f\"Wrote results to {loader.result_path}\")\n"
  },
  {
    "path": "surya/scripts/ocr_text.py",
    "content": "import os\nimport click\nimport json\nimport time\nfrom collections import defaultdict\n\nfrom surya.common.surya.schema import TaskNames\nfrom surya.detection import DetectionPredictor\nfrom surya.debug.text import draw_text_on_image\nfrom surya.logging import configure_logging, get_logger\nfrom surya.foundation import FoundationPredictor\nfrom surya.recognition import RecognitionPredictor\nfrom surya.scripts.config import CLILoader\n\nconfigure_logging()\nlogger = get_logger()\n\n\n@click.command(help=\"OCR text.\")\n@click.option(\"--task_name\", type=str, default=TaskNames.ocr_with_boxes)\n@click.option(\n    \"--disable_math\", is_flag=True, default=False, help=\"Do not recognize math in OCR.\"\n)\n@CLILoader.common_options\ndef ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs):\n    loader = CLILoader(input_path, kwargs, highres=True)\n    task_names = [task_name] * len(loader.images)\n\n    foundation_predictor = FoundationPredictor()\n    det_predictor = DetectionPredictor()\n    rec_predictor = RecognitionPredictor(foundation_predictor)\n\n    start = time.time()\n    predictions_by_image = rec_predictor(\n        loader.images,\n        task_names=task_names,\n        det_predictor=det_predictor,\n        highres_images=loader.highres_images,\n        math_mode=not disable_math,\n    )\n\n    if loader.debug:\n        logger.debug(f\"OCR took {time.time() - start:.2f} seconds\")\n        max_chars = max(\n            [len(line.text) for p in predictions_by_image for line in p.text_lines]\n        )\n        logger.debug(f\"Max chars: {max_chars}\")\n\n    if loader.save_images:\n        for idx, (name, image, pred) in enumerate(\n            zip(loader.names, loader.images, predictions_by_image)\n        ):\n            bboxes = [line.bbox for line in pred.text_lines]\n            pred_text = [line.text for line in pred.text_lines]\n            page_image = draw_text_on_image(bboxes, pred_text, image.size)\n       
     page_image.save(os.path.join(loader.result_path, f\"{name}_{idx}_text.png\"))\n\n    out_preds = defaultdict(list)\n    for name, pred, image in zip(loader.names, predictions_by_image, loader.images):\n        out_pred = pred.model_dump()\n        out_pred[\"page\"] = len(out_preds[name]) + 1\n        out_preds[name].append(out_pred)\n\n    with open(\n        os.path.join(loader.result_path, \"results.json\"), \"w+\", encoding=\"utf-8\"\n    ) as f:\n        json.dump(out_preds, f, ensure_ascii=False)\n\n    logger.info(f\"Wrote results to {loader.result_path}\")\n"
  },
  {
    "path": "surya/scripts/run_streamlit_app.py",
    "content": "import subprocess\nimport os\n\n\ndef streamlit_app_cli():\n    cur_dir = os.path.dirname(os.path.abspath(__file__))\n    ocr_app_path = os.path.join(cur_dir, \"streamlit_app.py\")\n    cmd = [\"streamlit\", \"run\", ocr_app_path, \"--server.fileWatcherType\", \"none\", \"--server.headless\", \"true\"]\n    subprocess.run(cmd, env={**os.environ, \"IN_STREAMLIT\": \"true\"})"
  },
  {
    "path": "surya/scripts/run_texify_app.py",
    "content": "import subprocess\nimport os\n\n\ndef texify_app_cli():\n    cur_dir = os.path.dirname(os.path.abspath(__file__))\n    ocr_app_path = os.path.join(cur_dir, \"texify_app.py\")\n    cmd = [\"streamlit\", \"run\", ocr_app_path, \"--server.fileWatcherType\", \"none\", \"--server.headless\", \"true\"]\n    subprocess.run(cmd, env={**os.environ, \"IN_STREAMLIT\": \"true\"})"
  },
  {
    "path": "surya/scripts/streamlit_app.py",
    "content": "import io\nimport tempfile\nfrom typing import List\n\nimport pypdfium2\nimport streamlit as st\n\nfrom surya.common.surya.schema import TaskNames\nfrom surya.models import load_predictors\n\nfrom surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image\n\nfrom surya.debug.text import draw_text_on_image\nfrom PIL import Image, ImageDraw\nfrom surya.table_rec import TableResult\nfrom surya.detection import TextDetectionResult\nfrom surya.recognition import OCRResult\nfrom surya.layout import LayoutResult\nfrom surya.settings import settings\nfrom surya.common.util import rescale_bbox, expand_bbox\n\n\n@st.cache_resource()\ndef load_predictors_cached():\n    return load_predictors()\n\n\ndef ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):\n    from pdftext.extraction import plain_text_output\n\n    with tempfile.NamedTemporaryFile(suffix=\".pdf\") as f:\n        f.write(pdf_file.getvalue())\n        f.seek(0)\n\n        # Sample the text from the middle of the PDF\n        page_middle = page_count // 2\n        page_range = range(\n            max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)\n        )\n        text = plain_text_output(f.name, page_range=page_range)\n\n    sample_gap = len(text) // max_samples\n    if len(text) == 0 or sample_gap == 0:\n        return \"This PDF has no text or very little text\", [\"no text\"]\n\n    if sample_gap < sample_len:\n        sample_gap = sample_len\n\n    # Split the text into samples for the model\n    samples = []\n    for i in range(0, len(text), sample_gap):\n        samples.append(text[i : i + sample_len])\n\n    results = predictors[\"ocr_error\"](samples)\n    label = \"This PDF has good text.\"\n    if results.labels.count(\"bad\") / len(results.labels) > 0.2:\n        label = \"This PDF may have garbled or bad OCR text.\"\n    return label, results.labels\n\n\ndef text_detection(img) -> (Image.Image, TextDetectionResult):\n    
text_pred = predictors[\"detection\"]([img])[0]\n    text_polygons = [p.polygon for p in text_pred.bboxes]\n    det_img = draw_polys_on_image(text_polygons, img.copy())\n    return det_img, text_pred\n\n\ndef layout_detection(img) -> (Image.Image, LayoutResult):\n    pred = predictors[\"layout\"]([img])[0]\n    polygons = [p.polygon for p in pred.bboxes]\n    labels = [\n        f\"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}\" for p in pred.bboxes\n    ]\n    layout_img = draw_polys_on_image(\n        polygons, img.copy(), labels=labels, label_font_size=18\n    )\n    return layout_img, pred\n\n\ndef table_recognition(\n    img, highres_img, skip_table_detection: bool\n) -> (Image.Image, List[TableResult]):\n    # draw_bboxes_on_image draws in place, and for image uploads highres_img is the\n    # same object as the displayed upload, so work on a copy\n    highres_img = highres_img.copy()\n    if skip_table_detection:\n        layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]\n        table_imgs = [highres_img]\n    else:\n        _, layout_pred = layout_detection(img)\n        layout_tables_lowres = [\n            line.bbox\n            for line in layout_pred.bboxes\n            if line.label in [\"Table\", \"TableOfContents\"]\n        ]\n        table_imgs = []\n        layout_tables = []\n        for tb in layout_tables_lowres:\n            highres_bbox = rescale_bbox(tb, img.size, highres_img.size)\n            # Slightly expand the box\n            highres_bbox = expand_bbox(highres_bbox)\n            table_imgs.append(highres_img.crop(highres_bbox))\n            layout_tables.append(highres_bbox)\n\n    table_preds = predictors[\"table_rec\"](table_imgs)\n    table_img = img.copy()\n\n    for results, table_bbox in zip(table_preds, layout_tables):\n        adjusted_bboxes = []\n        labels = []\n        colors = []\n\n        for item in results.cells:\n            adjusted_bboxes.append(\n                [\n                    (item.bbox[0] + table_bbox[0]),\n                    (item.bbox[1] + table_bbox[1]),\n                    (item.bbox[2] + table_bbox[0]),\n                    (item.bbox[3] + 
table_bbox[1]),\n                ]\n            )\n            labels.append(item.label)\n            if \"Row\" in item.label:\n                colors.append(\"blue\")\n            else:\n                colors.append(\"red\")\n        table_img = draw_bboxes_on_image(\n            adjusted_bboxes,\n            highres_img,\n            labels=labels,\n            label_font_size=18,\n            color=colors,\n        )\n    return table_img, table_preds\n\n\n# Function for OCR\ndef ocr(\n    img: Image.Image,\n    highres_img: Image.Image,\n    skip_text_detection: bool = False,\n    recognize_math: bool = True,\n    with_bboxes: bool = True,\n) -> (Image.Image, OCRResult):\n    if skip_text_detection:\n        img = highres_img\n        bboxes = [[[0, 0, img.width, img.height]]]\n    else:\n        bboxes = None\n\n    if with_bboxes:\n        tasks = [TaskNames.ocr_with_boxes]\n    else:\n        tasks = [TaskNames.ocr_without_boxes]\n\n    img_pred = predictors[\"recognition\"](\n        [img],\n        task_names=tasks,\n        bboxes=bboxes,\n        det_predictor=predictors[\"detection\"],\n        highres_images=[highres_img],\n        math_mode=recognize_math,\n        return_words=True,\n    )[0]\n\n    bboxes = [line.bbox for line in img_pred.text_lines]\n    text = [line.text for line in img_pred.text_lines]\n    rec_img = draw_text_on_image(bboxes, text, img.size)\n\n    word_boxes = []\n    for line in img_pred.text_lines:\n        if line.words:\n            word_boxes.extend([word.bbox for word in line.words])\n\n    box_img = img.copy()\n    draw = ImageDraw.Draw(box_img)\n    for word_box in word_boxes:\n        draw.rectangle(word_box, outline=\"red\", width=2)\n\n    return rec_img, img_pred, box_img\n\n\ndef open_pdf(pdf_file):\n    stream = io.BytesIO(pdf_file.getvalue())\n    return pypdfium2.PdfDocument(stream)\n\n\n@st.cache_data()\ndef get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):\n    doc = open_pdf(pdf_file)\n    
renderer = doc.render(\n        pypdfium2.PdfBitmap.to_pil,\n        page_indices=[page_num - 1],\n        scale=dpi / 72,\n    )\n    png = list(renderer)[0]\n    png_image = png.convert(\"RGB\")\n    doc.close()\n    return png_image\n\n\n@st.cache_data()\ndef page_counter(pdf_file):\n    doc = open_pdf(pdf_file)\n    doc_len = len(doc)\n    doc.close()\n    return doc_len\n\n\nst.set_page_config(layout=\"wide\")\ncol1, col2 = st.columns([0.5, 0.5])\n\npredictors = load_predictors_cached()\n\nst.markdown(\"\"\"\n# Surya OCR Demo\n\nThis app will let you try surya, a multilingual OCR toolkit.\n\nNotes:\n\n- This works best on documents with printed text.\n- For OCR, the formatting (math, italics, etc) will not show up in the image preview, but it will show up in the returned text lines.\n- If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease).\n\nFind the project [here](https://github.com/VikParuchuri/surya).\n\"\"\")\n\nin_file = st.sidebar.file_uploader(\n    \"PDF file or image:\", type=[\"pdf\", \"png\", \"jpg\", \"jpeg\", \"gif\", \"webp\"]\n)\n\nif in_file is None:\n    st.stop()\n\nfiletype = in_file.type\npage_count = None\nif \"pdf\" in filetype:\n    page_count = page_counter(in_file)\n    page_number = st.sidebar.number_input(\n        f\"Page number out of {page_count}:\", min_value=1, value=1, max_value=page_count\n    )\n\n    pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)\n    pil_image_highres = get_page_image(\n        in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES\n    )\nelse:\n    pil_image = Image.open(in_file).convert(\"RGB\")\n    pil_image_highres = pil_image\n    page_number = None\n\nrun_text_det = st.sidebar.button(\"Run Text Detection\")\nrun_text_rec = st.sidebar.button(\"Run OCR\")\nrun_layout_det = st.sidebar.button(\"Run Layout Analysis\")\nrun_table_rec = st.sidebar.button(\"Run Table Rec\")\nrun_ocr_errors = st.sidebar.button(\"Run bad PDF 
text detection\")\nuse_pdf_boxes = st.sidebar.checkbox(\n    \"PDF table boxes\",\n    value=True,\n    help=\"Table recognition only: Use the bounding boxes from the PDF file vs text detection model.\",\n)\nskip_table_detection = st.sidebar.checkbox(\n    \"Skip table detection\",\n    value=False,\n    help=\"Table recognition only: Skip table detection and treat the whole image/page as a table.\",\n)\nskip_text_detection = st.sidebar.checkbox(\n    \"Skip text detection\",\n    value=False,\n    help=\"OCR only: Skip text detection and treat the whole image as a single line.\",\n)\nrecognize_math = st.sidebar.checkbox(\n    \"Recognize math in OCR\",\n    value=True,\n    help=\"Enable math mode in OCR - this will recognize math.\",\n)\nocr_with_boxes = st.sidebar.checkbox(\n    \"OCR with boxes\",\n    value=True,\n    help=\"Enable OCR with boxes - this will predict character-level boxes.\",\n)\n\nif pil_image is None:\n    st.stop()\n\n# Run Text Detection\nif run_text_det:\n    det_img, text_pred = text_detection(pil_image)\n    with col1:\n        st.image(det_img, caption=\"Detected Text\", use_container_width=True)\n        st.json(\n            text_pred.model_dump(exclude=[\"heatmap\", \"affinity_map\"]), expanded=True\n        )\n\n\n# Run layout\nif run_layout_det:\n    layout_img, pred = layout_detection(pil_image)\n    with col1:\n        st.image(layout_img, caption=\"Detected Layout\", use_container_width=True)\n        st.json(pred.model_dump(exclude=[\"segmentation_map\"]), expanded=True)\n\n# Run OCR\nif run_text_rec:\n    rec_img, pred, box_img = ocr(\n        pil_image,\n        pil_image_highres,\n        skip_text_detection,\n        recognize_math,\n        with_bboxes=ocr_with_boxes,\n    )\n    with col1:\n        st.image(rec_img, caption=\"OCR Result\", use_container_width=True)\n        json_tab, text_tab = st.tabs([\"JSON\", \"Text Lines (for debugging)\"])\n        with json_tab:\n            st.json(pred.model_dump(), 
expanded=False)\n        with text_tab:\n            st.text(\"\\n\".join([p.text for p in pred.text_lines]))\n\n        st.image(\n            box_img,\n            caption=\"OCR with Word Boxes (for debugging)\",\n            use_container_width=True,\n        )\n\n\nif run_table_rec:\n    table_img, pred = table_recognition(\n        pil_image, pil_image_highres, skip_table_detection\n    )\n    with col1:\n        st.image(table_img, caption=\"Table Recognition\", use_container_width=True)\n        st.json([p.model_dump() for p in pred], expanded=True)\n\nif run_ocr_errors:\n    if \"pdf\" not in filetype:\n        st.error(\"This feature only works with PDFs.\")\n    else:\n        label, results = ocr_errors(in_file, page_count)\n        with col1:\n            st.write(label)\n            st.json(results)\n\nwith col2:\n    st.image(pil_image, caption=\"Uploaded Image\", use_container_width=True)\n"
  },
  {
    "path": "surya/scripts/table_recognition.py",
    "content": "import os\nimport click\nimport copy\nimport json\nfrom collections import defaultdict\n\nfrom surya.logging import configure_logging, get_logger\nfrom surya.scripts.config import CLILoader\nfrom surya.foundation import FoundationPredictor\nfrom surya.layout import LayoutPredictor\nfrom surya.table_rec import TableRecPredictor\nfrom surya.debug.draw import draw_bboxes_on_image\nfrom surya.common.util import rescale_bbox, expand_bbox\nfrom surya.settings import settings\n\nconfigure_logging()\nlogger = get_logger()\n\n\n@click.command(help=\"Recognize tables in an input file or folder (PDFs or images).\")\n@CLILoader.common_options\n@click.option(\n    \"--skip_table_detection\",\n    is_flag=True,\n    help=\"Tables are already cropped, so don't re-detect tables.\",\n    default=False,\n)\ndef table_recognition_cli(input_path: str, skip_table_detection: bool, **kwargs):\n    loader = CLILoader(input_path, kwargs, highres=True)\n\n    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)\n    layout_predictor = LayoutPredictor(foundation_predictor)\n    table_rec_predictor = TableRecPredictor()\n\n    pnums = []\n    prev_name = None\n    for i, name in enumerate(loader.names):\n        if prev_name is None or prev_name != name:\n            pnums.append(0)\n        else:\n            pnums.append(pnums[-1] + 1)\n\n        prev_name = name\n\n    layout_predictions = layout_predictor(loader.images)\n\n    table_imgs = []\n    table_counts = []\n\n    for layout_pred, img, highres_img in zip(\n        layout_predictions, loader.images, loader.highres_images\n    ):\n        # The table may already be cropped\n        if skip_table_detection:\n            table_imgs.append(highres_img)\n            table_counts.append(1)\n        else:\n            # The bbox for the entire table\n            bbox = [\n                line.bbox\n                for line in layout_pred.bboxes\n                if line.label in [\"Table\", 
\"TableOfContents\"]\n            ]\n            # Number of tables per page\n            table_counts.append(len(bbox))\n\n            if len(bbox) == 0:\n                continue\n\n            page_table_imgs = []\n            highres_bbox = []\n            for bb in bbox:\n                highres_bb = rescale_bbox(bb, img.size, highres_img.size)\n                highres_bb = expand_bbox(highres_bb)\n                page_table_imgs.append(highres_img.crop(highres_bb))\n                highres_bbox.append(highres_bb)\n\n            table_imgs.extend(page_table_imgs)\n\n    table_preds = table_rec_predictor(table_imgs)\n\n    img_idx = 0\n    prev_count = 0\n    table_predictions = defaultdict(list)\n    for i in range(sum(table_counts)):\n        while i >= prev_count + table_counts[img_idx]:\n            prev_count += table_counts[img_idx]\n            img_idx += 1\n\n        pred = table_preds[i]\n        orig_name = loader.names[img_idx]\n        pnum = pnums[img_idx]\n        table_img = table_imgs[i]\n\n        out_pred = pred.model_dump()\n        out_pred[\"page\"] = pnum + 1\n        table_idx = i - prev_count\n        out_pred[\"table_idx\"] = table_idx\n        table_predictions[orig_name].append(out_pred)\n\n        if loader.save_images:\n            rows = [line.bbox for line in pred.rows]\n            cols = [line.bbox for line in pred.cols]\n            row_labels = [f\"Row {line.row_id}\" for line in pred.rows]\n            col_labels = [f\"Col {line.col_id}\" for line in pred.cols]\n            cells = [line.bbox for line in pred.cells]\n\n            rc_image = copy.deepcopy(table_img)\n            rc_image = draw_bboxes_on_image(\n                rows, rc_image, labels=row_labels, label_font_size=20, color=\"blue\"\n            )\n            rc_image = draw_bboxes_on_image(\n                cols, rc_image, labels=col_labels, label_font_size=20, color=\"red\"\n            )\n            rc_image.save(\n                os.path.join(\n            
        loader.result_path, f\"{orig_name}_page{pnum + 1}_table{table_idx}_rc.png\"\n                )\n            )\n\n            cell_image = copy.deepcopy(table_img)\n            cell_image = draw_bboxes_on_image(cells, cell_image, color=\"green\")\n            cell_image.save(\n                os.path.join(\n                    loader.result_path,\n                    f\"{orig_name}_page{pnum + 1}_table{table_idx}_cells.png\",\n                )\n            )\n\n    with open(\n        os.path.join(loader.result_path, \"results.json\"), \"w+\", encoding=\"utf-8\"\n    ) as f:\n        json.dump(table_predictions, f, ensure_ascii=False)\n\n    logger.info(f\"Wrote results to {loader.result_path}\")\n"
  },
  {
    "path": "surya/scripts/texify_app.py",
    "content": "import os\nimport re\nfrom typing import List\n\nfrom surya.recognition import RecognitionPredictor\nfrom surya.foundation import FoundationPredictor\nfrom surya.common.surya.schema import TaskNames\n\nos.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = (\n    \"1\"  # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS\n)\n\nimport io\n\nimport pandas as pd\nimport streamlit as st\nfrom streamlit_drawable_canvas import st_canvas\nimport hashlib\nimport pypdfium2\n\nfrom surya.settings import settings\nfrom PIL import Image\n\nMAX_WIDTH = 800\nMAX_HEIGHT = 1000\n\n\ndef replace_fences(text):\n    text = re.sub(r'<math display=\"block\">(.*?)</math>', r\"$$\\1$$\", text)\n    text = re.sub(r\"<math>(.*?)</math>\", r\"$\\1$\", text)\n    text = re.sub(r'<math display=\"inline\">(.*?)</math>', r\"$\\1$\", text)\n    return text\n\n\n@st.cache_resource()\ndef load_predictor():\n    foundation_predictor = FoundationPredictor()\n    return RecognitionPredictor(foundation_predictor)\n\n\n@st.cache_data()\ndef inference(pil_image: Image.Image, bbox: List[float]):\n    input_img = pil_image.crop(bbox)\n    bbox = [0, 0, input_img.width, input_img.height]\n    model_output = predictor(\n        [input_img], [TaskNames.block_without_boxes], bboxes=[[bbox]]\n    )\n    return model_output[0].text_lines[0].text\n\n\ndef open_pdf(pdf_file):\n    stream = io.BytesIO(pdf_file.getvalue())\n    return pypdfium2.PdfDocument(stream)\n\n\n@st.cache_data()\ndef get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI_HIGHRES):\n    doc = open_pdf(pdf_file)\n    renderer = doc.render(\n        pypdfium2.PdfBitmap.to_pil,\n        page_indices=[page_num - 1],\n        scale=dpi / 72,\n    )\n    png = list(renderer)[0]\n    png_image = png.convert(\"RGB\")\n    doc.close()\n    return png_image\n\n\n@st.cache_data()\ndef page_counter(pdf_file):\n    doc = open_pdf(pdf_file)\n    doc_len = len(doc)\n    doc.close()\n    return 
doc_len\n\n\ndef resize_image(pil_image):\n    if pil_image is None:\n        return\n    pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)\n\n\ndef get_canvas_hash(pil_image):\n    return hashlib.md5(pil_image.tobytes()).hexdigest()\n\n\nst.set_page_config(layout=\"wide\")\n\ntop_message = \"\"\"### LaTeX OCR\n\nAfter the model loads, upload an image or a PDF, then draw a box around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right.\n\"\"\"\n\nst.markdown(top_message)\ncol1, col2 = st.columns([0.7, 0.3])\n\npredictor = load_predictor()\n\nin_file = st.sidebar.file_uploader(\n    \"PDF file or image:\", type=[\"pdf\", \"png\", \"jpg\", \"jpeg\", \"gif\", \"webp\"]\n)\n\nif in_file is None:\n    st.stop()\n\nfiletype = in_file.type\npage_count = None\nif \"pdf\" in filetype:\n    page_count = page_counter(in_file)\n    page_number = st.sidebar.number_input(\n        f\"Page number out of {page_count}:\", min_value=1, value=1, max_value=page_count\n    )\n    pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)\nelse:\n    pil_image = Image.open(in_file).convert(\"RGB\")\n    page_number = None\n\nif pil_image is None:\n    st.stop()\n\npil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)\ncanvas_hash = get_canvas_hash(pil_image)\n\nwith col1:\n    # Create a canvas component\n    canvas_result = st_canvas(\n        fill_color=\"rgba(255, 165, 0, 0.1)\",  # Fixed fill color with some opacity\n        stroke_width=1,\n        stroke_color=\"#FFAA00\",\n        background_color=\"#FFF\",\n        background_image=pil_image,\n        update_streamlit=True,\n        height=pil_image.height,\n        width=pil_image.width,\n        drawing_mode=\"rect\",\n        point_display_radius=0,\n        key=canvas_hash,\n    )\n\nif not canvas_result.json_data:\n    st.stop()\n\nobjects = 
pd.json_normalize(\n    canvas_result.json_data[\"objects\"]\n)  # need to convert obj to str because PyArrow\nbbox_list = None\nif objects.shape[0] > 0:\n    boxes = objects[objects[\"type\"] == \"rect\"][[\"left\", \"top\", \"width\", \"height\"]]\n    boxes[\"right\"] = boxes[\"left\"] + boxes[\"width\"]\n    boxes[\"bottom\"] = boxes[\"top\"] + boxes[\"height\"]\n    bbox_list = boxes[[\"left\", \"top\", \"right\", \"bottom\"]].values.tolist()\n\nif bbox_list:\n    with col2:\n        texts = [inference(pil_image, bbox) for bbox in bbox_list]\n        for idx, latex in enumerate(reversed(texts)):\n            st.markdown(f\"### {len(texts) - idx}\")\n            st.markdown(replace_fences(latex), unsafe_allow_html=True)\n            st.code(latex)\n            st.divider()\n\nwith col2:\n    tips = \"\"\"\n    ### Usage tips\n    - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple.\n    \"\"\"\n    st.markdown(tips)\n"
  },
  {
    "path": "surya/settings.py",
    "content": "import os\nfrom typing import Callable, Dict, Optional\n\nimport torch\nfrom dotenv import find_dotenv\nfrom pydantic import computed_field\nfrom pydantic_settings import BaseSettings\nfrom pathlib import Path\nfrom platformdirs import user_cache_dir\n\n\nclass Settings(BaseSettings):\n    # General\n    TORCH_DEVICE: Optional[str] = None\n    IMAGE_DPI: int = 96  # Used for detection, layout, reading order\n    IMAGE_DPI_HIGHRES: int = 192  # Used for OCR, table rec\n    IN_STREAMLIT: bool = False  # Whether we're running in streamlit\n    FLATTEN_PDF: bool = True  # Flatten PDFs by merging form fields before processing\n    DISABLE_TQDM: bool = False  # Disable tqdm progress bars\n    S3_BASE_URL: str = \"https://models.datalab.to\"\n    PARALLEL_DOWNLOAD_WORKERS: int = (\n        10  # Number of workers for parallel model downloads\n    )\n    MODEL_CACHE_DIR: str = str(Path(user_cache_dir(\"datalab\")) / \"models\")\n    LOGLEVEL: str = \"INFO\"  # Logging level\n\n    # Paths\n    DATA_DIR: str = \"data\"\n    RESULT_DIR: str = \"results\"\n    BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))\n    FONT_DIR: str = os.path.join(BASE_DIR, \"static\", \"fonts\")\n\n    @computed_field\n    def TORCH_DEVICE_MODEL(self) -> str:\n        if self.TORCH_DEVICE is not None:\n            return self.TORCH_DEVICE\n\n        if torch.cuda.is_available():\n            return \"cuda\"\n\n        if torch.backends.mps.is_available():\n            return \"mps\"\n\n        try:\n            import torch_xla\n\n            if len(torch_xla.devices()) > 0:\n                return \"xla\"\n        except Exception:\n            pass\n\n        return \"cpu\"\n\n    # Text detection\n    DETECTOR_BATCH_SIZE: Optional[int] = None  # Defaults to 2 for CPU/MPS, 32 otherwise\n    DETECTOR_MODEL_CHECKPOINT: str = \"s3://text_detection/2025_05_07\"\n    DETECTOR_BENCH_DATASET_NAME: str = \"vikp/doclaynet_bench\"\n    
DETECTOR_IMAGE_CHUNK_HEIGHT: int = (\n        1400  # Height at which to slice images vertically\n    )\n    DETECTOR_TEXT_THRESHOLD: float = (\n        0.6  # Threshold for text detection (above this is considered text)\n    )\n    DETECTOR_BLANK_THRESHOLD: float = (\n        0.35  # Threshold for blank space (below this is considered blank)\n    )\n    DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(\n        8, os.cpu_count() or 1  # os.cpu_count() can return None\n    )  # Number of workers for postprocessing\n    DETECTOR_MIN_PARALLEL_THRESH: int = (\n        3  # Minimum number of images before we parallelize\n    )\n    DETECTOR_BOX_Y_EXPAND_MARGIN: float = (\n        0.05  # Margin by which to expand detected boxes vertically\n    )\n    COMPILE_DETECTOR: bool = False\n\n    # Text recognition\n    FOUNDATION_MODEL_CHECKPOINT: str = \"s3://text_recognition/2025_09_23\"\n    FOUNDATION_MODEL_QUANTIZE: bool = False\n    FOUNDATION_MAX_TOKENS: Optional[int] = None\n    FOUNDATION_CHUNK_SIZE: Optional[int] = None\n    FOUNDATION_PAD_TO_NEAREST: int = 256\n    COMPILE_FOUNDATION: bool = False\n    FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE: float = 0.9\n\n    RECOGNITION_MODEL_CHECKPOINT: str = \"s3://text_recognition/2025_09_23\"\n    RECOGNITION_BATCH_SIZE: Optional[int] = (\n        None  # Defaults to 8 for CPU/MPS, 256 otherwise\n    )\n    RECOGNITION_RENDER_FONTS: Dict[str, str] = {\n        \"all\": os.path.join(FONT_DIR, \"GoNotoCurrent-Regular.ttf\"),\n        \"zh\": os.path.join(FONT_DIR, \"GoNotoCJKCore.ttf\"),\n        \"ja\": os.path.join(FONT_DIR, \"GoNotoCJKCore.ttf\"),\n        \"ko\": os.path.join(FONT_DIR, \"GoNotoCJKCore.ttf\"),\n    }\n    RECOGNITION_FONT_DL_BASE: str = (\n        \"https://github.com/satbyy/go-noto-universal/releases/download/v7.0\"\n    )\n    RECOGNITION_BENCH_DATASET_NAME: str = \"vikp/rec_bench\"\n    RECOGNITION_PAD_VALUE: int = 255  # Should be 0 or 255\n\n    # Layout\n    LAYOUT_MODEL_CHECKPOINT: str = \"s3://layout/2025_09_23\"\n    LAYOUT_IMAGE_SIZE: Dict 
= {\"height\": 768, \"width\": 768}\n    LAYOUT_SLICE_MIN: Dict = {\n        \"height\": 1500,\n        \"width\": 1500,\n    }  # When to start slicing images\n    LAYOUT_SLICE_SIZE: Dict = {\"height\": 1200, \"width\": 1200}  # Size of slices\n    LAYOUT_BATCH_SIZE: Optional[int] = None\n    LAYOUT_BENCH_DATASET_NAME: str = \"vikp/publaynet_bench\"\n    LAYOUT_MAX_BOXES: int = 100\n    COMPILE_LAYOUT: bool = False\n    ORDER_BENCH_DATASET_NAME: str = \"vikp/order_bench\"\n\n    # Table Rec\n    TABLE_REC_MODEL_CHECKPOINT: str = \"s3://table_recognition/2025_02_18\"\n    TABLE_REC_IMAGE_SIZE: Dict = {\"height\": 768, \"width\": 768}\n    TABLE_REC_MAX_BOXES: int = 150\n    TABLE_REC_BATCH_SIZE: Optional[int] = None\n    TABLE_REC_BENCH_DATASET_NAME: str = \"datalab-to/fintabnet_bench\"\n    COMPILE_TABLE_REC: bool = False\n\n    # Texify\n    TEXIFY_BENCHMARK_DATASET: str = \"datalab-to/texify_bench\"\n\n    # OCR Error Detection\n    OCR_ERROR_MODEL_CHECKPOINT: str = \"s3://ocr_error_detection/2025_02_18\"\n    OCR_ERROR_BATCH_SIZE: Optional[int] = None\n    COMPILE_OCR_ERROR: bool = False\n\n    # Tesseract (for benchmarks only)\n    TESSDATA_PREFIX: Optional[str] = None\n\n    COMPILE_ALL: bool = False\n\n    @computed_field\n    def DETECTOR_STATIC_CACHE(self) -> bool:\n        return (\n            self.COMPILE_ALL\n            or self.COMPILE_DETECTOR\n            or self.TORCH_DEVICE_MODEL == \"xla\"\n        )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise\n\n    @computed_field\n    def LAYOUT_STATIC_CACHE(self) -> bool:\n        return (\n            self.COMPILE_ALL or self.COMPILE_LAYOUT or self.TORCH_DEVICE_MODEL == \"xla\"\n        )\n\n    @computed_field\n    def FOUNDATION_XLA(self) -> bool:\n        return (\n            self.TORCH_DEVICE_MODEL == \"xla\"\n        )  # We need to static cache and pad to batch size for XLA, since it will 
recompile otherwise\n\n    @computed_field\n    def FOUNDATION_STATIC_CACHE(self) -> bool:\n        return (\n            self.COMPILE_ALL\n            or self.COMPILE_FOUNDATION\n            or self.TORCH_DEVICE_MODEL == \"xla\"\n        )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise\n\n    @computed_field\n    def TABLE_REC_STATIC_CACHE(self) -> bool:\n        return (\n            self.COMPILE_ALL\n            or self.COMPILE_TABLE_REC\n            or self.TORCH_DEVICE_MODEL == \"xla\"\n        )\n\n    @computed_field\n    def OCR_ERROR_STATIC_CACHE(self) -> bool:\n        return (\n            self.COMPILE_ALL\n            or self.COMPILE_OCR_ERROR\n            or self.TORCH_DEVICE_MODEL == \"xla\"\n        )\n\n    @computed_field\n    def MODEL_DTYPE(self) -> torch.dtype:\n        if self.TORCH_DEVICE_MODEL == \"cpu\":\n            return torch.float32\n        if self.TORCH_DEVICE_MODEL == \"xla\":\n            return torch.bfloat16\n        return torch.float16\n\n    @computed_field\n    def MODEL_DTYPE_BFLOAT(self) -> torch.dtype:\n        if self.TORCH_DEVICE_MODEL == \"cpu\":\n            return torch.float32\n        if self.TORCH_DEVICE_MODEL == \"mps\":\n            return torch.bfloat16\n        return torch.bfloat16\n\n    @computed_field\n    def INFERENCE_MODE(self) -> Callable:\n        if self.TORCH_DEVICE_MODEL == \"xla\":\n            return torch.no_grad\n        return torch.inference_mode\n\n    class Config:\n        env_file = find_dotenv(\"local.env\")\n        extra = \"ignore\"\n\n\nsettings = Settings()\n"
  },
  {
    "path": "surya/table_rec/__init__.py",
    "content": "from copy import deepcopy\nfrom itertools import chain\nfrom typing import List\n\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom tqdm import tqdm\n\nfrom surya.common.xla import mark_step\nfrom surya.common.predictor import BasePredictor\nfrom surya.table_rec.schema import TableCell, TableRow, TableCol, TableResult\nfrom surya.common.polygon import PolygonBox\nfrom surya.settings import settings\nfrom surya.table_rec.loader import TableRecModelLoader\nfrom surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM, CATEGORY_TO_ID, MERGE_KEYS, \\\n    MERGE_VALUES\nfrom surya.table_rec.shaper import LabelShaper\n\n\nclass TableRecPredictor(BasePredictor):\n    model_loader_cls = TableRecModelLoader\n    batch_size = settings.TABLE_REC_BATCH_SIZE\n    default_batch_sizes = {\n        \"cpu\": 8,\n        \"mps\": 8,\n        \"cuda\": 32,\n        \"xla\": 16\n    }\n\n    def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]:\n        return self.batch_table_recognition(images, batch_size)\n\n    def inference_loop(\n            self,\n            encoder_hidden_states: torch.Tensor,\n            batch_input_ids: torch.Tensor,\n            current_batch_size: int,\n            batch_size: int\n    ):\n        shaper = LabelShaper()\n        batch_predictions = [[] for _ in range(current_batch_size)]\n        max_tokens = settings.TABLE_REC_MAX_BOXES\n        decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device).cumsum(\n            0) - 1\n        inference_token_count = batch_input_ids.shape[1]\n\n        if settings.TABLE_REC_STATIC_CACHE:\n            encoder_hidden_states = self.pad_to_batch_size(encoder_hidden_states, batch_size)\n            batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)\n\n        # Move to device after padding for XLA\n        encoder_hidden_states = 
encoder_hidden_states.to(self.model.device)\n        batch_input_ids = batch_input_ids.to(self.model.device)\n\n        self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype)\n\n        with settings.INFERENCE_MODE():\n            token_count = 0\n            all_done = torch.zeros(encoder_hidden_states.shape[0], dtype=torch.bool, device=self.model.device)\n\n            while token_count < max_tokens:\n                is_prefill = token_count == 0\n                return_dict = self.model.decoder(\n                    input_ids=batch_input_ids,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cache_position=decoder_position_ids,\n                    use_cache=True,\n                    prefill=is_prefill\n                )\n\n                decoder_position_ids = decoder_position_ids[-1:] + 1\n\n                # Get predictions for each box element\n                box_properties = []\n                done = []\n\n                # Pre-process all logits at once\n                processed_logits = {}\n                for k, _, mode in BOX_PROPERTIES:\n                    k_logits = return_dict[\"box_property_logits\"][k][:, -1, :]  # Get all batch logits at once\n                    \n                    if mode == \"classification\":\n                        # Process all classification logits in one operation\n                        items = torch.argmax(k_logits, dim=-1)\n                        if k == \"category\":\n                            done = (items == self.model.decoder.config.eos_token_id) | (items == self.model.decoder.config.pad_token_id)\n                        items = items - SPECIAL_TOKENS\n                        processed_logits[k] = items\n                    elif mode == \"regression\":\n                        if k == \"bbox\":\n                            k_logits = k_logits * BOX_DIM\n                            processed_logits[k] = k_logits\n    
                    elif k == \"colspan\":\n                            k_logits = torch.clamp(k_logits, min=1)\n                            processed_logits[k] = torch.round(k_logits)\n\n                items = {k: processed_logits[k].cpu() for k, _, _ in BOX_PROPERTIES}\n                for j in range(current_batch_size):\n                    box_property = {}\n                    for k, _, mode in BOX_PROPERTIES:\n                        if mode == \"classification\":\n                            box_property[k] = int(items[k][j].item())\n                        elif mode == \"regression\":\n                            if k == \"bbox\":\n                                box_property[k] = items[k][j].tolist()\n                            elif k == \"colspan\":\n                                box_property[k] = int(items[k][j].item())\n                    box_properties.append(box_property)\n\n                all_done = all_done | done\n                all_done_cpu = all_done.cpu()\n\n                if all_done_cpu[:current_batch_size].all():\n                    break\n\n                batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long)\n                batch_input_ids = batch_input_ids.unsqueeze(1)  # Add sequence length dimension\n\n                for j, (box_property, status) in enumerate(zip(box_properties, all_done_cpu)):\n                    if not status:\n                        batch_predictions[j].append(box_property)\n\n                token_count += inference_token_count\n                inference_token_count = batch_input_ids.shape[1]\n\n                if settings.TABLE_REC_STATIC_CACHE:\n                    batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)\n\n                # Move to device after padding for XLA\n                batch_input_ids = batch_input_ids.to(self.model.device)\n        return batch_predictions\n\n    def batch_table_recognition(\n            self,\n            
images: List,\n            batch_size=None) -> List[TableResult]:\n        assert all([isinstance(image, Image.Image) for image in images])\n        if batch_size is None:\n            batch_size = self.get_batch_size()\n\n        if len(images) == 0:\n            return []\n\n        query_items = []\n        for image in images:\n            query_items.append({\n                \"polygon\": [[0, 0], [image.width, 0], [image.width, image.height], [0, image.height]],\n                \"category\": CATEGORY_TO_ID[\"Table\"],\n                \"colspan\": 0,\n                \"merges\": 0,\n                \"is_header\": 0\n            })\n\n        output_order = []\n        for i in tqdm(range(0, len(images), batch_size), desc=\"Recognizing tables\", disable=self.disable_tqdm):\n            batch_query_items = query_items[i:i + batch_size]\n\n            batch_images = images[i:i + batch_size]\n            batch_images = [image.convert(\"RGB\") for image in batch_images]  # also copies the images\n\n            current_batch_size = len(batch_images)\n\n            orig_sizes = [image.size for image in batch_images]\n            model_inputs = self.processor(images=batch_images, query_items=batch_query_items)\n\n            batch_pixel_values = model_inputs[\"pixel_values\"]\n\n            batch_input_ids = model_inputs[\"input_ids\"]\n            batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=self.model.dtype)\n\n            if settings.TABLE_REC_STATIC_CACHE:\n                batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size)\n\n            # Move to device after padding for XLA\n            batch_pixel_values = batch_pixel_values.to(self.model.device)\n\n            shaper = LabelShaper()\n\n            # We only need to process each image once\n            with settings.INFERENCE_MODE():\n                encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state\n\n            # 
Inference to get rows and columns\n            rowcol_predictions = self.inference_loop(\n                encoder_hidden_states,\n                batch_input_ids,\n                current_batch_size,\n                batch_size\n            )\n            mark_step()\n\n            row_query_items = []\n            row_encoder_hidden_states = []\n            idx_map = []\n            columns = []\n            for j, img_predictions in enumerate(rowcol_predictions):\n                for row_prediction in img_predictions:\n                    polygon = shaper.convert_bbox_to_polygon(row_prediction[\"bbox\"])\n                    if row_prediction[\"category\"] == CATEGORY_TO_ID[\"Table-row\"]:\n                        row_query_items.append({\n                            \"polygon\": polygon,\n                            \"category\": row_prediction[\"category\"],\n                            \"colspan\": 0,\n                            \"merges\": 0,\n                            \"is_header\": int(row_prediction[\"is_header\"] == 1)\n                        })\n                        row_encoder_hidden_states.append(encoder_hidden_states[j])\n                        idx_map.append(j)\n                    elif row_prediction[\"category\"] == CATEGORY_TO_ID[\"Table-column\"]:\n                        columns.append({\n                            \"polygon\": polygon,\n                            \"category\": row_prediction[\"category\"],\n                            \"colspan\": 0,\n                            \"merges\": 0,\n                            \"is_header\": int(row_prediction[\"is_header\"] == 1)\n                        })\n\n            # Re-inference to predict cells\n            row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)\n            row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)\n            row_input_ids = row_inputs[\"input_ids\"]\n            cell_predictions = 
[]\n            for j in range(0, len(row_input_ids), batch_size):\n                cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size]\n                cell_batch_input_ids = row_input_ids[j:j + batch_size]\n                cell_batch_size = len(cell_batch_input_ids)\n                cell_predictions.extend(\n                    self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)\n                )\n                mark_step()\n\n            result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper)\n            output_order.extend(result)\n\n        return output_order\n\n\n    def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper):\n        results = []\n        for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):\n            row_cell_predictions = [c for i, c in enumerate(cell_predictions) if idx_map[i] == j]\n            # Each row prediction matches a cell prediction\n            rows = []\n            cells = []\n            columns = []\n\n            cell_id = 0\n            row_predictions = [pred for pred in img_predictions if pred[\"category\"] == CATEGORY_TO_ID[\"Table-row\"]]\n            col_predictions = [pred for pred in img_predictions if pred[\"category\"] == CATEGORY_TO_ID[\"Table-column\"]]\n\n            # Generate table columns\n            for z, col_prediction in enumerate(col_predictions):\n                polygon = shaper.convert_bbox_to_polygon(col_prediction[\"bbox\"])\n                polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)\n                columns.append(\n                    TableCol(\n                        polygon=polygon,\n                        col_id=z,\n                        is_header=col_prediction[\"is_header\"] == 1\n                    )\n                )\n\n            # Generate table 
rows\n            for z, row_prediction in enumerate(row_predictions):\n                polygon = shaper.convert_bbox_to_polygon(row_prediction[\"bbox\"])\n                polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)\n                row = TableRow(\n                    polygon=polygon,\n                    row_id=z,\n                    is_header=row_prediction[\"is_header\"] == 1\n                )\n                rows.append(row)\n\n                # Get cells that span multiple columns within a row\n                spanning_cells = []\n                for l, spanning_cell in enumerate(row_cell_predictions[z]):\n                    polygon = shaper.convert_bbox_to_polygon(spanning_cell[\"bbox\"])\n                    polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)\n                    colspan = max(1, int(spanning_cell[\"colspan\"]))\n                    if colspan == 1 and spanning_cell[\"merges\"] not in MERGE_VALUES:\n                        # Skip single column cells if they don't merge\n                        continue\n                    if PolygonBox(polygon=polygon).height < row.height * .85:\n                        # Spanning cell must cover most of the row\n                        continue\n\n                    spanning_cells.append(\n                        TableCell(\n                            polygon=polygon,\n                            row_id=z,\n                            rowspan=1,\n                            cell_id=cell_id,\n                            within_row_id=l,\n                            colspan=colspan,\n                            merge_up=spanning_cell[\"merges\"] in [MERGE_KEYS[\"merge_up\"], MERGE_KEYS[\"merge_both\"]],\n                            merge_down=spanning_cell[\"merges\"] in [MERGE_KEYS[\"merge_down\"],\n                                                                   MERGE_KEYS[\"merge_both\"]],\n                            
is_header=row.is_header or z == 0\n                        )\n                    )\n                    cell_id += 1\n\n                # Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col\n                used_spanning_cells = set()\n                skip_columns = 0\n                for l, col in enumerate(columns):\n                    if skip_columns:\n                        skip_columns -= 1\n                        continue\n                    cell_polygon = row.intersection_polygon(col)\n                    cell_added = False\n                    for zz, spanning_cell in enumerate(spanning_cells):\n                        cell_polygonbox = PolygonBox(polygon=cell_polygon)\n                        intersection_pct = cell_polygonbox.intersection_pct(spanning_cell)\n                        # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns)\n                        correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]])\n                        if intersection_pct > .9:\n                            if spanning_cell.width > (correct_col_width * .85):\n                                cell_added = True\n                                if zz not in used_spanning_cells:\n                                    used_spanning_cells.add(zz)\n                                    spanning_cell.col_id = l\n                                    cells.append(spanning_cell)\n                                    skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell\n                            else:\n                                used_spanning_cells.add(zz) # Skip this spanning cell\n\n                    if not cell_added:\n                        cells.append(\n                            TableCell(\n                                polygon=cell_polygon,\n                                row_id=z,\n                
                rowspan=1,\n                                cell_id=cell_id,\n                                within_row_id=l,\n                                colspan=1,\n                                merge_up=False,\n                                merge_down=False,\n                                col_id=l,\n                                is_header=row.is_header or col.is_header or z == 0\n                            )\n                        )\n                        cell_id += 1\n\n            # Turn cells into a row grid\n            grid_cells = deepcopy([\n                [cell for cell in cells if cell.row_id == row.row_id]\n                for row in rows\n            ])\n\n            # Merge cells across rows\n            for z, grid_row in enumerate(grid_cells[1:]):\n                prev_row = grid_cells[z]\n                for l, cell in enumerate(grid_row):\n                    if l >= len(prev_row):\n                        continue\n\n                    above_cell = prev_row[l]\n                    if all([\n                        above_cell.merge_down,\n                        cell.merge_up,\n                        above_cell.col_id == cell.col_id,\n                        above_cell.colspan == cell.colspan,\n                    ]):\n                        above_cell.merge(cell)\n                        above_cell.rowspan += cell.rowspan\n                        grid_row[l] = above_cell\n\n            merged_cells_all = list(chain.from_iterable(grid_cells))\n            used_ids = set()\n            merged_cells = []\n            for cell in merged_cells_all:\n                if cell.cell_id in used_ids:\n                    continue\n                used_ids.add(cell.cell_id)\n                merged_cells.append(cell)\n\n            result = TableResult(\n                cells=merged_cells,\n                unmerged_cells=cells,\n                rows=rows,\n                cols=columns,\n                image_bbox=[0, 0, orig_size[0], 
orig_size[1]],\n            )\n            results.append(result)\n        return results\n"
  },
  {
    "path": "surya/table_rec/loader.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom surya.common.load import ModelLoader\nfrom surya.logging import get_logger\nfrom surya.settings import settings\nfrom surya.table_rec.model.config import (\n    SuryaTableRecConfig,\n    SuryaTableRecDecoderConfig,\n    DonutSwinTableRecConfig,\n)\nfrom surya.table_rec.model.encoderdecoder import TableRecEncoderDecoderModel\nfrom surya.table_rec.processor import SuryaTableRecProcessor\n\nlogger = get_logger()\n\n\nclass TableRecModelLoader(ModelLoader):\n    def __init__(self, checkpoint: Optional[str] = None):\n        super().__init__(checkpoint)\n\n        if self.checkpoint is None:\n            self.checkpoint = settings.TABLE_REC_MODEL_CHECKPOINT\n\n    def model(\n        self,\n        device=settings.TORCH_DEVICE_MODEL,\n        dtype=settings.MODEL_DTYPE,\n        attention_implementation: Optional[str] = None,\n    ) -> TableRecEncoderDecoderModel:\n        if device is None:\n            device = settings.TORCH_DEVICE_MODEL\n        if dtype is None:\n            dtype = settings.MODEL_DTYPE\n\n        if device == \"mps\":\n            logger.warning(\n                \"`TableRecEncoderDecoderModel` is not compatible with mps backend. 
Defaulting to cpu instead\"\n            )\n            device = \"cpu\"\n            dtype = \"float32\"\n\n        config = SuryaTableRecConfig.from_pretrained(self.checkpoint)\n        decoder_config = config.decoder\n        decoder = SuryaTableRecDecoderConfig(**decoder_config)\n        config.decoder = decoder\n\n        encoder_config = config.encoder\n        encoder = DonutSwinTableRecConfig(**encoder_config)\n        config.encoder = encoder\n\n        model = TableRecEncoderDecoderModel.from_pretrained(\n            self.checkpoint, config=config, dtype=dtype\n        )\n\n        model = model.to(device)\n        model = model.eval()\n\n        if settings.COMPILE_ALL or settings.COMPILE_TABLE_REC:\n            torch.set_float32_matmul_precision(\"high\")\n            torch._dynamo.config.cache_size_limit = 16\n            torch._dynamo.config.suppress_errors = False\n\n            logger.info(\n                f\"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}\"\n            )\n            compile_args = {\"backend\": \"openxla\"} if device == \"xla\" else {}\n            model.encoder = torch.compile(model.encoder, **compile_args)\n            model.decoder = torch.compile(model.decoder, **compile_args)\n\n        logger.debug(\n            f\"Loaded table recognition model {self.checkpoint} from {TableRecEncoderDecoderModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}\"\n        )\n        return model\n\n    def processor(\n        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE\n    ) -> SuryaTableRecProcessor:\n        processor = SuryaTableRecProcessor(self.checkpoint)\n\n        processor.token_pad_id = 0\n        processor.token_eos_id = 1\n        processor.token_bos_id = 1\n        processor.token_query_end_id = 4\n        return processor\n"
  },
  {
    "path": "surya/table_rec/model/__init__.py",
    "content": ""
  },
  {
    "path": "surya/table_rec/model/config.py",
    "content": "from dataclasses import dataclass\nfrom typing import Dict\n\nimport torch\nfrom transformers import PretrainedConfig\nfrom transformers.utils import ModelOutput\n\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.settings import settings\n\nBOX_DIM = 1024\nSPECIAL_TOKENS = 5\nMAX_BOXES = 150\n\nMERGE_KEYS = {\n    \"none\": 0,\n    \"merge_up\": 1,\n    \"merge_down\": 2,\n    \"merge_both\": 3\n}\nMERGE_VALUES = [MERGE_KEYS[\"merge_up\"], MERGE_KEYS[\"merge_down\"], MERGE_KEYS[\"merge_both\"]]\n\nID_TO_CATEGORY = {\n    0: 'Blank',\n    1: 'Table-row',\n    2: 'Table-column',\n    3: 'Table-cell',\n    4: 'Table'\n}\nCATEGORY_TO_ID = {v: k for k, v in ID_TO_CATEGORY.items()}\n\nID_TO_HEADER = {\n    0: \"None\",\n    1: \"Header\"\n}\nHEADER_TO_ID = {v: k for k, v in ID_TO_HEADER.items()}\n\nBOX_PROPERTIES = [\n    (\"bbox\", 6, \"regression\"),\n    (\"category\", len(ID_TO_CATEGORY), \"classification\"),\n    (\"merges\", len(MERGE_KEYS), \"classification\"),\n    (\"colspan\", 1, \"regression\"),\n    (\"is_header\", len(ID_TO_HEADER), \"classification\")\n]\n\n\n@dataclass\nclass TableRecModelOutput(ModelOutput):\n    box_property_logits: Dict[str, torch.Tensor]\n    hidden_states: torch.Tensor | None = None\n\n\nclass SuryaTableRecConfig(S3DownloaderMixin, PretrainedConfig):\n    model_type = \"vision-encoder-decoder\"\n    is_composition = True\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        if \"encoder\" in kwargs:\n            encoder_config = kwargs.pop(\"encoder\")\n            decoder_config = kwargs.pop(\"decoder\")\n        else:\n            encoder_config = DonutSwinTableRecConfig()\n            decoder_config = SuryaTableRecDecoderConfig()\n\n        self.encoder = encoder_config\n        self.decoder = decoder_config\n        self.is_encoder_decoder = True\n\n        if isinstance(decoder_config, dict):\n            self.decoder_start_token_id = decoder_config[\"bos_token_id\"]\n       
     self.pad_token_id = decoder_config[\"pad_token_id\"]\n            self.eos_token_id = decoder_config[\"eos_token_id\"]\n        else:\n            self.decoder_start_token_id = decoder_config.bos_token_id\n            self.pad_token_id = decoder_config.pad_token_id\n            self.eos_token_id = decoder_config.eos_token_id\n\n\nclass DonutSwinTableRecConfig(PretrainedConfig):\n    model_type = \"donut-swin\"\n\n    attribute_map = {\n        \"num_attention_heads\": \"num_heads\",\n        \"num_hidden_layers\": \"num_layers\",\n    }\n\n    def __init__(\n        self,\n        image_size=(settings.TABLE_REC_IMAGE_SIZE[\"width\"], settings.TABLE_REC_IMAGE_SIZE[\"height\"]),\n        patch_size=4,\n        num_channels=3,\n        embed_dim=128,\n        depths=[2, 2, 12, 2],\n        num_heads=[4, 8, 16, 32],\n        num_kv_heads=[4, 8, 16, 32],\n        window_size=8,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        hidden_dropout_prob=0.0,\n        attention_probs_dropout_prob=0.0,\n        drop_path_rate=0.1,\n        hidden_act=\"gelu\",\n        use_absolute_embeddings=False,\n        initializer_range=0.02,\n        layer_norm_eps=1e-5,\n        encoder_length=1024,\n        use_positional_embeddings=True,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.num_channels = num_channels\n        self.embed_dim = embed_dim\n        self.depths = depths\n        self.num_layers = len(depths)\n        self.num_heads = num_heads\n        self.num_kv_heads = num_kv_heads\n        self.window_size = window_size\n        self.mlp_ratio = mlp_ratio\n        self.qkv_bias = qkv_bias\n        self.hidden_dropout_prob = hidden_dropout_prob\n        self.attention_probs_dropout_prob = attention_probs_dropout_prob\n        self.drop_path_rate = drop_path_rate\n        self.hidden_act = hidden_act\n        self.use_absolute_embeddings = 
use_absolute_embeddings\n        self.layer_norm_eps = layer_norm_eps\n        self.initializer_range = initializer_range\n        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel\n        # this indicates the channel dimension after the last stage of the model\n        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))\n        self.encoder_length = encoder_length\n        self.use_positional_embeddings = use_positional_embeddings\n\n\nclass SuryaTableRecDecoderConfig(PretrainedConfig):\n    model_type = \"surya_tablerec\"\n\n    def __init__(\n        self,\n        num_hidden_layers=6,\n        vocab_size=BOX_DIM + 1,\n        bbox_size=BOX_DIM,\n        hidden_size=512,\n        property_embed_size=64,\n        box_embed_size=512 - 64,\n        intermediate_size=4 * 512,\n        encoder_hidden_size=1024,\n        num_attention_heads=8,\n        lru_width=None,\n        attention_window_size=16,\n        conv1d_width=4,\n        logits_soft_cap=30.0,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        eos_token_id=1,\n        bos_token_id=1,\n        pause_token_id=2,\n        query_end_token_id=4,\n        hidden_activation=\"gelu_pytorch_tanh\",\n        rope_theta=10000.0,\n        block_types=(\"attention\",),\n        cross_attn_layers=tuple(range(10)),\n        encoder_cross_attn_layers=tuple(range(10)),\n        self_attn_layers=tuple(range(10)),\n        global_attn_layers=tuple(range(10)),\n        attention_dropout=0.0,\n        num_key_value_heads=4,\n        attention_bias=False,\n        w_init_variance_scale=0.01,\n        init_std=0.02,\n        tie_word_embeddings=False,\n        aux_heads=0, # How many n-token-ahead heads to add\n        causal=True,\n        layer_norm_eps=1e-5,\n        dropout=0.0,\n        special_token_count=SPECIAL_TOKENS,\n        **kwargs,\n    ):\n        self.num_hidden_layers = num_hidden_layers\n        self.vocab_size = 
vocab_size\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_attention_heads = num_attention_heads\n        self.lru_width = lru_width if lru_width is not None else hidden_size\n        self.attention_window_size = attention_window_size\n        self.conv1d_width = conv1d_width\n        self.logits_soft_cap = logits_soft_cap\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.block_types = list(block_types)\n        self.hidden_activation = hidden_activation\n        self.head_dim = self.hidden_size // self.num_attention_heads\n        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads\n        if self.num_key_value_heads > self.num_attention_heads:\n            raise ValueError(\"The number of `num_key_value_heads` must be less than or equal to `num_attention_heads`\")\n        self.cross_attn_layers = cross_attn_layers\n        self.self_attn_layers = self_attn_layers\n        self.global_attn_layers = global_attn_layers\n        self.attention_dropout = attention_dropout\n        self.attention_bias = attention_bias\n        self.w_init_variance_scale = w_init_variance_scale\n        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers\n        self.init_std = init_std\n        self.tie_word_embeddings = tie_word_embeddings\n        self.aux_heads = aux_heads\n        self.encoder_hidden_size = encoder_hidden_size\n        self.causal = causal\n        self.encoder_cross_attn_layers = encoder_cross_attn_layers\n        self.layer_norm_eps = layer_norm_eps\n        self.dropout = dropout\n        self.bbox_size = bbox_size\n        self.pause_token_id = pause_token_id\n        self.box_properties = BOX_PROPERTIES\n        self.property_embed_size = property_embed_size\n        self.box_embed_size = box_embed_size\n        self.special_token_count = special_token_count\n        self.query_end_token_id = query_end_token_id
\n        self.double_residual_flow = False\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            **kwargs,\n        )\n\n    @property\n    def layers_block_type(self):\n        return (self.block_types * 100)[: self.num_hidden_layers]"
  },
  {
    "path": "surya/table_rec/model/decoder.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\n\nfrom surya.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel\nfrom surya.table_rec.model.config import TableRecModelOutput\nfrom surya.table_rec.shaper import LabelShaper\nfrom surya.settings import settings\n\n\nclass LabelEmbedding(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        # Bboxes\n        self.w_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.h_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.cx_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.cy_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.xskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.yskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n\n        self.x1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.y1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.x2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.y2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.x3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.y3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.x4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n        self.y4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)\n\n        # Get indexes for passed in tensor\n        shaper = LabelShaper()\n        self.component_idxs = shaper.component_idx_dict()\n        merge_count = shaper.get_box_property(\"merges\")[1] + config.special_token_count\n        category_count = shaper.get_box_property(\"category\")[1] + config.special_token_count\n\n        # Other box properties\n        self.category_embed = 
nn.Embedding(category_count, config.property_embed_size)\n        self.merge_embed = nn.Embedding(merge_count, config.property_embed_size)\n        self.colspan_embed = nn.Embedding(config.vocab_size, config.property_embed_size)\n\n        self.config = config\n\n    def forward(self, boxes: torch.LongTensor, *args):\n        # Need to keep *args for compatibility with common decoder\n        boxes = boxes.to(torch.long).clamp(0, self.config.vocab_size)\n\n        boxes_unbound = boxes.to(torch.long).unbind(dim=-1)\n        cx, cy, w, h, xskew, yskew = boxes_unbound[self.component_idxs[\"bbox\"][0]:self.component_idxs[\"bbox\"][1]]\n        category = boxes_unbound[self.component_idxs[\"category\"][0]:self.component_idxs[\"category\"][1]][0]\n        merges = boxes_unbound[self.component_idxs[\"merges\"][0]:self.component_idxs[\"merges\"][1]][0]\n        colspan = boxes_unbound[self.component_idxs[\"colspan\"][0]:self.component_idxs[\"colspan\"][1]][0]\n\n        xskew_actual = ((xskew - self.config.bbox_size // 2) / 2).to(torch.long)\n        yskew_actual = ((yskew - self.config.bbox_size // 2) / 2).to(torch.long)\n\n        x1 = (cx - w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)\n        y1 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)\n        x3 = (cx + w // 2 + xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)\n        y3 = (cy + h // 2 + yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)\n\n        size_embeds = self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)\n        skew_embeds = self.xskew_embed(xskew) + self.yskew_embed(yskew)\n        corner_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x3_embed(x3) + self.y3_embed(y3)\n        box_embeds = size_embeds + skew_embeds + corner_embeds\n\n        property_embeds = self.category_embed(category) + self.merge_embed(merges) + self.colspan_embed(colspan)\n\n        # Cat bbox and property embeddings\n   
     embedded = torch.cat([box_embeds, property_embeds], dim=-1)\n        return embedded\n\n\nclass SuryaTableRecDecoder(SuryaADETRDecoderPreTrainedModel):\n    _tied_weights_keys = None\n\n    def __init__(self, config, **kwargs):\n        super().__init__(config)\n        embed_tokens = LabelEmbedding(config)\n        self.model = SuryaADETRDecoderModel(\n            config,\n            embedder=embed_tokens,\n            static_cache=settings.TABLE_REC_STATIC_CACHE,\n            max_boxes=settings.TABLE_REC_MAX_BOXES\n        )\n        self.vocab_size = config.vocab_size\n\n        shaper = LabelShaper()\n        property_heads = {}\n        for k in shaper.property_keys:\n            _, kcount, mode = shaper.get_box_property(k)\n            property_heads[k] = nn.Linear(config.hidden_size, kcount, bias=False)\n\n        self.box_property_heads = nn.ModuleDict(property_heads)\n        self.pre_output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    # Ignore copy\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        prefill: bool = False,\n        **kwargs\n    ) -> 
Union[Tuple, TableRecModelOutput]:\n        outputs = self.model(\n            input_ids=input_ids,\n            cache_position=cache_position,\n            attention_mask=attention_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            use_cache=use_cache,\n            output_hidden_states=True,\n            return_dict=True,\n            prefill=prefill,\n        )\n\n        hidden_states = self.pre_output_norm(outputs[0])\n        box_property_logits = {}\n        for key in self.box_property_heads:\n            box_property_logits[key] = self.box_property_heads[key](hidden_states)\n\n        bbox_logits = nn.functional.sigmoid(box_property_logits[\"bbox\"])\n        box_property_logits[\"bbox\"] = bbox_logits\n\n        return TableRecModelOutput(\n            box_property_logits=box_property_logits,\n            hidden_states=hidden_states,\n        )"
  },
  {
    "path": "surya/table_rec/model/encoder.py",
    "content": "from typing import Optional, Union, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder\n\n\nclass DonutSwinModel(DonutSwinPreTrainedModel):\n    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):\n        super().__init__(config)\n        self.config = config\n        self.num_layers = len(config.depths)\n        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))\n\n        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)\n        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)\n\n        self.position_embeddings = None\n        if hasattr(config, \"encoder_length\"):\n            self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size))\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embeddings.patch_embeddings\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.layer[layer].attention.prune_heads(heads)\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        bool_masked_pos: Optional[torch.BoolTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        interpolate_pos_encoding: bool = False,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, DonutSwinModelOutput]:\n        r\"\"\"\n        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):\n            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if pixel_values is None:\n            raise ValueError(\"You have to specify pixel_values\")\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, len(self.config.depths))\n\n        embedding_output, input_dimensions = self.embeddings(\n            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding\n      
  )\n\n        encoder_outputs = self.encoder(\n            embedding_output,\n            input_dimensions,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        last_hidden_state = encoder_outputs[0]\n\n        if self.position_embeddings is not None:\n            last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :]\n\n        return DonutSwinModelOutput(\n            last_hidden_state=last_hidden_state,\n        )\n"
  },
  {
    "path": "surya/table_rec/model/encoderdecoder.py",
    "content": "from dataclasses import dataclass\nfrom typing import Optional, Union, Tuple, Dict\n\nimport torch\nfrom transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig\n\nfrom surya.common.pretrained import SuryaPreTrainedModel\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.table_rec.model.decoder import SuryaTableRecDecoder\nfrom surya.table_rec.model.encoder import DonutSwinModel\nfrom transformers.utils import ModelOutput\n\n\n@dataclass\nclass TableRecOutput(ModelOutput):\n    box_property_logits: Dict[str, torch.FloatTensor]\n    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n\n\nclass TableRecEncoderDecoderModel(S3DownloaderMixin, SuryaPreTrainedModel):\n    config_class = VisionEncoderDecoderConfig\n    base_model_prefix = \"vision_encoder_decoder\"\n    main_input_name = \"pixel_values\"\n    supports_gradient_checkpointing = True\n    _supports_param_buffer_assignment = False\n\n    def __init__(\n        self,\n        config: Optional[PretrainedConfig] = None,\n        encoder: Optional[PreTrainedModel] = None,\n        decoder: Optional[PreTrainedModel] = None,\n        **kwargs,\n    ):\n        # initialize with config\n        # make sure input & output embeddings is not tied\n        config.tie_word_embeddings = False\n        config.decoder.tie_word_embeddings = False\n        super().__init__(config, **kwargs)\n\n        if encoder is None:\n            encoder = DonutSwinModel(config.encoder)\n\n        if decoder is None:\n            decoder = SuryaTableRecDecoder(\n                config.decoder, attn_implementation=config._attn_implementation\n            )\n\n        self.encoder = encoder\n        self.decoder = decoder\n\n        # make sure that the individual model's config refers to the shared config\n        # so that the updates to the config will be synced\n        self.encoder.config = self.config.encoder\n        self.decoder.config = self.config.decoder\n\n    def 
get_encoder(self):\n        return self.encoder\n\n    def get_decoder(self):\n        return self.decoder\n\n    def get_output_embeddings(self):\n        return self.decoder.get_output_embeddings()\n\n    def set_output_embeddings(self, new_embeddings):\n        return self.decoder.set_output_embeddings(new_embeddings)\n\n    def forward(\n        self,\n        decoder_input_ids: torch.LongTensor = None,\n        decoder_cache_position: Optional[torch.LongTensor] = None,\n        decoder_attention_mask: Optional[torch.LongTensor] = None,\n        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,\n        use_cache: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs,\n    ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:\n        kwargs_decoder = {\n            argument[len(\"decoder_\") :]: value\n            for argument, value in kwargs.items()\n            if argument.startswith(\"decoder_\")\n        }\n\n        # Decode\n        decoder_outputs = self.decoder(\n            input_labels=decoder_input_ids,\n            input_boxes_counts=None,\n            cache_position=decoder_cache_position,\n            attention_mask=decoder_attention_mask,\n            encoder_hidden_states=encoder_outputs,\n            encoder_attention_mask=None,\n            use_cache=use_cache,\n            **kwargs_decoder,\n        )\n\n        return TableRecOutput(\n            box_property_logits=decoder_outputs.box_property_logits,\n            decoder_hidden_states=decoder_outputs.hidden_states,\n        )\n\n    def resize_token_embeddings(self, *args, **kwargs):\n        raise NotImplementedError(\n            \"Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported. Please use the\"\n            \" respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))\"\n        )\n\n    def _reorder_cache(self, past_key_values, beam_idx):\n        # apply decoder 
cache reordering here\n        return self.decoder._reorder_cache(past_key_values, beam_idx)\n"
  },
  {
    "path": "surya/table_rec/processor.py",
    "content": "from typing import List\n\nimport PIL\nimport torch\nfrom transformers import ProcessorMixin\n\nfrom surya.common.s3 import S3DownloaderMixin\nfrom surya.common.donut.processor import SuryaEncoderImageProcessor\nfrom surya.table_rec.shaper import LabelShaper\nfrom surya.settings import settings\nfrom surya.table_rec.model.config import BOX_DIM, SPECIAL_TOKENS\n\n\nclass SuryaTableRecProcessor(S3DownloaderMixin, ProcessorMixin):\n    attributes = [\"image_processor\"]\n    image_processor_class = \"AutoImageProcessor\"\n\n    def __init__(self, checkpoint, **kwargs):\n        image_processor = SuryaEncoderImageProcessor.from_pretrained(checkpoint)\n        image_processor.do_align_long_axis = False\n        image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE\n        self.image_processor = image_processor\n        super().__init__(image_processor)\n\n        self.box_size = (BOX_DIM, BOX_DIM)\n        self.special_token_count = SPECIAL_TOKENS\n        self.shaper = LabelShaper()\n\n    def resize_polygon(self, polygon, orig_size, new_size):\n        w_scaler = new_size[0] / orig_size[0]\n        h_scaler = new_size[1] / orig_size[1]\n\n        for corner in polygon:\n            corner[0] = corner[0] * w_scaler\n            corner[1] = corner[1] * h_scaler\n\n            if corner[0] < 0:\n                corner[0] = 0\n            if corner[1] < 0:\n                corner[1] = 0\n            if corner[0] > new_size[0]:\n                corner[0] = new_size[0]\n            if corner[1] > new_size[1]:\n                corner[1] = new_size[1]\n\n        return polygon\n\n    def __call__(\n            self,\n            images: List[PIL.Image.Image] | None,\n            query_items: List[dict],\n            columns: List[dict] | None = None,\n            convert_images: bool = True,\n            *args,\n            **kwargs\n    ):\n        if convert_images:\n            assert len(images) == len(query_items)\n            assert len(images) > 
0\n\n            # Resize input query items\n            for image, query_item in zip(images, query_items):\n                query_item[\"polygon\"] = self.resize_polygon(query_item[\"polygon\"], image.size, self.box_size)\n\n        query_items = self.shaper.convert_polygons_to_bboxes(query_items)\n        query_labels = self.shaper.dict_to_labels(query_items)\n\n        decoder_input_boxes = []\n        col_count = len(query_labels[0])\n        for label in query_labels:\n            decoder_input_boxes.append([\n                [self.token_bos_id] * col_count,\n                label,\n                [self.token_query_end_id] * col_count\n            ])\n\n        # Add columns to end of decoder input\n        if columns:\n            columns = self.shaper.convert_polygons_to_bboxes(columns)\n            column_labels = self.shaper.dict_to_labels(columns)\n            for decoder_box in decoder_input_boxes:\n                decoder_box += column_labels\n\n        input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long)\n        input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long)\n\n        inputs = {\n            \"input_ids\": input_boxes,\n            \"attention_mask\": input_boxes_mask\n        }\n        if convert_images:\n            inputs[\"pixel_values\"] = self.image_processor(images, *args, **kwargs)[\"pixel_values\"]\n        return inputs\n"
  },
  {
    "path": "surya/table_rec/schema.py",
    "content": "from typing import List\n\nfrom pydantic import BaseModel\n\nfrom surya.common.polygon import PolygonBox\n\n\nclass TableCell(PolygonBox):\n    row_id: int\n    colspan: int\n    within_row_id: int\n    cell_id: int\n    is_header: bool\n    rowspan: int | None = None\n    merge_up: bool = False\n    merge_down: bool = False\n    col_id: int | None = None\n    text_lines: List[dict] | None = None\n\n    @property\n    def label(self):\n        return f'Cell {self.cell_id} {self.rowspan}/{self.colspan}'\n\n\nclass TableRow(PolygonBox):\n    row_id: int\n    is_header: bool\n\n    @property\n    def label(self):\n        return f'Row {self.row_id}'\n\n\nclass TableCol(PolygonBox):\n    col_id: int\n    is_header: bool\n\n    @property\n    def label(self):\n        return f'Column {self.col_id}'\n\n\nclass TableResult(BaseModel):\n    cells: List[TableCell]\n    unmerged_cells: List[TableCell]\n    rows: List[TableRow]\n    cols: List[TableCol]\n    image_bbox: List[float]\n"
  },
  {
    "path": "surya/table_rec/shaper.py",
    "content": "import math\nfrom typing import List, Dict\nimport numpy as np\n\nfrom surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM\n\n\nclass LabelShaper:\n    def __init__(self):\n        self.property_keys = [k for (k, kcount, mode) in BOX_PROPERTIES]\n\n    def dict_to_labels(self, label_components: List[dict]):\n        if len(label_components) == 0:\n            return []\n\n        out_list = []\n        for (k, kcount, mode) in BOX_PROPERTIES:\n            for label_component in label_components:\n                if k not in label_component:\n                    raise ValueError(f\"Missing key {k} in label component {label_component}\")\n\n                if mode == \"classification\":\n                    assert isinstance(label_component[k], int)\n                elif mode == \"regression\":\n                    assert (isinstance(label_component[k], (int, float)) and kcount == 1) or len(label_component[k]) == kcount\n                else:\n                    raise ValueError(f\"Invalid mode {mode} for key {k}\")\n\n        for label_component in label_components:\n            bbox = label_component[\"bbox\"]\n            for i in range(len(bbox)):\n                if bbox[i] < 0:\n                    bbox[i] = 0\n                if bbox[i] > BOX_DIM:\n                    bbox[i] = BOX_DIM\n\n            vector = []\n            for (k, kcount, mode) in BOX_PROPERTIES:\n                item = label_component[k]\n                if isinstance(item, (list, tuple)):\n                    vector += list(item)\n                elif isinstance(item, (float, int)):\n                    if mode == \"classification\":\n                        # Shift up for model\n                        item += SPECIAL_TOKENS\n                    vector.append(item)\n                else:\n                    raise ValueError(f\"Invalid item {item} for key {k}\")\n\n            out_list.append(vector)\n\n        return out_list\n\n    def 
component_idx(self, key):\n        idx = 0\n        for (k, kcount, mode) in BOX_PROPERTIES:\n            if mode == \"regression\":\n                incr = kcount\n            elif mode == \"classification\":\n                incr = 1\n            else:\n                raise ValueError(f\"Invalid mode {mode} for key {k}\")\n            if k == key:\n                return (idx, idx + incr)\n            idx += incr\n        raise ValueError(f\"Key {key} not found in properties\")\n\n    def get_box_property(self, key, add_special_tokens=True):\n        for (k, kcount, mode) in BOX_PROPERTIES:\n            if k == key:\n                # Add special token count\n                if mode == \"classification\" and add_special_tokens:\n                    kcount += SPECIAL_TOKENS\n                return (k, kcount, mode)\n        raise ValueError(f\"Key {key} not found in properties\")\n\n    def component_idx_dict(self):\n        idx_dict = {}\n        for (k, kcount, mode) in BOX_PROPERTIES:\n            idx_dict[k] = self.component_idx(k)\n        return idx_dict\n\n    def convert_polygons_to_bboxes(self, label_components: List[Dict]):\n        for i, label_component in enumerate(label_components):\n            poly = label_component[\"polygon\"]\n            poly = np.clip(poly, 0, BOX_DIM)\n\n            (x1, y1), (x2, y2), (x3, y3), (x4, y4) = poly\n            cx = (x1 + x2 + x3 + x4) / 4\n            cy = (y1 + y2 + y3 + y4) / 4\n            width = (x2 + x3) / 2 - (x1 + x4) / 2\n            height = (y3 + y4) / 2 - (y2 + y1) / 2\n            bottom_avg_x = (x3 + x4) / 2\n            top_avg_x = (x1 + x2) / 2\n            right_avg_y = (y2 + y3) / 2\n            left_avg_y = (y1 + y4) / 2\n\n            x_skew = bottom_avg_x - top_avg_x\n            y_skew = right_avg_y - left_avg_y\n            x_skew += BOX_DIM // 2 # Shift up into positive space\n            y_skew += BOX_DIM // 2 # Shift up into positive space\n            new_poly = [\n                
cx,\n                cy,\n                width,\n                height,\n                x_skew,\n                y_skew\n            ]\n            label_component[\"bbox\"] = new_poly\n\n        return label_components\n\n    def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001):\n        cx = box[0]\n        cy = box[1]\n        width = box[2]\n        height = box[3]\n        x1 = cx - width / 2\n        y1 = cy - height / 2\n        x2 = cx + width / 2\n        y2 = cy + height / 2\n        skew_x = math.floor((box[4] - skew_scaler) / 2)\n        skew_y = math.floor((box[5] - skew_scaler) / 2)\n\n        # Ensures we don't get slightly warped boxes\n        # Note that the values are later scaled, so this is in 1/1024 space\n        if abs(skew_x) < skew_min:\n            skew_x = 0\n\n        if abs(skew_y) < skew_min:\n            skew_y = 0\n\n        polygon = [x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x,\n                   y2 - skew_y]\n        poly = []\n        for i in range(4):\n            poly.append([\n                polygon[2 * i],\n                polygon[2 * i + 1]\n            ])\n        return poly\n\n\n\n"
  },
  {
    "path": "table_recognition.py",
    "content": "from surya.scripts.table_recognition import table_recognition_cli\n\nif __name__ == \"__main__\":\n    table_recognition_cli()"
  },
  {
    "path": "tests/conftest.py",
    "content": "import os\n\nos.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n\nimport pytest\nfrom PIL import Image, ImageDraw\n\nfrom surya.detection import DetectionPredictor\nfrom surya.ocr_error import OCRErrorPredictor\nfrom surya.layout import LayoutPredictor\nfrom surya.recognition import RecognitionPredictor\nfrom surya.foundation import FoundationPredictor\nfrom surya.table_rec import TableRecPredictor\nfrom surya.settings import settings\n\n@pytest.fixture(scope=\"session\")\ndef ocr_error_predictor() -> OCRErrorPredictor:\n    ocr_error_predictor = OCRErrorPredictor()\n    yield ocr_error_predictor\n    del ocr_error_predictor\n\n\n@pytest.fixture(scope=\"session\")\ndef layout_predictor() -> LayoutPredictor:\n    layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))\n    yield layout_predictor\n    del layout_predictor\n\n\n@pytest.fixture(scope=\"session\")\ndef detection_predictor() -> DetectionPredictor:\n    detection_predictor = DetectionPredictor()\n    yield detection_predictor\n    del detection_predictor\n\n\n@pytest.fixture(scope=\"session\")\ndef recognition_predictor() -> RecognitionPredictor:\n    recognition_predictor = RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT))\n    yield recognition_predictor\n    del recognition_predictor\n\n\n@pytest.fixture(scope=\"session\")\ndef table_rec_predictor() -> TableRecPredictor:\n    table_rec_predictor = TableRecPredictor()\n    yield table_rec_predictor\n    del table_rec_predictor\n\n\n@pytest.fixture()\ndef test_image():\n    image = Image.new(\"RGB\", (1024, 1024), \"white\")\n    draw = ImageDraw.Draw(image)\n    draw.text((10, 10), \"Hello World\", fill=\"black\", font_size=72)\n    draw.text(\n        (10, 200),\n        \"This is a sentence of text.\\nNow it is a paragraph.\\nA three-line one.\",\n        fill=\"black\",\n        font_size=24,\n    )\n    return image\n\n\n@pytest.fixture()\ndef 
test_image_tall():\n    image = Image.new(\"RGB\", (4096, 4096), \"white\")\n    draw = ImageDraw.Draw(image)\n    draw.text((10, 10), \"Hello World\", fill=\"black\", font_size=72)\n    draw.text(\n        (4000, 4000),\n        \"This is a sentence of text.\\n\\nNow it is a paragraph.\\n\\nA three-line one.\",\n        fill=\"black\",\n        font_size=24,\n    )\n    return image\n\n@pytest.fixture()\ndef test_image_latex():\n    assets_dir = os.path.join(os.path.dirname(__file__), \"assets\")\n    img_path = os.path.join(assets_dir, \"test_latex.png\")\n    image = Image.open(img_path).convert(\"RGB\")\n    return image"
  },
  {
    "path": "tests/test_detection.py",
    "content": "def test_detection(detection_predictor, test_image):\n    detection_results = detection_predictor([test_image])\n\n    assert len(detection_results) == 1\n    assert detection_results[0].image_bbox == [0, 0, 1024, 1024]\n\n    bboxes = detection_results[0].bboxes\n    assert len(bboxes) == 4\n\n\ndef test_detection_chunking(detection_predictor, test_image_tall):\n    detection_results = detection_predictor([test_image_tall])\n\n    assert len(detection_results) == 1\n    assert detection_results[0].image_bbox == [0, 0, 4096, 4096]\n\n    bboxes = detection_results[0].bboxes\n    assert len(bboxes) >= 3 # Sometimes merges into 3\n    assert abs(4000 - bboxes[1].polygon[0][0]) < 50"
  },
  {
    "path": "tests/test_foundation.py",
    "content": "from surya.foundation import FoundationPredictor\n\n\ndef test_foundation_flash2():\n    try:\n        f = FoundationPredictor(None, None, None, \"flash_attention_2\")\n        assert f.model.decoder.config._attn_implementation == \"flash_attention_2\"\n        assert f.model.vision_encoder.config._attn_implementation == \"flash_attention_2\"\n    except Exception as e:\n        assert False, (\n            f\"FoundationPredictor with flash_attention_2 raised an exception: {e}\"\n        )\n"
  },
  {
    "path": "tests/test_latex_ocr.py",
    "content": "from typing import List\n\nfrom PIL import Image, ImageDraw\n\nfrom surya.common.surya.schema import TaskNames\nfrom surya.recognition import OCRResult\n\n\ndef test_latex_ocr(recognition_predictor, test_image_latex):\n    width, height = test_image_latex.size\n    results: List[OCRResult] = recognition_predictor(\n        [test_image_latex], [TaskNames.block_without_boxes], bboxes=[[[0, 0, width, height]]]\n    )\n    text = results[0].text_lines[0].text\n    assert len(results) == 1\n\n    assert text.startswith(\"<math\")\n    assert text.endswith(\"</math>\")\n"
  },
  {
    "path": "tests/test_layout.py",
    "content": "def test_layout_topk(layout_predictor, test_image):\n    layout_results = layout_predictor([test_image])\n\n    assert len(layout_results) == 1\n    assert layout_results[0].image_bbox == [0, 0, 1024, 1024]\n\n    bboxes = layout_results[0].bboxes\n    assert len(bboxes) == 2\n\n    assert bboxes[0].label == \"SectionHeader\"\n    assert len(bboxes[0].top_k) == 5\n\n    assert bboxes[1].label == \"Text\"\n    assert len(bboxes[1].top_k) == 5\n"
  },
  {
    "path": "tests/test_ocr_errors.py",
    "content": "def test_garbled_text(ocr_error_predictor):\n    text = \"\"\"\"\n    ; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj\n    2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d\n    \"\"\".strip()\n    results = ocr_error_predictor([text])\n    assert results.labels[0] == \"bad\"\n\n\ndef test_good_text(ocr_error_predictor):\n    text = \"\"\"\"\n    There are professions more harmful than industrial design, but only a very few of them.\n    \"\"\".strip()\n    results = ocr_error_predictor([text])\n    assert results.labels[0] == \"good\""
  },
  {
    "path": "tests/test_recognition.py",
    "content": "import time\nfrom PIL import ImageDraw, Image\nfrom surya.recognition.util import clean_math_tags\n\n\ndef test_recognition(recognition_predictor, detection_predictor, test_image):\n    recognition_results = recognition_predictor([test_image], None, detection_predictor)\n\n    assert len(recognition_results) == 1\n    assert recognition_results[0].image_bbox == [0, 0, 1024, 1024]\n\n    text_lines = recognition_results[0].text_lines\n    assert len(text_lines) == 4\n    assert \"Hello World\" in text_lines[0].text\n\n\ndef test_recognition_input_text(recognition_predictor, detection_predictor, test_image):\n    start = time.time()\n    recognition_predictor([test_image], None, detection_predictor)\n    end = time.time() - start\n\n    input_text = \"a\" * 400\n    start2 = time.time()\n    recognition_results = recognition_predictor(\n        [test_image], None, detection_predictor, input_text=[input_text]\n    )\n    end2 = time.time() - start2\n\n    assert max([end, end2]) / min([end, end2]) < 1.5, (\n        \"Input text should be truncated and not change inference time\"\n    )\n\n    assert len(recognition_results) == 1\n    assert recognition_results[0].image_bbox == [0, 0, 1024, 1024]\n\n    text_lines = recognition_results[0].text_lines\n    assert len(text_lines) == 4\n    assert \"Hello World\" in text_lines[0].text\n\n\ndef test_recognition_drop_repeats(recognition_predictor, detection_predictor):\n    image = Image.new(\"RGB\", (1024, 128), \"white\")\n    draw = ImageDraw.Draw(image)\n    text = \"a\" * 80\n    draw.text((5, 5), text, fill=\"black\", font_size=24)\n\n    recognition_results = recognition_predictor(\n        [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True\n    )\n    assert len(recognition_results) == 1\n    result = recognition_results[0].text_lines\n    assert result[0].text == \"\"\n\n\ndef test_recognition_clean_math():\n    math = \"\"\"<math display=\"block\">na_n^{1+2r} 
\\\\text{cov}(\\\\hat{f}_n^{(r)}(x), \\\\hat{f}_n^{(r)}(y)) = \\\\frac{1}{n} \\\\sum_{j=1}^n \\\\frac{a_n^{1+2r}}{a_j^{1+2r}} \\\\text{cov}\\\\left(K^{(r)}\\\\left(\\\\frac{x-X_j}{a_j}\\\\right), K^{(r)}\\\\left(\\\\frac{y-X_j}{a_j}\\\\right)\\\\right) <br>+ \\\\frac{a_n^{1+2r}}{n} \\\\sum_{\\\\substack{j \\\\neq k \\\\\\\\ 1 \\\\le j, k \\\\le n}} \\\\frac{1}{(a_j a_k)^{1+r}} \\\\text{cov}\\\\left(K^{(r)}\\\\left(\\\\frac{x-X_j}{a_j}\\\\right), K^{(r)}\\\\left(\\\\frac{y-X_k}{a_k}\\\\right)\\\\right) <br>=: I_1 + I_2.</math> (1.7)</math>'\"\"\"\n    clean_math = clean_math_tags(math)\n\n    assert clean_math.count(\"</math>\") == 1, \"Should have one closing math tag\"\n    assert \"<br>\" not in clean_math, \"Should not have <br> tags in cleaned math\"\n\n\ndef test_recognition_clean_math_preserve_text():\n    text = \"\"\"Hello, this is a sentence with <math display=\"inline\">x^2 + y^2 = z^2</math> and some text after it, with a weird tag <hello> and <goodbye>.\"\"\"\n    clean_text = clean_math_tags(text)\n\n    assert clean_text == text\n"
  },
  {
    "path": "tests/test_table_rec.py",
    "content": "from PIL import Image, ImageDraw\n\ndef test_table_rec(table_rec_predictor):\n    data = [\n        [\"Name\", \"Age\", \"City\"],\n        [\"Alice\", 25, \"New York\"],\n        [\"Bob\", 30, \"Los Angeles\"],\n        [\"Charlie\", 35, \"Chicago\"],\n    ]\n    test_image = draw_table(data)\n\n    results = table_rec_predictor([test_image])\n    assert len(results) == 1\n    assert results[0].image_bbox == [0, 0, test_image.size[0], test_image.size[1]]\n\n    cells = results[0].cells\n    assert len(cells) == 12\n    for row_id in range(4):\n        for col_id in range(3):\n            cell = [c for c in cells if c.row_id == row_id and c.col_id == col_id]\n            assert len(cell) == 1, f\"Missing cell at row {row_id}, col {col_id}\"\n\ndef draw_table(data, cell_width=100, cell_height=40):\n    rows = len(data)\n    cols = len(data[0])\n    width = cols * cell_width\n    height = rows * cell_height\n\n    image = Image.new('RGB', (width, height), 'white')\n    draw = ImageDraw.Draw(image)\n\n    for i in range(rows + 1):\n        y = i * cell_height\n        draw.line([(0, y), (width, y)], fill='black', width=1)\n\n    for i in range(cols + 1):\n        x = i * cell_width\n        draw.line([(x, 0), (x, height)], fill='black', width=1)\n\n    for i in range(rows):\n        for j in range(cols):\n            text = str(data[i][j])\n            text_bbox = draw.textbbox((0, 0), text)\n            text_width = text_bbox[2] - text_bbox[0]\n            text_height = text_bbox[3] - text_bbox[1]\n\n            x = j * cell_width + (cell_width - text_width) // 2\n            y = i * cell_height + (cell_height - text_height) // 2\n\n            draw.text((x, y), text, fill='black')\n\n    return image"
  },
  {
    "path": "texify_app.py",
    "content": "from surya.scripts.run_texify_app import texify_app_cli\n\nif __name__ == \"__main__\":\n    texify_app_cli()"
  }
]