Repository: datalab-to/surya Branch: master Commit: e735028979a2 Files: 136 Total size: 740.2 KB Directory structure: gitextract_x32e43uo/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows/ │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CLA.md ├── LICENSE ├── MODEL_LICENSE ├── README.md ├── benchmark/ │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils/ │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── detect_layout.py ├── detect_text.py ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── pyproject.toml ├── pytest.ini ├── signatures/ │ └── version1/ │ └── cla.json ├── static/ │ └── fonts/ │ └── .gitignore ├── surya/ │ ├── __init__.py │ ├── common/ │ │ ├── __init__.py │ │ ├── adetr/ │ │ │ └── decoder.py │ │ ├── donut/ │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder/ │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder/ │ │ │ │ └── __init__.py │ │ │ ├── encoder/ │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor/ │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug/ │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection/ │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation/ │ │ ├── 
__init__.py │ │ ├── cache/ │ │ │ ├── __init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input/ │ │ ├── load.py │ │ └── processing.py │ ├── layout/ │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error/ │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition/ │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts/ │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec/ │ ├── __init__.py │ ├── loader.py │ ├── model/ │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests/ │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/breaking-bug-report.md ================================================ --- name: Breaking bug report about: Create a report about a breaking bug title: "[BUG: Breaking]" labels: 'bug: breaking' assignees: '' --- ## 🧨 Describe the Bug A clear and concise description of the breaking issue (e.g., crash, OOM, exception, etc). ## 📄 Input Document Attach the PDF or input file that triggered the error. 
## 📤 Output Trace / Stack Trace Paste the **complete** stack trace or error output, if available.
<details>
<summary>Click to expand</summary>

```
Paste stack trace here
```

</details>
## ⚙️ Environment Please fill in all relevant details: - **Marker version**: - **Surya version**: - **Python version**: - **PyTorch version**: - **Transformers version**: - **Operating System** (incl. container info if relevant): ## ✅ Expected Behavior What did you expect Marker to do? ## 📟 Command or Code Used Paste the **exact bash command** or **Python code** you used to run Marker:
<details>
<summary>Click to expand</summary>

```bash
# or Python code block
your_command_here --with-flags
```

</details>
## 📎 Additional Context Any other context that might help us debug this (e.g., CLI options, working directory, runtime settings). ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: "[FEAT]" labels: enhancement assignees: '' --- ## ✨ Is your feature request related to a problem? A clear and concise description of what the problem is. ## 💡 Describe the Solution You'd Like A concise description of what you want to happen or how you envision it working. ## 📋 Alternatives Considered Any alternative solutions or workarounds you've tried. ## 🧩 Additional Context Any additional context, references, or related issues. ================================================ FILE: .github/ISSUE_TEMPLATE/output-bug-report.md ================================================ --- name: Output bug report about: Create a report about poor output quality title: "[BUG: Output]" labels: 'bug: output' assignees: '' --- ## 📝 Describe the Output Issue A clear and concise description of the incorrect or unexpected output. ## 📄 Input Document Attach the PDF or input file used. ## 📤 Current Output Paste the Markdown or HTML that Marker generated: ````markdown Paste output here ````` ## ✅ Expected Output Describe or paste what you expected Marker to generate. ## ⚙️ Environment Please fill in all relevant details: * **Marker version**: * **Surya version**: * **Python version**: * **PyTorch version**: * **Transformers version**: * **Operating System**: ## 📟 Command or Code Used Paste the **exact bash command** or **Python code** you used to run Marker:
<details>
<summary>Click to expand</summary>

```bash
# or Python code block
your_command_here --with-flags
```

</details>
## 📎 Additional Context Any other relevant info, configs, or assumptions. ================================================ FILE: .github/workflows/benchmarks.yml ================================================ name: Integration test on: [push] env: PYTHONIOENCODING: "utf-8" jobs: build: runs-on: t4_gpu steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install python dependencies run: | pip install poetry poetry install - name: Run detection benchmark test run: | poetry run python benchmark/detection.py --max_rows 2 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection - name: Run recognition benchmark test run: | poetry run python benchmark/recognition.py --max_rows 2 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition - name: Run layout benchmark test run: | poetry run python benchmark/layout.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout - name: Run ordering benchmark run: | poetry run python benchmark/ordering.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering - name: Run table recognition benchmark run: | poetry run python benchmark/table_recognition.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition - name: Run texify benchmark run: | poetry run python benchmark/texify.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify ================================================ FILE: .github/workflows/ci.yml ================================================ name: Unit tests on: 
[push] jobs: build: runs-on: ${{ matrix.os }} strategy: matrix: os: [t4_gpu, ubuntu-latest, windows-latest] steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install python dependencies run: | pip install poetry poetry install - name: Run tests run: poetry run pytest ================================================ FILE: .github/workflows/cla.yml ================================================ name: "Surya CLA Assistant" on: issue_comment: types: [created] pull_request_target: types: [opened,closed,synchronize] # explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings permissions: actions: write contents: write pull-requests: write statuses: write jobs: CLAAssistant: runs-on: ubuntu-latest steps: - name: "Surya CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' uses: contributor-assistant/github-action@v2.3.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # the below token should have repo scope and must be manually added by you in the repository's secret # This token is required only if you have configured to store the signatures in a remote repository/organization PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} with: path-to-signatures: 'signatures/version1/cla.json' path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md' # branch should not be protected branch: 'master' allowlist: VikParuchuri ================================================ FILE: .github/workflows/publish.yml ================================================ name: Python package on: push: tags: - "v*.*.*" jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install 
python dependencies run: | pip install poetry poetry install - name: Build package run: | poetry build - name: Publish package env: PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} run: | poetry config pypi-token.pypi "$PYPI_TOKEN" poetry publish ================================================ FILE: .github/workflows/scripts.yml ================================================ name: Test CLI scripts on: [push] jobs: build: runs-on: t4_gpu steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install python dependencies run: | pip install poetry poetry install - name: Download benchmark data run: | wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi" unzip -o benchmark_data.zip - name: Test detection run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test OCR env: RECOGNITION_MAX_TOKENS: 25 run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test layout run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test table run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test texify env: TEXIFY_MAX_TOKENS: 25 run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test detection folder run: poetry run surya_detect benchmark_data/pdfs --page_range 0 ================================================ FILE: .gitignore ================================================ private.py .DS_Store local.env experiments test_data training wandb notebooks results data slices # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are 
written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.9.10 hooks: # Run the linter. - id: ruff types_or: [ python, pyi ] args: [ --fix ] # Run the formatter. - id: ruff-format types_or: [ python, pyi ] ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 message: "If you use this software, please cite it using the following metadata." 
title: "Surya: A lightweight framework for analyzing documents and PDFs at scale" authors: - family-names: Paruchuri given-names: Vikas - name: Datalab Team date-released: 2025-05-13 url: https://github.com/VikParuchuri/surya version: 0.14.0 repository-code: https://github.com/VikParuchuri/surya ================================================ FILE: CLA.md ================================================ Surya Contributor Agreement This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement. 1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project. 2. With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution: - you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty free, unrestricted license to exercise all rights under those copyrights. 
This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers; - you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work; - you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees; - you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and - you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution. 3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to: - make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and - at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements. If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed. 4. 
Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license. 5. You covenant, represent, warrant and agree that: - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws. You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA. 6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply. ================================================ FILE: LICENSE ================================================ GNU GENERAL PUBLIC LICENSE Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU General Public License is a free, copyleft license for software and other kinds of works. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. 
By contrast, the GNU General Public License is intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. We, the Free Software Foundation, use the GNU General Public License for most of our software; it applies also to any other work released this way by its authors. You can apply it to your programs, too. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. To protect your rights, we need to prevent others from denying you these rights or asking you to surrender the rights. Therefore, you have certain responsibilities if you distribute copies of the software, or if you modify it: responsibilities to respect the freedom of others. For example, if you distribute copies of such a program, whether gratis or for a fee, you must pass on to the recipients the same freedoms that you received. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. Developers that use the GNU GPL protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License giving you legal permission to copy, distribute and/or modify it. For the developers' and authors' protection, the GPL clearly explains that there is no warranty for this free software. For both users' and authors' sake, the GPL requires that modified versions be marked as changed, so that their problems will not be attributed erroneously to authors of previous versions. 
Some devices are designed to deny users access to install or run modified versions of the software inside them, although the manufacturer can do so. This is fundamentally incompatible with the aim of protecting users' freedom to change the software. The systematic pattern of such abuse occurs in the area of products for individuals to use, which is precisely where it is most unacceptable. Therefore, we have designed this version of the GPL to prohibit the practice for those products. If such problems arise substantially in other domains, we stand ready to extend this provision to those domains in future versions of the GPL, as needed to protect the freedom of users. Finally, every program is threatened constantly by software patents. States should not allow patents to restrict development and use of software on general-purpose computers, but in those that do, we wish to avoid the special danger that patents applied to a free program could make it effectively proprietary. To prevent this, the GPL assures that patents cannot be used to render the program non-free. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. 
To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. 
A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. 
You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. 
You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. 
You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. 
Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. 
If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. 
When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. 
If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. 
If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. 
For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. 
If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. "Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. 
You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Use with the GNU Affero General Public License. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU Affero General Public License into a single combined work, and to convey the resulting work. 
The terms of this License will continue to apply to the part which is the covered work, but the special requirements of the GNU Affero General Public License, section 13, concerning interaction through a network will apply to the combination as such. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. 
SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Surya OCR Copyright (C) 2024 Endless Labs, Inc. 
This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. Also add information on how to contact you by electronic and paper mail. If the program does terminal interaction, make it output a short notice like this when it starts in an interactive mode: Surya OCR Copyright (C) 2024 Endless Labs, Inc. This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, your program's commands might be different; for a GUI interface, you would use an "about box". You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU GPL, see <https://www.gnu.org/licenses/>. The GNU General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read <https://www.gnu.org/licenses/why-not-lgpl.html>.
================================================ FILE: MODEL_LICENSE ================================================ AI PUBS OPEN RAIL-M LICENSE (MODIFIED) Version 0.1, March 2, 2023 (Modified) http://licenses.ai/ PLEASE READ THESE TERMS CAREFULLY BEFORE USING THE MODEL OR A DERIVATIVE WORK OF THE MODEL MADE AVAILABLE IN CONNECTION WITH THESE TERMS. BY DOWNLOADING, REPRODUCING, DISTRIBUTING OR USING THE MODEL OR A DERIVATIVE WORK OF THE MODEL IN ANY MANNER, YOU (“YOU”) AGREE TO BE BOUND BY THESE TERMS (THE “AGREEMENT”) TO THE EXCLUSION OF ALL OTHER TERMS. YOU REPRESENT AND WARRANT THAT YOU HAVE THE AUTHORITY TO ENTER INTO THIS AGREEMENT; IF YOU ARE ENTERING INTO THIS AGREEMENT ON BEHALF OF AN ORGANIZATION OR ENTITY, REFERENCES TO “YOU” IN THIS AGREEMENT REFER TO THAT ORGANIZATION OR ENTITY. IF YOU DO NOT AGREE TO ALL OF THE FOLLOWING, YOU MAY NOT DOWNLOAD, REPRODUCE, DISTRIBUTE OR USE THE MODEL OR A DERIVATIVE WORK OF THE MODEL IN ANY MANNER. Section I: PREAMBLE This OpenRAIL-M License, as modified, is generally applicable to any machine-learning Model. The “Open” nomenclature indicates that the licensed Model is to be freely accessible to downstream and other users. The “RAIL” nomenclature indicates that there are use restrictions prohibiting the use of the Model. These restrictions are intended to avoid potential misuse. This License specifies that the use restrictions in the original License must apply to such derivatives. NOW THEREFORE, You and Licensor agree as follows: 1. Definitions (a) “Complementary Material” means the applicable source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, and any related information, if any. Complementary Material is not licensed under this License.
(b) "Contribution" means any work, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the rights owner or by an individual or legal entity authorized to submit on behalf of the rights owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the rights owner as "Not a Contribution." (c) "Contributor" means Licensor and any individual or legal entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. (d) “Data” means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. (e) “Derivatives of the Model” means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. 
(f) “Distribution” means any transmission, reproduction, publication, distribution, or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means, including but not limited to API-based or web access. (g) “Harm” includes but is not limited to physical, mental, psychological, financial and reputational damage, pain, or loss. (h) "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. (i) “Licensor” means the rights owner or entity authorized by the rights owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. (j) “Model” means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. (k) “Output” means the results of operating a Model as embodied in informational content resulting therefrom. (l) “Third Parties” means individuals or legal entities that are not under common control with Licensor or You. (m) "You" (or "Your") means an individual or legal entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application, including but not limited to a chatbot, translator, or image generator. Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants may apply to the Model and Derivatives of the Model. The Model and Derivatives of the Model are subject to additional terms as described in Section III, which shall govern the use of the Model and Derivatives of the Model even in the event Section II is held unenforceable. 2.
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Model and Derivatives of the Model. 3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and/or Derivatives of the Model where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model or Derivatives of the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model or Derivative of the Model and/or a Contribution incorporated within the Model or Derivative of the Model constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Derivative of the Model shall terminate as of the date such litigation is asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION 4. Distribution and Redistribution. 
You may host the Model or Derivatives of the Model for remote access by Third Parties, including but not limited to software-as-a-service, reproduce, or Distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the conditions in this Section III: (a) Use-based restrictions in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (for example, a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model and Derivatives of the Model are subject to paragraph 5; (b) You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; (c) You must cause any modified files to carry prominent notices stating that You changed the files; and (d) You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model or Derivatives of the Model. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions, consistent with paragraph 4.a., for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Accordingly, You cannot use the Model or the Derivatives of the Model in violation of such restrictions. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, fine-tuning, updating, running, training, evaluating and/or re-parametrizing the Model. 
You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph 5. 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are solely responsible for the Output you generate and its subsequent uses. No use of the Output can contravene any provision as stated in the License. 7. Attribution. In connection with any Output, or use of Distribution of any Model or Derivatives of the Model, You agree to give appropriate credit and attribution to Licensor, provide a link to the original Model or Derivatives of the Model, provide a copy of this License, and identify any changes You have made to the Model or Derivatives of the Model (collectively, the “Attribution”). The Attribution must not suggest endorsement by any Licensor. 8. Share-a-Like. As a condition to the license and authorizations herein, You agree to apply this License (to the exclusion of all others) to any and all copies of the Model, Derivatives of the Model, any changes or improvements to the Model or Derivatives of the Model, and to the Output and any derivatives, changes or improvements to or of the Output. Section IV: OTHER PROVISIONS 9. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or cause modification to the Output resulting from updates to the Model based. 10. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. 11. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Model (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model and Derivatives of the Model, and assume any risks associated with Your exercise of permissions under this License. 12. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 13. Accepting Warranty or Additional Liability. While Distributing the Model or Derivatives of the Model, You may choose to charge a fee in exchange for support, warranty, indemnity, or other obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor or Licensor, and only if You agree to indemnify, defend, and hold each Contributor and the Licensor harmless for any liability incurred by, or claims asserted against, such Contributor or Licensor by reason of your accepting any such warranty or additional liability. 14. 
If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. END OF TERMS AND CONDITIONS Attachment A USE RESTRICTIONS As conditions to the Licenses set forth in this Agreement, You agree not to use, reproduce, modify, create or Distribute the Model, Derivatives of the Model, or Output (collectively, “Use”) in any of the following ways: 1. Legal: (a) In any way that violates any applicable national, federal, state, local or international law or regulation; or (b) to directly or indirectly infringe or misappropriate any third party intellectual property rights (including those of Licensor or any Contributor) 2. Commercial: (a) for any purpose if You (your employer, or the entity you are affiliated with) generated more than two million US Dollars ($2,000,000) in gross revenue in the prior year, except where Your Use is limited to personal use or research purposes; (b) for any purpose if You (your employer, or the entity you are affiliated with) has raised more than two million US dollars ($2,000,000) in total equity or debt funding from any source, except where Your Use is limited to personal use or research purposes; or (c) for any purpose if You (your employer, or the entity you are affiliated with) provides or otherwise makes available any product or service that competes with any product or service offered by or made available by Licensor or any of its affiliates. 
Commercial and broader use licenses may be available from Licensor at the following URL: https://www.datalab.to/ ================================================ FILE: README.md ================================================ # Surya Surya is a document OCR toolkit that does: - OCR in 90+ languages that benchmarks favorably vs cloud services - Line-level text detection in any language - Layout analysis (table, image, header, etc detection) - Reading order detection - Table recognition (detecting rows/columns) - LaTeX OCR It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details). For our managed API or on-prem document intelligence solution, check out [our platform here](https://datalab.to?utm_source=gh-surya). | Detection | OCR | |:----------------------------------------------------------------:|:-----------------------------------------------------------------------:| | | | | Layout | Reading Order | |:------------------------------------------------------------------:|:--------------------------------------------------------------------------:| | | | | Table Recognition | LaTeX OCR | |:-------------------------------------------------------------:|:------------------------------------------------------:| | | | Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision. ## Community [Discord](https://discord.gg/KuZwXNGnfH) is where we discuss future development. 
## Examples | Name | Detection | OCR | Layout | Order | Table Rec | |------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:| | Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) | | Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) | | | Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) | | | Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) | | | Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) | | | Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) | [Image](static/images/pres_tablerec.png) | | Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) | [Image](static/images/paper_tablerec.png) | | Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) | [Image](static/images/scanned_tablerec.png) | | New York Times | [Image](static/images/nyt.jpg) | 
[Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) | | | Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) | | Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) | | # Hosted API There is a hosted API for all surya models available [here](https://www.datalab.to?utm_source=gh-surya): - Works with PDF, images, word docs, and powerpoints - Consistent speed, with no latency spikes - High reliability and uptime # Commercial usage Our model weights use a modified AI Pubs Open Rail-M license (free for research, personal use, and startups under $2M funding/revenue) and our code is GPL. For broader commercial licensing or to remove GPL requirements, visit our pricing page [here](https://www.datalab.to/pricing?utm_source=gh-surya). # Installation You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details. Install with: ```shell pip install surya-ocr ``` Model weights will automatically download the first time you run surya. # Usage - Inspect the settings in `surya/settings.py`. You can override any settings with environment variables. - Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. ## Interactive App I've included a streamlit app that lets you interactively try Surya on images or PDF files. 
Run it with: ```shell pip install streamlit pdftext surya_gui ``` ## OCR (text recognition) This command will write out a json file with the detected text and bboxes: ```shell surya_ocr DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--task_name` will specify which task to use for predicting the lines. `ocr_with_boxes` is the default, which will format text and give you bboxes. If you get bad performance, try `ocr_without_boxes`, which will give you potentially better performance but no bboxes. For blocks like equations and paragraphs, try `block_without_boxes`. - `--images` will save images of the pages and detected text lines (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. - `--disable_math` - by default, surya will recognize math in text. This can lead to false positives - you can disable this with this flag. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `text_lines` - the detected text and bounding boxes for each line - `text` - the text in the line - `confidence` - the confidence of the model in the detected text (0-1) - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. 
- `chars` - the individual characters in the line - `text` - the text of the character - `bbox` - the character bbox (same format as line bbox) - `polygon` - the character polygon (same format as line polygon) - `confidence` - the confidence of the model in the detected character (0-1) - `bbox_valid` - if the character is a special token or math, the bbox may not be valid - `words` - the individual words in the line (computed from the characters) - `text` - the text of the word - `bbox` - the word bbox (same format as line bbox) - `polygon` - the word polygon (same format as line polygon) - `confidence` - mean character confidence - `bbox_valid` - if the word is a special token or math, the bbox may not be valid - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. **Performance tips** Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `40MB` of VRAM, so very high batch sizes are possible. The default is a batch size `512`, which will use about 20GB of VRAM. Depending on your CPU core count, it may help, too - the default CPU batch size is `32`. ### From python ```python from PIL import Image from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor from surya.detection import DetectionPredictor image = Image.open(IMAGE_PATH) foundation_predictor = FoundationPredictor() recognition_predictor = RecognitionPredictor(foundation_predictor) detection_predictor = DetectionPredictor() predictions = recognition_predictor([image], det_predictor=detection_predictor) ``` ## Text line detection This command will write out a json file with the detected bboxes. 
```shell surya_detect DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected text lines (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `bboxes` - detected bounding boxes for text - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `confidence` - the confidence of the model in the detected text (0-1) - `vertical_lines` - vertical lines detected in the document - `bbox` - the axis-aligned line coordinates. - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. **Performance tips** Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `440MB` of VRAM, so very high batch sizes are possible. The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`. 
### From python ```python from PIL import Image from surya.detection import DetectionPredictor image = Image.open(IMAGE_PATH) det_predictor = DetectionPredictor() # predictions is a list of dicts, one per image predictions = det_predictor([image]) ``` ## Layout and reading order This command will write out a json file with the detected layout and reading order. ```shell surya_layout DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected text lines (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `bboxes` - detected bounding boxes for text - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `position` - the reading order of the box. - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`. - `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values. - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. 
All line bboxes will be contained within this bbox. **Performance tips** Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `220MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 7GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`. ### From python ```python from PIL import Image from surya.foundation import FoundationPredictor from surya.layout import LayoutPredictor from surya.settings import settings image = Image.open(IMAGE_PATH) layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)) # layout_predictions is a list of dicts, one per image layout_predictions = layout_predictor([image]) ``` ## Table Recognition This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get cell positions and text, along with nice formatting, check out the [marker](https://www.github.com/VikParuchuri/marker) repo. You can use the `TableConverter` to detect and extract tables in images and PDFs. It supports output in json (with bboxes), markdown, and html. ```shell surya_table DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected table cells + rows and columns (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. - `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible. - `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table. 
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `rows` - detected table rows - `bbox` - the bounding box of the table row - `row_id` - the id of the row - `is_header` - if it is a header row. - `cols` - detected table columns - `bbox` - the bounding box of the table column - `col_id` - the id of the column - `is_header` - if it is a header column - `cells` - detected table cells - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `text` - if text could be pulled out of the pdf, the text of this cell. - `row_id` - the id of the row the cell belongs to. - `col_id` - the id of the column the cell belongs to. - `colspan` - the number of columns spanned by the cell. - `rowspan` - the number of rows spanned by the cell. - `is_header` - whether it is a header cell. - `page` - the page number in the file - `table_idx` - the index of the table on the page (sorted in vertical order) - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. **Performance tips** Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size `64`, which will use about 10GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `8`. 
### From python ```python from PIL import Image from surya.table_rec import TableRecPredictor image = Image.open(IMAGE_PATH) table_rec_predictor = TableRecPredictor() table_predictions = table_rec_predictor([image]) ``` ## LaTeX OCR This command will write out a json file with the LaTeX of the equations. You must pass in images that are already cropped to the equations. You can do this by running the layout model, then cropping, if you want. ```shell surya_latex_ocr DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. See the OCR section above for the format of the output. ### From python ```python from PIL import Image from surya.texify import TexifyPredictor image = Image.open(IMAGE_PATH) predictor = TexifyPredictor() predictor([image]) ``` ### Interactive app You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with: ```shell pip install streamlit==1.40 streamlit-drawable-canvas-jsretry texify_gui ``` ## Compilation The following models have support for compilation. You will need to set the following environment variables to enable compilation: - Detection: `COMPILE_DETECTOR=true` - Layout: `COMPILE_LAYOUT=true` - Table recognition: `COMPILE_TABLE_REC=true` Alternatively, you can also set `COMPILE_ALL=true` which will compile all models. 
Here are the speedups on an A10 GPU: | Model | Time per page (s) | Compiled time per page (s) | Speedup (%) | | ----------------- | ----------------- | -------------------------- | ----------- | | Detection | 0.108808 | 0.10521 | 3.306742151 | | Layout | 0.27319 | 0.27063 | 0.93707676 | | Table recognition | 0.0219 | 0.01938 | 11.50684932 | # Limitations - This is specialized for document OCR. It will likely not work on photos or other images. - It is for printed text, not handwriting (though it may work on some handwriting). - The text detection model has trained itself to ignore advertisements. - You can find language support for OCR in `surya/recognition/languages.py`. Text detection, layout analysis, and reading order will work with any language. ## Troubleshooting If OCR isn't working properly: - Try increasing resolution of the image so the text is bigger. If the resolution is already very high, try decreasing it to no more than a `2048px` width. - Preprocessing the image (binarizing, deskewing, etc) can help with very old/blurry images. - You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results. `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space. `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text. `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range. Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds). 
# Manual install If you want to develop surya, you can install it manually: - `git clone https://github.com/VikParuchuri/surya.git` - `cd surya` - `poetry install` - installs main and dev dependencies - `poetry shell` - activates the virtual environment # Benchmarks ## OCR ![Benchmark chart tesseract](static/images/benchmark_rec_chart.png) | Model | Time per page (s) | Avg similarity (⬆) | |-----------|-------------------|--------------------| | surya | .62 | 0.97 | | tesseract | .45 | 0.88 | [Full language results](static/images/rec_acc_table.png) Tesseract is CPU-based, and surya is CPU or GPU. I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean). ### Google Cloud Vision I benchmarked OCR against Google Cloud vision since it has similar language coverage to Surya. ![Benchmark chart google cloud](static/images/gcloud_rec_bench.png) [Full language results](static/images/gcloud_full_langs.png) **Methodology** I measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs. I sampled PDFs from common crawl, then filtered out the ones with bad OCR. I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those. I used the reference line bboxes from the PDFs with both tesseract and surya, to just evaluate the OCR quality. For Google Cloud, I aligned the output from Google Cloud with the ground truth. I had to skip RTL languages since they didn't align well. ## Text line detection ![Benchmark chart](static/images/benchmark_chart_small.png) | Model | Time (s) | Time per page (s) | precision | recall | |-----------|------------|---------------------|-------------|----------| | surya | 47.2285 | 0.094452 | 0.835857 | 0.960807 | | tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 | Tesseract is CPU-based, and surya is CPU or GPU. 
I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU. This was the resource usage: - tesseract - 32 CPU cores, or 8 workers using 4 cores each - surya - 36 batch size, for 16GB VRAM usage **Methodology** Surya predicts line-level bboxes, while tesseract and others predict word-level or character-level. It's hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation. I instead used coverage, which calculates: - Precision - how well the predicted bboxes cover ground truth bboxes - Recall - how well ground truth bboxes cover predicted bboxes First calculate coverage for each bbox, then add a small penalty for double coverage, since we want the detection to have non-overlapping bboxes. Anything with a coverage of 0.5 or higher is considered a match. Then we calculate precision and recall for the whole dataset. ## Layout analysis | Layout Type | precision | recall | |---------------|-------------|----------| | Image | 0.91265 | 0.93976 | | List | 0.80849 | 0.86792 | | Table | 0.84957 | 0.96104 | | Text | 0.93019 | 0.94571 | | Title | 0.92102 | 0.95404 | Time per image - .13 seconds on GPU (A10). **Methodology** I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data. I had to align publaynet labels with the surya layout labels. I was then able to find coverage for each layout type: - Precision - how well the predicted bboxes cover ground truth bboxes - Recall - how well ground truth bboxes cover predicted bboxes ## Reading Order 88% mean accuracy, and .4 seconds per image on an A10 GPU. See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check. **Methodology** I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. 
Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth. The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct. ## Table Recognition | Model | Row Intersection | Col Intersection | Time Per Image | |-------------------|--------------------|--------------------|------------------| | Surya | 1 | 0.98625 | 0.30202 | | Table transformer | 0.84 | 0.86857 | 0.08082 | Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions. This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker). **Methodology** The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns. ## LaTeX OCR | Method | edit ⬇ | time taken (s) ⬇ | |--------|----------|------------------| | texify | 0.122617 | 35.6345 | This runs texify inference on a ground truth set of LaTeX, then computes edit distance. This is a bit noisy, since 2 LaTeX strings that render the same can have different symbols in them. ## Running your own benchmarks You can benchmark the performance of surya on your machine. - Follow the manual install instructions above. - `poetry install --group dev` - installs dev dependencies **Text line detection** This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench). 
```shell python benchmark/detection.py --max_rows 256 ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images and detected bboxes - `--pdf_path` will let you specify a pdf to benchmark instead of the default data - `--results_dir` will let you specify a directory to save results to instead of the default one **Text recognition** This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages). ```shell python benchmark/recognition.py --tesseract ``` - `--max_rows` controls how many images to process for the benchmark - `--debug 2` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one - `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder. - Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark. - Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking. This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus). **Layout analysis** This will evaluate surya on the publaynet dataset. 
```shell python benchmark/layout.py ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one **Reading Order** ```shell python benchmark/ordering.py ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one **Table Recognition** ```shell python benchmark/table_recognition.py --max_rows 1024 --tatr ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one - `--tatr` specifies whether to also run table transformer **LaTeX OCR** ```shell python benchmark/texify.py --max_rows 128 ``` - `--max_rows` controls how many images to process for the benchmark - `--results_dir` will let you specify a directory to save results to instead of the default one # Training Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation. Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes). # Finetuning Surya OCR You can now take Surya OCR further by training it on your own data with our [finetuning script](/surya/scripts/finetune_ocr.py). It’s built on Hugging Face Trainer, and supports all the [arguments](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) that the huggingface trainer provides, and integrations like torchrun, or deepspeed. 
To set up your dataset, follow the example dataset format [here](https://huggingface.co/datasets/datalab-to/ocr_finetune_example) and provide the path to your own dataset when launching the training script. ```bash # Tested on 1xH100 GPU # Set --pretrained_checkpoint_path to load from a custom checkpoint, otherwise # the default surya ocr weights will be loaded as the initialization python surya/scripts/finetune_ocr.py \ --output_dir $OUTPUT_DIR \ --dataset_name datalab-to/ocr_finetune_example \ --per_device_train_batch_size 64 \ --gradient_checkpointing true \ --max_sequence_length 1024 ``` This is a minimal training script to get you started finetuning Surya. Our internal training stack includes character bounding box finetuning, sliding window attention with specialized attention masks, custom kernels, augmentations, and other optimizations that can push OCR accuracy well beyond standard finetuning. If you want to get the most out of your data, reach us at hi@datalab.to! # Thanks This work would not have been possible without amazing open source AI work: - [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA - [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT - [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman - [Donut](https://github.com/clovaai/donut) from Naver - [transformers](https://github.com/huggingface/transformers) from huggingface - [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model Thank you to everyone who makes open source AI possible. 
@click.command(help="Benchmark detection model.")
@click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
@click.option("--results_dir", type=str, help="Path to directory to save benchmark results in.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
@click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False)
def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool):
    """Benchmark surya text detection, optionally alongside tesseract.

    Images come either from the first ``max_rows`` pages of ``pdf_path`` (with
    reference boxes taken from ``get_pdf_lines``) or, when no PDF is given,
    from the first ``max_rows`` rows of ``settings.DETECTOR_BENCH_DATASET_NAME``.
    Per-page and mean ``precision_recall`` metrics plus timings are written to
    ``<results_dir>/<name>/results.json`` and a summary table is printed.

    Args:
        pdf_path: Optional PDF to benchmark instead of the default dataset.
        results_dir: Directory under which the result folder is created.
        max_rows: Maximum number of pages/images to process.
        debug: When set, save each page with the predicted polygons drawn.
        tesseract: When set, also run and score tesseract.

    Raises:
        click.ClickException: If there are no images to benchmark.
    """
    det_predictor = DetectionPredictor()

    if pdf_path is not None:
        pathname = pdf_path
        doc = open_pdf(pdf_path)
        try:
            page_indices = list(range(len(doc)))[:max_rows]
            images = get_page_images(doc, page_indices)
        finally:
            # Always release the document, even if rendering a page fails.
            doc.close()

        image_sizes = [img.size for img in images]
        correct_boxes = get_pdf_lines(pdf_path, image_sizes)
    else:
        pathname = "det_bench"
        # These have already been shuffled randomly, so sampling from the start is fine
        dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
        images = list(dataset["image"])
        images = convert_if_not_rgb(images)
        correct_boxes = []
        for i, boxes in enumerate(dataset["bboxes"]):
            img_size = images[i].size
            # 1000,1000 is bbox size for doclaynet
            correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])

    if not images:
        # Without this, the code below fails with an opaque KeyError on
        # page_metrics[0] (or a ZeroDivisionError on the per-page timing).
        raise click.ClickException("No pages or images to benchmark.")

    if settings.DETECTOR_STATIC_CACHE:
        # Run through one batch to compile the model
        det_predictor(images[:1])

    start = time.time()
    predictions = det_predictor(images)
    surya_time = time.time() - start

    if tesseract:
        start = time.time()
        tess_predictions = tesseract_parallel(images)
        tess_time = time.time() - start
    else:
        # Placeholders keep the zip() below aligned when tesseract is skipped.
        tess_predictions = [None] * len(images)
        tess_time = None

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    page_metrics = collections.OrderedDict()
    for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
        surya_boxes = [s.bbox for s in sb.bboxes]
        surya_polys = [s.polygon for s in sb.bboxes]

        surya_metrics = precision_recall(surya_boxes, cb)
        tess_metrics = precision_recall(tb, cb) if tb is not None else None

        page_metrics[idx] = {
            "surya": surya_metrics,
            "tesseract": tess_metrics
        }

        if debug:
            # Draw on a copy so the benchmark images themselves stay untouched.
            bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
            bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))

    metric_types = sorted(page_metrics[0]["surya"].keys())
    models = ["surya"]
    if tesseract:
        models.append("tesseract")

    # Unweighted mean of each metric over all pages, per model.
    mean_metrics = {}
    for k in models:
        mean_metrics[k] = {}
        for m in metric_types:
            metric = [page_metrics[page][k][m] for page in page_metrics]
            mean_metrics[k][m] = sum(metric) / len(metric)

    out_data = {
        "times": {
            "surya": surya_time,
            "tesseract": tess_time
        },
        "metrics": mean_metrics,
        "page_metrics": page_metrics
    }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
    table_data = [
        ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
    ]
    if tesseract:
        table_data.append(
            ["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types]
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.  There is a precision penalty for multiple boxes overlapping reference lines.")
    print(f"Wrote results to {result_path}")
LayoutPredictor(foundation_predictor) pathname = "layout_bench" # These have already been shuffled randomly, so sampling from the start is fine dataset = datasets.load_dataset( settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]" ) images = list(dataset["image"]) images = convert_if_not_rgb(images) if settings.LAYOUT_STATIC_CACHE: layout_predictor(images[:1]) start = time.time() layout_predictions = layout_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] result_path = os.path.join(results_dir, folder_name) os.makedirs(result_path, exist_ok=True) label_alignment = { # First is publaynet, second is surya "Image": [["Figure"], ["Picture", "Figure"]], "Table": [["Table"], ["Table", "Form", "TableOfContents"]], "Text": [ ["Text"], [ "Text", "Formula", "Footnote", "Caption", "TextInlineMath", "Code", "Handwriting", ], ], "List": [["List"], ["ListItem"]], "Title": [["Title"], ["SectionHeader", "Title"]], } page_metrics = collections.OrderedDict() for idx, pred in enumerate(layout_predictions): row = dataset[idx] all_correct_bboxes = [] page_results = {} for label_name in label_alignment: correct_cats, surya_cats = label_alignment[label_name] correct_bboxes = [ b for b, category in zip(row["bboxes"], row["labels"]) if category in correct_cats ] all_correct_bboxes.extend(correct_bboxes) pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats] metrics = precision_recall( pred_bboxes, correct_bboxes, penalize_double=False ) weight = len(correct_bboxes) metrics["weight"] = weight page_results[label_name] = metrics page_metrics[idx] = page_results if debug: bbox_image = draw_bboxes_on_image( all_correct_bboxes, copy.deepcopy(images[idx]) ) bbox_image.save(os.path.join(result_path, f"{idx}_layout.png")) mean_metrics = collections.defaultdict(dict) layout_types = sorted(page_metrics[0].keys()) metric_types = sorted(page_metrics[0][layout_types[0]].keys()) metric_types.remove("weight") for label in 
layout_types: for m in metric_types: metric = [] total = 0 for page in page_metrics: metric.append( page_metrics[page][label][m] * page_metrics[page][label]["weight"] ) total += page_metrics[page][label]["weight"] value = sum(metric) if value > 0: value /= total mean_metrics[label][m] = value out_data = { "time": surya_time, "metrics": mean_metrics, "page_metrics": page_metrics, } with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(out_data, f, indent=4) table_headers = [ "Layout Type", ] + metric_types table_data = [] for layout_type in layout_types: table_data.append( [ layout_type, ] + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types] ) print(tabulate(table_data, headers=table_headers, tablefmt="github")) print( f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total." ) print( "Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold." ) print(f"Wrote results to {result_path}") if __name__ == "__main__": main() ================================================ FILE: benchmark/ordering.py ================================================ import collections import json import click from surya.foundation import FoundationPredictor from surya.input.processing import convert_if_not_rgb from surya.layout import LayoutPredictor from surya.common.polygon import PolygonBox from surya.settings import settings from benchmark.utils.metrics import rank_accuracy import os import time import datasets @click.command(help="Benchmark surya layout for reading order.") @click.option( "--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"), ) @click.option( "--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=None, ) def main(results_dir: str, max_rows: int): foundation_predictor = 
FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) layout_predictor = LayoutPredictor(foundation_predictor) pathname = "order_bench" # These have already been shuffled randomly, so sampling from the start is fine split = "train" if max_rows is not None: split = f"train[:{max_rows}]" dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split) images = list(dataset["image"]) images = convert_if_not_rgb(images) start = time.time() layout_predictions = layout_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] result_path = os.path.join(results_dir, folder_name) os.makedirs(result_path, exist_ok=True) page_metrics = collections.OrderedDict() mean_accuracy = 0 for idx, order_pred in enumerate(layout_predictions): row = dataset[idx] labels = row["labels"] bboxes = row["bboxes"] pred_positions = [] for label, bbox in zip(labels, bboxes): max_intersection = 0 matching_idx = 0 for pred_box in order_pred.bboxes: intersection = pred_box.intersection_pct(PolygonBox(polygon=bbox)) if intersection > max_intersection: max_intersection = intersection matching_idx = pred_box.position pred_positions.append(matching_idx) accuracy = rank_accuracy(pred_positions, labels) mean_accuracy += accuracy page_results = {"accuracy": accuracy, "box_count": len(labels)} page_metrics[idx] = page_results mean_accuracy /= len(layout_predictions) out_data = { "time": surya_time, "mean_accuracy": mean_accuracy, "page_metrics": page_metrics, } with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(out_data, f, indent=4) print(f"Mean accuracy is {mean_accuracy:.2f}.") print( f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total." 
@click.command(help="Benchmark recognition model.")
@click.option(
    "--results_dir",
    type=str,
    help="Path to directory to save benchmark results in.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None
)
@click.option(
    "--debug",
    is_flag=False,
    flag_value=1,
    type=int,
    default=0,
    help="Debug level. A bare --debug means level 1; --debug 2 also renders images.",
)
@click.option(
    "--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False
)
@click.option(
    "--textract", is_flag=True, help="Run benchmarks on textract.", default=False
)
@click.option(
    "--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28
)
@click.option(
    "--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28
)
@click.option(
    "--languages",
    type=str,
    help="Comma-separated list of languages to benchmark.",
    default=None,
)
@click.option(
    "--print_results",
    is_flag=True,
)
def main(
    results_dir: str,
    max_rows: int,
    debug: int,
    tesseract: bool,
    textract: bool,
    tess_cpus: int,
    textract_cpus: int,
    languages: str | None,
    print_results: bool,
):
    """Benchmark surya text recognition, optionally against tesseract/textract.

    Loads ``settings.RECOGNITION_BENCH_DATASET_NAME``, optionally filters it by
    language and subsamples ``max_rows`` rows, runs recognition on the
    reference line boxes, and scores predictions against the reference text
    after ``normalize_text``. Scores are written under
    ``<results_dir>/rec_bench`` and a summary table is printed.

    Args:
        results_dir: Directory under which ``rec_bench`` is created.
        max_rows: Optional cap on the number of images (sampled with seed 1).
        debug: 0 disables debugging. Any non-zero level uploads per-line
            outputs to the hub dataset and writes ``bad_detections.json``;
            level 2 additionally renders predicted/reference text images.
        tesseract: Also benchmark tesseract on the languages it supports.
        textract: Also benchmark textract.
        tess_cpus: Worker count passed to the tesseract runner.
        textract_cpus: Worker count passed to the textract runner.
        languages: Optional comma-separated language filter.
        print_results: Print every predicted/reference line pair.
    """
    foundation_predictor = FoundationPredictor()
    rec_predictor = RecognitionPredictor(foundation_predictor)

    split = "train"
    dataset = datasets.load_dataset(
        settings.RECOGNITION_BENCH_DATASET_NAME, split=split
    )

    if languages:
        requested_languages = languages.split(",")
        dataset = dataset.filter(
            lambda x: list_in(x["language"], requested_languages), num_proc=4
        )

    if max_rows and max_rows < len(dataset):
        dataset = dataset.shuffle(seed=1).select(range(max_rows))

    images = list(dataset["image"])
    images = convert_if_not_rgb(images)
    bboxes = list(dataset["bboxes"])
    line_text = list(dataset["text"])
    image_languages = list(dataset["language"])

    print(f"Loaded {len(images)} images. Running OCR...")

    start = time.time()
    predictions_by_image = rec_predictor(images, None, bboxes=bboxes)
    surya_time = time.time() - start

    # Normalize every row's language field to a list of language codes.
    lang_list = [
        lang if isinstance(lang, list) else [lang] for lang in image_languages
    ]

    surya_scores = defaultdict(list)  # language name -> image scores
    img_surya_scores = []  # one score per image, aligned with `images`
    outputs = []
    for idx, (pred, ref_text, langs) in enumerate(
        zip(predictions_by_image, line_text, lang_list)
    ):
        pred_text = [line.text for line in pred.text_lines]

        score_ref_text = [normalize_text(line) for line in ref_text]
        score_pred_text = [normalize_text(text) for text in pred_text]
        image_scores, image_weights = overlap_score_exact(
            score_pred_text, score_ref_text
        )
        normalized_scores = [
            score / max(1, weight) for score, weight in zip(image_scores, image_weights)
        ]
        image_score = sum(image_scores) / max(1, sum(image_weights))

        img_surya_scores.append(image_score)
        for lang in langs:
            surya_scores[CODE_TO_LANGUAGE[lang]].append(image_score)

        assert len(pred_text) == len(ref_text) == len(bboxes[idx])
        if debug:
            for pred_line, ref_line, score, bbox in zip(
                pred_text, ref_text, normalized_scores, bboxes[idx]
            ):
                image_slice = images[idx].crop(bbox)

                outputs.append(
                    {
                        "image": image_slice,
                        "bbox": bbox,
                        "score": score,
                        "pred": pred_line,
                        "ref": ref_line,
                        "langs": ",".join(langs),
                    }
                )

    if debug:
        out_ds = datasets.Dataset.from_list(outputs)
        out_ds.push_to_hub("datalab-to/rec_bench_outputs", private=True)

    # An image tagged with several languages contributes once per language.
    flat_surya_scores = [score for lang in surya_scores for score in surya_scores[lang]]
    benchmark_stats = {
        "surya": {
            "avg_score": sum(flat_surya_scores) / max(1, len(flat_surya_scores)),
            "lang_scores": {
                lang: sum(scores) / max(1, len(scores))
                for lang, scores in surya_scores.items()
            },
            "time_per_img": surya_time / max(1, len(images)),
        }
    }

    result_path = os.path.join(results_dir, "rec_bench")
    os.makedirs(result_path, exist_ok=True)

    with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
        json.dump(surya_scores, f)

    if tesseract:
        tess_valid = []
        tess_langs = []
        for idx, lang in enumerate(lang_list):
            # Tesseract does not support all languages
            tess_lang = surya_lang_to_tesseract(lang[0])
            if tess_lang is None:
                continue

            tess_valid.append(idx)
            tess_langs.append(tess_lang)

        tess_imgs = [images[i] for i in tess_valid]
        tess_bboxes = [bboxes[i] for i in tess_valid]
        tess_reference = [line_text[i] for i in tess_valid]
        start = time.time()
        tess_predictions = tesseract_ocr_parallel(
            tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus
        )
        tesseract_time = time.time() - start

        tess_scores = defaultdict(list)
        for pred, ref_text, lang in zip(tess_predictions, tess_reference, tess_langs):
            image_scores, image_weights, _ = overlap_score(pred, ref_text)
            image_score = sum(image_scores) / max(1, sum(image_weights))
            tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)

        flat_tess_scores = [
            score for lang in tess_scores for score in tess_scores[lang]
        ]
        # max(1, ...) mirrors the surya stats and avoids a ZeroDivisionError
        # when tesseract supports none of the benchmarked languages.
        benchmark_stats["tesseract"] = {
            "avg_score": sum(flat_tess_scores) / max(1, len(flat_tess_scores)),
            "lang_scores": {
                lang: sum(scores) / max(1, len(scores))
                for lang, scores in tess_scores.items()
            },
            "time_per_img": tesseract_time / max(1, len(tess_imgs)),
        }

        with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
            json.dump(tess_scores, f)

    if textract:
        start = time.time()
        textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus)
        textract_time = time.time() - start

        textract_scores = defaultdict(list)
        for pred, ref_text, langs in zip(textract_predictions, line_text, lang_list):
            image_scores, image_weights, _ = overlap_score(pred, ref_text)
            image_score = sum(image_scores) / max(1, sum(image_weights))

            for lang in langs:
                textract_scores[CODE_TO_LANGUAGE[lang]].append(image_score)

        flat_textract_scores = [
            score for lang in textract_scores for score in textract_scores[lang]
        ]
        benchmark_stats["textract"] = {
            "avg_score": sum(flat_textract_scores) / max(1, len(flat_textract_scores)),
            "lang_scores": {
                lang: sum(scores) / max(1, len(scores))
                for lang, scores in textract_scores.items()
            },
            "time_per_img": textract_time / max(1, len(images)),
        }
        print(len(flat_textract_scores))

        with open(os.path.join(result_path, "textract_scores.json"), "w+") as f:
            json.dump(textract_scores, f)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(benchmark_stats, f)

    key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
    table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages
    table_data = [
        [
            "surya",
            benchmark_stats["surya"]["time_per_img"],
            benchmark_stats["surya"]["avg_score"],
        ]
        + [benchmark_stats["surya"]["lang_scores"][lang] for lang in key_languages],
    ]
    if tesseract:
        table_data.append(
            [
                "tesseract",
                benchmark_stats["tesseract"]["time_per_img"],
                benchmark_stats["tesseract"]["avg_score"],
            ]
            + [
                benchmark_stats["tesseract"]["lang_scores"].get(lang, 0)
                for lang in key_languages
            ]
        )
    if textract:
        table_data.append(
            [
                "textract",
                benchmark_stats["textract"]["time_per_img"],
                benchmark_stats["textract"]["avg_score"],
            ]
            + [
                benchmark_stats["textract"]["lang_scores"].get(lang, 0)
                for lang in key_languages
            ],
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print(
        "Only a few major languages are displayed. See the result path for additional languages."
    )

    if debug >= 1:
        # Use the per-image scores: they are index-aligned with lang_list,
        # whereas the flattened scores are grouped by language.
        bad_detections = [
            (idx, lang, score)
            for idx, (score, lang) in enumerate(zip(img_surya_scores, lang_list))
            if score < 0.8
        ]
        print(f"Found {len(bad_detections)} bad detections. Writing to file...")
        with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
            json.dump(bad_detections, f)

    if debug == 2:
        for idx, (image, pred, ref_text, bbox, lang) in enumerate(
            zip(images, predictions_by_image, line_text, bboxes, lang_list)
        ):
            pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
            ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
            pred_text = [line.text for line in pred.text_lines]
            pred_image = draw_text_on_image(bbox, pred_text, image.size)
            pred_image.save(os.path.join(result_path, pred_image_name))
            ref_image = draw_text_on_image(bbox, ref_text, image.size)
            ref_image.save(os.path.join(result_path, ref_image_name))
            image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))

    print(f"Wrote results to {result_path}")

    if print_results:
        for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)):
            print(f"Image {idx}")
            print("----")
            for line_idx, (pred_line, ref_line) in enumerate(
                zip(pred.text_lines, ref_text)
            ):
                print(f"Sample {line_idx}")
                print(f"Pred: {pred_line.text}")
                print(f"Ref: {ref_line}")
                print()

    if settings.TORCH_DEVICE == "xla":
        import torch_xla.debug.metrics as met

        print(met.short_metrics_report())
"--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=512, ) @click.option("--tatr", is_flag=True, help="Run table transformer.", default=False) @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False) def main(results_dir: str, max_rows: int, tatr: bool, debug: bool): table_rec_predictor = TableRecPredictor() pathname = "table_rec_bench" # These have already been shuffled randomly, so sampling from the start is fine split = "train" if max_rows is not None: split = f"train[:{max_rows}]" dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split) images = list(dataset["image"]) images = convert_if_not_rgb(images) if settings.TABLE_REC_STATIC_CACHE: # Run through one batch to compile the model table_rec_predictor(images[:1]) start = time.time() table_rec_predictions = table_rec_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] result_path = os.path.join(results_dir, folder_name) os.makedirs(result_path, exist_ok=True) page_metrics = collections.OrderedDict() mean_col_iou = 0 mean_row_iou = 0 for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)): row = dataset[idx] pred_row_boxes = [p.bbox for p in pred.rows] pred_col_bboxes = [p.bbox for p in pred.cols] actual_row_bboxes = [r["bbox"] for r in row["rows"]] actual_col_bboxes = [c["bbox"] for c in row["columns"]] row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes) col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes) page_results = { "row_score": row_score, "col_score": col_score, "row_count": len(actual_row_bboxes), "col_count": len(actual_col_bboxes), } mean_col_iou += col_score mean_row_iou += row_score page_metrics[idx] = page_results if debug: # Save debug images draw_img = image.copy() draw_bboxes_on_image( pred_row_boxes, draw_img, [f"Row {i}" for i in range(len(pred_row_boxes))], ) draw_bboxes_on_image( pred_col_bboxes, draw_img, 
[f"Col {i}" for i in range(len(pred_col_bboxes))], color="blue", ) draw_img.save(os.path.join(result_path, f"{idx}_bbox.png")) actual_draw_image = image.copy() draw_bboxes_on_image( actual_row_bboxes, actual_draw_image, [f"Row {i}" for i in range(len(actual_row_bboxes))], ) draw_bboxes_on_image( actual_col_bboxes, actual_draw_image, [f"Col {i}" for i in range(len(actual_col_bboxes))], color="blue", ) actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png")) mean_col_iou /= len(table_rec_predictions) mean_row_iou /= len(table_rec_predictions) out_data = { "surya": { "time": surya_time, "mean_row_iou": mean_row_iou, "mean_col_iou": mean_col_iou, "page_metrics": page_metrics, } } if tatr: tatr_model = load_tatr() start = time.time() tatr_predictions = batch_inference_tatr(tatr_model, images, 1) tatr_time = time.time() - start page_metrics = collections.OrderedDict() mean_col_iou = 0 mean_row_iou = 0 for idx, pred in enumerate(tatr_predictions): row = dataset[idx] pred_row_boxes = [p["bbox"] for p in pred["rows"]] pred_col_bboxes = [p["bbox"] for p in pred["cols"]] actual_row_bboxes = [r["bbox"] for r in row["rows"]] actual_col_bboxes = [c["bbox"] for c in row["columns"]] row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes) col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes) page_results = { "row_score": row_score, "col_score": col_score, "row_count": len(actual_row_bboxes), "col_count": len(actual_col_bboxes), } mean_col_iou += col_score mean_row_iou += row_score page_metrics[idx] = page_results mean_col_iou /= len(tatr_predictions) mean_row_iou /= len(tatr_predictions) out_data["tatr"] = { "time": tatr_time, "mean_row_iou": mean_row_iou, "mean_col_iou": mean_col_iou, "page_metrics": page_metrics, } with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(out_data, f, indent=4) table = [ ["Model", "Row Intersection", "Col Intersection", "Time Per Image"], [ "Surya", 
def score_text(predictions, references):
    """Return the mean normalized Levenshtein distance between paired texts.

    Each prediction/reference pair is passed through ``normalize_text`` before
    being compared. Pairs are formed with ``zip``, so extra items in the longer
    sequence are ignored. Raises ``ZeroDivisionError`` when there are no pairs.
    """
    distances = [
        Levenshtein.normalized_distance(normalize_text(pred), normalize_text(ref))
        for pred, ref in zip(predictions, references)
    ]
    return sum(distances) / len(distances)
@click.command(help="Benchmark the performance of texify.")
@click.option(
    "--ds_name",
    type=str,
    help="Path to dataset file with source images/equations.",
    default=settings.TEXIFY_BENCHMARK_DATASET,
)
@click.option(
    "--results_dir",
    type=str,
    help="Path to JSON file with benchmark results.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows", type=int, help="Maximum number of images to benchmark.", default=None
)
@click.option(
    "--line_mode", is_flag=True, help="Use line mode for texify.", default=False
)
def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool):
    """Benchmark texify-style OCR on an image/equation dataset.

    Runs ``inference_texify`` over the ``train`` split of ``ds_name``, scores
    the predictions against the reference equations with ``score_text``,
    prints a small table, and writes the score plus all prediction/reference
    pairs to ``<results_dir>/texify_bench/results.json``.

    Args:
        ds_name: Dataset to load.
        results_dir: Directory under which ``texify_bench`` is created.
        max_rows: Optional cap; only the first ``max_rows`` rows are used.
        line_mode: Forwarded to ``inference_texify`` to select the task mode.
    """
    foundation_predictor = FoundationPredictor()
    predictor = RecognitionPredictor(foundation_predictor)
    ds = datasets.load_dataset(ds_name, split="train")

    if max_rows:
        # Take the leading rows by index instead of evaluating a filter
        # callback on every row of the dataset; the selected rows are the same.
        ds = ds.select(range(min(max_rows, len(ds))))

    start = time.time()
    predictions = inference_texify(ds, predictor, line_mode)
    time_taken = time.time() - start

    text = [p["text"] for p in predictions]
    references = [p["equation"] for p in predictions]
    scores = score_text(text, references)

    write_data = {
        "scores": scores,
        "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)],
    }

    score_table = [["texify", write_data["scores"], time_taken]]
    score_headers = ["edit", "time taken (s)"]
    # Arrows mark that lower is better for both columns.
    score_dirs = ["⬇", "⬇"]

    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
    table = tabulate(score_table, headers=["Method", *score_headers])
    print()
    print(table)

    result_path = Path(results_dir) / "texify_bench"
    result_path.mkdir(parents=True, exist_ok=True)
    with open(result_path / "results.json", "w", encoding="utf-8") as f:
        json.dump(write_data, f, indent=4)
def join_lines(bboxes, max_gap=5):
    """Merge boxes that sit just above a box which spans them horizontally.

    For every ordered pair ``(box1, box2)`` with ``box2`` later in the list,
    ``box2`` is absorbed into ``box1`` when ``box1`` horizontally contains
    ``box2`` (``box1[0] <= box2[0]`` and ``box1[2] >= box2[2]``) and the
    distance between ``box1[1]`` and ``box2[3]`` is at most ``max_gap``.
    Boxes are assumed to be laid out as ``(x1, y1, x2, y2)``.

    The previous implementation unpacked each element as an ``(index, box)``
    pair without ``enumerate``, which raised on a plain list of boxes and
    computed wrong indices otherwise.

    Args:
        bboxes: Sequence of boxes, each an indexable of four numbers.
        max_gap: Maximum vertical distance for two boxes to be joined.

    Returns:
        A list of boxes in the original order, where absorbed boxes are
        dropped and absorbing boxes are replaced by the merged extent.
    """
    # Index of the absorbing box -> indices of the boxes merged into it.
    to_merge = {}
    for i, box1 in enumerate(bboxes):
        for j in range(i + 1, len(bboxes)):
            box2 = bboxes[j]
            if box1 == box2:
                continue
            spans_horizontally = box1[0] <= box2[0] and box1[2] >= box2[2]
            if spans_horizontally and abs(box1[1] - box2[3]) <= max_gap:
                to_merge.setdefault(i, []).append(j)

    absorbed = set()
    merged = []
    for i, box in enumerate(bboxes):
        # A box already folded into an earlier one is not emitted again.
        if i in absorbed:
            continue
        for j in to_merge.get(i, ()):
            box = merge_boxes(box, bboxes[j])
            absorbed.add(j)
        merged.append(box)
    return merged
def intersection_pixels(box1, box2):
    """Return the set of integer pixel coordinates shared by two boxes.

    The overlap rectangle is computed from indices 0-3 of each box, its
    corners are truncated with ``int``, and every ``(x, y)`` grid point in
    the half-open ranges ``[x_left, x_right)`` x ``[y_top, y_bottom)`` is
    collected.

    Returns:
        A set of ``(x, y)`` tuples; empty when the boxes do not overlap.
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Inverted extents mean there is no overlap at all.
    if left > right or top > bottom:
        return set()

    xs = np.arange(int(left), int(right))
    ys = np.arange(int(top), int(bottom))
    grid_x, grid_y = np.meshgrid(xs, ys)
    return set(zip(grid_x.flat, grid_y.flat))
def calculate_coverage_fast(box, other_boxes, penalize_double=False):
    """Return the fraction of ``box`` covered by ``other_boxes`` (vectorised).

    The intersection area of ``box`` with each other box is summed and
    divided by the area of ``box``; the result is capped at 1.0. Overlaps
    between the other boxes are not de-duplicated, the cap only bounds the
    total.

    Args:
        box: Four numbers read as indices 0-3 (left, top, right, bottom).
        other_boxes: Sequence of boxes with the same layout. May be empty.
        penalize_double: Accepted for signature parity with
            ``calculate_coverage``; not used by this implementation.

    Returns:
        ``0`` when ``box`` has zero area or there are no other boxes,
        otherwise the covered fraction, at most 1.0.
    """
    box = np.array(box)
    other_boxes = np.array(other_boxes)

    # Named ``area`` so the module-level ``box_area`` helper is not shadowed.
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0

    # An empty list becomes a 1-D array, which cannot be column-indexed
    # below; nothing covers the box in that case.
    if other_boxes.size == 0:
        return 0

    x_left = np.maximum(box[0], other_boxes[:, 0])
    y_top = np.maximum(box[1], other_boxes[:, 1])
    x_right = np.minimum(box[2], other_boxes[:, 2])
    y_bottom = np.minimum(box[3], other_boxes[:, 3])

    # Negative extents mean no overlap and are clamped to zero.
    widths = np.maximum(0, x_right - x_left)
    heights = np.maximum(0, y_bottom - y_top)
    total_intersect = np.sum(widths * heights)

    return min(1.0, total_intersect / area)
def rank_accuracy(preds, references):
    """Return the fraction of pairwise orderings in ``preds`` matched by ``references``.

    Preds and references need to be aligned so each position refers to the
    same bbox. For every ordered pair of distinct positions ``(i, j)`` the
    outcome of ``preds[i] > preds[j]`` is recorded; a pair counts as correct
    when ``references[i] > references[j]`` gives the same outcome.

    Args:
        preds: Sequence of mutually comparable rank values.
        references: Sequence of reference rank values, aligned with ``preds``.

    Returns:
        ``correct / number_of_pairs`` as a float.

    Raises:
        ZeroDivisionError: If ``preds`` has fewer than two elements, because
            there are then no pairs to score.
    """
    # A set keeps the membership test below O(1); every (i, j) occurs once,
    # so its size equals the number of ordered pairs.
    pairs = {
        (i, j, pred > pred2)
        for i, pred in enumerate(preds)
        for j, pred2 in enumerate(preds)
        if i != j
    }

    # Find how many of the prediction rankings are correct
    correct = 0
    for i, ref in enumerate(references):
        for j, ref2 in enumerate(references):
            if (i, j, ref > ref2) in pairs:
                correct += 1

    return correct / len(pairs)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from ``(cx, cy, w, h)`` to ``(x1, y1, x2, y2)``.

    The last dimension of ``x`` holds the four box values. The components
    are stacked back along the last dimension, so a single box of shape
    ``(4,)``, a list of boxes ``(N, 4)`` and batched boxes ``(..., N, 4)``
    all keep their shape. Stacking on ``dim=1`` gave the same result for
    ``(N, 4)`` input but failed or transposed the output for other ranks.

    Args:
        x: Tensor whose last dimension has size 4.

    Returns:
        Tensor of the same shape holding corner coordinates.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=-1)
def batch_inference_tatr(model, images, batch_size):
    """Run the table-transformer model over ``images`` in batches.

    Each image is preprocessed with ``structure_transform``, the batch is
    stacked and moved to ``model.device``, and the raw outputs are converted
    to row/column objects by ``outputs_to_objects`` using each image's
    original ``size``.

    The label map is built once, on a copy of ``model.config.id2label``.
    Previously the "no object" entry was written into the model's own config
    on every batch, adding one more spurious key each time.

    Args:
        model: Object-detection model exposing ``device`` and
            ``config.id2label``.
        images: Sequence of images accepted by ``structure_transform``.
            NOTE(review): ``torch.stack`` needs every transformed image in a
            batch to have the same shape - confirm callers guarantee this.
        batch_size: Number of images per forward pass.

    Returns:
        A list with one ``{"rows": [...], "cols": [...]}`` dict per image.
    """
    device = model.device

    # The extra class index after the configured labels stands for "no object".
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    rows_cols = []
    for i in range(0, len(images), batch_size):
        batch_images = images[i:i + batch_size]
        pixel_values = torch.stack(
            [structure_transform(img) for img in batch_images], dim=0
        ).to(device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        rows_cols.extend(
            outputs_to_objects(outputs, [img.size for img in batch_images], id2label)
        )
    return rows_cols
def surya_lang_to_tesseract(code: str) -> Optional[str]:
    """Translate a surya language code into the matching tesseract code.

    The code is first resolved to a language name through
    ``CODE_TO_LANGUAGE`` (an unknown code raises ``KeyError``), then the name
    is looked up in ``TESS_LANGUAGE_TO_CODE``.

    Returns:
        The tesseract language code, or ``None`` when tesseract has no entry
        for that language name.
    """
    language_name = CODE_TO_LANGUAGE[code]
    # ``dict.get`` yields None for languages tesseract does not list.
    return TESS_LANGUAGE_TO_CODE.get(language_name)
def tesseract_parallel(imgs):
    """Detect tesseract bounding boxes for ``imgs`` in a process pool.

    The worker count is bounded by the number of images, the detection
    batch size and the CPU count, then divided by four.

    Args:
        imgs: Sequence of images passed one by one to ``tesseract_bboxes``.

    Returns:
        A list with one list of bounding boxes per image, in input order.
    """
    # os.cpu_count() returns None when the count cannot be determined, which
    # would make min() raise TypeError; fall back to a single core.
    cpus = os.cpu_count() or 1
    tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size(), cpus)

    # Tesseract uses 4 threads per instance
    tess_parallel = max(tess_parallel_cores // 4, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_bboxes = tqdm(
            executor.map(tesseract_bboxes, imgs),
            total=len(imgs),
            desc="Running tesseract bbox detection",
        )
        # Drain the lazy map while the pool is still open.
        tess_bboxes = list(tess_bboxes)
    return tess_bboxes
def textract_ocr(extractor, img):
    """OCR one image with Textract, returning its lines of text.

    This is best-effort: any ordinary error is printed with its traceback
    and reported as ``[None]`` so one bad image does not abort a batch.
    ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed (the
    previous bare ``except:`` caught them too), so a run can be cancelled.

    Args:
        extractor: Object exposing ``detect_document_text(file_source=...)``.
        img: Image passed through as ``file_source``.

    Returns:
        The ``text`` of every line in the detected document, or ``[None]``
        when detection failed.
    """
    try:
        document = extractor.detect_document_text(file_source=img)
        return [line.text for line in document.lines]
    except Exception:
        traceback.print_exc()
        return [None]
def verify_table_rec(data):
    """Check the surya table-recognition scores against the 0.75 threshold.

    Reads ``mean_row_iou`` and ``mean_col_iou`` from ``data["surya"]``.

    Raises:
        ValueError: If either score is below 0.75.
    """
    surya_metrics = data["surya"]
    # Both scores are read before any comparison is made.
    ious = (surya_metrics["mean_row_iou"], surya_metrics["mean_col_iou"])

    if any(iou < 0.75 for iou in ious):
        raise ValueError("Scores do not meet the required threshold")
================================================ FILE: ocr_app.py ================================================ from surya.scripts.run_streamlit_app import streamlit_app_cli if __name__ == "__main__": streamlit_app_cli() ================================================ FILE: ocr_latex.py ================================================ from surya.scripts.ocr_latex import ocr_latex_cli if __name__ == "__main__": ocr_latex_cli() ================================================ FILE: ocr_text.py ================================================ from surya.scripts.ocr_text import ocr_text_cli if __name__ == "__main__": ocr_text_cli() ================================================ FILE: pyproject.toml ================================================ [tool.poetry] name = "surya-ocr" version = "0.17.1" description = "OCR, layout, reading order, and table recognition in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" license = "GPL-3.0-or-later" repository = "https://github.com/VikParuchuri/surya" keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"] packages = [ {include = "surya"} ] [tool.poetry.dependencies] python = "^3.10" transformers = ">=4.56.1" torch = "^2.7.0" pydantic = "^2.5.3" pydantic-settings = "^2.1.0" python-dotenv = "^1.0.0" pillow = "^10.2.0" pypdfium2 = "=4.30.0" filetype = "^1.2.0" click = "^8.1.8" platformdirs = "^4.3.6" opencv-python-headless = "==4.11.0.86" einops = "^0.8.1" pre-commit = "^4.2.0" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0" pytesseract = "^0.3.10" pymupdf = "^1.23.8" datasets = "^2.16.1" rapidfuzz = "^3.6.1" streamlit = "^1.31.0" pytest = "^8.3.4" pdftext = "^0.5.1" tabulate = "^0.9.0" [tool.poetry.scripts] surya_detect = "surya.scripts.detect_text:detect_text_cli" surya_ocr = "surya.scripts.ocr_text:ocr_text_cli" surya_layout = "surya.scripts.detect_layout:detect_layout_cli" surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli" surya_table = 
"surya.scripts.table_recognition:table_recognition_cli" surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli" texify_gui = "surya.scripts.run_texify_app:texify_app_cli" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [[tool.poetry.source]] name = "libtpu-releases" url = "https://storage.googleapis.com/libtpu-releases/index.html" priority = "supplemental" [[tool.poetry.source]] name = "libtpu-wheels" url = "https://storage.googleapis.com/libtpu-wheels/index.html" priority = "supplemental" [tool.poetry.group.xla] optional = true [tool.poetry.group.xla.dependencies] torch-xla = {version = "^2.4.1", extras = ["tpu"]} ================================================ FILE: pytest.ini ================================================ [pytest] testpaths=tests pythonpath=. filterwarnings = ignore::UserWarning ignore::PendingDeprecationWarning ignore::DeprecationWarning ================================================ FILE: signatures/version1/cla.json ================================================ { "signedContributors": [ { "name": "rishiraj", "id": 44090649, "comment_id": 2170578748, "created_at": "2024-06-15T19:31:20Z", "repoId": 741297064, "pullRequestNo": 135 }, { "name": "mmacvicar", "id": 59354, "comment_id": 2236493182, "created_at": "2024-07-18T13:17:43Z", "repoId": 741297064, "pullRequestNo": 152 }, { "name": "jimexist", "id": 622789, "comment_id": 2255151376, "created_at": "2024-07-29T07:23:55Z", "repoId": 741297064, "pullRequestNo": 160 }, { "name": "michaeldriscoll-avant", "id": 85255083, "comment_id": 2259143427, "created_at": "2024-07-30T20:21:33Z", "repoId": 741297064, "pullRequestNo": 161 }, { "name": "EdoardoPona", "id": 29152472, "comment_id": 2271115922, "created_at": "2024-08-06T11:58:00Z", "repoId": 741297064, "pullRequestNo": 167 }, { "name": "hidenori-endo", "id": 15546605, "comment_id": 2307217499, "created_at": "2024-08-23T14:31:17Z", "repoId": 741297064, "pullRequestNo": 182 }, { "name": "dobosevych", "id": 
12053536, "comment_id": 2430376828, "created_at": "2024-10-22T21:48:34Z", "repoId": 741297064, "pullRequestNo": 220 }, { "name": "iammosespaulr", "id": 28682735, "comment_id": 2447941238, "created_at": "2024-10-30T17:55:23Z", "repoId": 741297064, "pullRequestNo": 235 }, { "name": "ArthurMor4is", "id": 42987302, "comment_id": 2515315717, "created_at": "2024-12-03T18:37:45Z", "repoId": 741297064, "pullRequestNo": 255 }, { "name": "tarun-menta", "id": 66506307, "comment_id": 2543457960, "created_at": "2024-12-15T05:43:33Z", "repoId": 741297064, "pullRequestNo": 261 }, { "name": "jonaskahn", "id": 4338500, "comment_id": 2556622097, "created_at": "2024-12-20T09:36:20Z", "repoId": 741297064, "pullRequestNo": 269 }, { "name": "kumsumit", "id": 95072784, "comment_id": 2574534622, "created_at": "2025-01-07T07:05:59Z", "repoId": 741297064, "pullRequestNo": 276 }, { "name": "kevinhu", "id": 6051736, "comment_id": 2614135351, "created_at": "2025-01-25T23:34:12Z", "repoId": 741297064, "pullRequestNo": 291 }, { "name": "zanussbaum", "id": 33707069, "comment_id": 3008673416, "created_at": "2025-06-26T14:20:46Z", "repoId": 741297064, "pullRequestNo": 403 }, { "name": "mebriki", "id": 35892987, "comment_id": 3154706976, "created_at": "2025-08-05T10:54:27Z", "repoId": 741297064, "pullRequestNo": 418 }, { "name": "starikovplusplus", "id": 56602036, "comment_id": 3168958011, "created_at": "2025-08-08T18:29:50Z", "repoId": 741297064, "pullRequestNo": 423 }, { "name": "sandy0kwon", "id": 78377296, "comment_id": 3207932260, "created_at": "2025-08-20T20:07:15Z", "repoId": 741297064, "pullRequestNo": 434 }, { "name": "n0kovo", "id": 16690056, "comment_id": 3208251881, "created_at": "2025-08-20T22:22:06Z", "repoId": 741297064, "pullRequestNo": 435 }, { "name": "davidxifeng", "id": 158052, "comment_id": 3249594859, "created_at": "2025-09-03T14:52:16Z", "repoId": 741297064, "pullRequestNo": 445 }, { "name": "u-ashish", "id": 14264791, "comment_id": 3258734182, "created_at": 
class WrappedEmbedding(nn.Embedding):
    """``nn.Embedding`` whose forward pass accepts and discards extra arguments."""

    def forward(self, input_ids, *args, **kwargs):
        # Only the ids reach the embedding lookup; any further positional or
        # keyword arguments are accepted and ignored.
        embeddings = super().forward(input_ids)
        return embeddings
class SuryaADETRDecoderRotaryEmbedding(nn.Module):
    """Rotary position embedding: produces the cos/sin tables for given positions.

    ``inv_freq`` holds ``1 / base ** (2i / dim)`` for ``i`` in
    ``range(dim // 2)`` and is registered as a non-persistent buffer, so it
    follows the module across devices but is not part of the state dict.
    """

    def __init__(self, dim, base=10000, device=None):
        """Build the inverse-frequency buffer.

        Args:
            dim: Size of the rotary dimension (the attention head size).
            base: Base of the geometric frequency progression.
            device: Unused; kept for signature compatibility.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        )
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaADETRDecoder
    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        Args:
            x: Tensor used only for its device and dtype
                ([bs, num_attention_heads, seq_len, head_size]).
            position_ids: Integer positions of shape ``(bs, seq_len)``.
            seq_len: Unused; kept for signature compatibility.

        Returns:
            Two tensors of shape ``(bs, seq_len, dim)``.
        """
        # ``Tensor.to`` returns a new tensor instead of moving in place, so
        # its result has to be used; the bare call used before was a no-op.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = (
            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        # (bs, dim/2, 1) @ (bs, 1, seq_len) -> (bs, dim/2, seq_len) -> (bs, seq_len, dim/2)
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        # Both halves of the head dimension share the same frequencies.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors with a rotary position embedding.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table of the rotary embedding.
        sin (`torch.Tensor`): Sine table of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into ``cos`` and ``sin`` so they broadcast
            against ``q`` and ``k``. With tables shaped
            [batch_size, seq_len, head_dim], use 1 when ``q``/``k`` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation for queries and keys.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
class SuryaADETRDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA

    Queries come from the decoder hidden states; keys and values are
    projected from the encoder hidden states. The projected keys/values can
    be cached on the module (``key_states`` / ``value_states``) so later
    calls reuse them instead of re-projecting the encoder output.
    """

    def __init__(self, config: PretrainedConfig):
        """Create the projections from ``config``.

        Reads ``attention_dropout``, ``hidden_size``, ``num_attention_heads``,
        ``head_dim``, ``num_key_value_heads``, ``encoder_hidden_size``,
        ``attention_bias`` and ``rope_theta`` from ``config``.
        """
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each key/value head.
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size,
            self.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            self.config.encoder_hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.config.encoder_hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
        )
        # Not referenced by this class's forward pass; kept so the module
        # layout is unchanged.
        self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Args:
            hidden_states: Decoder states, ``(bsz, q_len, hidden_size)``.
            encoder_hidden_states: Encoder states,
                ``(bsz, v_len, encoder_hidden_size)``.
            attention_mask: Ignored.
            encoder_attention_mask: Ignored.
            use_cache: When true and nothing is cached yet, store the freshly
                projected keys/values on the module.

        Returns:
            The attention output, ``(bsz, q_len, hidden_size)``.
        """
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(
            bsz, q_len, self.num_attention_heads, self.head_dim
        ).transpose(1, 2)

        # The cache attributes only exist after ``_setup_cache`` (and are
        # removed again by ``_clear_cache``); a missing attribute is treated
        # the same as an empty cache instead of raising AttributeError.
        cached_keys = getattr(self, "key_states", None)
        if cached_keys is None:
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(
                bsz, v_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, v_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            key_states = cached_keys
            value_states = self.value_states

        # Expand the key/value heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        # ``o_proj`` consumes heads * head_dim features, which equals
        # ``hidden_size`` only for some configurations.
        attn_output = attn_output.view(
            bsz, q_len, self.num_attention_heads * self.head_dim
        )
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _clear_cache(self):
        """Drop cached keys/values; safe to call when nothing is cached."""
        if getattr(self, "value_states", None) is not None:
            del self.value_states
        if getattr(self, "key_states", None) is not None:
            del self.key_states

    def _setup_cache(self, batch_size, device, dtype=None):
        """Reset the cache to empty; the arguments are not used."""
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Store projected keys/values for reuse; ``cache_kwargs`` are ignored."""
        self.value_states = value_states
        self.key_states = key_states
self.num_key_value_heads self.q_proj = nn.Linear( self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias, ) self.k_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.v_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.o_proj = nn.Linear( self.num_attention_heads * self.head_dim, self.hidden_size, bias=True ) self.rotary_emb = SuryaADETRDecoderRotaryEmbedding( self.head_dim, base=config.rope_theta, ) self.static_cache = static_cache self.max_boxes = max_boxes def forward( self, hidden_states: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: bool = False, window_attn: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) # Final is bsz, num_attention_heads, seq_len, head_dim query_states = query_states.view( bsz, q_len, self.num_attention_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if use_cache and hasattr(self, "key_states"): cache_kwargs = { "cache_position": cache_position, "window_attn": window_attn, } key_states, value_states = self._update_cache( key_states, value_states, **cache_kwargs ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = 
def _clear_cache(self):
    """Delete the cached key/value tensors from this attention module.

    Safe to call when the cache was never set up or has already been
    cleared: a missing attribute is treated the same as an empty (None)
    cache instead of raising AttributeError.
    """
    for name in ("value_states", "key_states"):
        if getattr(self, name, None) is not None:
            delattr(self, name)
class SuryaADETRDecoderMlp(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)

        # A config without an activation is updated in place with the default,
        # so the choice stays visible to anything else reading the config.
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Apply the gated projection to `x` and map back to the hidden size."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
def forward(
    self,
    activations: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    encoder_attention_mask: torch.Tensor = None,
    cache_position: torch.Tensor = None,
    use_cache: bool = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run one decoder layer: optional cross-attention, optional
    self-attention, then the MLP, each added back onto the residual stream.

    When `double_residual_flow` is set, the call is delegated unchanged to
    `double_res_forward`.
    """
    if self.double_residual_flow:
        return self.double_res_forward(
            activations,
            position_ids,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            cache_position,
            use_cache,
        )

    stream = activations

    if self.cross_attn_block is not None:
        # Do cross-attention on encoder outputs
        cross_out = self.cross_attn_block(
            self.cross_pre_norm(stream),
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
            use_cache=use_cache,
        )
        stream = cross_out + stream

    if self.temporal_block is not None:
        # RMSNorm introduces slight differences
        self_out = self.temporal_block(
            self.temporal_pre_norm(stream),
            position_ids,
            attention_mask,
            cache_position=cache_position,
            use_cache=use_cache,
            window_attn=self.window_attn,
        )
        stream = self_out + stream

    mlp_out = self.mlp_block(self.channel_pre_norm(stream))
    return mlp_out + stream
def _init_weights(self, module):
    """Initialise `module` with normal(0, config.init_std) weights.

    Attention blocks get their four projection weights initialised; plain
    linear layers additionally have their bias zeroed; embeddings have the
    padding row zeroed. Other module types are left untouched.
    """
    if isinstance(module, SuryaADETRDecoderSdpaAttention):
        projections = (module.q_proj, module.k_proj, module.v_proj, module.o_proj)
        for projection in projections:
            torch.nn.init.normal_(
                projection.weight, mean=0.0, std=self.config.init_std
            )
        return

    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
        if getattr(module, "bias", None) is not None:
            torch.nn.init.zeros_(module.bias)
        return

    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=self.config.init_std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
def _clear_cache(self):
    """Clear the key/value caches of every attention block in every layer.

    Layers are read from `self.model.layers` when a `model` attribute exists,
    otherwise from `self.layers`.
    """
    owner = getattr(self, "model", self)
    for layer in owner.layers:
        # Self-attention first, then cross-attention; absent blocks are skipped
        for attr in ("temporal_block", "cross_attn_block"):
            attn_block = getattr(layer, attr)
            if attn_block:
                attn_block._clear_cache()
position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, prefill: bool = False, ) -> Union[Tuple, BaseModelOutputWithNoAttention]: use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if self.gradient_checkpointing and self.training and use_cache: use_cache = False inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts) hidden_states = inputs_embeds if use_cache and prefill: self._setup_cache( self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype, ) if cache_position is None: cache_position = torch.arange( hidden_states.shape[1], device=hidden_states.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position ) all_hidden_states = () if output_hidden_states else None for i, residual_block in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache, ) else: hidden_states = residual_block( hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache, ) hidden_states = self.final_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states] if 
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# Ignore copy
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
    """Build the additive causal mask for the self-attention layers.

    Args:
        attention_mask: Optional padding mask; only the 2D form is merged
            into the causal mask (entries equal to 0 are masked out).
        input_tensor: Embedded inputs; supplies batch size (dim 0), sequence
            length (dim 1), dtype and device.
        cache_position: Position index of each query row.

    Returns:
        None when the model is not causal, otherwise a tensor of shape
        (batch, 1, sequence_length, target_length) holding 0 where attention
        is allowed and the dtype minimum where it is blocked.
        `target_length` is `max(max_boxes, sequence_length)`, or just
        `sequence_length` when `max_boxes` is None.
    """
    if not self.causal:
        return None

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]
    # `max_boxes` defaults to None in the model constructor; `max(None, int)`
    # would raise, so fall back to the sequence length in that case.
    if self.max_boxes is None:
        target_length = sequence_length
    else:
        target_length = max(self.max_boxes, sequence_length)

    diagonal = torch.full(
        (sequence_length, target_length),
        fill_value=min_dtype,
        dtype=dtype,
        device=device,
    )
    causal_mask = diagonal
    if sequence_length != 1:
        # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
        # triu will be the min_dtype, everything else is 0 (attended to)
        causal_mask = torch.triu(diagonal, diagonal=1)

    # Keep the block only for key positions beyond each query's cache position
    causal_mask *= torch.arange(
        target_length, device=device
    ) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(
        input_tensor.shape[0], 1, -1, -1
    )
    if attention_mask is not None:
        causal_mask = (
            causal_mask.clone()
        )  # copy to contiguous memory for in-place edit
        if attention_mask.dim() == 2:
            # Mask positions in the causal mask that are masked in the attention mask
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
                :, None, None, :
            ].eq(0.0)
            causal_mask[..., :mask_length] = causal_mask[
                ..., :mask_length
            ].masked_fill(padding_mask, min_dtype)

    if attention_mask is not None and attention_mask.device.type == "cuda":
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(
            causal_mask, min_dtype
        )

    return causal_mask
# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
class DonutSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.

    Position information is added in up to two forms, both optional:
    a learned absolute table (`config.use_absolute_embeddings`) and learned
    per-row / per-column tables (`config.use_2d_embeddings`).
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = (
            nn.Parameter(torch.zeros(1, 1, config.embed_dim))
            if use_mask_token
            else None
        )

        self.position_embeddings = None
        self.row_embeddings = None
        self.column_embeddings = None
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(
                torch.zeros(1, num_patches + 1, config.embed_dim)
            )

        if hasattr(config, "use_2d_embeddings") and config.use_2d_embeddings:
            self.row_embeddings = nn.Parameter(
                torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim)
            )
            self.column_embeddings = nn.Parameter(
                torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim)
            )

        self.norm = nn.LayerNorm(config.embed_dim)

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings: Token embeddings; dim 1 is treated as one leading
                slot plus the patch tokens.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            The stored table unchanged when the patch count already matches
            and the image is square, otherwise the first table row followed
            by the bicubically resized patch rows.
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        # The patch size is taken from the patch embedding layer, which keeps
        # it as a (height, width) pair; this module never stores `config`.
        patch_height, patch_width = self.patch_embeddings.patch_size
        h0 = height // patch_height
        w0 = width // patch_width
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        grid_side = math.sqrt(num_positions)
        patch_pos_embed = patch_pos_embed.reshape(
            1, int(grid_side), int(grid_side), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / grid_side, w0 / grid_side),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Embed `pixel_values` into normalised patch tokens plus positions.

        Args:
            pixel_values: Image batch, (batch, channels, height, width).
            bool_masked_pos: Optional per-token mask; masked tokens are
                replaced by the learned mask token.
            interpolate_pos_encoding: Resize the absolute position table to
                the input resolution instead of slicing it.

        Returns:
            `(embeddings, output_dimensions)`, where `output_dimensions` is
            the patch grid size reported by the patch embedding layer.
        """
        _, _, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                # NOTE(review): the interpolated table carries one leading row
                # in addition to the patch rows - confirm its length matches
                # `embeddings` for the resolutions used with this flag.
                embeddings = embeddings + self.interpolate_pos_encoding(
                    embeddings, height, width
                )
            else:
                embeddings = embeddings + self.position_embeddings[:, :seq_len]

        if self.row_embeddings is not None and self.column_embeddings is not None:
            # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ...
            row_embeddings = self.row_embeddings[
                :, : output_dimensions[0], :
            ].repeat_interleave(output_dimensions[1], dim=1)
            column_embeddings = self.column_embeddings[
                :, : output_dimensions[1], :
            ].repeat(1, output_dimensions[0], 1)

            embeddings = embeddings + row_embeddings + column_embeddings

        return embeddings, output_dimensions
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class DonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer: concatenates every 2x2 neighbourhood of tokens along
    the channel axis, normalises it and projects it to twice the input width.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(
        self,
        input_resolution: Tuple[int],
        dim: int,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        merged_width = 4 * dim
        self.reduction = nn.Linear(merged_width, 2 * dim, bias=False)
        self.norm = norm_layer(merged_width)

    def maybe_pad(self, input_feature, height, width):
        """Zero-pad the bottom/right edge so both spatial sizes are even."""
        if height % 2 != 1 and width % 2 != 1:
            return input_feature
        # F.pad pairs run from the last axis backwards: channels, width, height
        padding = (0, 0, 0, width % 2, 0, height % 2)
        return nn.functional.pad(input_feature, padding)

    def forward(
        self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
    ) -> torch.Tensor:
        height, width = input_dimensions
        # the middle axis of the input is height * width
        batch_size, _, num_channels = input_feature.shape

        grid = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by two along height and width, if needed
        grid = self.maybe_pad(grid, height, width)

        # (row offset, column offset) of each member of the 2x2 neighbourhood;
        # every slice is [batch_size, height/2, width/2, num_channels]
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        quadrants = [grid[:, row::2, col::2, :] for row, col in offsets]

        # [batch_size, height/2 * width/2, 4 * num_channels]
        merged = torch.cat(quadrants, -1).view(batch_size, -1, 4 * num_channels)

        return self.reduction(self.norm(merged))
def transpose_kv_for_scores(self, x, repeats):
    """Split the last axis of `x` into key/value heads, tile the heads
    `repeats` times and move the head axis in front of the sequence axis.

    The returned tensor is contiguous, with axes
    (dim 0, num_kv_heads * repeats, dim 1, attention_head_size).
    """
    *leading, _ = x.size()
    per_head = x.view(*leading, self.num_kv_heads, self.attention_head_size)
    # repeat the values for each key-value head to match query dim
    tiled = per_head.repeat(1, 1, repeats, 1)
    return tiled.permute(0, 2, 1, 3).contiguous()
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
    """Output projection applied to the result of window self-attention."""

    def __init__(self, config, dim):
        super().__init__()
        # Square projection; `config` is accepted but not consulted
        self.dense = nn.Linear(in_features=dim, out_features=dim)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Project `hidden_states`; `input_tensor` is accepted but not used."""
        projected = self.dense(hidden_states)
        return projected
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
class DonutSwinIntermediate(nn.Module):
    """Expanding linear layer followed by the configured activation.

    Args:
        config: Must expose ``mlp_ratio`` (width multiplier) and ``hidden_act``
            (either a key into ``ACT2FN`` or a callable used directly).
        dim: Input feature size; the output size is ``int(mlp_ratio * dim)``.
    """

    def __init__(self, config, dim):
        super().__init__()
        expanded_dim = int(config.mlp_ratio * dim)
        self.dense = nn.Linear(dim, expanded_dim)

        # String names are resolved through the activation table; anything else
        # is assumed to already be callable.
        activation = config.hidden_act
        self.intermediate_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``activation(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
def forward(
    self,
    hidden_states: torch.Tensor,
    input_dimensions: Tuple[int, int],
    head_mask: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = False,
    always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, ...]:
    """Apply (shifted-)window attention and the MLP, each with a residual.

    Args:
        hidden_states: Tensor that is viewed as ``(batch, height, width, channels)``
            using ``input_dimensions``, so its middle axis must hold
            ``height * width`` positions.
        input_dimensions: ``(height, width)`` of the token grid.
        head_mask: Forwarded unchanged to the attention module.
        output_attentions: When true, any attention tensors the attention module
            returned are appended to the result.
        always_partition: When false, ``set_shift_and_window_size`` may shrink
            the window and disable the shift for small inputs.

    Returns:
        ``(layer_output,)`` followed by whatever extra entries the attention
        module produced, if ``output_attentions`` is set.
    """
    if not always_partition:
        self.set_shift_and_window_size(input_dimensions)

    height, width = input_dimensions
    batch_size, _, channels = hidden_states.size()
    shortcut = hidden_states

    hidden_states = self.layernorm_before(hidden_states)
    hidden_states = hidden_states.view(batch_size, height, width, channels)

    # pad hidden_states to multiples of window size
    hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
    _, height_pad, width_pad, _ = hidden_states.shape

    # cyclic shift
    if self.shift_size > 0:
        shifted_hidden_states = torch.roll(
            hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
        )
    else:
        shifted_hidden_states = hidden_states

    # partition windows
    hidden_states_windows = window_partition(
        shifted_hidden_states, self.window_size
    )
    hidden_states_windows = hidden_states_windows.view(
        -1, self.window_size * self.window_size, channels
    )
    attn_mask = self.get_attn_mask(
        height_pad,
        width_pad,
        dtype=hidden_states.dtype,
        device=hidden_states_windows.device,
    )

    attention_outputs = self.attention(
        hidden_states_windows,
        attn_mask,
        head_mask,
        output_attentions=output_attentions,
    )

    attention_output = attention_outputs[0]

    attention_windows = attention_output.view(
        -1, self.window_size, self.window_size, channels
    )
    shifted_windows = window_reverse(
        attention_windows, self.window_size, height_pad, width_pad
    )

    # reverse cyclic shift
    if self.shift_size > 0:
        attention_windows = torch.roll(
            shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
        )
    else:
        attention_windows = shifted_windows

    # Crop away the padding added by maybe_pad (right / bottom only).
    was_padded = pad_values[3] > 0 or pad_values[5] > 0
    if was_padded:
        attention_windows = attention_windows[:, :height, :width, :].contiguous()

    attention_windows = attention_windows.view(batch_size, height * width, channels)

    hidden_states = shortcut + attention_windows

    layer_output = self.layernorm_after(hidden_states)
    layer_output = self.intermediate(layer_output)
    layer_output = hidden_states + self.output(layer_output)

    layer_outputs = (layer_output,)
    if output_attentions:
        # The attention module in this file returns only the attended values
        # (a 1-tuple), so indexing element 1 unconditionally raised IndexError.
        # Slicing appends attention tensors only when they actually exist.
        layer_outputs += tuple(attention_outputs[1:])
    return layer_outputs
# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
class DonutSwinEncoder(nn.Module):
    """Stack of ``DonutSwinStage`` modules run one after another.

    Stage ``i`` is built with ``embed_dim * 2**i`` channels and a resolution of
    ``grid_size // 2**i`` along each axis.  Every stage except the last gets a
    ``DonutSwinPatchMerging`` downsampler.
    """

    def __init__(self, config, grid_size):
        super().__init__()
        # One stage per entry in config.depths.
        self.num_layers = len(config.depths)
        self.config = config
        self.layers = nn.ModuleList(
            [
                DonutSwinStage(
                    config=config,
                    layer_num=i_layer,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(
                        grid_size[0] // (2**i_layer),
                        grid_size[1] // (2**i_layer),
                    ),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    # Configs without num_kv_heads fall back to one KV head per
                    # query head.
                    num_kv_heads=config.num_kv_heads[i_layer]
                    if hasattr(config, "num_kv_heads")
                    else config.num_heads[i_layer],
                    # The final stage keeps its resolution (no patch merging).
                    downsample=DonutSwinPatchMerging
                    if (i_layer < self.num_layers - 1)
                    else None,
                )
                for i_layer in range(self.num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DonutSwinEncoderOutput]:
        """Run every stage and optionally collect per-stage states.

        Args:
            hidden_states: Input tokens; viewed as ``(batch, h, w, hidden)`` using
                ``input_dimensions`` when hidden states are collected.
            input_dimensions: ``(height, width)`` of the incoming token grid.
            head_mask: Optional; indexed per stage as ``head_mask[i]``.
            output_attentions: Collect any attention entries the stages return.
            output_hidden_states: Collect the input and each stage's output, both
                flat and reshaped to ``(batch, hidden, h, w)``.
            output_hidden_states_before_downsampling: With ``output_hidden_states``,
                record each stage's state before its downsampler instead of after.
            always_partition: Forwarded to every stage.
            return_dict: When false, return a tuple of the non-``None`` values
                among (last state, hidden states, attentions).

        Returns:
            A ``DonutSwinEncoderOutput`` or, if ``return_dict`` is false, a tuple.
            The tuple form does not include the reshaped hidden states.
        """
        # Accumulators stay None when the corresponding output was not requested.
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            # Record the encoder input as the first hidden state.
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(
                batch_size, *input_dimensions, hidden_size
            )
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # NOTE(review): _gradient_checkpointing_func is not defined in this
                # class; presumably it is installed by the owning model when
                # checkpointing is enabled — confirm before relying on it.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )

            # Stage output layout: (state, state_before_downsampling,
            # (h, w, h_down, w_down), *attentions).
            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            # The next stage consumes the (possibly downsampled) grid.
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size,
                    *(output_dimensions[0], output_dimensions[1]),
                    hidden_size,
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(
                    batch_size, *input_dimensions, hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                # Everything after the first three entries is attention output.
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return DonutSwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
""" config_class = DonutSwinConfig base_model_prefix = "swin" main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["DonutSwinStage"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) ================================================ FILE: surya/common/donut/processor.py ================================================ from typing import Dict, Union, Optional, List, Iterable import cv2 from torch import TensorType from transformers import ImageProcessingMixin from transformers.image_processing_utils import BatchFeature from transformers.image_transforms import pad, normalize from transformers.image_utils import ( ImageInput, ChannelDimension, make_list_of_images, get_image_size, ) import numpy as np from PIL import Image import PIL from transformers.utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from surya.common.s3 import S3DownloaderMixin from surya.settings import settings class SuryaEncoderImageProcessor(S3DownloaderMixin, ImageProcessingMixin): def __init__( self, *args, max_size=None, align_long_axis=False, rescale_factor: Union[int, float] = 1 / 255, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, **kwargs, ): super().__init__(*args, **kwargs) self.patch_size = kwargs.get("patch_size", (4, 4)) self.max_size = max_size self.do_align_long_axis = align_long_axis self.resample = Image.Resampling.BILINEAR self.rescale_factor = rescale_factor self.image_mean = ( image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN ) self.image_std = 
def process_inner(self, images: List[np.ndarray]):
    """Resize, pad, rescale and normalize a batch of channel-last RGB arrays.

    Steps, in order: optional long-axis alignment, resize to ``self.max_size``
    (which also moves channels first), cast to float32, pad to ``self.max_size``
    with ``settings.RECOGNITION_PAD_VALUE``, multiply by ``self.rescale_factor``
    and normalize with ``self.image_mean`` / ``self.image_std``.

    Args:
        images: Non-empty list of arrays whose third axis has size 3.

    Returns:
        List of channel-first arrays, one per input image.

    NOTE(review): ``self.resample`` is set to ``Image.Resampling.BILINEAR`` in
    ``__init__`` but is handed to ``cv2.resize`` as its interpolation flag.
    PIL's enum value for BILINEAR appears to equal OpenCV's ``INTER_CUBIC``
    rather than ``INTER_LINEAR`` — confirm which filter is actually intended
    before changing it, since model inputs depend on it.
    """
    assert images[0].shape[2] == 3  # RGB input images, channel dim last

    processor = SuryaEncoderImageProcessor

    if self.do_align_long_axis:
        # Rotate if the bbox is wider than it is tall
        images = [
            processor.align_long_axis(
                image, size=self.max_size, input_data_format=ChannelDimension.LAST
            )
            for image in images
        ]

        # Verify that the image is wider than it is tall
        for img in images:
            assert img.shape[1] >= img.shape[0]

    # This also applies the right channel dim format, to channel x height x width
    images = [
        processor.numpy_resize(img, self.max_size, self.resample) for img in images
    ]
    assert images[0].shape[0] == 3  # RGB input images, channel dim first

    # Convert to float32 for rescale/normalize
    images = [img.astype(np.float32) for img in images]

    # Pads with 255 (whitespace)
    # Pad to max size to improve performance
    max_size = self.max_size
    images = [
        processor.pad_image(
            image=image,
            size=max_size,
            input_data_format=ChannelDimension.FIRST,
            pad_value=settings.RECOGNITION_PAD_VALUE,
        )
        for image in images
    ]

    # Rescale in float64, then store as float32 again
    images = [
        (img.astype(np.float64) * self.rescale_factor).astype(np.float32)
        for img in images
    ]

    # Normalize
    return [
        processor.normalize(
            img,
            mean=self.image_mean,
            std=self.image_std,
            input_data_format=ChannelDimension.FIRST,
        )
        for img in images
    ]
@classmethod
def align_long_axis(
    cls, image: np.ndarray, size: Dict[str, int], **kwargs
) -> np.ndarray:
    """Rotate ``image`` so its orientation matches that of ``size``.

    Args:
        image: Array whose first two axes are (height, width).
        size: Mapping with ``"height"`` and ``"width"`` of the target.
        **kwargs: Accepted and ignored.

    Returns:
        ``np.rot90(image, 3)`` when one of image/target is strictly wider than
        tall and the other strictly taller than wide; otherwise ``image`` itself.
    """
    src_height, src_width = image.shape[:2]
    dst_height, dst_width = size["height"], size["width"]

    # Strict comparisons: a square image or square target never triggers rotation.
    if dst_width < dst_height and src_width > src_height:
        return np.rot90(image, 3)
    if dst_width > dst_height and src_width < src_height:
        return np.rot90(image, 3)
    return image
class PolygonBox(BaseModel):
    """Polygon given as ``[x, y]`` corners, with an optional confidence.

    Several helpers index corners positionally and assume the order produced
    when a bbox is converted: ``(x_min, y_min)``, ``(x_max, y_min)``,
    ``(x_max, y_max)``, ``(x_min, y_max)``.
    """

    polygon: List[List[float]]
    confidence: Optional[float] = None

    @field_validator("polygon", mode="before")
    @classmethod
    def convert_bbox_to_polygon(cls, value):
        """Normalize the accepted input shapes to a list of float corners.

        Accepts a 4-number bbox ``[x_min, y_min, x_max, y_max]``, a 4-item
        list/tuple of 2-item points, or a ``(4, 2)`` numpy array.

        Raises:
            ValueError: For any other input, including empty or non-indexable
                values.
        """
        if isinstance(value, (list, tuple)) and len(value) == 4:
            if all(isinstance(x, numbers.Number) for x in value):
                x_min, y_min, x_max, y_max = (float(v) for v in value)
                return [
                    [x_min, y_min],
                    [x_max, y_min],
                    [x_max, y_max],
                    [x_min, y_max],
                ]
            if all(
                isinstance(point, (list, tuple)) and len(point) == 2
                for point in value
            ):
                return [[float(v) for v in point] for point in value]
        elif isinstance(value, np.ndarray):
            if value.shape == (4, 2):
                return value.tolist()

        # Building the message must not itself fail: indexing an empty or
        # non-indexable value used to raise IndexError/TypeError here, hiding
        # the intended ValueError.
        try:
            first_type = type(value[0])
        except (TypeError, IndexError, KeyError):
            first_type = "unavailable (input is empty or not indexable)"
        raise ValueError(
            f"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)]. All values must be numeric. You passed {value} of type {type(value)}. The first value is of type {first_type}."
        )

    @property
    def height(self):
        """Vertical extent of ``bbox``."""
        box = self.bbox
        return box[3] - box[1]

    @property
    def width(self):
        """Horizontal extent of ``bbox``."""
        box = self.bbox
        return box[2] - box[0]

    @property
    def area(self):
        """Area of ``bbox`` (not of the polygon itself)."""
        return self.width * self.height

    @computed_field
    @property
    def bbox(self) -> List[float]:
        """Axis-aligned bounds ``[x_min, y_min, x_max, y_max]`` of the corners."""
        x_coords = [point[0] for point in self.polygon]
        y_coords = [point[1] for point in self.polygon]
        return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]

    def rescale(self, processor_size, image_size):
        """Scale corners in place from ``processor_size`` to ``image_size``.

        Both sizes are ``(width, height)``.  Results are truncated with ``int``.
        """
        # Point is in x, y format
        page_width, page_height = processor_size
        img_width, img_height = image_size
        width_scaler = img_width / page_width
        height_scaler = img_height / page_height
        for corner in self.polygon:
            corner[0] = int(corner[0] * width_scaler)
            corner[1] = int(corner[1] * height_scaler)

    def round(self, divisor):
        """Snap each coordinate in place to a multiple of ``divisor`` (truncating)."""
        for corner in self.polygon:
            corner[0] = int(corner[0] / divisor) * divisor
            corner[1] = int(corner[1] / divisor) * divisor

    def fit_to_bounds(self, bounds):
        """Clamp corners to ``bounds`` = ``[x_min, y_min, x_max, y_max]``.

        Unlike ``clamp``, this replaces ``self.polygon`` with a new list.
        """
        new_corners = copy.deepcopy(self.polygon)
        for corner in new_corners:
            corner[0] = max(min(corner[0], bounds[2]), bounds[0])
            corner[1] = max(min(corner[1], bounds[3]), bounds[1])
        self.polygon = new_corners

    def merge(self, other):
        """Replace this polygon with the bounding rectangle of both boxes."""
        # bbox is recomputed on every access, so read each one once.
        mine, theirs = self.bbox, other.bbox
        x1 = min(mine[0], theirs[0])
        y1 = min(mine[1], theirs[1])
        x2 = max(mine[2], theirs[2])
        y2 = max(mine[3], theirs[3])
        self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

    def merge_left(self, other):
        """Move corners 0 and 3 to the smaller of the two left edges."""
        x1 = min(self.bbox[0], other.bbox[0])
        self.polygon[0][0] = x1
        self.polygon[3][0] = x1

    def merge_right(self, other):
        """Move corners 1 and 2 to the larger of the two right edges."""
        x2 = max(self.bbox[2], other.bbox[2])
        self.polygon[1][0] = x2
        self.polygon[2][0] = x2

    def expand(self, x_margin: float, y_margin: float):
        """Push the first four corners outward by a fraction of width/height.

        ``x_margin`` / ``y_margin`` are multiplied by the current width / height;
        the resulting coordinates are truncated with ``int``.  Corners beyond
        the fourth are dropped, as before.
        """
        x_margin = x_margin * self.width
        y_margin = y_margin * self.height
        # Outward direction for corner 0..3 in the order documented on the class.
        directions = ((-1, -1), (1, -1), (1, 1), (-1, 1))
        self.polygon = [
            [int(poly[0] + x_sign * x_margin), int(poly[1] + y_sign * y_margin)]
            for poly, (x_sign, y_sign) in zip(self.polygon, directions)
        ]

    def intersection_polygon(self, other) -> List[List[float]]:
        """Corner-wise intersection, assuming both polygons use the class order.

        No check is made that the boxes overlap; a non-overlapping pair yields
        an inverted rectangle.
        """
        mine, theirs = self.polygon, other.polygon
        return [
            [max(mine[0][0], theirs[0][0]), max(mine[0][1], theirs[0][1])],
            [min(mine[1][0], theirs[1][0]), max(mine[1][1], theirs[1][1])],
            [min(mine[2][0], theirs[2][0]), min(mine[2][1], theirs[2][1])],
            [max(mine[3][0], theirs[3][0]), min(mine[3][1], theirs[3][1])],
        ]

    def intersection_area(self, other, x_margin=0, y_margin=0):
        """Product of ``x_overlap`` and ``y_overlap`` for the given margins."""
        x_overlap = self.x_overlap(other, x_margin)
        y_overlap = self.y_overlap(other, y_margin)
        return x_overlap * y_overlap

    def x_overlap(self, other, x_margin=0):
        """Horizontal overlap of the bboxes, each widened by ``x_margin`` per side."""
        mine, theirs = self.bbox, other.bbox
        return max(
            0,
            min(mine[2] + x_margin, theirs[2] + x_margin)
            - max(mine[0] - x_margin, theirs[0] - x_margin),
        )

    def y_overlap(self, other, y_margin=0):
        """Vertical overlap of the bboxes, each grown by ``y_margin`` per side."""
        mine, theirs = self.bbox, other.bbox
        return max(
            0,
            min(mine[3] + y_margin, theirs[3] + y_margin)
            - max(mine[1] - y_margin, theirs[1] - y_margin),
        )

    def intersection_pct(self, other, x_margin=0, y_margin=0):
        """Intersection area divided by this box's area.

        ``x_margin`` / ``y_margin`` are fractions in ``[0, 1]`` of the smaller
        width / height of the two boxes.  Because margins grow both boxes but
        the denominator is the unexpanded area, the result can exceed 1.
        Returns 0 when this box has zero area.
        """
        assert 0 <= x_margin <= 1
        assert 0 <= y_margin <= 1
        if self.area == 0:
            return 0

        if x_margin:
            x_margin = int(min(self.width, other.width) * x_margin)
        if y_margin:
            y_margin = int(min(self.height, other.height) * y_margin)

        intersection = self.intersection_area(other, x_margin, y_margin)
        return intersection / self.area

    def shift(self, x_shift: float | None = None, y_shift: float | None = None):
        """Translate all corners in place; ``None`` leaves that axis untouched."""
        if x_shift is not None:
            for corner in self.polygon:
                corner[0] += x_shift
        if y_shift is not None:
            for corner in self.polygon:
                corner[1] += y_shift

    def clamp(self, bbox: List[float]):
        """Clamp corners in place to ``bbox`` = ``[x_min, y_min, x_max, y_max]``."""
        for corner in self.polygon:
            corner[0] = max(min(corner[0], bbox[2]), bbox[0])
            corner[1] = max(min(corner[1], bbox[3]), bbox[1])

    @property
    def center(self):
        """Midpoint ``[x, y]`` of ``bbox``."""
        box = self.bbox
        return [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]

    def distance(self, other):
        """Euclidean distance between the two bbox centers."""
        center = self.center
        other_center = other.center

        return (
            (center[0] - other_center[0]) ** 2 + (center[1] - other_center[1]) ** 2
        ) ** 0.5

    def __hash__(self):
        # Hash follows the current bbox, so it changes if the polygon is mutated.
        return hash(tuple(self.bbox))
@staticmethod
def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Zero-pad ``tensor`` along its first axis up to ``batch_size`` rows.

    Args:
        tensor: Tensor with at least one dimension.
        batch_size: Desired size of the first dimension.

    Returns:
        ``tensor`` itself when it already has ``batch_size`` or more rows,
        otherwise a new tensor with zeros appended along dimension 0.
    """
    missing_rows = batch_size - tensor.shape[0]
    if missing_rows <= 0:
        return tensor

    # F.pad takes (left, right) pairs starting from the LAST dimension, so every
    # trailing dimension gets a (0, 0) pair and the first dimension comes last.
    # For a 1-D tensor this reduces to (0, missing_rows).
    trailing_dims = tensor.dim() - 1
    pad_spec = (0, 0) * trailing_dims + (0, missing_rows)
    return F.pad(tensor, pad_spec, mode="constant", value=0)
def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024):
    """Stream ``remote_path`` to ``local_path`` with a byte-level progress bar.

    Args:
        remote_path: URL to fetch (redirects are followed).
        local_path: Destination file; converted to ``Path``.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        The destination as a ``Path``.

    Raises:
        Exception: Any error from the request or the file write is logged and
            re-raised after a partially written destination file is removed.
    """
    local_path = Path(local_path)
    try:
        # Context managers guarantee the connection and the progress bar are
        # released even when the transfer fails midway.
        with requests.get(remote_path, stream=True, allow_redirects=True) as response:
            response.raise_for_status()  # Raise an exception for bad status codes

            # Get file size from headers for progress bar (0 when not reported)
            total_size = int(response.headers.get("content-length", 0))

            with tqdm(
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                desc=f"Downloading {local_path.name}",
                miniters=1,
            ) as pbar, open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    # Empty chunks are skipped.
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

        return local_path
    except Exception as e:
        # Never leave a truncated file behind.
        if local_path.exists():
            local_path.unlink()
        logger.error(f"Download error for file {remote_path}: {str(e)}")
        raise
class S3DownloaderMixin:
    """Mixin that lets ``from_pretrained`` accept ``s3://``-prefixed names.

    Names without the prefix are passed straight to the next class in the MRO.
    Prefixed names are downloaded into the model cache first and then loaded
    from that local directory.
    """

    s3_prefix = "s3://"

    @classmethod
    def get_local_path(cls, pretrained_model_name_or_path) -> str:
        """Return (and create) the cache directory for an ``s3://`` name.

        Returns an empty string for names that do not start with the prefix.
        """
        if not pretrained_model_name_or_path.startswith(cls.s3_prefix):
            return ""

        relative_name = pretrained_model_name_or_path.replace(cls.s3_prefix, "")
        local_path = os.path.join(settings.MODEL_CACHE_DIR, relative_name)
        os.makedirs(local_path, exist_ok=True)
        return local_path

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a model, downloading it first when the name is ``s3://``-prefixed.

        The download is attempted up to three times with a five second pause
        between attempts; the last failure is re-raised.
        """
        # Allow loading models directly from the hub, or using s3
        is_s3_name = pretrained_model_name_or_path.startswith(cls.s3_prefix)
        if not is_s3_name:
            return super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )

        local_path = cls.get_local_path(pretrained_model_name_or_path)
        pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
            cls.s3_prefix, ""
        )

        # Retry logic for downloading the model folder
        retries = 3
        delay = 5
        for attempt in range(retries):
            try:
                download_directory(pretrained_model_name_or_path, local_path)
            except Exception as e:
                logger.error(
                    f"Error downloading model from {pretrained_model_name_or_path}. Attempt {attempt + 1} of {retries}. Error: {e}"
                )
                if attempt + 1 >= retries:
                    logger.error(
                        f"Failed to download {pretrained_model_name_or_path} after {retries} attempts."
                    )
                    raise e  # Reraise exception after max retries
                logger.info(f"Retrying in {delay} seconds...")
                time.sleep(delay)  # Wait before retrying
            else:
                break  # Download succeeded

        return super().from_pretrained(local_path, *args, **kwargs)
""" cu_seq_lens_q: Optional[torch.LongTensor] cu_seq_lens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] class KwargsForCausalLM(FlashAttentionKwargs): ... class DistanceProjection(nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() self.fc1 = nn.Linear(in_features, out_features) self.act = nn.SiLU() self.fc2 = nn.Linear(out_features, out_features) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = self.act(x) x = self.fc2(x) return x def init_weights(self): nn.init.xavier_uniform_(self.fc1.weight) nn.init.xavier_uniform_(self.fc2.weight) nn.init.zeros_(self.fc1.bias) nn.init.zeros_(self.fc2.bias) class BboxHead(nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() self.proj_layers = nn.ModuleList( [nn.Linear(in_features, in_features) for _ in range(6)] ) self.act = nn.SiLU() self.out_proj = nn.Linear(in_features, out_features) def forward(self, x: torch.Tensor) -> torch.Tensor: for layer in self.proj_layers: x = layer(x) x = self.act(x) x = self.out_proj(x) return x class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel): config_class = SuryaModelConfig supports_gradient_checkpointing = True _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True main_input_name = "input_ids" _tied_weights_keys = ["lm_head.weight"] def __init__( self, config: SuryaModelConfig, embedder: SimpleTokenEmbedder = None, vision_encoder: SuryaEncoderModel = None, decoder: SuryaDecoderModel = None, **kwargs, ): super().__init__(config, **kwargs) if vision_encoder is None: vision_encoder = SuryaEncoderModel(config.vision_encoder) if decoder is None: decoder = SuryaDecoderModel(config.decoder) if embedder is None: embedder = SimpleTokenEmbedder(config) 
def maybe_static_pad_image_inputs(
    self,
    chunk_pixels: torch.Tensor,
    chunk_grid_thw: torch.Tensor,
    actual_chunk_len: int,
    encoder_chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Optionally pad one encoder chunk up to ``encoder_chunk_size``.

    Padding is applied only when ``settings.FOUNDATION_STATIC_CACHE`` is truthy
    and the chunk is shorter than ``encoder_chunk_size``; otherwise the inputs
    are returned unchanged.

    Args:
        chunk_pixels: Pixel/patch rows for this chunk. Assumed to be 2D with
            the sequence on dim 0 — TODO confirm against the caller.
        chunk_grid_thw: Grid rows describing the images in this chunk.
        actual_chunk_len: Number of real (unpadded) entries in the chunk.
        encoder_chunk_size: Target length when static padding is enabled.

    Returns:
        ``(chunk_pixels, chunk_grid_thw, valid_embed_len)`` where
        ``valid_embed_len`` is ``actual_chunk_len`` divided by the square of
        the vision encoder's ``spatial_merge_size``.
    """
    # Computed from the unpadded length, before any padding is applied.
    valid_embed_len = actual_chunk_len // (
        self.vision_encoder.spatial_merge_size**2
    )
    if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size:
        padding_len = encoder_chunk_size - actual_chunk_len
        # Pad spec (0, 0, 0, padding_len): last dim untouched, second-to-last
        # dim extended by `padding_len` zeros at the end.
        chunk_pixels = F.pad(
            chunk_pixels,
            (0, 0, 0, padding_len),
            mode="constant",
            value=0.0,
        )
        # Extra grid row that accounts for the padded entries.
        # NOTE(review): 1 * 2 * (padding_len // 2) is one short of padding_len
        # when padding_len is odd — confirm padding_len is always even.
        padding_grid = torch.tensor(
            [[1, 2, padding_len // 2]],
            device=chunk_grid_thw.device,
            dtype=chunk_grid_thw.dtype,
        )
        chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0)
    return chunk_pixels, chunk_grid_thw, valid_embed_len
) elif len(embeddings) == 1: embeddings = embeddings[0] else: embeddings = torch.cat(embeddings, dim=0) encoding_2d = self.get_2d_learned_embeddings( grid_thw, device=embeddings.device, bbox_size=self.config.image_embed_encoding_multiplier, ) assert embeddings.shape[0] == encoding_2d.shape[0], ( f"Mismatch in image embedding seq len: {embeddings.shape} vs {encoding_2d.shape}" ) assert embeddings.shape[1] == encoding_2d.shape[1], ( f"Mismatch in image embedding token counts: {embeddings.shape} vs {encoding_2d.shape}" ) embeddings = embeddings + encoding_2d return embeddings def embed_ids_boxes_images( self, input_ids, image_embeddings, encoder_chunk_size: int, valid_batch_size: torch.Tensor | None = None, input_boxes: torch.Tensor | None = None, embed_boxes: torch.Tensor | None = None, ): """ Insert embedded image tiles into the corresponding positions into the full input sequence Positions to insert new tokens are indicated by the special image token index """ # This is batched in the inner call inputs_embeds = self.embedder.embed( input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes ) if image_embeddings is not None: special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds) if inputs_embeds[special_image_mask].numel() != image_embeddings.numel(): n_image_tokens = torch.sum((input_ids == self.config.image_token_id)) n_image_features = image_embeddings.shape[0] * image_embeddings.shape[1] warnings.warn( f"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. 
def get_logits(self, hidden_states):
    """Compute LM and bbox logits for the last token plus multi-output steps.

    The first prediction uses ``hidden_states`` directly. Each additional step
    ``i`` first passes the running hidden state through
    ``self.multi_output_projections[i - 1]`` and then applies the same heads.

    Args:
        hidden_states: Tensor whose dim 1 (sequence) must have length 1.

    Returns:
        ``(lm_logits, bbox_logits)``, each concatenated along dim 1 so that
        dim 1 has ``multi_output_distance + 1`` entries. Bbox logits are
        passed through a sigmoid.
    """
    assert hidden_states.shape[1] == 1, (
        "Multi output predictions only applied on the last token"
    )

    # A missing (None) or non-positive distance means "no extra predictions";
    # without this guard `None + 1` raised a TypeError.
    extra_steps = max(self.config.multi_output_distance or 0, 0)

    all_lm_logits = []
    all_bbox_logits = []
    current_hidden = hidden_states

    # Loop includes the initial prediction (step 0) plus `extra_steps` more
    for step in range(extra_steps + 1):
        if step > 0:
            current_hidden = self.multi_output_projections[step - 1](current_hidden)
        all_lm_logits.append(self.lm_head(current_hidden))
        all_bbox_logits.append(torch.sigmoid(self.bbox_head(current_hidden)))

    # Concatenate along sequence dimension (dim=1)
    final_lm_logits = torch.cat(all_lm_logits, dim=1)
    final_bbox_logits = torch.cat(all_bbox_logits, dim=1)
    return final_lm_logits, final_bbox_logits
position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=True, use_cache=use_cache, cache_idxs=cache_idxs, num_valid_tokens=num_valid_tokens, prefill=prefill, text_lengths=text_lengths, **kwargs, ) hidden_states = outputs.last_hidden_state if logits_to_keep is not None: hidden_states = hidden_states[:, -logits_to_keep:, :] hidden_states = hidden_states.contiguous() loss = None if labels is not None: # Training, return full logits lm_logits = self.lm_head(hidden_states) bbox_logits = None vocab_size = lm_logits.shape[-1] labels = torch.roll(labels, shifts=-1, dims=-1) loss = F.cross_entropy( lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean" ) else: lm_logits, bbox_logits = self.get_logits(hidden_states) return SuryaModelOutput( loss=loss, bbox_logits=bbox_logits, lm_logits=lm_logits, hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions if output_attentions else None, past_key_values=outputs.past_key_values, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): if self.decoder.config._attn_implementation == "flash_attention_2": return attention_mask # We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_key_values.max_cache_len ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    config: SuryaModelConfig,
    past_key_values: Cache,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence. Shape `(batch_size, sequence_length)`.
        batch_size (`int`):
            Batch size. Not read by this function.
        config (`SuryaModelConfig`):
            The model's configuration class. Not read by this function.
        past_key_values (`Cache`):
            The cache class that is being used currently to generate. Not read by this function.

    Returns:
        The 4D mask: the input itself when it is already 4D, otherwise a newly built tensor in which
        blocked positions hold `torch.finfo(dtype).min` and attendable positions hold 0.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        min_dtype = torch.finfo(dtype).min
        # Start fully blocked; the multiplication below zeroes the positions that may be attended.
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        # Batch-aware diagonal attend mask
        # True where the key index is strictly greater than the query's cache position (i.e. a future key).
        diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(
            0
        ) > cache_position.unsqueeze(-1)
        causal_mask = (
            causal_mask.unsqueeze(0) * diagonal_attend_mask
        )  # (batch_size, seq_len, target_len)
        causal_mask = causal_mask[
            :, None, :, :
        ]  # (batch_size, 1, seq_len, target_len)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            if attention_mask.shape[-1] > target_length:
                attention_mask = attention_mask[:, :target_length]
            mask_length = attention_mask.shape[-1]
            # A sum of exactly 0 means: causally allowed (0) AND marked as padding (0) in the 2D mask.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                :, None, None, :
            ].to(causal_mask.device)
            padding_mask = padding_mask == 0
            # Block those padded key positions as well.
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

    return causal_mask
get_nearest_pad( unpadded_max_grid_size, ) # If we need zero padding, we still need to allocate a bit of room for the extra grid_thw # Always need 2 items in each row batch if max_grid_size == unpadded_max_grid_size: max_grid_size += 16 full_image_grid = torch.zeros( (valid_batch_size, max_grid_size, pixel_values.shape[-1]), dtype=pixel_values.dtype, ) # Roll out into a full grid seq_len = 0 row_grids = [] for i in range(valid_batch_size): curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2] full_image_grid[i, -curr_sample_len:] = pixel_values[ seq_len : seq_len + curr_sample_len ] padded_len = max_grid_size - curr_sample_len if padded_len > 0: row_grid = torch.tensor( [ [1, 4, padded_len // 4], grid_thw[i].tolist(), ], dtype=torch.long, ) else: row_grid = torch.tensor( [ grid_thw[i].tolist(), ], dtype=torch.long, ) row_grids.append(row_grid) seq_len += curr_sample_len # bsz, 2, 3 row_grids = torch.stack(row_grids, dim=0) if settings.FOUNDATION_STATIC_CACHE: # Pad to max batch size, repeat the final row row_grids = pad_to_batch_size_repeat( row_grids, batch_size=max_batch_size, ) full_image_grid = pad_to_batch_size( full_image_grid, batch_size=max_batch_size, ) full_image_grid = full_image_grid.to(self.device) embeddings = self.vision_encoder.embed_images( image_batch=full_image_grid, grid_thw=row_grids.to(self.device) ) encoding_2d = self.get_2d_learned_embeddings( row_grids, bbox_size=self.config.image_embed_encoding_multiplier, ) embeddings += encoding_2d return embeddings def embed_ids_boxes_images( self, input_ids, image_embeddings, encoder_chunk_size: int, valid_batch_size: torch.Tensor | None = None, input_boxes: torch.Tensor | None = None, embed_boxes: torch.Tensor | None = None, ): """ Insert embedded image tiles into the corresponding positions into the full input sequence Positions to insert new tokens are indicated by the special image token index """ # This is batched in the inner call inputs_embeds = self.embedder.embed( 
def get_2d_learned_embeddings(
    self,
    grid_thw,
    bbox_size: int = 256,
):
    """Build learned row + column position embeddings for batched image grids.

    ``grid_thw`` is iterated per sample; each sample holds one or more
    ``(t, h, w)`` rows. For every row the merged height/width
    (``h // merge_size``, ``w // merge_size``) is spread over ``[0, bbox_size]``,
    rounded to integer coordinates, and laid out row-major. The per-sample
    coordinate vectors are stacked, so every sample must yield the same
    number of positions.

    Returns:
        ``img_h_embed(row_coords) + img_w_embed(col_coords)`` for the stacked
        coordinates.
    """
    source_device = grid_thw.device
    merge = self.config.merge_size

    row_batches = []
    col_batches = []
    for sample_grids in grid_thw:
        # per-sample grid sizes after merge
        heights = (sample_grids[:, 1] // merge).long().tolist()
        widths = (sample_grids[:, 2] // merge).long().tolist()

        row_parts = []
        col_parts = []
        for h, w in zip(heights, widths):
            row_axis = torch.linspace(
                0, bbox_size, steps=int(h), device=source_device
            ).round()
            col_axis = torch.linspace(
                0, bbox_size, steps=int(w), device=source_device
            ).round()
            # Row-major layout: each row value repeats w times, while the
            # column vector is tiled h times.
            row_parts.append(row_axis.repeat_interleave(w))
            col_parts.append(col_axis.repeat(int(h)))

        row_batches.append(torch.cat(row_parts))  # (full_grid_size,)
        col_batches.append(torch.cat(col_parts))  # (full_grid_size,)

    row_index = torch.stack(row_batches, dim=0).to(self.device).long()
    col_index = torch.stack(col_batches, dim=0).to(self.device).long()
    return self.img_h_embed(row_index) + self.img_w_embed(col_index)
image_embed_encoding_size self.image_embed_encoding_multiplier = image_embed_encoding_multiplier self.num_beacon_tokens = num_beacon_tokens self.beacon_token_interval = beacon_token_interval self.sliding_window = sliding_window self.multi_output_distance = multi_output_distance self.max_multi_out = max_multi_out if self.sliding_window is None: self.sliding_window = self.max_sequence_length if isinstance(vision_encoder, dict): vision_encoder = SuryaEncoderConfig(**vision_encoder) elif vision_encoder is None: vision_encoder = SuryaEncoderConfig() self.vision_encoder = vision_encoder if isinstance(decoder, dict): decoder = SuryaDecoderConfig(**decoder) elif decoder is None: decoder = SuryaDecoderConfig() self.decoder = decoder self.hidden_size = self.decoder.hidden_size self.patch_size = self.vision_encoder.spatial_patch_size self.merge_size = self.vision_encoder.spatial_merge_size ================================================ FILE: surya/common/surya/decoder/__init__.py ================================================ from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import ( Cache, ) from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import ( BaseModelOutputWithPast, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.processing_utils import Unpack from transformers.utils import ( logging, ) from surya.common.pretrained import SuryaPreTrainedModel from surya.common.surya.decoder.config import SuryaDecoderConfig logger = logging.get_logger(__name__) class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, 
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    With ``x = [a, b]`` split at ``x.shape[-1] // 2``, returns ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
""" cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query.dtype ) attn_weights = nn.functional.dropout( attn_weights, p=dropout, training=module.training ) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Qwen2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: SuryaDecoderConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads ) self.num_key_value_groups = ( config.num_attention_heads 
class Qwen2Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Queries are projected to `config.num_attention_heads` heads and keys/values to
    `config.num_key_value_heads` heads; `num_key_value_groups` is the integer ratio of the two.
    Rotary position embeddings are applied to queries and keys, and an optional cache object can
    replace the key/value states before the attention backend is invoked.
    """

    def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
        """
        Args:
            config: Decoder configuration; read for head counts, hidden size and dropout.
            layer_idx: Index of this layer, passed to the cache on `update`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Use an explicit `head_dim` from the config when present, otherwise derive it.
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        # q/k/v projections carry a bias, the output projection does not.
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=True
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
        num_valid_tokens: Optional[List[int]] = None,
        text_lengths: Optional[List[int]] = None,
        prefill: bool = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Project, rotate, optionally cache, and attend.

        Args:
            hidden_states: Input activations; the last axis is projected, the leading axes are
                kept as `input_shape` (presumably (batch, seq_len) - TODO confirm with caller).
            position_embeddings: `(cos, sin)` pair handed to `apply_rotary_pos_emb`.
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_value: Optional cache; when given, its `update` result replaces the
                key/value states.
            cache_position, cache_idxs, num_valid_tokens, text_lengths: Forwarded to the cache
                through `cache_kwargs`; not otherwise used here.
            prefill: Forwarded to the cache, and selects `flash_attn_prefill` vs
                `flash_attn_decode` when the implementation is "flash_attention_2".
            **kwargs: Forwarded to the attention backend; `output_attentions` is also inspected.

        Returns:
            `(attn_output, attn_weights)`. NOTE(review): the return annotation lists three
            elements, but only these two are returned.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the projected features into heads, then move the head axis in front of axis 1.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "cache_idxs": cache_idxs,
                "num_valid_tokens": num_valid_tokens,
                "prefill": prefill,
                "text_lengths": text_lengths,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Eager attention is the default; the branches below may swap in another backend.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get(
                "output_attentions", False
            ):
                # The eager default selected above is kept in this case.
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            elif self.config._attn_implementation == "flash_attention_2":
                # Needed for CPU -> GPU
                from surya.common.surya.flash_attn_utils import (
                    flash_attn_decode,
                    flash_attn_prefill,
                )

                if prefill:
                    attention_interface = flash_attn_prefill
                else:
                    attention_interface = flash_attn_decode
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[
                    self.config._attn_implementation
                ]

        """
        IMPORTANT:
        We sometimes use a custom sliding window impl. during training

        We force this to None to ensure that the HF attention integrations do not
        perform any special handling - FA2 in particular will ignore the 4D mask, and use this instead
        to infer the final mask

        SDPA ignores this completely, and is fully dependent on the 4D mask - (https://github.com/huggingface/transformers/blob/b9faf2f93085e3cf2c65184a69d1d9e502f95786/src/transformers/integrations/sdpa_attention.py#L23)
        """
        sliding_window = None

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Dropout is only active in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=sliding_window,  # main diff with Llama
            **kwargs,
        )

        # Merge the head axes back into one feature axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize in float32, cast back to the input dtype, then apply the scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2DecoderLayer(nn.Module):
    """One pre-norm decoder block: self-attention and an MLP, each wrapped in a residual add."""

    def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen2MLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
        num_valid_tokens: Optional[List[int]] = None,
        text_lengths: Optional[List[int]] = None,
        prefill: bool = False,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run the attention sub-block followed by the feed-forward sub-block.

        All cache- and position-related arguments are passed straight through to
        `self.self_attn`.

        Returns:
            A 1-tuple `(hidden_states,)`, or `(hidden_states, attn_weights)` when
            `output_attentions` is truthy.
        """
        # Attention sub-block: normalize, attend, add back onto the un-normalized input.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            cache_idxs=cache_idxs,
            num_valid_tokens=num_valid_tokens,
            text_lengths=text_lengths,
            prefill=prefill,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
+ hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class Qwen2RotaryEmbedding(nn.Module): def __init__(self, config: SuryaDecoderConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type") ) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len ) self.register_buffer( "inv_freq", inv_freq, persistent=False ) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if ( seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len ): # reset # This .to() is needed if the model has been moved to a device after being initialized (because # the buffer is automatically moved, but not the original copy) self.original_inv_freq = self.original_inv_freq.to(device) 
class Qwen2PreTrainedModel(SuryaPreTrainedModel):
    """Base class carrying the capability flags and weight initialization for the decoder."""

    config_class = SuryaDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Initialize `nn.Linear` and `nn.Embedding` weights from N(0, initializer_range).

        Linear biases are zeroed, as is the embedding row at `padding_idx` when one is set.
        Any other module type is left untouched.
        """
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
class SuryaDecoderModel(Qwen2PreTrainedModel):
    """
    Stack of *config.num_hidden_layers* [`Qwen2DecoderLayer`] blocks followed by a final RMS norm.

    Unlike the stock Qwen2 model this variant owns no embedding layer: callers must supply
    `inputs_embeds` directly.

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: SuryaDecoderConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        blocks = [
            Qwen2DecoderLayer(config, depth)
            for depth in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
        num_valid_tokens: Optional[List[int]] = None,
        text_lengths: Optional[List[int]] = None,
        prefill: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run every decoder layer over `inputs_embeds` and apply the final norm.

        `inputs_embeds`, `cache_position` and `position_ids` are mandatory even though they
        default to None. The attention mask is handed to the layers as-is.

        Returns:
            A `BaseModelOutputWithPast` (or its tuple form when `return_dict` resolves to
            False) holding the normalized hidden states and, when caching is enabled, the
            cache object that was passed in.

        Raises:
            ValueError: If any of the three mandatory arguments is None.
        """
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        required = (
            (inputs_embeds, "You must specify inputs_embeds"),
            (cache_position, "You must specify cache_position"),
            (position_ids, "You must specify position_ids"),
        )
        for supplied, complaint in required:
            if supplied is None:
                raise ValueError(complaint)

        hidden_states = inputs_embeds
        # We make the 4D mask in the combined model when needed
        causal_mask = attention_mask

        # One set of rotary tables is shared by all decoder layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                cache_idxs=cache_idxs,
                num_valid_tokens=num_valid_tokens,
                prefill=prefill,
                text_lengths=text_lengths,
                **flash_attn_kwargs,
            )[0]

        hidden_states = self.norm(hidden_states)

        result = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
        if return_dict:
            return result
        return result.to_tuple()
class SuryaDecoderConfig(PretrainedConfig):
    """
    Configuration for the Surya decoder (registered under the "qwen2" model type).

    The `use_sliding_window` argument is accepted for compatibility but the attribute is always
    stored as False.
    """

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        **kwargs,
    ):
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Model geometry
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings

        # Sliding-window settings; the feature itself is switched off regardless of the argument
        self.use_sliding_window = False
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # Activation, initialization, normalization, caching, dropout
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_dropout = attention_dropout

        # Rotary embedding parameters
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # BC: if there is a 'type' field, move it to 'rope_type'.
        legacy_rope = self.rope_scaling is not None and "type" in self.rope_scaling
        if legacy_rope:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        # Validate the correctness of rotary position embeddings parameters
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) ================================================ FILE: surya/common/surya/embedder/__init__.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class SimpleTokenEmbedder(nn.Module): def __init__(self, config): super().__init__() self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size) self.bbox_embed = nn.ModuleList( [ nn.Embedding( config.bbox_size + config.special_token_count, config.bbox_embed_size, ) for _ in range(6) ] ) self.max_bbox_embedding = config.bbox_size + config.special_token_count - 1 self.max_bbox_size = config.bbox_size def embed( self, input_tokens: torch.Tensor, input_boxes: torch.Tensor | None, embed_boxes: torch.Tensor, ) -> torch.Tensor: # Embed tokens token_embeds = self.token_embed(input_tokens) # Optionally embed boxes if input_boxes is not None and embed_boxes.any(): # Is none in prefill input_boxes = input_boxes.to(torch.long) bbox_loss_ignore_mask = ( (input_boxes[:, :, 0] < 0) | (input_boxes[:, :, 0] > self.max_bbox_size) ).unsqueeze(-1) input_boxes = torch.clamp(input_boxes, 0, self.max_bbox_embedding) bbox_embeds = torch.sum( torch.stack( [ self.bbox_embed[i](input_boxes[:, :, i]) for i in range(len(self.bbox_embed)) ], dim=-1, ), dim=-1, ) bbox_embeds = F.pad( bbox_embeds, (token_embeds.shape[-1] - bbox_embeds.shape[-1], 0) ) embed_boxes = embed_boxes.unsqueeze(1).unsqueeze(1).expand_as(bbox_embeds) bbox_loss_ignore_mask = bbox_loss_ignore_mask.expand_as(bbox_embeds) mask = embed_boxes & ~bbox_loss_ignore_mask bbox_embeds *= mask.float() token_embeds = token_embeds + bbox_embeds return token_embeds ================================================ FILE: surya/common/surya/encoder/__init__.py 
class Qwen2_5_VLMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        wide, narrow = self.intermediate_size, self.hidden_size
        self.gate_proj = nn.Linear(narrow, wide, bias=bias)
        self.up_proj = nn.Linear(narrow, wide, bias=bias)
        self.down_proj = nn.Linear(wide, narrow, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gated = self.act_fn(self.gate_proj(hidden_state))
        lifted = self.up_proj(hidden_state)
        return self.down_proj(gated * lifted)


class Qwen2_5_VisionPatchEmbed(nn.Module):
    """Projects flattened image patches to `embed_dim` with a bias-free, non-overlapping 3D conv."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # Kernel and stride are identical, so every patch is seen exactly once.
        window = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=window,
            stride=window,
            bias=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return embeddings of shape (batch, num_patches, embed_dim); batch is the input's axis 0."""
        bsz = hidden_states.shape[0]
        weight_dtype = self.proj.weight.dtype

        patches = hidden_states.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        projected = self.proj(patches.to(dtype=weight_dtype))
        return projected.view(bsz, -1, self.embed_dim)
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
    """Produces the table of position-times-inverse-frequency angles for the vision tower."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Stored as a plain attribute (not a registered buffer).
        self.inv_freq = 1.0 / (theta**exponents)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a (seqlen, len(inv_freq)) tensor of angles, computed on CPU."""
        positions = torch.arange(seqlen, device="cpu", dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)


class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize in float32, cast back to the input dtype, then apply the scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Qwen2_5_VLPatchMerger(nn.Module):
    """Normalizes patch features, groups `spatial_merge_size**2` of them, and projects to `dim`."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (batch, merged_tokens, dim) tensor; batch is the input's axis 0."""
        normed = self.ln_q(x)
        grouped = normed.view(x.shape[0], -1, self.hidden_size)
        return self.mlp(grouped)


def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate `q` and `k` with flash-attn's `apply_rotary_emb`.

    Only the first half of the last axis of `cos`/`sin` is passed on. The rotation runs in
    float32 and each result is cast back to the dtype of its input.
    """
    from flash_attn.layers.rotary import apply_rotary_emb

    half_cos = cos.chunk(2, dim=-1)[0].contiguous().float()
    half_sin = sin.chunk(2, dim=-1)[0].contiguous().float()

    rotated_q = apply_rotary_emb(q.float(), half_cos, half_sin).type_as(q)
    rotated_k = apply_rotary_emb(k.float(), half_cos, half_sin).type_as(k)
    return rotated_q, rotated_k
class Qwen2_5_VLVisionXLASdpaAttention(nn.Module):
    """
    Batched vision attention using `F.scaled_dot_product_attention` with a boolean
    block-diagonal mask that is assembled on CPU and then moved to the query device.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.head_dim = dim // num_heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Attend within the segments described by `cu_seqlens`.

        `cu_seqlens[j]` is read as a sequence of boundaries for batch row `j`; tokens may only
        attend to tokens between the same pair of consecutive boundaries.
        """
        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]

        fused = self.qkv(hidden_states).reshape(bsz, seq_length, 3, self.num_heads, -1)
        q, k, v = fused.permute(0, 2, 1, 3, 4).unbind(1)

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            angles = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos, sin = angles.cos(), angles.sin()
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # True marks pairs of positions that may attend to each other.
        allowed = torch.zeros([bsz, 1, seq_length, seq_length], dtype=torch.bool)
        boundaries_cpu = cu_seqlens.cpu()
        for row in range(bsz):
            bounds = boundaries_cpu[row]
            for idx in range(1, len(bounds)):
                lo, hi = bounds[idx - 1], bounds[idx]
                allowed[row, ..., lo:hi, lo:hi] = True
        allowed = allowed.to(q.device)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        context = F.scaled_dot_product_attention(
            q,
            k,
            v,
            allowed,
            dropout_p=0.0,
        )
        context = context.transpose(1, 2).reshape(bsz, seq_length, -1)
        return self.proj(context)
class Qwen2_5_VLVisionXLAFlashAttention2(nn.Module):
    """
    Vision attention backed by `torch_xla.experimental.custom_kernel.flash_attention`.

    Masking is expressed as an additive bias built from the second entry of each row of
    `cu_seqlens`.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """
        Args:
            dim: Feature width of the input; also the output width.
            num_heads: Number of attention heads; `head_dim` is `dim // num_heads`.
        """
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.head_dim = dim // num_heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Indexed as (batch, seq_len, dim).
            cu_seqlens: Indexed as `cu_seqlens[i][1]` per batch row.
            rotary_pos_emb: Accepted but never read by this implementation.
            position_embeddings: Optional `(cos, sin)`; when None no rotary embedding is applied.

        Returns:
            Tensor of shape (batch, seq_len, dim) after the output projection.
        """
        # Note, this is faster than SDPA, but pretty memory inefficient
        # It also has significant accuracy issues

        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]

        # Single reshape to target layout - avoid multiple operations
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(bsz, seq_length, 3, self.num_heads, -1)
            .permute(0, 2, 1, 3, 4)
            .unbind(1)
        )

        # Apply rotary embeddings if provided
        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)  # [bsz, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        total_seqlen = q.shape[2]
        # from cu_seqlens to segment ids for each position in dim 0
        # The bias is created without a device argument and moved to the input device below.
        additive_bias = torch.zeros((bsz, 1, total_seqlen, total_seqlen), dtype=q.dtype)
        min_val = torch.finfo(q.dtype).min

        for i in range(bsz):
            # Key columns before this index get the most negative value, i.e. are masked out.
            # NOTE(review): presumably the first segment is left padding - confirm with caller.
            padding_end = cu_seqlens[i][1].item()
            additive_bias[i, :, :, :padding_end] = min_val

        additive_bias = additive_bias.to(hidden_states.device)

        attn_scale = 1 / math.sqrt(self.head_dim)
        attn_output = torch_xla.experimental.custom_kernel.flash_attention(
            q, k, v, sm_scale=attn_scale, ab=additive_bias
        )
        # Back to (batch, seq_len, heads * head_dim) for the output projection.
        attn_output = (
            attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_length, -1)
        )
        attn_output = self.proj(attn_output)
        return attn_output
class Qwen2_5_VLVisionFlashAttention2(nn.Module):
    """Vision attention backed by flash-attn's variable-length kernel (`flash_attn_varlen_func`)."""

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """
        Args:
            dim: Feature width of the input; also the output width.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Indexed as (batch, seq_len, dim).
            cu_seqlens: Boundaries passed to the varlen kernel after `squeeze(0)`.
            rotary_pos_emb: Only used (with a warning) when `position_embeddings` is None.
            position_embeddings: `(cos, sin)` pair; each is squeezed on axis 0 before use.

        Returns:
            Tensor of shape (batch, seq_len, dim) after the output projection.
        """
        # Imported here so the module can be loaded without flash-attn installed.
        from flash_attn import flash_attn_varlen_func

        bsz = hidden_states.shape[0]
        seq_length = hidden_states.shape[1]
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(bsz, seq_length, 3, self.num_heads, -1)
            .permute(0, 2, 1, 3, 4)
            .unbind(1)
        )
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_flashatt(q, k, cos.squeeze(0), sin.squeeze(0))

        # squeeze(0) only drops axis 0 when it has size 1.
        # NOTE(review): presumably this path expects a batch of one packed sequence - confirm.
        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)
        cu_seqlens = cu_seqlens.squeeze(0)

        # Longest individual segment, required by the varlen kernel.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
        ).reshape(bsz, seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to `q` and `k`.

    The computation runs in float32; `cos`/`sin` get a broadcast axis inserted before their
    last axis, and each result is cast back to the dtype of its input.
    """
    q_dtype, k_dtype = q.dtype, k.dtype
    q32, k32 = q.float(), k.float()
    cos32 = cos.unsqueeze(-2).float()
    sin32 = sin.unsqueeze(-2).float()

    rotated_q = (q32 * cos32) + (rotate_half(q32) * sin32)
    rotated_k = (k32 * cos32) + (rotate_half(k32) * sin32)
    return rotated_q.to(q_dtype), rotated_k.to(k_dtype)


class Qwen2_5_VLVisionAttention(nn.Module):
    """Eager (matmul + softmax) vision attention with an additive block-diagonal mask."""

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Attend within the segments described by `cu_seqlens`.

        `cu_seqlens[j]` is read as a sequence of boundaries for batch row `j`; positions between
        the same pair of consecutive boundaries attend to each other, every other pair receives
        the most negative value of the query dtype.
        """
        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]

        fused = self.qkv(hidden_states).reshape(bsz, seq_length, 3, self.num_heads, -1)
        q, k, v = fused.permute(0, 2, 1, 3, 4).unbind(1)

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            angles = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos, sin = angles.cos(), angles.sin()
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # Start fully masked, then open up one square block per segment.
        mask = torch.full(
            [bsz, 1, seq_length, seq_length],
            torch.finfo(q.dtype).min,
            device=q.device,
            dtype=q.dtype,
        )
        for row in range(bsz):
            bounds = cu_seqlens[row]
            for idx in range(1, len(bounds)):
                lo, hi = bounds[idx - 1], bounds[idx]
                mask[row, ..., lo:hi, lo:hi] = 0

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = scores + mask
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)

        context = torch.matmul(probs, v).transpose(1, 2)
        context = context.reshape(bsz, seq_length, -1)
        return self.proj(context)
class Qwen2_5_VLVisionSdpaAttention(nn.Module):
    """
    Vision attention using `F.scaled_dot_product_attention` on a packed sequence.

    The packed tokens are scattered into a padded (batch, max_seq_len) layout, attended with an
    additive mask, and gathered back into packed order.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """
        Args:
            dim: Feature width of the input; also the output width.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def unpack_qkv_with_mask(self, q, k, v, cu_seqlens):
        """
        Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask.

        Args:
            q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim)
            cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths

        Returns:
            A 6-tuple:
            batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
            batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
            batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
            attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len)
                            with 0 where both the query and key position are valid and -inf
                            everywhere else (for additive attention)
            batch_indices: 1-D tensor giving, for every packed token, its row in the batch
            position_indices: 1-D tensor giving, for every packed token, its position in that row

        When `settings.FOUNDATION_STATIC_CACHE` is set, `max_seq_len` and `batch_size` are
        rounded with `get_nearest_pad` (pad multiples 16 and 2) and the extra rows have length 0.
        """
        device = q.device
        dtype = q.dtype

        batch_size = cu_seqlens.shape[0] - 1
        num_heads = q.shape[1]
        head_dim = q.shape[2]

        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # Keep as tensor
        max_seq_len = seq_lengths.max().item()  # Use .max() on tensor

        if settings.FOUNDATION_STATIC_CACHE:
            # Pad max_seq_len to the nearest multiple for compilation
            max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16)

            # Pad batch_size to the nearest multiple for compilation
            batch_size = get_nearest_pad(batch_size, pad_multiple=2)

            # Ensure seq_lengths is a tensor of the correct size
            seq_lengths = F.pad(
                seq_lengths, (0, batch_size - seq_lengths.size(0)), "constant", 0
            )

        # some day, you may look at this, and think: "what if I used repeat_interlave or some other fancy torch instead"?
        # don't do this - it's a path to madness.  For some reason, this loop is optimal

        # For every packed token, record which batch row and which slot it scatters to.
        batch_indices = []
        position_indices = []

        for i, seq_len in enumerate(
            seq_lengths.tolist()
        ):  # Convert to list only for iteration
            batch_indices.extend([i] * seq_len)
            position_indices.extend(list(range(seq_len)))

        batch_indices = torch.tensor(batch_indices, device=device)
        position_indices = torch.tensor(position_indices, device=device)

        # Padded slots stay zero.
        batched_q = torch.zeros(
            (batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype
        )
        batched_k = torch.zeros_like(batched_q)
        batched_v = torch.zeros_like(batched_q)

        # Create additive attention mask
        attention_mask = torch.full(
            (batch_size, max_seq_len, max_seq_len),
            fill_value=float("-inf"),
            device=device,
            dtype=dtype,
        )

        # Create mask for valid positions
        seq_range = torch.arange(max_seq_len, device=device)
        valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze(
            1
        )  # (batch_size, max_seq_len)
        valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(
            1
        )  # (batch_size, max_seq_len, max_seq_len)

        # Simply use boolean indexing to set valid positions to 0
        attention_mask[valid_2d] = 0
        attention_mask = attention_mask.unsqueeze(
            1
        )  # (batch_size, 1, max_seq_len, max_seq_len)

        # Scatter the packed tokens into their (row, slot) positions.
        batched_q[batch_indices, position_indices] = q
        batched_k[batch_indices, position_indices] = k
        batched_v[batch_indices, position_indices] = v

        return (
            batched_q,
            batched_k,
            batched_v,
            attention_mask,
            batch_indices,
            position_indices,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Squeezed on axis 0 and then treated as (total_seq_len, dim).
            cu_seqlens: Squeezed on axis 0 and passed to `unpack_qkv_with_mask`.
            rotary_pos_emb: Only used (with a warning) when `position_embeddings` is None.
            position_embeddings: `(cos, sin)` pair for `apply_rotary_pos_emb_vision`.

        Returns:
            The projected attention output in packed order with a leading axis of size 1.
        """
        hidden_states = hidden_states.squeeze(0)
        cu_seqlens = cu_seqlens.squeeze(0)

        seq_length = hidden_states.shape[0]
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(seq_length, 3, self.num_heads, -1)
            .permute(1, 0, 2, 3)
            .unbind(0)
        )
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
        # NOTE(review): presumably cos/sin carry a leading axis of size 1 that broadcasting
        # adds to q/k, and these squeezes remove it again - confirm with caller.
        q = q.squeeze(0)
        k = k.squeeze(0)

        q, k, v, attention_mask, batch_indices, position_indices = (
            self.unpack_qkv_with_mask(q, k, v, cu_seqlens)
        )
        batch_size, max_seqlen = q.shape[:2]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attention_mask,
            dropout_p=0.0,
        )
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(
            batch_size, max_seqlen, -1
        )  # Bring back to (batch_size, max_seqlen, hidden_dim)
        # Gather only the real tokens back into packed order; padded slots are dropped here.
        attn_output = attn_output[batch_indices, position_indices]
        attn_output = self.proj(attn_output)

        return attn_output.unsqueeze(0)
internally " "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " "removed and `position_embeddings` will be mandatory." ) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) cos = emb.cos() sin = emb.sin() else: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) q = q.squeeze(0) k = k.squeeze(0) q, k, v, attention_mask, batch_indices, position_indices = ( self.unpack_qkv_with_mask(q, k, v, cu_seqlens) ) batch_size, max_seqlen = q.shape[:2] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) attn_output = F.scaled_dot_product_attention( q, k, v, attention_mask, dropout_p=0.0, ) attn_output = attn_output.permute(0, 2, 1, 3).reshape( batch_size, max_seqlen, -1 ) # Bring back to (batch_size, max_seqlen, hidden_dim) attn_output = attn_output[batch_indices, position_indices] attn_output = self.proj(attn_output) return attn_output.unsqueeze(0) QWEN2_5_VL_VISION_ATTENTION_CLASSES = { "eager": Qwen2_5_VLVisionAttention, "flash_attention_2": Qwen2_5_VLVisionXLAFlashAttention2 if settings.FOUNDATION_XLA else Qwen2_5_VLVisionFlashAttention2, "sdpa": Qwen2_5_VLVisionXLASdpaAttention if settings.FOUNDATION_XLA else Qwen2_5_VLVisionSdpaAttention, } class Qwen2_5_VLVisionBlock(nn.Module): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( config.hidden_size, num_heads=config.num_heads ) self.mlp = Qwen2_5_VLMLP(config, bias=True) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: hidden_states = hidden_states + 
self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states Qwen2_5_VL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Qwen2_5_VLConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ class Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel): config_class = SuryaEncoderConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = False # TODO (joao): fix. 
def rot_pos_emb(self, grid_thw):
    """
    Build the rotary position embedding rows for every batch item.

    `grid_thw` is converted to nested Python lists and iterated as
    batch items, each a sequence of `(t, h, w)` triples. For every triple the
    (row, column) index of each patch is produced in merge-window order
    (patches of one `spatial_merge_size` x `spatial_merge_size` window are
    contiguous), repeated `t` times, and used to look up rows of the table
    returned by `self.rotary_pos_emb(max_grid_size)`.

    Returns:
        A tensor stacking one embedding matrix per batch item; every batch
        item must therefore yield the same number of tokens.
    """
    merge = self.spatial_merge_size

    def window_order(ids):
        # (h, w) -> (h/merge, merge, w/merge, merge) -> window-major flat order
        rows, cols = ids.shape
        tiled = ids.reshape(rows // merge, merge, cols // merge, merge)
        return tiled.permute(0, 2, 1, 3).flatten()

    per_item = []
    for grids in grid_thw.cpu().tolist():
        # The lookup table must cover the largest height or width in the item.
        max_grid_size = max(
            [h for _, h, _ in grids] + [w for _, _, w in grids]
        )

        chunks = []
        for t, h, w in grids:
            row_ids = window_order(torch.arange(h).unsqueeze(1).expand(-1, w))
            col_ids = window_order(torch.arange(w).unsqueeze(0).expand(h, -1))
            # shape: token_count, 2
            chunks.append(torch.stack([row_ids, col_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(chunks, dim=0)
        table = self.rotary_pos_emb(max_grid_size)
        per_item.append(table[pos_ids].flatten(1))

    return torch.stack(per_item, dim=0)
class SuryaEncoderModel(Qwen2_5_VisionTransformerPretrainedModel):
    """Vision encoder with convenience accessors over its configuration."""

    @property
    def image_size(self) -> Tuple[int, int]:
        """
        Return the configured image size as a 2-tuple.

        Accepts a single int (used for both entries) or a 2-element tuple or
        list. Lists are accepted because a config that has been serialized to
        JSON and loaded back stores sequences as lists, not tuples.

        Raises:
            ValueError: if `config.image_size` has any other form.
        """
        config: SuryaEncoderConfig = self.config
        size = config.image_size
        if isinstance(size, (tuple, list)) and len(size) == 2:
            return (size[0], size[1])
        if isinstance(size, int):
            return (size, size)
        raise ValueError(
            f"The `image_size` for SuryaEncoderConfig should be a tuple of (int, int) or a single int but found {type(size)}"
        )

    @property
    def hidden_size(self) -> int:
        """Return `hidden_size` from the model config."""
        config: SuryaEncoderConfig = self.config
        return config.hidden_size

    def embed_images(
        self,
        image_batch: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode a batch of image patches.

        Thin wrapper that forwards `image_batch` as `hidden_states`, together
        with `grid_thw`, to the parent class `forward`.
        """
        return super().forward(
            hidden_states=image_batch,
            grid_thw=grid_thw,
        )
================================================ FILE: surya/common/surya/encoder/config.py ================================================ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) class SuryaEncoderConfig(PretrainedConfig): model_type = "qwen2_5_vl" base_config_key = "vision_config" attribute_map = { "num_attention_heads": "num_heads", "num_hidden_layers": "depth", } def __init__( self, depth=8, hidden_size=1280, hidden_act="silu", intermediate_size=3420, num_heads=16, in_channels=3, patch_size=14, spatial_merge_size=2, spatial_patch_size=14, temporal_patch_size=1, tokens_per_second=4, window_size=112, out_hidden_size=1280, fullatt_block_indexes=(3, 7), initializer_range=0.02, image_size=4096, **kwargs, ): super().__init__(**kwargs) self.depth = depth self.hidden_size = hidden_size self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.num_heads = num_heads self.in_channels = in_channels self.patch_size = patch_size self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size self.tokens_per_second = tokens_per_second self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size self.initializer_range = initializer_range self.spatial_patch_size = spatial_patch_size self.image_size = image_size ================================================ FILE: surya/common/surya/flash_attn_utils.py ================================================ from typing import Optional import torch import torch.nn.functional as F from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache from flash_attn.bert_padding import index_first_axis as _index_first_axis from flash_attn.bert_padding import pad_input def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, 
int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. Arguments: attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. Return: indices (`torch.Tensor`): The indices of non-masked tokens from the flattened input sequence. cu_seqlens (`torch.Tensor`): The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). max_seqlen_in_batch (`int`): Maximum sequence length in batch. """ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) def _upad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, query_length: int, indices_k, cu_seqlens_k, max_seqlen_in_batch_k ): """ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. Arguments: query_layer (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). key_layer (`torch.Tensor`): Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). value_layer (`torch.Tensor`): Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. query_length (`int`): Target length. Return: query_layer (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). 
key_layer (`torch.Tensor`): Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). value_layer (`torch.Tensor`): Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). indices_q (`torch.Tensor`): The indices of non-masked tokens from the flattened input target sequence. (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = _index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) value_layer = _index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) if query_length == kv_seq_len: query_layer = _index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. 
def flash_attn_prefill(
    module: torch.nn.Module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    dropout: float,
    scaling: float,
    query_length: int,
    batch_size: int,
    indices_k: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_k: int,
    **kwargs
):
    """
    Flash-attention wrapper for the prefill stage.

    query_states must have shape (batch_size, num_heads, seq_len, head_dim)
    key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)
    This is the opposite of what flash attention requires, but keeps parity with the HF convention,
    so all three are transposed on dims 1 and 2 before use.

    query_length, batch_size, indices_k, cu_seqlens_k, and max_seqlen_in_batch_k should come from
    the flash attention kwargs. `attention_mask` is accepted for interface parity and is not read here.

    Returns:
        A pair `(output, None)`: the varlen attention result re-padded with `pad_input`, and None in
        place of attention weights to match other attention interfaces.
    """
    q_bshd = query_states.transpose(1, 2)
    k_bshd = key_states.transpose(1, 2)
    v_bshd = value_states.transpose(1, 2)

    # Strip padding so flash attention sees one flat token axis per tensor.
    (
        q_flash,
        k_flash,
        v_flash,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    ) = _upad_input(
        q_bshd,
        k_bshd,
        v_bshd,
        query_length,
        indices_k,
        cu_seqlens_k,
        max_seqlen_in_batch_k,
    )

    unpadded_out = _flash_attn_varlen_func(
        q_flash,
        k_flash,
        v_flash,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout,
        softmax_scale=scaling,
        causal=module.is_causal,
    )

    # Returning None for attn_weights to match other attention interfaces
    repadded = pad_input(unpadded_out, indices_q, batch_size, query_length)
    return repadded, None
def flash_attn_decode(
    module: torch.nn.Module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    scaling: float,
    **kwargs,
):
    """
    Flash-attention wrapper for the decode stage.

    Dropout is not supported; extra keyword arguments are accepted only to stay
    call-compatible with the other attention interfaces.

    query_states must have shape (batch_size, num_heads, seq_len, head_dim), 1 is the seq length in the decoding stage
    key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)
    This is the opposite of what flash attention requires, but keeps parity with the HF convention.

    The left padding and cache sequence lengths passed to FA2 are derived from `attention_mask`.
    For example, given an attention_mask shaped (batch_size=2, seq_len=8), where 0 = padding, 1 = real token:

        attention_mask = tensor([
            [0, 0, 1, 1, 1, 0, 0, 0],   # batch 0
            [0, 1, 1, 1, 1, 1, 1, 0],   # batch 1
        ])

        cache_leftpad = tensor([2, 1], dtype=torch.int32)
        cache_seqlens = tensor([5, 7], dtype=torch.int32)

    Returns:
        A pair `(output, None)`; None stands in for attention weights to match other attention interfaces.
    """
    q = query_states.transpose(1, 2)
    k_cache = key_states.transpose(1, 2)
    v_cache = value_states.transpose(1, 2)

    # Leading pads: the running product stays 1 only while every position so far is padding.
    is_padding = attention_mask == 0
    cache_leftpad = is_padding.cumprod(dim=1).sum(dim=1).to(torch.int32)

    # Index of the last real token, plus one.
    positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
    last_real = (attention_mask * positions).argmax(dim=1)
    cache_seqlens = last_real.to(torch.int32) + 1

    attn_out = _flash_attn_with_kvcache(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_leftpad=cache_leftpad,
        cache_seqlens=cache_seqlens,
        causal=module.is_causal,
        softmax_scale=scaling,
    )
    # Returning None for attn_weights to match other attention interfaces
    return attn_out, None
IMAGE_TOKEN = "" PAD_TOKEN = "" NO_OUTPUT_TOKEN = "" IMAGE_ROTATED_TOKEN = "" REGISTER_TOKENS = ["", "", "", ""] BEACON_TOKEN = "" NOMATH_TOKEN = "" # Task specific tokens OCR_WITH_BOXES_BOS_TOKEN = "" OCR_WITHOUT_BOXES_BOS_TOKEN = "" BLOCK_WITHOUT_BOXES_TOKEN = "" LAYOUT_BOS_TOKEN = "" TABLE_STRUCTURE_BOS_TOKEN = "" class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin): attributes = ["image_processor", "ocr_tokenizer"] image_processor_class = "BaseImageProcessor" ocr_tokenizer_class = "PreTrainedTokenizer" rescale_factor = 1 / 255.0 image_mean = (0.485, 0.456, 0.406) image_std = (0.229, 0.224, 0.225) def __init__( self, ocr_tokenizer: PreTrainedTokenizer, blank_bbox_token_id: int, num_register_tokens: int, patch_size: int, merge_size: int, num_beacon_tokens: int, beacon_token_interval: int, model_device: str, **kwargs, ): self.ocr_tokenizer = ocr_tokenizer self.patch_size = patch_size self.merge_size = merge_size self.num_register_tokens = num_register_tokens self.num_beacon_tokens = num_beacon_tokens self.beacon_token_interval = beacon_token_interval self.tokenizer_vocab_size = 0 for attr in self.attributes: if "tokenizer" in attr: self.tokenizer_vocab_size += getattr(self, attr).vocab_size self.offsets = {"ocr": 0} # Create special token mapping self.special_token_mapping = self.ocr_tokenizer.system_tokens self.register_token_ids = [ self.special_token_mapping.get(r) for r in REGISTER_TOKENS ] self.beacon_token_id = self.special_token_mapping.get(BEACON_TOKEN) self.image_token_id = self.special_token_mapping.get(IMAGE_TOKEN) self.pad_token_id = self.special_token_mapping.get(PAD_TOKEN) self.eos_token_id = self.special_token_mapping.get(EOS_TOKEN) self.eoi_token_id = self.special_token_mapping.get(EOI_TOKEN) self.no_output_token = self.special_token_mapping.get(NO_OUTPUT_TOKEN) self.image_rotated_token = self.special_token_mapping.get(IMAGE_ROTATED_TOKEN) self.nomath_token = self.special_token_mapping.get(NOMATH_TOKEN) self.bos_token_id = { 
TaskNames.ocr_with_boxes: self.special_token_mapping.get( OCR_WITH_BOXES_BOS_TOKEN ), TaskNames.ocr_without_boxes: self.special_token_mapping.get( OCR_WITHOUT_BOXES_BOS_TOKEN ), TaskNames.block_without_boxes: self.special_token_mapping.get( BLOCK_WITHOUT_BOXES_TOKEN ), TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN), TaskNames.table_structure: self.special_token_mapping.get( TABLE_STRUCTURE_BOS_TOKEN ), } if self.image_token_id is None: logger.warning("Warning: Image token not found in special tokens") self.blank_bbox_token_id = blank_bbox_token_id self.bbox_pad_token_id = self.blank_bbox_token_id self.ignore_bbox_token_ids = [ v for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() if k not in self.ocr_tokenizer.special_tokens["math_external"] ] math_end_token = "" self.math_start_token_ids = [ v for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() if k in self.ocr_tokenizer.special_tokens["math_external"] and k != math_end_token ] self.math_end_token_ids = [ v for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() if k == math_end_token ] if self.num_register_tokens > len(self.register_token_ids): raise ValueError( "The number of register tokens requested exceeds the number of register tokens defined in the special token mapping." 
) self.image_mean = np.array(self.image_mean, dtype=np.float32) self.image_std = np.array(self.image_std, dtype=np.float32) self.model_device = model_device @property def vocab_size(self): return self.tokenizer_vocab_size def image_processor(self, image: Image.Image) -> np.ndarray: # Convert to array image = np.asarray(image, dtype=np.float32) return image @staticmethod def scale_to_fit( img: np.ndarray, max_size: Tuple[int, int], min_size: Tuple[int, int] = (168, 168), ): # Get current dimensions height, width = img.shape[:2] # Check for empty or invalid image if width == 0 or height == 0: return img max_width, max_height = max_size min_width, min_height = min_size # Calculate pixel counts current_pixels = width * height max_pixels = max_width * max_height min_pixels = min_width * min_height if current_pixels > max_pixels: scale_factor = (max_pixels / current_pixels) ** 0.5 new_width = math.floor(width * scale_factor) new_height = math.floor(height * scale_factor) elif current_pixels == 0: return img elif current_pixels < min_pixels: scale_factor = (min_pixels / current_pixels) ** 0.5 new_width = math.ceil(width * scale_factor) new_height = math.ceil(height * scale_factor) else: return img return cv2.resize( img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4 ) def _image_processor(self, image: np.ndarray): image = image.astype(np.float64) * self.rescale_factor image = (image.astype(np.float32) - self.image_mean) / self.image_std return image def _process_and_tile( self, image: np.ndarray ) -> Tuple[torch.Tensor, Tuple[int, int, int]]: """ Resizes the input image to the closest multiple of tile_size while preserving the aspect ratio and returns a tensor of image tiles. 
def _process_image_input(self, image_input: ImageInput) -> ProcessorOutput:
    """
    Turn one image input dict into token ids plus image tiles.

    The image is tiled with `_process_and_tile`; one image token is emitted
    per `merge_size**2` tiles, followed by the first `num_register_tokens`
    register tokens. When the input is flagged as rotated, the rotated-image
    token is placed in front of everything else.

    Args:
        image_input: Dict read for the keys "image" (required) and
            "rotated" (optional, defaults to False).

    Returns:
        A `ProcessorOutput` carrying `input_ids`, `image_tiles` and `grid_thw`.
    """
    is_rotated = image_input.get("rotated", False)
    image = image_input.get("image", None)
    assert image is not None, (
        "A PIL Image must be provided when the input type is `image`"
    )

    image_tiles, grid_thw = self._process_and_tile(image)

    # Every merge_size x merge_size group of tiles maps to one token.
    num_tokens = image_tiles.shape[0] / self.merge_size**2
    assert num_tokens.is_integer(), (
        f"Expected number of tokens to be an integer, got {num_tokens}"
    )

    # Rotated images are announced with a marker token ahead of the image tokens.
    rotation_prefix = [self.image_rotated_token] if is_rotated else []
    image_tokens = [self.image_token_id] * int(num_tokens)
    register_tokens = self.register_token_ids[: self.num_register_tokens]

    return ProcessorOutput(
        input_ids=rotation_prefix + image_tokens + register_tokens,
        image_tiles=image_tiles,
        grid_thw=grid_thw,
    )
def __call__(
    self,
    mixed_batch: List[dict],
    padding_side: Optional[str] = "left",
    device: Optional[torch.device] = None,
    pad_to_multiple: Optional[int] = None,
):
    """
    Convert a batch of mixed image/text inputs into padded model tensors.

    Each batch entry is a dict with an "inputs" list and a "task" name; the
    entry is dispatched to the matching `_process_<task>` method. The
    resulting id sequences are padded to a common length (optionally rounded
    up to a multiple of `pad_to_multiple`), and an attention mask and
    position ids are derived from the padding.

    Args:
        mixed_batch: Batch entries as described above.
        padding_side: Side on which pad tokens are added. The extra padding
            for `pad_to_multiple` goes on the same side; any value other than
            "right" pads on the left.
        device: Target device. When it is a CUDA device (with or without an
            index, given as `torch.device` or string) the outputs are placed
            in pinned memory.
        pad_to_multiple: If set, the padded length is rounded up to the next
            multiple of this value.

    Returns:
        A `BatchFeature` with "input_ids", "image_tiles", "attention_mask",
        "position_ids" and "grid_thw".
    """
    all_image_tiles = []
    all_input_ids = []
    all_grid_thw = []

    for b in mixed_batch:
        mixed_input = b["inputs"]
        task = b["task"]
        assert task in self.bos_token_id, f"Task {task} has no bos token defined."

        # Select the correct processing function based on the task type
        input_ids, image_tiles, grid_thw = getattr(self, f"_process_{task}")(
            mixed_input, self.bos_token_id[task]
        )

        all_input_ids.append(input_ids)
        all_image_tiles.extend(image_tiles)
        all_grid_thw.extend(grid_thw)

    batched_input_ids = pad_sequence(
        all_input_ids,
        batch_first=True,
        padding_side=padding_side,
        padding_value=self.pad_token_id,
    )

    if pad_to_multiple is not None:
        current_len = batched_input_ids.shape[1]

        # Calculate the next multiple of pad_to_multiple
        padded_len = (
            (current_len + pad_to_multiple - 1) // pad_to_multiple
        ) * pad_to_multiple

        if padded_len > current_len:
            pad_len = padded_len - current_len
            # Keep all padding on one side: follow the side pad_sequence used
            # instead of always prepending.
            pad_spec = (0, pad_len) if padding_side == "right" else (pad_len, 0)
            batched_input_ids = torch.nn.functional.pad(
                batched_input_ids, pad_spec, value=self.pad_token_id
            )

    attention_mask = batched_input_ids.ne(self.pad_token_id)

    # Generating position IDs that are independent of left and right padding;
    # This should ensure same results for either padding side. Exact position id
    # for the pad tokens themselves don't matter since they are masked
    position_ids = attention_mask.cumsum(dim=-1) - 1
    # For left padding, the position ids for padding will become -1 because of
    # the shift; Setting to 0
    position_ids[position_ids < 0] = 0
    # Ensure right pad ids get set to zero
    position_ids = attention_mask.to(torch.long) * position_ids

    batched_image_tiles = torch.cat(all_image_tiles, dim=0)
    batched_grid_thw = torch.from_numpy(np.array(all_grid_thw))

    # Pin memory for CUDA. Compare on the device *type* so that "cuda",
    # "cuda:0" and torch.device("cuda", 1) are all recognised; an equality
    # check against torch.device("cuda") only matches the index-less form.
    target_device = torch.device(device) if isinstance(device, str) else device
    if isinstance(target_device, torch.device) and target_device.type == "cuda":
        batched_image_tiles = batched_image_tiles.pin_memory()
        batched_grid_thw = batched_grid_thw.pin_memory()
        attention_mask = attention_mask.pin_memory()
        batched_input_ids = batched_input_ids.pin_memory()
        position_ids = position_ids.pin_memory()

    return BatchFeature(
        {
            "input_ids": batched_input_ids,
            "image_tiles": batched_image_tiles,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "grid_thw": batched_grid_thw,
        }
    )
def create_token_regex(tokens):
    """Compile a regex matching any of *tokens* at the start of a string.

    Each token is escaped so it matches literally. Alternatives are ordered
    longest-first (by escaped length) so a longer token always wins over a
    shorter token that is its prefix. The matched token is capture group 1.
    """
    alternatives = sorted(
        (re.escape(tok) for tok in tokens), key=len, reverse=True
    )
    return re.compile("^(" + "|".join(alternatives) + ")")
def _build_escape_patterns(self, math_token_to_rawid):
    """Collect vocab commands whose ``\\X`` prefix doubles as an escape sequence.

    For every math-vocab token of the form ``\\<letter><rest>`` where
    ``<letter>`` is one of ``b t n r f a v``, record ``(<rest>, token)`` under
    the control character that ``\\<letter>`` denotes (e.g. ``\\begin`` is
    stored as ``("egin", "\\begin")`` under backspace).

    Returns:
        Dict mapping each control character to a list of ``(suffix, token)``
        pairs, sorted by suffix length, longest first, so longer commands are
        tried before shorter ones that share a prefix.
    """
    # Escape letter -> the control character it produces when interpreted.
    letter_to_ctrl = {
        "b": "\x08",  # backspace
        "t": "\t",  # tab
        "n": "\n",  # newline
        "r": "\r",  # carriage return
        "f": "\f",  # form feed
        "a": "\x07",  # bell
        "v": "\x0b",  # vertical tab
    }
    patterns = {ctrl: [] for ctrl in letter_to_ctrl.values()}

    for token in math_token_to_rawid.keys():
        if len(token) < 2 or not token.startswith("\\"):
            continue
        ctrl = letter_to_ctrl.get(token[1])
        if ctrl is not None:
            # Everything after "\X" is what survives next to the control char.
            patterns[ctrl].append((token[2:], token))

    for entries in patterns.values():
        entries.sort(key=lambda pair: len(pair[0]), reverse=True)

    return patterns
with open(vocab_file, "r", encoding="utf-8") as f: mv = json.load(f) else: mv = math_vocab or {} # Make math ids contiguous if needed if mv: max_id = max(mv.values()) if set(mv.values()) != set(range(max_id + 1)): items = sorted(mv.items(), key=lambda kv: kv[1]) mv = {tok: i for i, (tok, _) in enumerate(items)} # Load special tokens (prefer category dict; fallback to flat list or defaults) sp_dict = None if specials_dict_file and os.path.isfile(specials_dict_file): with open(specials_dict_file, "r", encoding="utf-8") as f: sp_dict = json.load(f) elif special_tokens_dict is not None: sp_dict = dict(special_tokens_dict) if sp_dict is None: # Legacy path: flat list from file or provided/default list if specials_file and os.path.isfile(specials_file): with open(specials_file, "r", encoding="utf-8") as f: sp_list_flat = json.load(f) else: sp_list_flat = special_tokens or SPECIAL_TOKENS sp_dict = {"all": list(sp_list_flat)} # Ensure "all" exists and is unique/preserved in order. if "all" not in sp_dict or not isinstance(sp_dict["all"], list): order = [ "system", "formatting", "math_external", "script", "layout", "reasoning", "table_structure", "reserved", ] seen = set() all_tokens: List[str] = [] for k in order: if k in sp_dict and isinstance(sp_dict[k], list): for t in sp_dict[k]: if t not in seen: all_tokens.append(t) seen.add(t) sp_dict["all"] = all_tokens # Keep a copy of categories (if present) for downstream processor logic. self.special_tokens = sp_dict sp_list = list(sp_dict.get("all", [])) # Regex list should favor longest-first to avoid partial matches. 
@property
def vocab_size(self) -> int:
    """Total id space: UTF-16 units, then math tokens, then special tokens."""
    special_count = len(self.special_tokens_list)
    return self.UTF16_SPACE + self.math_vocab_size + special_count
def _fix_latex_escapes(self, text: str) -> str:
    """Restore LaTeX commands whose ``\\X`` prefix became a control character.

    Walks the complete decoded string. Whenever a character is a key of
    ``self.latex_escape_patterns`` and the text right after it starts with one
    of the registered suffixes, the control character plus suffix is replaced
    by the full command. Patterns are tried in their stored order and the
    first match wins. Characters with no matching pattern are kept as-is.

    Args:
        text: Fully decoded string.

    Returns:
        The string with matching sequences replaced.
    """
    escape_patterns = self.latex_escape_patterns
    result = []
    i, n = 0, len(text)
    while i < n:
        char = text[i]
        if char in escape_patterns:
            for pattern, replacement in escape_patterns[char]:
                # Compare in place with an offset instead of slicing
                # ``text[i + 1:]``, which copied the remainder of the string
                # for every candidate pattern.
                if text.startswith(pattern, i + 1):
                    result.append(replacement)
                    i += 1 + len(pattern)
                    break
            else:
                # Not a LaTeX command: keep the control character unchanged.
                result.append(char)
                i += 1
        else:
            result.append(char)
            i += 1
    return "".join(result)
def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str:
    """Decode absolute token ids into a string.

    Args:
        token_ids: A single int, or a list/tuple/NumPy array/torch tensor of
            ids. Scalars that expose ``tolist()`` (NumPy integers, 0-d
            arrays/tensors) are accepted too.
        skip_special_tokens: When True, ids at or above ``self.SPECIAL_BASE``
            are dropped before decoding.
        **kwargs: Accepted and ignored.

    Returns:
        The string produced by ``self._decode_core`` for the normalized ids.
    """
    if hasattr(token_ids, "tolist"):
        token_ids = token_ids.tolist()
    # ``tolist()`` on a scalar yields a bare int, so the scalar check must run
    # after the conversion as well; otherwise iterating it raises TypeError.
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    token_ids = [int(i) for i in token_ids]  # normalize early
    if skip_special_tokens:
        token_ids = [i for i in token_ids if i < self.SPECIAL_BASE]
    return self._decode_core(token_ids)
def __init__(
    self,
    special_tokens: Dict[str, list] | None = None,
    model_checkpoint: str = settings.FOUNDATION_MODEL_CHECKPOINT,
    **kwargs,
):
    """Wrap a ``GreedyMathUTF16Tokenizer`` loaded from *model_checkpoint*.

    Args:
        special_tokens: Category name -> list of special token strings.
            ``None`` is treated as an empty dict.
        model_checkpoint: Checkpoint passed to
            ``GreedyMathUTF16Tokenizer.from_pretrained``.
        **kwargs: Forwarded to the parent class initializer.
    """
    self.special_tokens = {} if special_tokens is None else special_tokens
    self.ocr_tokenizer = GreedyMathUTF16Tokenizer.from_pretrained(model_checkpoint)

    # Map each "system" token string to the first entry of the ids the
    # wrapped tokenizer returns for it.
    system_tokens = {}
    for token in self.special_tokens.get("system", []):
        system_tokens[token] = self.ocr_tokenizer(token)["input_ids"][0]
    self.system_tokens = system_tokens

    self.SPECIAL_TOKEN_MAPPING = self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING

    # Parent init runs last, after the attributes above exist.
    # NOTE(review): presumably because the parent initializer calls back into
    # get_vocab()/vocab_size — confirm before reordering.
    super().__init__(**kwargs)
def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
    """Drop degenerate boxes and boxes fully contained in another box.

    A box is discarded when its polygon has zero width or zero height, or
    when its ``bbox`` lies inside the ``bbox`` of some other entry. Entries
    with an identical polygon or an identical bbox never eliminate each
    other, so exact duplicates are all kept. Input order is preserved.
    """

    def is_nested(inner_obj, inner_bbox, outer_obj) -> bool:
        # Same polygon: this is the box itself or an exact duplicate.
        if outer_obj.polygon == inner_obj.polygon:
            return False
        outer_bbox = outer_obj.bbox
        # Equal extents do not count as containment.
        if inner_bbox == outer_bbox:
            return False
        return (
            inner_bbox[0] >= outer_bbox[0]
            and inner_bbox[1] >= outer_bbox[1]
            and inner_bbox[2] <= outer_bbox[2]
            and inner_bbox[3] <= outer_bbox[3]
        )

    kept = []
    for candidate in boxes:
        xs = [pt[0] for pt in candidate.polygon]
        ys = [pt[1] for pt in candidate.polygon]
        if max(xs) == min(xs) or max(ys) == min(ys):
            continue  # zero-area polygon

        candidate_bbox = candidate.bbox
        if not any(is_nested(candidate, candidate_bbox, rival) for rival in boxes):
            kept.append(candidate)
    return kept
"latin": [ (0x0041, 0x005A), # Latin uppercase A-Z (0x0061, 0x007A), # Latin lowercase a-z (0x0080, 0x00FF), # Latin-1 Supplement (0x0100, 0x017F), # Latin Extended-A (0x0180, 0x024F), # Latin Extended-B (0x0250, 0x02AF), # IPA Extensions (0x02B0, 0x02FF), # Spacing Modifier Letters (0x0300, 0x036F), # Combining Diacritical Marks (0x1E00, 0x1EFF), # Latin Extended Additional (0x2C60, 0x2C7F), # Latin Extended-C (0xA720, 0xA7FF), # Latin Extended-D ], # Punctuation, universal characters, and general symbols "punctuation": [ (0x0020, 0x0020), # Space (0x0021, 0x002F), # Basic punctuation and symbols (0x0030, 0x0039), # Digits 0-9 (0x003A, 0x0040), # More punctuation and symbols (0x005B, 0x0060), # More punctuation and symbols (0x007B, 0x007F), # More punctuation and symbols (0x2000, 0x206F), # General Punctuation ], # Cyrillic scripts (used by Russian, Ukrainian, etc.) "cyrillic": [ (0x0400, 0x04FF), # Cyrillic (0x0500, 0x052F), # Cyrillic Supplement ], # Arabic scripts "arabic": [ (0x0600, 0x06FF), # Arabic (0x0750, 0x077F), # Arabic Supplement (0x08A0, 0x08FF), # Arabic Extended-A ], # Chinese characters "chinese": [ (0x4E00, 0x9FFF), # Common CJK Unified Ideographs (0x3400, 0x4DBF), # CJK Extension A (0x20000, 0x2A6DF), # CJK Extension B ], # Japanese-specific scripts (excluding shared CJK) "japanese": [ (0x3040, 0x30FF), # Hiragana and Katakana ], # Korean-specific scripts "korean": [ (0x1100, 0x11FF), # Hangul Jamo (0x3130, 0x318F), # Hangul Compatibility Jamo (0xAC00, 0xD7AF), # Hangul Syllables ], # Various mathematical and technical symbols "symbols": [ (0x2070, 0x209F), # Superscripts and Subscripts (0x20A0, 0x20CF), # Currency Symbols (0x2100, 0x214F), # Letterlike Symbols (0x2150, 0x218F), # Number Forms (0x2190, 0x21FF), # Arrows (0x2200, 0x22FF), # Mathematical Operators (0x2300, 0x23FF), # Miscellaneous Technical (0x2500, 0x257F), # Box Drawing (0x2580, 0x259F), # Block Elements (0x25A0, 0x25FF), # Geometric Shapes (0x2600, 0x26FF), # Miscellaneous 
Symbols (0x2700, 0x27BF), # Dingbats (0x27C0, 0x27EF), # Miscellaneous Mathematical Symbols-A (0x2980, 0x29FF), # Miscellaneous Mathematical Symbols-B (0x2A00, 0x2AFF), # Supplemental Mathematical Operators (0x1D400, 0x1D7FF), # Mathematical Alphanumeric Symbols ], # Individual scripts for languages with unique writing systems "greek": [(0x0370, 0x03FF)], # Greek and Coptic "armenian": [(0x0530, 0x058F)], # Armenian "hebrew": [(0x0590, 0x05FF)], # Hebrew "devanagari": [(0x0900, 0x097F)], # Devanagari (Hindi, Sanskrit) "bengali": [(0x0980, 0x09FF)], # Bengali "gurmukhi": [(0x0A00, 0x0A7F)], # Gurmukhi (Punjabi) "gujarati": [(0x0A80, 0x0AFF)], # Gujarati "oriya": [(0x0B00, 0x0B7F)], # Oriya "tamil": [(0x0B80, 0x0BFF)], # Tamil "telugu": [(0x0C00, 0x0C7F)], # Telugu "kannada": [(0x0C80, 0x0CFF)], # Kannada "malayalam": [(0x0D00, 0x0D7F)], # Malayalam "sinhala": [(0x0D80, 0x0DFF)], # Sinhala "thai": [(0x0E00, 0x0E7F)], # Thai "lao": [(0x0E80, 0x0EFF)], # Lao "myanmar": [(0x1000, 0x109F)], # Myanmar "georgian": [(0x10A0, 0x10FF)], # Georgian "ethiopic": [(0x1200, 0x137F)], # Ethiopic "khmer": [(0x1780, 0x17FF)], # Khmer "mongolian": [(0x1800, 0x18AF)], # Mongolian } # Convert to a flat structure with character ranges flat_ranges = {} for category, ranges in script_categories.items(): # Create a set of all characters in this category char_set = set() for start, end in ranges: char_set.update(range(start, end + 1)) # Store the set in flat_ranges flat_ranges[category] = char_set return script_categories, flat_ranges def get_top_scripts(text: str, max_scripts: int = 5): script_categories, flat_ranges = script_ranges() char_count = {category: 0 for category in script_categories.keys()} for char in text: for category, char_set in flat_ranges.items(): if ord(char) in char_set: char_count[category] += 1 break top_scripts = sorted(char_count.items(), key=lambda x: x[1], reverse=True) top_scripts = [ts[0] for ts in top_scripts if ts[1] > 0] if " bool: if not 
def pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int):
    """Pad *tensor* along dim 0 up to *batch_size* by repeating its last row.

    Works for tensors of any rank >= 1; the repeat previously assumed exactly
    three dimensions.

    Args:
        tensor: Input whose first dimension is the batch dimension.
        batch_size: Desired size of the first dimension.

    Returns:
        The input itself when it already has at least *batch_size* rows,
        otherwise a new tensor with the last row appended as many times as
        needed. An input with zero rows has no last row to repeat and is
        returned with zero rows.
    """
    pad_size = batch_size - tensor.shape[0]
    if pad_size <= 0:
        return tensor

    # Repeat only along the batch dimension; every other dimension keeps its
    # size, whatever the rank of the tensor.
    repeats = (pad_size,) + (1,) * (tensor.dim() - 1)
    last_row = tensor[-1:].repeat(*repeats)

    return torch.cat([tensor, last_row], dim=0)
def get_font_path(langs: Optional[List[str]] = None) -> str:
    """Return the local path of the render font, downloading it if missing.

    The font configured under ``"all"`` in
    ``settings.RECOGNITION_RENDER_FONTS`` is used unless *langs* contains
    exactly one language that has its own entry there.

    When the chosen file does not exist it is fetched from
    ``settings.RECOGNITION_FONT_DL_BASE``. The download goes to a temporary
    ``.part`` file and is renamed into place only after it completes, so a
    failed or interrupted request never leaves an empty or truncated font at
    the final path, which later calls would treat as already downloaded.

    Args:
        langs: Optional list of language codes.

    Returns:
        Path to the font file on disk.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    font_path = settings.RECOGNITION_RENDER_FONTS["all"]
    if langs is not None:
        for k in settings.RECOGNITION_RENDER_FONTS:
            if k in langs and len(langs) == 1:
                font_path = settings.RECOGNITION_RENDER_FONTS[k]
                break

    if not os.path.exists(font_path):
        font_dir = os.path.dirname(font_path)
        if font_dir:  # a bare filename has no directory to create
            os.makedirs(font_dir, exist_ok=True)
        font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}"
        tmp_path = f"{font_path}.part"
        try:
            with requests.get(font_dl_path, stream=True) as r:
                # Check the status before any file is created on disk.
                r.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, font_path)
        finally:
            # Only present if the download or rename did not complete.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    return font_path
def strip_html_tags(html_text):
    """Return *html_text* with tag-like spans removed.

    A span is removed when it starts with ``<`` immediately followed by a
    word character or ``/`` and runs to the next ``>``. A ``<`` followed by
    anything else (for example a space) is left untouched.
    """
    return re.sub(r"<[\w/][^>]*>", "", html_text)
def draw_text_on_image(
    bboxes,
    texts,
    image_size: Tuple[int, int],
    font_path=None,
    max_font_size=60,
    res_upscale=2,
) -> Image.Image:
    """Render *texts* inside their *bboxes* on a blank white image.

    When Playwright is available the work is delegated to
    ``draw_text_with_playwright``; ``font_path``, ``max_font_size`` and
    ``res_upscale`` are not used on that path. Otherwise HTML tags are
    stripped from the texts and they are drawn with PIL on a canvas of
    ``image_size`` scaled by ``res_upscale``.

    Args:
        bboxes: Boxes indexed as ``x1, y1, x2, y2``, in ``image_size`` space.
        texts: One string per box.
        image_size: ``(width, height)`` of the unscaled image.
        font_path: Font file to use; defaults to ``get_font_path()``.
        max_font_size: Upper bound for the starting font size.
        res_upscale: Factor applied to the canvas size and box coordinates.

    Returns:
        The rendered PIL image.
    """
    if has_playwright:
        return draw_text_with_playwright(bboxes, texts, image_size)

    plain_texts = [strip_html_tags(text) for text in texts]
    if font_path is None:
        font_path = get_font_path()

    canvas_size = (image_size[0] * res_upscale, image_size[1] * res_upscale)
    canvas = Image.new("RGB", canvas_size, color="white")
    pen = ImageDraw.Draw(canvas)

    for bbox, text in zip(bboxes, plain_texts):
        scaled = [int(coord * res_upscale) for coord in bbox]
        box_w = scaled[2] - scaled[0]
        box_h = scaled[3] - scaled[1]

        # Start at 75% of the box height, clamped to [6, max_font_size];
        # render_text shrinks the font further if the text does not fit.
        start_size = max(6, min(int(0.75 * box_h), max_font_size))
        render_text(pen, text, scaled, box_w, box_h, font_path, start_size)

    return canvas
DetectionPredictor(BasePredictor): model_loader_cls = DetectionModelLoader batch_size = settings.DETECTOR_BATCH_SIZE default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 36, "xla": 18} def __call__( self, images: List[Image.Image], batch_size=None, include_maps=False ) -> List[TextDetectionResult]: detection_generator = self.batch_detection( images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE ) postprocessing_futures = [] max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) parallelize = ( not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH ) executor = ThreadPoolExecutor if parallelize else FakeExecutor with executor(max_workers=max_workers) as e: for preds, orig_sizes in detection_generator: for pred, orig_size in zip(preds, orig_sizes): postprocessing_futures.append( e.submit(parallel_get_boxes, pred, orig_size, include_maps) ) return [future.result() for future in postprocessing_futures] def prepare_image(self, img): new_size = (self.processor.size["width"], self.processor.size["height"]) # This double resize actually necessary for downstream accuracy img.thumbnail(new_size, Image.Resampling.LANCZOS) img = img.resize( new_size, Image.Resampling.LANCZOS ) # Stretch smaller dimension to fit new size img = np.asarray(img, dtype=np.uint8) img = self.processor(img)["pixel_values"][0] img = torch.from_numpy(img) return img def batch_detection( self, images: List, batch_size=None, static_cache=False ) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]: assert all([isinstance(image, Image.Image) for image in images]) if batch_size is None: batch_size = self.get_batch_size() heatmap_count = self.model.config.num_labels orig_sizes = [image.size for image in images] splits_per_image = [ get_total_splits(size, self.processor.size["height"]) for size in orig_sizes ] batches = [] current_batch_size = 0 current_batch = [] for i in range(len(images)): if current_batch_size + 
splits_per_image[i] > batch_size: if len(current_batch) > 0: batches.append(current_batch) current_batch = [] current_batch_size = 0 current_batch.append(i) current_batch_size += splits_per_image[i] if len(current_batch) > 0: batches.append(current_batch) for batch_idx in tqdm( range(len(batches)), desc="Detecting bboxes", disable=self.disable_tqdm ): batch_image_idxs = batches[batch_idx] batch_images = [images[j].convert("RGB") for j in batch_image_idxs] split_index = [] split_heights = [] image_splits = [] for image_idx, image in enumerate(batch_images): image_parts, split_height = split_image( image, self.processor.size["height"] ) image_splits.extend(image_parts) split_index.extend([image_idx] * len(image_parts)) split_heights.extend(split_height) image_splits = [self.prepare_image(image) for image in image_splits] # Batch images in dim 0 batch = torch.stack(image_splits, dim=0).to(self.model.dtype) if static_cache: batch = self.pad_to_batch_size(batch, batch_size) with settings.INFERENCE_MODE(): pred = self.model( pixel_values=batch.to(self.model.device) ) # Moving the to device here fixes issues with xla recompilation logits = pred.logits correct_shape = [ self.processor.size["height"], self.processor.size["width"], ] current_shape = list(logits.shape[2:]) if current_shape != correct_shape: logits = F.interpolate( logits, size=correct_shape, mode="bilinear", align_corners=False ) mark_step() logits = logits.to(torch.float32).cpu().numpy() preds = [] for i, (idx, height) in enumerate(zip(split_index, split_heights)): # If our current prediction length is below the image idx, that means we have a new image # Otherwise, we need to add to the current image if len(preds) <= idx: preds.append([logits[i][k] for k in range(heatmap_count)]) else: heatmaps = preds[idx] pred_heatmaps = [logits[i][k] for k in range(heatmap_count)] if height < self.processor.size["height"]: # Cut off padding to get original height pred_heatmaps = [ pred_heatmap[:height, :] for 
pred_heatmap in pred_heatmaps ] for k in range(heatmap_count): heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]]) preds[idx] = heatmaps yield preds, [orig_sizes[j] for j in batch_image_idxs] torch.cuda.empty_cache() ================================================ FILE: surya/detection/heatmap.py ================================================ from typing import List import cv2 import numpy as np from PIL import Image from surya.common.util import clean_boxes from surya.detection import TextDetectionResult from surya.common.polygon import PolygonBox from surya.settings import settings def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7): # Find average intensity of top 10% pixels flat_map = linemap.ravel() top_10_count = int(len(flat_map) * 0.9) avg_intensity = np.mean(np.partition(flat_map, top_10_count)[top_10_count:]) scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2) low_text = np.clip(low_text * scaling_factor, 0.1, 0.6) text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8) return text_threshold, low_text def detect_boxes(linemap, text_threshold, low_text): # From CRAFT - https://github.com/clovaai/CRAFT-pytorch # Modified to return boxes and for speed, accuracy img_h, img_w = linemap.shape text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text) text_score_comb = (linemap > low_text).astype(np.uint8) label_count, labels, stats, centroids = cv2.connectedComponentsWithStats( text_score_comb, connectivity=4 ) det = [] confidences = [] max_confidence = 0 for k in range(1, label_count): # size filtering size = stats[k, cv2.CC_STAT_AREA] if size < 10: continue # make segmentation map x, y, w, h = stats[ k, [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT], ] try: niter = int(np.sqrt(min(w, h))) except ValueError: niter = 0 buffer = 1 sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer) ex, ey = min(img_w, x + w + niter + 
def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
    """Run ``detect_boxes`` on a text heatmap and wrap the output in ``PolygonBox``.

    Args:
        textmap: 2-D heatmap array; converted to ``float32`` when it has any
            other dtype.
        text_threshold: Peak-score threshold handed to ``detect_boxes``;
            ``None`` selects ``settings.DETECTOR_TEXT_THRESHOLD``.
        low_text: Binarisation threshold handed to ``detect_boxes``; ``None``
            selects ``settings.DETECTOR_BLANK_THRESHOLD``.

    Returns:
        One ``PolygonBox`` per detected polygon, carrying its confidence.
    """
    peak_threshold = (
        settings.DETECTOR_TEXT_THRESHOLD if text_threshold is None else text_threshold
    )
    blank_threshold = settings.DETECTOR_BLANK_THRESHOLD if low_text is None else low_text

    heat = textmap if textmap.dtype == np.float32 else textmap.astype(np.float32)
    polygons, scores = detect_boxes(heat, peak_threshold, blank_threshold)

    # Convert from raw corner points to box objects.
    wrapped = []
    for polygon, score in zip(polygons, scores):
        wrapped.append(PolygonBox(polygon=polygon, confidence=score))
    return wrapped
def parallel_get_boxes(preds, orig_sizes, include_maps=False):
    """Post-process one image's heatmaps into a ``TextDetectionResult``.

    Args:
        preds: Pair ``(heatmap, affinity_map)`` of 2-D arrays for one image.
        orig_sizes: Size of that single image as ``(width, height)``; despite
            the plural name it is indexed as one size.
        include_maps: When true, both maps are scaled by 255, cast to
            ``uint8`` and attached to the result as PIL images.

    Returns:
        A ``TextDetectionResult`` with the cleaned boxes, the optional map
        images and ``image_bbox`` spanning the whole original image.
    """
    heatmap, affinity_map = preds

    def _as_image(score_map):
        return Image.fromarray((score_map * 255).astype(np.uint8))

    heat_img = _as_image(heatmap) if include_maps else None
    aff_img = _as_image(affinity_map) if include_maps else None

    # numpy shape is (rows, cols); the box code expects (width, height).
    heatmap_size = list(heatmap.shape[::-1])
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)

    y_margin = settings.DETECTOR_BOX_Y_EXPAND_MARGIN
    for box in bboxes:
        # Boxes that are at least three times taller than wide are treated as
        # vertical text and left unexpanded.
        if box.height >= 3 * box.width:
            continue
        box.expand(x_margin=0, y_margin=y_margin)
        # Clamp again in case the expansion pushed the box outside the image.
        box.fit_to_bounds([0, 0, orig_sizes[0], orig_sizes[1]])

    return TextDetectionResult(
        bboxes=bboxes,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]],
    )
class EfficientViTConfig(S3DownloaderMixin, PretrainedConfig):
    r"""
    Configuration for the EfficientViT encoder and decode head used by the
    detection model.

    Args:
        num_classes: Stored on the config and copied onto the encoder. Note
            that the decode head sizes its classifier from
            ``config.num_labels`` (inherited from the base config), not from
            this value.
        num_channels: Number of channels of the input image; input width of
            the stem convolution.
        widths: Channel widths. ``widths[0]`` is the stem output width and
            ``widths[1:]`` are the output widths of the encoder stages, which
            are also the input widths of the decode-head MLPs.
        head_dim: Per-head dimension of the linear attention in the
            attention stages.
        num_stages: Number of encoder feature maps fused by the decode head.
            The fuse convolution expects
            ``decoder_layer_hidden_size * num_stages`` input channels, so this
            has to equal ``len(widths) - 1``.
        depths: ``depths[0]`` is the number of residual blocks in the stem;
            ``depths[1:]`` are the block counts of the stages (in addition to
            each stage's strided entry block).
        strides: ``strides[0]`` is the stem stride; ``strides[1:]`` are the
            strides of the stage entry blocks.
        hidden_sizes, patch_size, hidden_dropout_prob,
        attention_probs_dropout_prob, semantic_loss_ignore_index: Stored on
            the config only; the encoder/decoder code that accompanies this
            config does not read them.
        classifier_dropout_prob: Probability of the dropout module created in
            the decode head.
        layer_norm_eps: ``eps`` passed to the encoder's normalization layers
            (BatchNorm by default, despite the name).
        decoder_layer_hidden_size: Output width of each per-stage decode MLP.
        decoder_hidden_size: Channel width after the fuse convolution, i.e.
            the input width of the classifier.
        initializer_range: Standard deviation of the normal distribution used
            for weight initialisation.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "efficientvit"

    def __init__(
        self,
        num_classes=2,
        num_channels=3,
        widths=(32, 64, 128, 256, 512),
        head_dim=32,
        num_stages=4,
        depths=(1, 1, 1, 6, 6),
        strides=(2, 2, 2, 2, 2),
        hidden_sizes=(32, 64, 160, 256),
        patch_size=(7, 7),
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        classifier_dropout_prob=0.0,
        layer_norm_eps=1e-6,
        decoder_layer_hidden_size=128,
        decoder_hidden_size=512,
        semantic_loss_ignore_index=255,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Encoder layout
        self.num_classes = num_classes
        self.widths = widths
        self.head_dim = head_dim

        self.num_channels = num_channels
        self.num_stages = num_stages
        self.depths = depths
        self.strides = strides
        # Stored only; not read by the accompanying model code.
        self.hidden_sizes = hidden_sizes
        self.patch_size = patch_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # Decode head and initialisation
        self.classifier_dropout_prob = classifier_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_layer_hidden_size = decoder_layer_hidden_size
        self.semantic_loss_ignore_index = semantic_loss_ignore_index

        self.initializer_range = initializer_range
class ConvNormAct(nn.Module):
    """Dropout -> Conv2d -> normalization -> activation.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution (and of the norm).
        kernel_size, stride, dilation, groups, bias: Forwarded to
            ``nn.Conv2d``. Padding is derived from kernel size, stride and
            dilation via ``get_padding``.
        dropout: Dropout probability applied to the input. The default
            ``0.0`` makes the dropout step an identity.
        norm_layer: Callable building the normalization layer from
            ``num_features``; a falsy value disables normalization.
        act_layer: Callable building the activation (called with
            ``inplace=True``); ``None`` disables the activation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        dropout=0.0,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.ReLU,
    ):
        super(ConvNormAct, self).__init__()
        self.dropout = nn.Dropout(dropout, inplace=False)
        padding = get_padding(kernel_size, stride, dilation)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )
        self.norm = (
            norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
        )
        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()

    def forward(self, x):
        # The dropout module used to be constructed but never called, so the
        # ``dropout`` argument had no effect. It has no parameters, so state
        # dicts are unaffected, and it is an identity in eval mode or when
        # p == 0.0 (the default used by every caller in this module).
        x = self.dropout(x)
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x
class FusedMBConv(nn.Module):
    """Fused inverted-bottleneck block: one spatial conv that expands the
    channels, followed by a 1x1 conv that projects to ``out_channels``.

    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be a single value
    or a pair; they are expanded with ``val2tuple`` so that element 0
    configures the spatial conv and element 1 the pointwise conv. When
    ``mid_channels`` is falsy the hidden width is
    ``round(in_channels * expand_ratio)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        if not mid_channels:
            mid_channels = round(in_channels * expand_ratio)

        # Expanding convolution over the spatial neighbourhood.
        self.spatial_conv = ConvNormAct(
            in_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=groups,
            norm_layer=norms[0],
            act_layer=acts[0],
            bias=biases[0],
        )
        # 1x1 projection back down to the requested width.
        self.point_conv = ConvNormAct(
            mid_channels,
            out_channels,
            1,
            norm_layer=norms[1],
            act_layer=acts[1],
            bias=biases[1],
        )

    def forward(self, x):
        return self.point_conv(self.spatial_conv(x))
class ResidualBlock(nn.Module):
    """Apply ``main`` to the (optionally pre-normalized) input and add an
    optional shortcut branch.

    Args:
        main: Module computing the residual branch. ``None`` turns the whole
            block into an identity mapping.
        shortcut: Module applied to the raw input and added to the output of
            ``main``; ``None`` means no skip connection is added.
        pre_norm: Module applied to the input before ``main``; defaults to
            ``nn.Identity()``.
    """

    def __init__(
        self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module] = None,
        pre_norm: Optional[nn.Module] = None,
    ):
        super(ResidualBlock, self).__init__()
        self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
        self.main = main
        self.shortcut = shortcut

    def forward(self, x):
        # The signature allows ``main=None``; previously that crashed with
        # "'NoneType' object is not callable". Treat it as a pass-through.
        if self.main is None:
            return x
        res = self.main(self.pre_norm(x))
        if self.shortcut is not None:
            # The shortcut sees the un-normalized input.
            res = res + self.shortcut(x)
        return res
class EfficientVitLargeStage(nn.Module):
    """One encoder stage: a strided entry block followed by ``depth`` blocks
    that keep the resolution and channel count.

    Args:
        in_chs: Channels entering the stage.
        out_chs: Channels produced by the entry block and kept afterwards.
        depth: Number of blocks after the entry block.
        stride: Stride of the entry block (its kernel size is ``stride + 1``).
        norm_layer, act_layer: Layer factories forwarded to every block.
        head_dim: Attention head dimension, used only when ``vit_stage``.
        vit_stage: Use ``EfficientVitBlock`` for the trailing blocks instead
            of local convolution blocks.
        fewer_norm: Forwarded to ``build_local_block``; also switches the
            local block type from ``"fused"`` to ``"default"``.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        depth,
        stride,
        norm_layer,
        act_layer,
        head_dim,
        vit_stage=False,
        fewer_norm=False,
    ):
        super().__init__()
        local_block_type = "default" if fewer_norm else "fused"

        # Entry block: changes width and resolution, so no shortcut is added.
        entry = ResidualBlock(
            build_local_block(
                in_channels=in_chs,
                out_channels=out_chs,
                stride=stride,
                kernel_size=stride + 1,
                expand_ratio=24 if vit_stage else 16,
                norm_layer=norm_layer,
                act_layer=act_layer,
                fewer_norm=vit_stage or fewer_norm,
                block_type=local_block_type,
            ),
            None,
        )

        if vit_stage:
            # for stage 4: attention blocks
            body = [
                EfficientVitBlock(
                    in_channels=out_chs,
                    head_dim=head_dim,
                    expand_ratio=6,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for _ in range(depth)
            ]
        else:
            # for stage 1, 2, 3: local conv blocks with identity shortcuts
            body = [
                ResidualBlock(
                    build_local_block(
                        in_channels=out_chs,
                        out_channels=out_chs,
                        stride=1,
                        kernel_size=3,
                        expand_ratio=4,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                        fewer_norm=fewer_norm,
                        block_type=local_block_type,
                    ),
                    nn.Identity(),
                )
                for _ in range(depth)
            ]

        self.blocks = nn.Sequential(entry, *body)

    def forward(self, x):
        out = self.blocks(x)
        return out
class DecodeMLP(nn.Module):
    """Per-position linear projection of a feature map.

    Takes a ``(B, C, H, W)`` tensor, flattens the spatial dimensions and
    projects the channel dimension from ``input_dim`` to ``output_dim``,
    returning a ``(B, H*W, output_dim)`` tensor.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        # (B, C, H, W) -> (B, C, H*W)
        tokens = torch.flatten(hidden_states, start_dim=2)
        # (B, C, H*W) -> (B, H*W, C) so the linear layer acts on channels
        tokens = torch.transpose(tokens, 1, 2)
        return self.proj(tokens)
out_channels=config.decoder_hidden_size, kernel_size=1, bias=False, ) self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size) self.activation = nn.ReLU() self.dropout = nn.Dropout(config.classifier_dropout_prob) self.classifier = nn.Conv2d( config.decoder_hidden_size, config.num_labels, kernel_size=1 ) self.config = config def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor: batch_size = encoder_hidden_states[-1].shape[0] all_hidden_states = () for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c): height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3] encoder_hidden_state = mlp(encoder_hidden_state) # Output is B, HW, C # Permute to B, C, HW encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1) encoder_hidden_state = encoder_hidden_state.reshape( batch_size, -1, height, width ) # upsample encoder_hidden_state = nn.functional.interpolate( encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False, ) all_hidden_states += (encoder_hidden_state,) hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1)) hidden_states = self.batch_norm(hidden_states) hidden_states = self.activation(hidden_states) # logits are of shape (batch_size, num_labels, height/4, width/4) logits = self.classifier(hidden_states) return logits class EfficientViTForSemanticSegmentation( S3DownloaderMixin, EfficientViTPreTrainedModel ): def __init__(self, config, **kwargs): super().__init__(config) self.vit = EfficientVitLarge(config) self.decode_head = DecodeHead(config) # Initialize weights and apply final processing self.post_init() def forward( self, pixel_values: torch.FloatTensor ) -> Union[Tuple, SemanticSegmenterOutput]: # Pixel values should be B,C,H,W encoder_hidden_states = self.vit( pixel_values, ) logits = self.decode_head(encoder_hidden_states) # Apply sigmoid to get 0-1 output logits = torch.special.expit(logits) return SemanticSegmenterOutput( 
class FakeFuture:
    """Synchronous stand-in for ``concurrent.futures.Future``.

    The wrapped call runs immediately in the constructor. To match a real
    future, an exception raised by the call is stored and re-raised from
    ``result()`` instead of escaping at submission time, so code written
    against ``ThreadPoolExecutor`` behaves the same with the fake executor.
    """

    def __init__(self, func, *args, **kwargs):
        self._result = None
        self._exception = None
        try:
            self._result = func(*args, **kwargs)
        except Exception as exc:
            # Deferred until result(); KeyboardInterrupt/SystemExit still
            # propagate right away because they are not ``Exception``s.
            self._exception = exc

    def result(self, timeout=None):
        """Return the value of the call, or re-raise the exception it raised.

        ``timeout`` is accepted for signature compatibility with
        ``Future.result`` and ignored: the call has already finished.
        """
        if self._exception is not None:
            raise self._exception
        return self._result
class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
    r"""
    Constructs a Segformer image processor.

    This is a trimmed-down variant: the per-image pipeline implemented in `_preprocess` only rescales and
    normalizes. The resize-related settings (`do_resize`, `size`, `resample`) and `do_reduce_labels` are stored
    on the instance and accepted by `preprocess`, but no resizing or label reduction is performed by this class.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Stored resize flag. Can be overridden by the `do_resize` parameter in the `preprocess` method. Not
            acted upon by `_preprocess`.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
            Stored target size, normalized through `get_size_dict`. Can be overridden by the `size` parameter in
            the `preprocess` method. Not acted upon by `_preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Stored resampling filter. Can be overridden by the `resample` parameter in the `preprocess` method.
            Not acted upon by `_preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in
            the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of
            the number of channels in the image. Can be overridden by the `image_std` parameter in the
            `preprocess` method.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Stored label-reduction flag (the deprecated spelling `reduce_labels` is still accepted through
            `**kwargs`). `preprocess` in this class does not process segmentation maps, so the flag has no
            effect here.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        # Backwards compatibility: the legacy `reduce_labels` kwarg wins over `do_reduce_labels`
        # and is removed from kwargs before they reach the base class.
        if "reduce_labels" in kwargs:
            warnings.warn(
                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
                "`do_reduce_labels` instead.",
                FutureWarning,
            )
            do_reduce_labels = kwargs.pop("reduce_labels")

        super().__init__(**kwargs)
        # Fall back to a 512x512 target when no size is given, then normalize the dict.
        size = size if size is not None else {"height": 512, "width": 512}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # ImageNet default statistics are used when no mean/std is supplied.
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.do_reduce_labels = do_reduce_labels
        # Names of the arguments `preprocess` declares.
        self._valid_processor_keys = [
            "images",
            "segmentation_maps",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_reduce_labels",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image
        processor is created using from_dict and kwargs e.g.
        `SegformerImageProcessor.from_pretrained(checkpoint, reduce_labels=True)`
        """
        # Work on a copy so the caller's dict is not mutated.
        image_processor_dict = image_processor_dict.copy()
        if "reduce_labels" in kwargs:
            image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
        return super().from_dict(image_processor_dict, **kwargs)

    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """Optionally rescale, then optionally normalize, a single image.

        `do_resize`, `size` and `resample` are part of the signature but are not used here: no resizing
        happens in this method.
        """
        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image.

        Infers the input channel layout when it is not given, runs `_preprocess`, and converts the result to
        `data_format` when one is requested.
        """
        # All transformations expect numpy arrays.
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image=image,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
        )

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

        return image

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
        passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Each image is passed through `_preprocess_image` (rescale, normalize, channel-layout conversion). The
        `segmentation_maps` and `do_reduce_labels` arguments are accepted but not used by this implementation,
        and the resize-related arguments are forwarded to `_preprocess`, which does not resize.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Accepted for interface compatibility; not used.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Forwarded resize flag; no resizing is performed.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Forwarded target size; no resizing is performed.
            resample (`int`, *optional*, defaults to `self.resample`):
                Forwarded resampling filter (one of the enum `PILImageResampling`); no resizing is performed.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_reduce_labels (`bool`, *optional*):
                Accepted for interface compatibility; not used.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `BatchFeature`: holds the processed images under the `"pixel_values"` key, converted according to
            `return_tensors`.
        """
        # Per-call arguments take precedence; otherwise fall back to the instance configuration.
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        size = size if size is not None else self.size
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        # A single image is wrapped into a one-element list.
        images = make_list_of_images(images)
        images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                resample=resample,
                size=size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
def get_total_splits(image_size, height):
    """Return how many vertical chunks of ``height`` pixels an image needs.

    ``image_size`` is indexed as ``(width, height)``. Images whose height does
    not exceed ``settings.DETECTOR_IMAGE_CHUNK_HEIGHT`` count as one chunk.
    """
    img_height = tuple(image_size)[1]
    if img_height <= settings.DETECTOR_IMAGE_CHUNK_HEIGHT:
        return 1
    return math.ceil(img_height / height)


def split_image(img, height):
    """Cut ``img`` into stacked horizontal strips of ``height`` pixels.

    The input image is never returned or modified: the result holds either
    crops or a copy. Splitting only happens when the image is taller than
    ``settings.DETECTOR_IMAGE_CHUNK_HEIGHT``; the strip size itself comes from
    the ``height`` argument. A final short strip is padded at the bottom with
    colour 255 up to ``height``.

    Returns:
        ``(strips, strip_heights)`` where ``strip_heights`` holds the unpadded
        height of each strip.
    """
    img_height = tuple(img.size)[1]
    if img_height <= settings.DETECTOR_IMAGE_CHUNK_HEIGHT:
        return [img.copy()], [img_height]

    strips = []
    strip_heights = []
    for index in range(math.ceil(img_height / height)):
        top = index * height
        bottom = (index + 1) * height
        if bottom > img_height:
            bottom = img_height
        piece = img.crop((0, top, img.size[0], bottom))
        real_height = bottom - top
        if real_height < height:
            # Anchor the content at the top-left and fill the remainder.
            piece = ImageOps.pad(
                piece, (img.size[0], height), color=255, centering=(0, 0)
            )
        strips.append(piece)
        strip_heights.append(real_height)
    return strips, strip_heights
@dataclass
class ContinuousBatchOutput:
    """Per-step predictions as assembled by ``FoundationPredictor.process_outputs``."""

    # Predicted token ids cast to torch.long; used as the next step's input ids.
    input_ids: torch.Tensor
    # Argmax over the vocabulary dimension of the LM logits.
    preds: torch.Tensor
    # Bbox logits multiplied by the model config's bbox_size, cast to torch.long.
    bbox_preds: torch.Tensor
    # Highest softmax probability per predicted token.
    scores: torch.Tensor
    # Full softmax distribution over the vocabulary for each predicted token.
    token_probs: torch.Tensor


@dataclass
class FoundationPrompt:
    """One queued request waiting to be prefilled into a batch slot."""

    # Identifier stored in the slot mapping; prediction_loop uses the input's index.
    id: int
    # Task to run; selects the entry of FoundationPredictor.tasks.
    task_name: TaskNames
    # Image for this prompt.
    image: np.ndarray
    # Input text that accompanies the image.
    text: str
    # Math flag passed along with the text input to the processor.
    math_mode: bool
def get_encoder_chunk_size(self) -> int:
    """Return the encoder chunk size to use.

    Resolution order:
      1. ``settings.FOUNDATION_CHUNK_SIZE`` when it is set (not ``None``).
      2. The entry of ``self.encoder_chunk_sizes`` for
         ``settings.TORCH_DEVICE_MODEL``.
      3. The class default ``self.encoder_chunk_size``.
    """
    override = settings.FOUNDATION_CHUNK_SIZE
    if override is not None:
        return override
    # A single lookup with a fallback replaces the duplicated membership test.
    return self.encoder_chunk_sizes.get(
        settings.TORCH_DEVICE_MODEL, self.encoder_chunk_size
    )
def maybe_insert_beacon_tokens(
    self,
    input_ids: torch.Tensor,
    input_boxes: torch.Tensor,
    num_predicted_tokens: torch.Tensor,
    num_new_tokens: "Optional[torch.Tensor | int]" = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Insert a beacon token into every row that reaches a beacon position.

    A position is a beacon position when ``num_predicted_tokens + offset``
    (offset running from 1 to ``seq_len``) is a multiple of
    ``self.beacon_token_interval``. If no row needs a beacon the inputs are
    returned untouched. Otherwise every row grows by one slot: rows that need
    a beacon get it at their first beacon position, all other rows are left
    padded with the pad token. Attention masking is not handled here.

    Args:
        input_ids: ``[batch_size, seq_len]`` token ids (``seq_len`` can be > 1
            for multi-token predictions).
        input_boxes: ``[batch_size, seq_len, box_dim]`` boxes aligned with
            ``input_ids``.
        num_predicted_tokens: per-row token counter; it is added to a
            ``[1, seq_len]`` offset row, so it is expected to broadcast as
            ``[batch_size, 1]``.
        num_new_tokens: number of valid new tokens per row, used only when no
            beacon is inserted. May be ``None`` (every row counts
            ``seq_len``), a plain ``int`` applied to every row, or a tensor
            with ``batch_size`` elements (e.g. ``[batch_size, 1]``).

    Returns:
        ``(input_ids, input_boxes, valid_token_counts)`` where
        ``valid_token_counts`` is a 1-D tensor with one entry per row.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    token_positions = num_predicted_tokens + torch.arange(
        1, seq_len + 1, device=device
    ).unsqueeze(0)
    beacon_positions = token_positions % self.beacon_token_interval == 0
    needs_beacon = beacon_positions.any(dim=1)  # shape: [batch_size]

    if not needs_beacon.any():
        # Normalise the count to a 1-D [batch_size] tensor whatever form the
        # caller used; the previous .squeeze(1) failed for None and for ints.
        if num_new_tokens is None or isinstance(num_new_tokens, int):
            fill = seq_len if num_new_tokens is None else num_new_tokens
            valid_token_counts = torch.full(
                (batch_size,), fill, dtype=torch.long, device=device
            )
        else:
            valid_token_counts = num_new_tokens.reshape(batch_size)
        return input_ids, input_boxes, valid_token_counts

    # argmax returns the index of the first maximal value, i.e. the first
    # beacon position of each row (0 for rows without one, which are skipped
    # below). This also copes with rows holding several beacon positions.
    first_beacon = beacon_positions.to(torch.long).argmax(dim=1)
    insert_positions = first_beacon.tolist()
    row_needs_beacon = needs_beacon.tolist()

    # Start from fully padded buffers that are one slot longer.
    new_input_ids = torch.full(
        (batch_size, seq_len + 1),
        self.device_pad_token,
        dtype=input_ids.dtype,
        device=device,
    )
    new_input_boxes = torch.full(
        (batch_size, seq_len + 1, input_boxes.shape[-1]),
        -100,
        dtype=input_boxes.dtype,
        device=input_boxes.device,
    )

    for row, (has_beacon, pos) in enumerate(zip(row_needs_beacon, insert_positions)):
        if has_beacon:
            # Tokens before the beacon keep their place, the rest shift right.
            new_input_ids[row, :pos] = input_ids[row, :pos]
            new_input_boxes[row, :pos] = input_boxes[row, :pos]
            new_input_ids[row, pos] = self.device_beacon_token
            # The beacon's box slot keeps the -100 fill value.
            new_input_ids[row, pos + 1 :] = input_ids[row, pos:]
            new_input_boxes[row, pos + 1 :] = input_boxes[row, pos:]
        else:
            # Left pad: slot 0 keeps the pad token / -100 box.
            new_input_ids[row, 1:] = input_ids[row]
            new_input_boxes[row, 1:] = input_boxes[row]

    # Rows with a beacon have seq_len + 1 valid tokens, padded rows seq_len.
    valid_token_counts = seq_len + needs_beacon.to(torch.long)
    return new_input_ids, new_input_boxes, valid_token_counts
num_predicted_tokens, num_new_tokens ) position_ids = position_ids[:, -1:] + torch.arange( 1, input_ids.shape[1] + 1, device=input_ids.device ) # Some of the input sequences may now have left padding tokens, so we want to account for that # offset is a per-batch offset of the position_ids offset = (input_ids.shape[1] - num_valid_tokens).unsqueeze(1) position_ids -= offset new_input = ContinuousBatchInput( input_ids=input_ids, input_boxes=input_boxes, position_ids=position_ids, num_valid_tokens=num_valid_tokens, num_predicted_tokens=num_predicted_tokens, needs_bbox_embedding=current_inputs.needs_bbox_embedding, ) return new_input, processed_output def pad_and_shift_input_ids_position_ids( self, input_ids: torch.Tensor, bbox_preds: torch.Tensor, position_ids: torch.Tensor, new_seq_len: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Pads new_input_ids to match the new seq len with **left padding** and creates updated position_ids Returns: padded_input_ids (torch.Tensor): [batch_size, current_seq_len] updated_position_ids (torch.Tensor): [batch_size, current_seq_len] """ # No padding if new_seq_len == input_ids.shape[1]: return ( input_ids, bbox_preds, position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device), ) pad_len = new_seq_len - input_ids.shape[1] padded_input_ids = torch.nn.functional.pad( input_ids, (pad_len, 0), value=self.device_pad_token ) padded_bbox_preds = torch.nn.functional.pad( bbox_preds, (0, 0, pad_len, 0), value=-100 ) # Since we have **left padding**, offset the new position_ids by the amount of padding # This ensures that the **true tokens** get the correct position_ids # The position_ids assigned to pad tokens do not matter. 
def get_cache_position(
    self,
    seq_len: int,
    attention_mask: torch.Tensor,
    prefill: bool,
):
    """Build per-row cache positions for ``seq_len`` incoming tokens.

    During prefill the positions are ``0 .. seq_len - 1`` for every row. During
    decode each row starts one past the highest index at which
    ``attention_mask`` is non-zero, so new tokens land right after the cached
    content.

    Args:
        seq_len: number of tokens about to be fed to the model.
        attention_mask: 2-D ``[batch_size, target_len]`` mask.
        prefill: whether this is the prefill pass.

    Returns:
        A ``[batch_size, seq_len]`` tensor of cache positions.
    """
    num_rows, mask_width = attention_mask.shape
    device = attention_mask.device

    offsets = torch.arange(seq_len, device=device)[None, :].expand(num_rows, -1)
    if not prefill:
        # Weighting the mask by its column index makes argmax pick the last
        # attended column of each row.
        column_index = torch.arange(mask_width, device=device)
        last_attended = torch.argmax(attention_mask * column_index, dim=1)
        filled = last_attended.to(torch.int32) + 1
        # Add a trailing axis so the per-row start broadcasts over seq_len.
        return filled[:, None] + offsets
    return offsets
processed_inputs["attention_mask"].to(dtype=torch.long) position_ids = processed_inputs["position_ids"].to(dtype=torch.long) valid_batch_size = len(idxs_to_merge) # Keep these off device until later image_tiles = processed_inputs["image_tiles"].to(dtype=self.model.dtype) grid_thw = processed_inputs["grid_thw"].to(dtype=torch.long) if settings.FOUNDATION_STATIC_CACHE: input_ids = self.pad_to_batch_size( input_ids, batch_size=self.kv_cache.max_batch_size ) attention_mask = self.pad_to_batch_size( attention_mask, batch_size=self.kv_cache.max_batch_size ) position_ids = self.pad_to_batch_size( position_ids, batch_size=self.kv_cache.max_batch_size ) needs_bbox_embedding = self.pad_to_batch_size( needs_bbox_embedding, batch_size=self.kv_cache.max_batch_size ) # Move to device after padding input_ids = input_ids.to(device=self.model.device) attention_mask = attention_mask.to(device=self.model.device) position_ids = position_ids.to(device=self.model.device) needs_bbox_embedding = needs_bbox_embedding.to(device=self.model.device) # Find text lengths of each # Oddly, this is optimal on GPU - causes a 30% slowdown if "optimized" # Be very careful with the type and device of this - can cause # a big slowdown if put on device is_special = ( (input_ids.unsqueeze(-1) == self.special_token_ids).any(-1).cpu() ) # (batch, seq_len) text_lengths = [] for i in range(input_ids.shape[0]): special_positions = is_special[i].nonzero(as_tuple=True)[0] if len(special_positions) > 0: # Assuming special tokens are contiguous at the start prefix_len = special_positions[-1].item() + 1 else: prefix_len = 0 text_lengths.append(input_ids.shape[1] - prefix_len) text_lengths = torch.tensor(text_lengths, dtype=torch.long) cache_position = self.get_cache_position( input_ids.shape[1], attention_mask, prefill=True ) with settings.INFERENCE_MODE(): image_embeddings = self.model.get_image_embeddings( pixel_values=image_tiles, grid_thw=grid_thw, encoder_chunk_size=self.get_encoder_chunk_size(), 
valid_batch_size=valid_batch_size, max_batch_size=self.kv_cache.max_batch_size, ) mark_step() outputs = self.model( input_ids=input_ids, image_embeddings=image_embeddings, attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, inputs_embeds=None, past_key_values=self.kv_cache, use_cache=True, encoder_chunk_size=self.get_encoder_chunk_size(), cache_idxs=idxs_to_merge, prefill=True, num_valid_tokens=None, # Not required during prefill text_lengths=text_lengths, valid_batch_size=valid_batch_size, logits_to_keep=1, ) # Process outputs processed_outputs = self.process_outputs( outputs, max_lookahead_tokens=max_lookahead_tokens ) # Multi-token prediction predicted_tokens = processed_outputs.input_ids.shape[1] num_valid_tokens = ( torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long) * predicted_tokens ) num_predicted_tokens = ( torch.ones( (input_ids.shape[0], 1), device=self.model.device, dtype=torch.long ) * predicted_tokens ) self.kv_cache.prefill_attention_mask_update( attention_mask, idxs_to_merge, valid_batch_size, text_lengths ) self.kv_cache.update_text_counts(idxs_to_merge, valid_batch_size, text_lengths) full_batch = len(idxs_to_merge) == self.kv_cache.max_batch_size # If full batch, then we can ignore current_inputs if current_inputs is None or full_batch: new_seq_len = processed_outputs.input_ids.shape[1] # No padding tokens - So we can safely set position_ids this way position_ids = position_ids[:, -1:] + torch.arange( 1, new_seq_len + 1, device=position_ids.device ) new_input = ContinuousBatchInput( input_ids=processed_outputs.input_ids, input_boxes=processed_outputs.bbox_preds, position_ids=position_ids, num_valid_tokens=num_valid_tokens, num_predicted_tokens=num_predicted_tokens, needs_bbox_embedding=needs_bbox_embedding, ) return ( new_input, processed_outputs, range(processed_outputs.input_ids.shape[0]), ) # Merging inputs for next steps current_input_ids = current_inputs.input_ids 
def get_max_image_token_count(
    self, images: list[np.ndarray], tasks: "List[TaskNames]"
) -> int:
    """Return an upper bound on the image-token count across ``images``.

    Each image is scaled the way the task's ``img_size`` budget dictates,
    rounded up to whole merged patches, and converted to a token count. The
    largest count is returned together with the processor's register tokens
    and a fixed allowance of 10 tokens for EOS/BOS/rotation tokens etc.

    Args:
        images: arrays whose first two dimensions are ``(H, W)``.
        tasks: task key for each image, used to look up
            ``self.tasks[task]["img_size"]`` (a ``(max_W, max_H)`` pair).

    Returns:
        ``10 + num_register_tokens + max image tokens`` as an ``int``.
    """
    patch_size = self.processor.patch_size
    merge_size = self.processor.merge_size
    # Images are padded up to a multiple of one merged patch.
    factor = patch_size * merge_size

    def compute_scaled_size(
        H: int, W: int, max_size: Tuple[int, int]
    ) -> Tuple[int, int]:
        """Scale (H, W) so the pixel count fits the [min, max] budget."""
        max_W, max_H = max_size
        min_W, min_H = (168, 168)

        max_pixels = max_H * max_W
        min_pixels = min_H * min_W
        current_pixels = max(1, H * W)  # Avoid zero division

        if current_pixels > max_pixels:
            scale = (max_pixels / current_pixels) ** 0.5
            return math.floor(H * scale), math.floor(W * scale)
        if current_pixels < min_pixels:
            scale = (min_pixels / current_pixels) ** 0.5
            return math.ceil(H * scale), math.ceil(W * scale)
        return H, W

    def get_tile_count(H: int, W: int) -> int:
        """Number of patches after rounding both sides up to ``factor``."""
        H_bar = math.ceil(H / factor) * factor
        W_bar = math.ceil(W / factor) * factor
        # Both sides are multiples of patch_size, so floor division is exact
        # and keeps the count an int.
        return (H_bar // patch_size) * (W_bar // patch_size)

    max_tokens = 0
    for image, task in zip(images, tasks):
        H, W = image.shape[:2]
        scaled_H, scaled_W = compute_scaled_size(H, W, self.tasks[task]["img_size"])
        # Each grid side is a multiple of merge_size, so this divides exactly.
        token_count = get_tile_count(scaled_H, scaled_W) // (merge_size**2)
        max_tokens = max(max_tokens, token_count)

    # Extra 10 to account for EOS/BOS/Rotation token etc.
    return 10 + self.processor.num_register_tokens + max_tokens
Supported tasks are {allowed_tasks}" ) predicted_tokens = [[] for _ in range(len(images))] scores = [[] for _ in range(len(images))] topk_probs = [[] for _ in range(len(images))] if batch_size is None: batch_size = self.get_batch_size() batch_size = min(len(images), batch_size) current_inputs = None max_image_tokens = self.get_max_image_token_count(images, task_names) if max_sliding_window is None: max_sliding_window = self.model.config.sliding_window self.setup_cache( batch_size, max_cache_len=max_image_tokens + max_sliding_window + self.extra_token_count.get(settings.TORCH_DEVICE_MODEL, 0), max_sliding_window=max_sliding_window, ) batch_max_tokens = {} for idx, (img, txt, task) in enumerate(zip(images, input_texts, task_names)): self.prompt_queue.append( FoundationPrompt( id=idx, task_name=task, text=txt, image=img, math_mode=math_mode ) ) batch_max_tokens[idx] = ( max_tokens or settings.FOUNDATION_MAX_TOKENS or self.tasks[task]["max_tokens"] ) overall_max_tokens = max(batch_max_tokens.values()) pbar = tqdm( total=len(self.prompt_queue), desc=tqdm_desc, disable=self.disable_tqdm, ) batch_bboxes = torch.zeros(len(images), overall_max_tokens, 6) batch_pos = [0] * len(images) while self.prompt_queue or self.num_active_slots > 0: if ( self.num_empty_slots / batch_size ) >= self.min_prefill_ratio and self.prompt_queue: updated_inputs, outputs, merge_idxs = self.prefill( current_inputs, max_lookahead_tokens=0 ) predicted_tokens_cpu = outputs.preds.cpu() scores_cpu = outputs.scores.cpu() bbox_preds_cpu = outputs.bbox_preds.cpu() if top_k > 0: batch_top_k_probs, batch_top_k_indices = torch.topk( outputs.token_probs, k=top_k, dim=-1 ) batch_top_k_probs_cpu = batch_top_k_probs.cpu() batch_top_k_indices_cpu = batch_top_k_indices.cpu() for temp_idx, b_idx in enumerate(merge_idxs): if self.batch_prompt_mapping[b_idx] is not None: p_idx = self.batch_prompt_mapping[b_idx] seq_len = predicted_tokens_cpu.shape[1] for t_idx in range(seq_len): token = predicted_tokens_cpu[temp_idx, 
t_idx].item() predicted_tokens[p_idx].append(token) batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ temp_idx, t_idx ] batch_pos[p_idx] += 1 scores[p_idx].append(scores_cpu[temp_idx, t_idx].item()) if top_k > 0: top_k_scores = { batch_top_k_indices_cpu[temp_idx, t_idx][ k ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][ k ].item() for k in range(top_k) } topk_probs[p_idx].append(top_k_scores) if token in [ self.processor.eos_token_id, self.processor.no_output_token, ]: self.batch_prompt_mapping[b_idx] = None pbar.update(1) break else: updated_inputs, outputs = self.decode( current_inputs, max_lookahead_tokens=max_lookahead_tokens ) mark_step() predicted_tokens_cpu = outputs.preds.cpu() scores_cpu = outputs.scores.cpu() bbox_preds_cpu = outputs.bbox_preds.cpu() if top_k > 0: batch_top_k_probs, batch_top_k_indices = torch.topk( outputs.token_probs, k=top_k, dim=-1 ) batch_top_k_probs_cpu = batch_top_k_probs.cpu() batch_top_k_indices_cpu = batch_top_k_indices.cpu() for b_idx, p_idx in self.batch_prompt_mapping.items(): if p_idx is not None: seq_len = predicted_tokens_cpu.shape[1] num_tokens = updated_inputs.num_valid_tokens[b_idx].item() should_stop = False for t_idx in range(seq_len): # don't use multitoken prediction for lower confidence tokens if t_idx > 0 and num_tokens < seq_len: # roll so tokens are right aligned updated_inputs.input_ids[b_idx] = ( updated_inputs.input_ids[b_idx].roll( shifts=seq_len - num_tokens, dims=0 ) ) # don't need to roll position_ids because that's handled in `decode` (and when we do beacon tokens) break token = predicted_tokens_cpu[b_idx, t_idx].item() predicted_tokens[p_idx].append(token) batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ b_idx, t_idx ] batch_pos[p_idx] += 1 scores[p_idx].append(scores_cpu[b_idx, t_idx].item()) if top_k > 0: top_k_scores = { batch_top_k_indices_cpu[temp_idx, t_idx][ k ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][ k ].item() for k in range(top_k) } topk_probs[p_idx].append(top_k_scores) 
repeats = len(predicted_tokens[p_idx]) >= batch_max_tokens[ p_idx ] or ( drop_repeated_tokens and detect_repeat_token(predicted_tokens[p_idx]) and task_names[p_idx] in [ TaskNames.ocr_with_boxes, TaskNames.ocr_without_boxes, ] ) if ( token in [ self.processor.eos_token_id, self.processor.pad_token_id, ] or repeats ): should_stop = True break if should_stop: self.batch_prompt_mapping[b_idx] = None pbar.update(1) # Update inputs and mark XLA step current_inputs = updated_inputs pbar.close() del self.kv_cache self.kv_cache = None torch.cuda.empty_cache() return predicted_tokens, batch_bboxes, scores, topk_probs ================================================ FILE: surya/foundation/cache/__init__.py ================================================ ================================================ FILE: surya/foundation/cache/dynamic_ops.py ================================================ from typing import Any, Dict, List, Optional, Tuple import torch from transformers import PretrainedConfig """ Special cache class for the surya foundation model that supports - 1) Static shape 2) A custom sliding window, where image tokens stay in cache, and text tokens are popped 3) Continuous batching - merging etc 4) Attention mask management - To match with what's currently in the cache Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079 """ class DynamicOpsCache: def __init__( self, config: PretrainedConfig, batch_size: int, max_cache_len: int, text_sliding_window: int, device: int, dtype: int, ): self.text_sliding_window = text_sliding_window self.num_layers = config.num_hidden_layers self.max_batch_size = batch_size self.max_cache_len = max_cache_len self.head_dim = ( getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads ) self._dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is 
None else config.num_key_value_heads ) # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125 self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] cache_shape = ( self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim, ) device = torch.device(device) if device is not None else None for _ in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros( cache_shape, dtype=self._dtype, device=device ) new_layer_value_cache = torch.zeros( cache_shape, dtype=self._dtype, device=device ) torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) self.attention_mask = torch.zeros( (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long ) self.text_token_counts = [ torch.zeros(self.max_batch_size, dtype=torch.long, device=device) for _ in range(self.num_layers) ] self.dtype = dtype self.device = device def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: prefill = cache_kwargs.get("prefill", False) update_fn = self._prefill_update if prefill else self._decode_update return update_fn( self.key_cache[layer_idx], self.value_cache[layer_idx], key_states, value_states, self.text_token_counts[layer_idx], cache_kwargs, ) def update_text_counts( self, merge_idxs: torch.Tensor, valid_batch_size: torch.Tensor, new_text_lens: torch.Tensor, ): new_text_len_tensor = new_text_lens.to(device=self.device) for layer_idx in range(self.num_layers): self.text_token_counts[layer_idx][merge_idxs] = new_text_len_tensor[ :valid_batch_size ] # Mirrors the logic from _prefill_update # Logic is better explained in this funcrtion def 
prefill_attention_mask_update( self, prefill_attention_mask: torch.Tensor, merge_idxs: torch.Tensor, valid_batch_mask: torch.Tensor, text_lengths: List[int], ): seq_len = prefill_attention_mask.shape[1] sliding_window = self.text_sliding_window total_cache_len = self.max_cache_len prefix_cache_space = total_cache_len - sliding_window for batch_idx, cache_idx in enumerate(merge_idxs): text_len = text_lengths[batch_idx] prefix_len = seq_len - text_len self.attention_mask[cache_idx] = 0 # Set default assert prefix_len > 0, "There are no prefix (image) tokens!" end_pos = prefix_cache_space # Handle prefix part - Which may be left padded if prefix_len <= prefix_cache_space: start_pos = prefix_cache_space - prefix_len self.attention_mask[cache_idx, start_pos:end_pos] = ( prefill_attention_mask[batch_idx, :prefix_len] ) else: self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[ batch_idx, prefix_len - prefix_cache_space : prefix_len ] # Handle text part, keeping sliding window in consideration # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here if text_len > 0: text_cache_start = prefix_cache_space if text_len <= sliding_window: self.attention_mask[ cache_idx, text_cache_start : text_cache_start + text_len ] = 1 else: self.attention_mask[cache_idx, -sliding_window:] = 1 # Slow impl for now - Prefill time is dominated by the large sequence length forward pass def _prefill_update( self, key_cache: torch.Tensor, value_cache: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, text_token_counts: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None, ): cache_idxs: List[int] = cache_kwargs.get("cache_idxs", None) text_lengths: List[int] = cache_kwargs.get("text_lengths", None) assert cache_idxs is not None, "cache_idxs must be specified during prefill" assert text_lengths is not None, "text_lengths must be specified during prefill" _, _, seq_len, _ = key_states.shape total_cache_len = 
self.max_cache_len sliding_window = self.text_sliding_window prefix_cache_space = total_cache_len - sliding_window for batch_idx, cache_idx in enumerate(cache_idxs): text_len = text_lengths[batch_idx] prefix_len = seq_len - text_len ###### Handle Image Tokens (Prefix) ##### # Place image tokens in appropriate cache space, aligned to the **right edge** assert prefix_len > 0, "There are no prefix (image) tokens!" # prefix_len may be greater than the prefix cache space due to left padding - This happens when # a different batch element has a large input text during prefill, causing others to have a lot of # left padding. We can safely take the last `prefix_cache_space` elements from the kv states, since # `prefix_cache_space` is large enough to fit any image, and the rest **has to be** padding end_pos = prefix_cache_space if prefix_len <= prefix_cache_space: start_pos = prefix_cache_space - prefix_len key_cache[cache_idx, :, start_pos:end_pos] = key_states[ batch_idx, :, :prefix_len ] value_cache[cache_idx, :, start_pos:end_pos] = value_states[ batch_idx, :, :prefix_len ] else: key_cache[cache_idx, :, :end_pos] = key_states[ batch_idx, :, prefix_len - prefix_cache_space : prefix_len ] value_cache[cache_idx, :, :end_pos] = value_states[ batch_idx, :, prefix_len - prefix_cache_space : prefix_len ] ###### Handle Text Tokens ##### # Text tokens start at the **left edge** of sliding window cache space if text_len > 0: text_cache_start = prefix_cache_space if text_len <= sliding_window: key_cache[ cache_idx, :, text_cache_start : text_cache_start + text_len ] = key_states[batch_idx, :, prefix_len : prefix_len + text_len] value_cache[ cache_idx, :, text_cache_start : text_cache_start + text_len ] = value_states[batch_idx, :, prefix_len : prefix_len + text_len] else: start_in_text = text_len - sliding_window key_cache[ cache_idx, :, text_cache_start : text_cache_start + sliding_window, ] = key_states[ batch_idx, :, prefix_len + start_in_text : prefix_len + text_len ] 
value_cache[ cache_idx, :, text_cache_start : text_cache_start + sliding_window, ] = value_states[ batch_idx, :, prefix_len + start_in_text : prefix_len + text_len ] # Return the full key/value states (not just cached) for use in subsequent layers return key_states, value_states # """ # Matches the logic of the decode update, but needs to be called before the updates # since some parts of the model depend on the attention mask # """ def decode_attention_mask_update( self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] ): sliding_window = self.text_sliding_window text_cache_start = self.max_cache_len - sliding_window # Using text_token_counts of first layer, should be same for all though current_text_lens = self.text_token_counts[0] cache_idxs_tensor = torch.tensor(cache_idxs, device=current_text_lens.device) # Get current text lengths for the relevant cache indices current_lens = current_text_lens[cache_idxs_tensor] new_text_lens = current_lens + num_valid_tokens is_full = new_text_lens > sliding_window # Handle full caches - set entire sliding window to 1 if is_full.any(): full_mask = is_full full_cache_idxs = cache_idxs_tensor[full_mask] self.attention_mask[full_cache_idxs, text_cache_start:] = 1 # Handle non-full caches - set specific ranges to 1 if (~is_full).any(): non_full_mask = ~is_full non_full_cache_idxs = cache_idxs_tensor[non_full_mask] non_full_current_lens = current_lens[non_full_mask] non_full_valid_tokens = num_valid_tokens[non_full_mask] max_valid_tokens = ( non_full_valid_tokens.max().item() if len(non_full_valid_tokens) > 0 else 0 ) if max_valid_tokens > 0: batch_size = len(non_full_cache_idxs) offset_range = torch.arange( max_valid_tokens, device=current_text_lens.device ) batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1) start_positions = non_full_current_lens.unsqueeze(1) valid_token_counts = non_full_valid_tokens.unsqueeze(1) position_indices = start_positions + batch_offsets valid_mask = batch_offsets < 
valid_token_counts row_indices = non_full_cache_idxs.unsqueeze(1).expand( -1, max_valid_tokens )[valid_mask] col_indices = text_cache_start + position_indices[valid_mask] self.attention_mask[row_indices, col_indices] = 1 """ Static cache update - respects per-batch text token limits - per-batch valid token lengths (right-padded inputs) kv states are expected to have shape [batch_size, kv_heads, T_pad, head_dim] They may have different `true` lengths, to account for multi token preds, or beacon tokens Expects `num_valid_tokens` in cache_kwargs: a tensor of shape (B,) indicating the number of actual (non-padded) tokens to add per batch element. """ def _decode_update( self, key_cache: torch.Tensor, value_cache: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, text_token_counts: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: num_valid_tokens: torch.Tensor = cache_kwargs.get( "num_valid_tokens" ) # shape: (B,) assert num_valid_tokens is not None, ( "`num_valid_tokens` must be provided in `cache_kwargs`" ) device = key_states.device batch_size, num_head, seq_len, head_dim = key_states.shape sliding_window = self.text_sliding_window max_cache_len = self.max_cache_len cache_text_start = max_cache_len - sliding_window new_text_lengths = text_token_counts + num_valid_tokens slide_amounts = torch.clamp(new_text_lengths - sliding_window, min=0) needs_rotate = slide_amounts > 0 # Rotate the cache if needed if torch.any(needs_rotate): k_slice = key_cache[:, :, -sliding_window:] # shape: [B, H, W, D] v_slice = value_cache[:, :, -sliding_window:] # same shape cache_indices = ( torch.arange(sliding_window, device=device) .unsqueeze(0) .repeat(batch_size, 1) ) # [B, W] rolled_indices = ( cache_indices + slide_amounts.unsqueeze(1) ) % sliding_window # [B, W] # We need to expand indices to shape: [B, 1, W, 1] to broadcast with k_slice rolled_indices = ( rolled_indices.unsqueeze(1) .unsqueeze(-1) .expand(-1, 
num_head, -1, head_dim) ) k_slice_rolled = k_slice.gather(dim=2, index=rolled_indices) v_slice_rolled = v_slice.gather(dim=2, index=rolled_indices) key_cache[:, :, -sliding_window:] = k_slice_rolled value_cache[:, :, -sliding_window:] = v_slice_rolled # Insert only **valid tokens** into the cache. These are **right aligned** within the input sequence insert_positions = torch.where( needs_rotate, max_cache_len - num_valid_tokens, text_token_counts + cache_text_start, ) max_tokens = num_valid_tokens.max().item() offsets = torch.arange(max_tokens, device=device).unsqueeze(0) # [1, max_T] valid_mask = offsets < num_valid_tokens.unsqueeze(1) # [B, max_T] src_indices = (seq_len - num_valid_tokens).unsqueeze(1) + offsets # [B, max_T] src_indices = src_indices.clamp(max=seq_len - 1) # safety tgt_indices = insert_positions.unsqueeze(1) + offsets # [B, max_T] tgt_indices = tgt_indices.clamp(max=max_cache_len - 1) # safety src_idx_exp = ( src_indices.unsqueeze(1) .unsqueeze(-1) .expand(batch_size, num_head, max_tokens, head_dim) ) tgt_idx_exp = ( tgt_indices.unsqueeze(1) .unsqueeze(-1) .expand(batch_size, num_head, max_tokens, head_dim) ) valid_mask_exp = ( valid_mask.unsqueeze(1) .unsqueeze(-1) .expand(batch_size, num_head, max_tokens, head_dim) ) k_src = torch.gather(key_states, 2, src_idx_exp) v_src = torch.gather(value_states, 2, src_idx_exp) k_src = k_src * valid_mask_exp v_src = v_src * valid_mask_exp # Write into cache key_cache.scatter_(2, tgt_idx_exp, k_src) value_cache.scatter_(2, tgt_idx_exp, v_src) # In-place edit - Mutates text_token_counts += num_valid_tokens text_token_counts.clamp_(max=sliding_window) return key_cache, value_cache # We have a non-uniform cache, so its better to not return it and handle any logic # that requires this ourselves def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: raise NotImplementedError() ================================================ FILE: surya/foundation/cache/static_ops.py 
class StaticOpsCache(DynamicOpsCache):
    """Variant of ``DynamicOpsCache`` that only uses fixed-shape roll/assign ops.

    Prefill states are written right-aligned at the end of the cache and
    every decode step rolls the whole cache (and attention mask) left by a
    fixed amount before appending the new states at the end.

    Construction and ``update`` dispatch are inherited from
    ``DynamicOpsCache``; only the prefill/decode paths and the attention mask
    handling are overridden here.

    Heavily inspired from
    https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079
    """

    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int,
        max_cache_len: int,
        text_sliding_window: int,
        device: int,
        dtype: int,
    ):
        """Allocate the cache tensors exactly like the parent class."""
        super().__init__(
            config=config,
            batch_size=batch_size,
            max_cache_len=max_cache_len,
            text_sliding_window=text_sliding_window,
            device=device,
            dtype=dtype,
        )

    def _prefill_update(
        self,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        text_token_counts: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Write prefill states right-aligned into the slots in ``cache_idxs``.

        Requires ``cache_idxs`` and ``text_lengths`` in ``cache_kwargs``
        (``text_lengths`` is only checked for presence). Returns the
        unmodified ``key_states``/``value_states``.
        """
        cache_idxs: torch.Tensor = cache_kwargs.get("cache_idxs", None)
        text_lengths: List[int] = cache_kwargs.get("text_lengths", None)
        assert cache_idxs is not None, "cache_idxs must be specified during prefill"
        assert text_lengths is not None, "text_lengths must be specified during prefill"

        cache_idx_length = len(cache_idxs)
        full_batch = len(cache_idxs) == self.max_batch_size

        # Insert key and value states at the end of the cache
        new_tokens = key_states.shape[2]

        # Direct right-aligned assignment
        if full_batch:
            key_cache[:, :, -new_tokens:] = key_states
            value_cache[:, :, -new_tokens:] = value_states
        else:
            key_cache[cache_idxs, :, -new_tokens:] = key_states[:cache_idx_length]
            value_cache[cache_idxs, :, -new_tokens:] = value_states[:cache_idx_length]

        return key_states, value_states

    def decode_attention_mask_update(
        self, num_valid_tokens: torch.Tensor, cache_idxs: List[int]
    ):
        """Shift the attention mask left and attend to the new positions.

        Matches the logic of the decode update, but needs to be called before
        the updates since some parts of the model depend on the attention
        mask. ``cache_idxs`` is accepted but not read: the whole mask is
        rolled by the batch-wide maximum of ``num_valid_tokens``.
        """
        max_valid_tokens = num_valid_tokens.max().item()
        if max_valid_tokens == 0:
            # If no valid tokens, we don't need to update the attention mask
            return

        # Shift the attention mask to the left by max_valid_tokens
        self.attention_mask = self.attention_mask.roll(-1 * max_valid_tokens, dims=1)
        self.attention_mask[:, -max_valid_tokens:] = (
            1  # Full attention to all new tokens
        )

    # Mirrors the logic from _prefill_update
    def prefill_attention_mask_update(
        self,
        attention_mask: torch.Tensor,
        merge_idxs: torch.Tensor,
        valid_batch_size: torch.Tensor,
        text_lengths: List[int],
    ):
        """Copy the prefill mask right-aligned into the rows of ``merge_idxs``.

        ``text_lengths`` is accepted but not read by this implementation.
        """
        # Set from -(image_length + text_length) to end to 1 for each batch element
        seq_len = attention_mask.shape[1]
        self.attention_mask[merge_idxs] = (
            0  # Reset the attention mask for the current batch elements
        )
        self.attention_mask[merge_idxs, -seq_len:] = attention_mask[:valid_batch_size]

    def _decode_update(
        self,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        text_token_counts: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Roll the cache left by the step length and append the new states.

        Returns the (mutated) ``key_cache`` and ``value_cache``;
        ``text_token_counts`` is updated in place.
        """
        # Naive, always assumes we'll roll by a fixed amount
        # Needs left padding with beacons to work properly

        num_valid_tokens: torch.Tensor = cache_kwargs.get(
            "num_valid_tokens"
        )  # shape: (B,)
        assert num_valid_tokens is not None, (
            "`num_valid_tokens` must be provided in `cache_kwargs`"
        )
        # (B, H, L, D)

        valid_tokens = key_states.shape[2]

        key_cache.copy_(torch.roll(key_cache, -valid_tokens, dims=2))
        value_cache.copy_(torch.roll(value_cache, -valid_tokens, dims=2))

        key_cache[:, :, -valid_tokens:, :] = key_states
        value_cache[:, :, -valid_tokens:, :] = value_states

        # In-place edit - Mutates
        text_token_counts += num_valid_tokens
        text_token_counts.clamp_(max=self.text_sliding_window)
        return key_cache, value_cache

    # The attention mask managed by our kv cache automatically masks the tokens
    # in the cache, so we can return full length for HF to use in other places
    # This is mainly utilized in the cache_positions creation
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the full cache length, regardless of ``layer_idx``."""
        return self.max_cache_len
class FoundationModelLoader(ModelLoader):
    """Loads the surya foundation model and its processor from a checkpoint."""

    def __init__(self, checkpoint: Optional[str] = None):
        """Use ``settings.FOUNDATION_MODEL_CHECKPOINT`` when no checkpoint is set."""
        super().__init__(checkpoint)

        if self.checkpoint is None:
            self.checkpoint = settings.FOUNDATION_MODEL_CHECKPOINT

    def model(
        self,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=None,
        attention_implementation: Optional[str] = None,
    ) -> SuryaModel:
        """Load the model, pick dtype/attention backend and optionally compile.

        Args:
            device: Target device; ``None`` falls back to
                ``settings.TORCH_DEVICE_MODEL``.
            dtype: Model dtype. ``None`` selects ``settings.MODEL_DTYPE_BFLOAT``
                (or ``settings.MODEL_DTYPE`` on CUDA devices without native
                bf16 support); ``torch.float16`` is replaced by
                ``torch.bfloat16``.
            attention_implementation: Explicit attention backend for decoder
                and vision encoder. When ``None``, flash attention 2 is used if
                available and supported on ``device``, otherwise SDPA.

        Returns:
            The model in eval mode on ``device``.
        """
        if device is None:
            device = settings.TORCH_DEVICE_MODEL
        if dtype is None:
            # See https://github.com/pytorch/pytorch/issues/118122 - T4 (device version 7.5) will return true since it supports
            # emulated bf16, but falls back to very slow kernels, especially for SDPA
            dtype = settings.MODEL_DTYPE_BFLOAT
            if device == "cuda" and not torch.cuda.is_bf16_supported(
                including_emulation=False
            ):
                # If the device is cuda, we check if bf16 is supported, and if not, we use float16
                dtype = settings.MODEL_DTYPE
        elif dtype == torch.float16:
            dtype = torch.bfloat16  # Model weights in bfloat16

        config = SuryaModelConfig.from_pretrained(self.checkpoint)

        if attention_implementation is not None:
            config.decoder._attn_implementation = attention_implementation
            config.vision_encoder._attn_implementation = attention_implementation
        elif is_flash_attn_2_available() and is_flash_attn_2_supported(device):
            config.decoder._attn_implementation = "flash_attention_2"
            config.vision_encoder._attn_implementation = "flash_attention_2"
        else:
            # SDPA is the fallback on every device, XLA included.
            config.decoder._attn_implementation = "sdpa"
            config.vision_encoder._attn_implementation = "sdpa"

        model_cls = SuryaModel
        if device == "xla":
            model_cls = SuryaXLAModel

        config._attn_implementation_autoset = True
        config.vision_encoder._attn_implementation_autoset = True
        config.decoder._attn_implementation_autoset = True

        model = model_cls.from_pretrained(
            self.checkpoint, dtype=dtype, config=config, ignore_mismatched_sizes=True
        ).to(device)
        model = model.eval()

        if settings.COMPILE_ALL or settings.COMPILE_FOUNDATION:
            torch._dynamo.config.cache_size_limit = 1000
            torch._dynamo.config.suppress_errors = True
            torch._dynamo.config.specialize_int = False
            torch._dynamo.config.allow_unspec_int_on_nn_module = True
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.recompile_limit = 32

            logger.info(
                f"Compiling foundation model {self.checkpoint} on device {device} with dtype {dtype}"
            )
            compile_args = get_compile_args(device)
            model.vision_encoder = torch.compile(model.vision_encoder, **compile_args)
            model.decoder = torch.compile(model.decoder, **compile_args)

        logger.debug(
            f"Loaded foundation model {self.checkpoint} from {SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}."
        )
        return model

    def processor(
        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE_BFLOAT
    ) -> SuryaOCRProcessor:
        """Build the OCR processor from the checkpoint's config.

        ``dtype`` is accepted but not read by this method.
        """
        config: SuryaModelConfig = SuryaModelConfig.from_pretrained(self.checkpoint)

        ocr_tokenizer = SuryaOCRTokenizer(
            special_tokens=config.special_ocr_tokens, model_checkpoint=self.checkpoint
        )

        processor = SuryaOCRProcessor(
            ocr_tokenizer=ocr_tokenizer,
            blank_bbox_token_id=config.blank_bbox_token_id,
            num_register_tokens=config.num_register_tokens,
            sequence_length=None,
            patch_size=config.vision_encoder.patch_size,
            merge_size=config.vision_encoder.spatial_merge_size,
            model_device=device,
            num_beacon_tokens=config.num_beacon_tokens,
            beacon_token_interval=config.beacon_token_interval,
        )

        return processor
def detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40):
    """Heuristically detect that generation is stuck in a short loop.

    Looks at the last ``max_repeats`` tokens. If they contain at most five
    distinct values (call that count ``n``), the output is considered
    repeating when the final ``n`` tokens equal the ``n`` tokens before them.
    Returns ``False`` when fewer than ``max_repeats`` tokens are available.
    """
    if len(predicted_tokens) < max_repeats:
        return False

    window = predicted_tokens[-max_repeats:]
    period = len(set(window))
    if period > 5:
        return False

    tail = window[-period:]
    preceding = window[-period * 2 : -period]
    return tail == preceding


def prediction_to_polygon_batch(
    pred: torch.Tensor,
    img_sizes: List[Tuple[int, int]],
    bbox_scaler,
    skew_scaler,
    skew_min=0.001,
):
    """Convert (cx, cy, w, h, skew_x, skew_y) predictions into 4-point polygons.

    Args:
        pred: Tensor indexed as ``[batch, seq, 6]``.
        img_sizes: Per-batch-element sizes; index 0 scales the y coordinates
            and index 1 the x coordinates.
        bbox_scaler: Divisor applied to the image sizes to obtain the scale.
        skew_scaler: Offset subtracted from the raw skew values before they
            are halved and floored.
        skew_min: Skews whose magnitude is below this are zeroed.

    Returns:
        Tensor of shape ``[batch, seq, 4, 2]`` holding (x, y) corners.
    """
    sizes = torch.from_numpy(np.array(img_sizes, dtype=np.float32)).to(pred.device)
    w_scale = (sizes[:, 1] / bbox_scaler)[:, None, None]
    h_scale = (sizes[:, 0] / bbox_scaler)[:, None, None]

    cx = pred[:, :, 0]
    cy = pred[:, :, 1]
    width = pred[:, :, 2]
    height = pred[:, :, 3]

    # Axis-aligned box edges around the centre point.
    left = cx - width / 2
    top = cy - height / 2
    right = cx + width / 2
    bottom = cy + height / 2

    skew_x = torch.floor((pred[:, :, 4] - skew_scaler) / 2)
    skew_y = torch.floor((pred[:, :, 5] - skew_scaler) / 2)
    skew_x.masked_fill_(torch.abs(skew_x) < skew_min, 0)
    skew_y.masked_fill_(torch.abs(skew_y) < skew_min, 0)

    corners = [
        left - skew_x,
        top - skew_y,
        right - skew_x,
        top + skew_y,
        right + skew_x,
        bottom + skew_y,
        left + skew_x,
        bottom - skew_y,
    ]
    flat = torch.stack(corners, dim=2)

    n_batch, n_seq, _ = pred.shape
    polygons = flat.view(n_batch, n_seq, 4, 2)

    # Scale in place so the result keeps the dtype of the stacked tensor.
    polygons[:, :, :, 0].mul_(w_scale)
    polygons[:, :, :, 1].mul_(h_scale)
    return polygons
def get_name_from_path(path):
    """Return the file name of ``path`` up to its first dot."""
    return os.path.basename(path).split(".")[0]


def load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI):
    """Render pages of a PDF to images.

    Args:
        pdf_path: Path of the PDF file.
        page_range: Zero-based page indices to render; all pages when empty
            or ``None``.
        dpi: Forwarded to ``get_page_images``.

    Returns:
        ``(images, names)`` with one name (derived from ``pdf_path``) per
        rendered page.

    Raises:
        AssertionError: If any index in ``page_range`` is out of bounds.
    """
    doc = open_pdf(pdf_path)
    try:
        last_page = len(doc)

        if page_range:
            assert all(0 <= page < last_page for page in page_range), (
                f"Invalid page range: {page_range}"
            )
        else:
            page_range = list(range(last_page))

        images = get_page_images(doc, page_range, dpi=dpi)
    finally:
        # Close the document even when validation or rendering fails.
        doc.close()

    names = [get_name_from_path(pdf_path) for _ in page_range]
    return images, names


def load_image(image_path):
    """Load one image as RGB and return ``([image], [name])``."""
    # `convert` returns a fully loaded copy, so the source file can be closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    name = get_name_from_path(image_path)
    return [image], [name]


def load_from_file(
    input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI
):
    """Load a PDF (detected via ``filetype``) or otherwise an image file."""
    input_type = filetype.guess(input_path)
    if input_type and input_type.extension == "pdf":
        return load_pdf(input_path, page_range, dpi=dpi)
    else:
        return load_image(input_path)


def load_from_folder(
    folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI
):
    """Load every non-hidden file directly inside ``folder_path``.

    PDFs are rendered page by page; other files are opened as images, and
    files PIL cannot identify are skipped with a warning.
    """
    image_paths = [
        os.path.join(folder_path, image_name)
        for image_name in os.listdir(folder_path)
        if not image_name.startswith(".")
    ]
    image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]

    images = []
    names = []
    for path in image_paths:
        extension = filetype.guess(path)
        if extension and extension.extension == "pdf":
            image, name = load_pdf(path, page_range, dpi=dpi)
            images.extend(image)
            names.extend(name)
        else:
            try:
                image, name = load_image(path)
                images.extend(image)
                names.extend(name)
            except PIL.UnidentifiedImageError:
                logger.warning(f"Could not load image {path}")
                continue
    return images, names


def load_lang_file(lang_path, names):
    """Read a JSON mapping and return a copy of the entry for each name.

    Raises:
        KeyError: If a name is missing from the file.
    """
    with open(lang_path, "r", encoding="utf-8") as f:
        lang_dict = json.load(f)
    return [lang_dict[name].copy() for name in names]
def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
    """Return the images with every non-RGB one converted to RGB."""
    return [
        img if img.mode == "RGB" else img.convert("RGB")
        for img in images
    ]


def open_pdf(pdf_filepath):
    """Open ``pdf_filepath`` as a ``pypdfium2.PdfDocument``."""
    document = pypdfium2.PdfDocument(pdf_filepath)
    return document


def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
    """Render the pages at ``indices`` to RGB PIL images.

    The render scale is ``dpi / 72`` and annotations are not drawn.
    """
    rendered = []
    for page_index in indices:
        bitmap = doc[page_index].render(scale=dpi / 72, draw_annots=False)
        rendered.append(bitmap.to_pil())
    return [page.convert("RGB") for page in rendered]


def slice_bboxes_from_image(image: np.ndarray, bboxes):
    """Crop one copy of ``image`` per ``(x1, y1, x2, y2)`` box.

    Coordinates are cast to int32 and clipped at zero, each box is forced to
    span at least one pixel per axis, and the far edges are clamped to the
    image size. Empty crops are still returned, with a warning logged.
    """
    img_h, img_w = image.shape[0], image.shape[1]
    crops = []
    for raw_box in bboxes:
        bbox = np.clip(np.array(raw_box, dtype=np.int32), 0, None)

        # Guarantee a positive extent before clamping to the image bounds.
        if bbox[3] <= bbox[1]:
            bbox[3] = bbox[1] + 1
        if bbox[2] <= bbox[0]:
            bbox[2] = bbox[0] + 1

        bbox[2] = min(bbox[2], img_w)
        bbox[3] = min(bbox[3], img_h)

        crop = image[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()
        if crop.size == 0:
            logger.warning(f"Warning: found an empty line with bbox {bbox}")
        crops.append(crop)
    return crops


def slice_polys_from_image(image: np.ndarray, polys):
    """Crop and pad ``image`` once per polygon via ``slice_and_pad_poly``."""
    return [slice_and_pad_poly(image, poly) for poly in polys]
def slice_and_pad_poly(image_array: np.ndarray, coordinates):
    """Crop a polygon's bounding box from an image and pad outside the polygon.

    Args:
        image_array: Image indexed as ``[row, column, ...]``.
        coordinates: Iterable of polygon corners; the first two items of each
            corner are used as ``(x, y)``.

    Returns:
        A copy of the bounding-box crop. Pixels outside the polygon are set to
        ``settings.RECOGNITION_PAD_VALUE``. Degenerate input (empty crop,
        fewer than three corners, zero-area box) returns the crop unmodified.
    """
    coordinates = [(corner[0], corner[1]) for corner in coordinates]
    xs = [x for x, _ in coordinates]
    ys = [y for _, y in coordinates]
    bbox = [min(xs), min(ys), max(xs), max(ys)]

    cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()
    height, width = cropped_polygon.shape[:2]
    # Shift the polygon into the crop's coordinate system.
    coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]

    # Nothing meaningful to mask for degenerate crops or polygons.
    if (
        bbox[3] <= bbox[1]
        or bbox[2] <= bbox[0]
        or len(coordinates) < 3
        or height == 0
        or width == 0
    ):
        return cropped_polygon

    try:
        mask = np.zeros((height, width), dtype=np.uint8)
        cv2.fillPoly(mask, [np.int32(coordinates)], 1)
        # Index with the 2-D mask so every channel of an outside pixel is
        # padded regardless of channel count (grayscale, RGB, RGBA, ...).
        cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
    except cv2.error as e:
        logger.warning(f"Warning: issue while processing polygon: {e}")

    return cropped_polygon
def __call__(
    self, images: List[Image.Image], batch_size: int | None = None, top_k: int = 5
) -> List[LayoutResult]:
    """Predict layout boxes for a list of page images.

    Args:
        images: PIL images; every element must be an ``Image.Image``.
        batch_size: Batch size passed to the foundation predictor's
            prediction loop. ``None`` uses ``self.get_batch_size()``.
        top_k: Passed to the prediction loop as ``top_k``. The default keeps
            the value that used to be hard-coded in the call.

    Returns:
        One ``LayoutResult`` per input image, in input order. An empty input
        list yields an empty list.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = self.get_batch_size()

    if len(images) == 0:
        return []

    images = convert_if_not_rgb(images)
    images = [self.processor.image_processor(image) for image in images]

    predicted_tokens, batch_bboxes, scores, topk_scores = (
        self.foundation_predictor.prediction_loop(
            images=images,
            input_texts=["" for _ in range(len(images))],
            task_names=[TaskNames.layout for _ in range(len(images))],
            batch_size=batch_size,
            max_lookahead_tokens=0,  # Do not do MTP for layout
            top_k=top_k,  # honor the caller's value instead of a fixed 5
            max_sliding_window=576,
            max_tokens=500,
            tqdm_desc="Recognizing Layout",
        )
    )

    # After processing, each image exposes ``.shape`` (array-like).
    image_sizes = [img.shape for img in images]
    predicted_polygons = prediction_to_polygon_batch(
        batch_bboxes, image_sizes, self.bbox_size, self.bbox_size // 2
    )

    layout_results = []
    for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip(
        images, predicted_tokens, predicted_polygons, scores, topk_scores
    ):
        layout_boxes = []
        for z, (tok, poly, score, tok_topk) in enumerate(
            zip(image_tokens, image_polygons, image_scores, image_topk_scores)
        ):
            # Everything after the end-of-sequence token is ignored.
            if tok == self.processor.eos_token_id:
                break

            predicted_label = self.processor.decode([tok], "layout")
            label = LAYOUT_PRED_RELABEL.get(predicted_label)
            if not label:
                # Layout can sometimes return unknown labels from other objectives
                continue

            top_k_dict = {}
            for k, v in tok_topk.items():
                topk_label = self.processor.decode([k], "layout")
                if topk_label in LAYOUT_PRED_RELABEL:
                    topk_label = LAYOUT_PRED_RELABEL[topk_label]
                # Skip alternatives that decode to a blank label.
                if not topk_label.strip():
                    continue
                top_k_dict[topk_label] = v

            layout_boxes.append(
                LayoutBox(
                    polygon=poly.tolist(),
                    label=label,
                    position=z,
                    top_k=top_k_dict,
                    confidence=score,
                )
            )

        layout_boxes = clean_boxes(layout_boxes)
        layout_results.append(
            LayoutResult(
                bboxes=layout_boxes,
                # shape is (rows, cols, ...), so width comes first here
                image_bbox=[0, 0, image.shape[1], image.shape[0]],
            )
        )

    assert len(layout_results) == len(images)
    return layout_results
": "Figure", "": "Text", "": "Caption", "": "ListItem", "": "SectionHeader", "": "Table", "": "TableOfContents", "
": "Form", "": "Equation", "": "Code", "": "Figure", } ================================================ FILE: surya/layout/schema.py ================================================ from typing import Optional, Dict, List from pydantic import BaseModel from surya.common.polygon import PolygonBox class LayoutBox(PolygonBox): label: str position: int top_k: Optional[Dict[str, float]] = None class LayoutResult(BaseModel): bboxes: List[LayoutBox] image_bbox: List[float] sliced: bool = False # Whether the image was sliced and reconstructed ================================================ FILE: surya/logging.py ================================================ import logging import warnings from surya.settings import settings def configure_logging(): logger = get_logger() # Remove any existing handlers to prevent duplicates for handler in logger.handlers[:]: logger.removeHandler(handler) # Add our handler handler = logging.StreamHandler() formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) # Prevent propagation to parent loggers to avoid double logging logger.propagate = False logger.setLevel(settings.LOGLEVEL) warnings.simplefilter(action="ignore", category=FutureWarning) def get_logger(): return logging.getLogger("surya") ================================================ FILE: surya/models.py ================================================ from typing import Dict import torch from surya.common.predictor import BasePredictor from surya.detection import DetectionPredictor from surya.layout import LayoutPredictor from surya.logging import configure_logging from surya.ocr_error import OCRErrorPredictor from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor from surya.table_rec import TableRecPredictor from surya.settings import settings configure_logging() def load_predictors( device: str | torch.device | None = None, dtype: torch.dtype | 
class OCRErrorPredictor(BasePredictor):
    """Predictor that labels each input text using the OCR-error classifier."""

    model_loader_cls = OCRErrorModelLoader
    batch_size = settings.OCR_ERROR_BATCH_SIZE
    default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 64, "xla": 32}

    def __call__(self, texts: List[str], batch_size: Optional[int] = None):
        """Run ``batch_ocr_error_detection`` on ``texts``."""
        return self.batch_ocr_error_detection(texts, batch_size)

    def batch_ocr_error_detection(
        self, texts: List[str], batch_size: Optional[int] = None
    ):
        """Classify ``texts`` in batches.

        Args:
            texts: Strings to classify.
            batch_size: Number of texts per forward pass; ``None`` uses
                ``self.get_batch_size()``.

        Returns:
            An ``OCRErrorDetectionResult`` holding the input texts and one
            ``ID2LABEL`` label per text.
        """
        # Nothing to classify: skip tokenization, which has no sequence to
        # pad against when the batch is empty.
        if not texts:
            return OCRErrorDetectionResult(texts=[], labels=[])

        if batch_size is None:
            batch_size = self.get_batch_size()

        num_batches = math.ceil(len(texts) / batch_size)
        texts_processed = self.processor(
            texts, padding="longest", truncation=True, return_tensors="pt"
        )
        predictions = []
        for batch_idx in tqdm(
            range(num_batches),
            desc="Running OCR Error Detection",
            disable=self.disable_tqdm,
        ):
            start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size
            batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(
                self.model.device
            )
            batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(
                self.model.device
            )

            # Remember the real batch length so padded rows can be dropped.
            current_batch_size = batch_input_ids.shape[0]
            if settings.OCR_ERROR_STATIC_CACHE:
                batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)
                batch_attention_mask = self.pad_to_batch_size(
                    batch_attention_mask, batch_size
                )

            with settings.INFERENCE_MODE():
                pred = self.model(batch_input_ids, attention_mask=batch_attention_mask)

                # argmax over the class dimension gives label ids, not logits.
                batch_label_ids = (
                    pred.logits.argmax(dim=1).cpu().tolist()[:current_batch_size]
                )
                predictions.extend(batch_label_ids)
            mark_step()

        return OCRErrorDetectionResult(
            texts=texts, labels=[ID2LABEL[p] for p in predictions]
        )
{"backend": "openxla"} if device == "xla" else {} model = torch.compile(model, **compile_args) return model def processor( self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE ) -> DistilBertTokenizer: return DistilBertTokenizer.from_pretrained(self.checkpoint) ================================================ FILE: surya/ocr_error/model/__init__.py ================================================ ================================================ FILE: surya/ocr_error/model/config.py ================================================ from collections import OrderedDict from typing import Mapping from transformers.configuration_utils import PretrainedConfig from transformers.onnx import OnnxConfig from surya.common.s3 import S3DownloaderMixin ID2LABEL = { 0: 'good', 1: 'bad' } class DistilBertConfig(S3DownloaderMixin, PretrainedConfig): model_type = "distilbert" attribute_map = { "hidden_size": "dim", "num_attention_heads": "n_heads", "num_hidden_layers": "n_layers", } def __init__( self, vocab_size=30522, max_position_embeddings=512, sinusoidal_pos_embds=False, n_layers=6, n_heads=12, dim=768, hidden_dim=4 * 768, dropout=0.1, attention_dropout=0.1, activation="gelu", initializer_range=0.02, qa_dropout=0.1, seq_classif_dropout=0.2, pad_token_id=0, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.sinusoidal_pos_embds = sinusoidal_pos_embds self.n_layers = n_layers self.n_heads = n_heads self.dim = dim self.hidden_dim = hidden_dim self.dropout = dropout self.attention_dropout = attention_dropout self.activation = activation self.initializer_range = initializer_range self.qa_dropout = qa_dropout self.seq_classif_dropout = seq_classif_dropout super().__init__(**kwargs, pad_token_id=pad_token_id) class DistilBertOnnxConfig(OnnxConfig): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: if self.task == "multiple-choice": dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} else: dynamic_axis = 
def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
    """Fill ``out`` in place with fixed sinusoidal position encodings.

    Column ``j`` of row ``pos`` is derived from the angle
    ``pos / 10000 ** (2 * (j // 2) / dim)``: even columns store its sine and
    odd columns its cosine.

    Args:
        n_pos: Number of positions (rows) to encode.
        dim: Embedding width (columns).
        out: Tensor indexed as ``[position, column]`` that receives the
            values. Its ``requires_grad`` flag is cleared and it is detached.
    """
    # Vectorized angle table, shape (n_pos, dim), computed in float64.
    positions = np.arange(n_pos, dtype=np.float64)[:, None]
    exponents = 2 * (np.arange(dim) // 2) / dim
    position_enc = positions / np.power(10000.0, exponents)

    out.requires_grad = False
    out[:, 0::2] = torch.as_tensor(np.sin(position_enc[:, 0::2]), dtype=torch.float32)
    out[:, 1::2] = torch.as_tensor(np.cos(position_enc[:, 1::2]), dtype=torch.float32)
    out.detach_()
class MultiHeadSelfAttention(nn.Module):
    """Eager (matmul + softmax) multi-head attention with a padding mask."""

    def __init__(self, config: DistilBertConfig):
        """Build the projections from ``config.dim``, ``config.n_heads`` and
        ``config.attention_dropout``.

        Raises:
            ValueError: If ``config.dim`` is not divisible by ``config.n_heads``.
        """
        super().__init__()
        self.config = config

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.is_causal = False

        # Every head must receive the same number of feature dimensions.
        if self.dim % self.n_heads != 0:
            raise ValueError(
                f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly"
            )

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads: Set[int] = set()
        self.attention_head_size = self.dim // self.n_heads

    def prune_heads(self, heads: List[int]):
        """Remove the given attention heads from all four projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.attention_head_size, self.pruned_heads
        )
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = self.attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length); positions equal to 0 are
                excluded from attention.
            head_mask: optional tensor multiplied into the attention weights.
            output_attentions: also return the attention weights.

        Returns:
            ``(context,)`` or ``(context, weights)`` where context is
            torch.tensor(bs, seq_length, dim) and weights is
            torch.tensor(bs, n_heads, seq_length, seq_length).
        """
        bs, _, _ = query.size()
        k_length = key.size(1)

        dim_per_head = self.dim // self.n_heads
        mask_reshp = (bs, 1, 1, k_length)

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return (
                x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
            )

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (
            (mask == 0).view(mask_reshp).expand_as(scores)
        )  # (bs, n_heads, q_length, k_length)
        # Pass the fill value as a plain number: no per-call tensor allocation
        # and no device mismatch with ``scores``.
        scores = scores.masked_fill(
            mask, torch.finfo(scores.dtype).min
        )  # (bs, n_heads, q_length, k_length)

        weights = nn.functional.softmax(
            scores, dim=-1
        )  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """Project the inputs, run flash attention and apply the output projection.

    Parameters:
        query: torch.tensor(bs, seq_length, dim)
        key: torch.tensor(bs, seq_length, dim)
        value: torch.tensor(bs, seq_length, dim)
        mask: padding mask, forwarded unchanged to ``_flash_attention_forward``
        head_mask: accepted for interface parity; not used by this method
        output_attentions: when true, the tensor returned by
            ``_flash_attention_forward`` is appended to the result

    Returns:
        ``(attn_output,)`` or ``(attn_output, flash_output)`` where
        ``attn_output`` is torch.tensor(bs, seq_length, n_heads * dim_per_head)
        after ``out_lin``.
    """
    batch_size, q_length, _ = query.size()
    head_dim = self.dim // self.n_heads

    # Heads stay on axis 2: (batch_size, seq_length, n_heads, head_dim).
    heads_shape = (batch_size, -1, self.n_heads, head_dim)
    query_states = self.q_lin(query).view(heads_shape)
    key_states = self.k_lin(key).view(heads_shape)
    value_states = self.v_lin(value).view(heads_shape)

    # Dropout is only applied while training.
    if self.training:
        attn_dropout = self.config.attention_dropout
    else:
        attn_dropout = 0.0

    # The projections can come out as float32 (e.g. when upstream layers were
    # upcast); pick the dtype to cast them to before the flash-attention call.
    if query_states.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            # Quantized model: use the dtype recorded before quantization.
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_lin.weight.dtype

        query_states, key_states, value_states = (
            states.to(target_dtype)
            for states in (query_states, key_states, value_states)
        )

    flash_output = self._flash_attention_forward(
        query_states, key_states, value_states, mask, q_length, dropout=attn_dropout
    )

    # Merge the head axes back into a single feature axis.
    merged = flash_output.reshape(batch_size, q_length, self.n_heads * head_dim)
    attn_output = self.out_lin(merged)

    return (attn_output, flash_output) if output_attentions else (attn_output,)
softmax_scale=None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`float`): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] ( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens, ) = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) attn_output = pad_input( attn_output_unpad, indices_q, batch_size, query_length ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, ) return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads def _upad_input( self, query_layer, key_layer, value_layer, attention_mask, query_length ): from flash_attn.bert_padding import index_first_axis, unpad_input indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k, ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # 
class FFN(nn.Module):
    """Position-wise feed-forward block: linear, activation, linear, dropout."""

    def __init__(self, config: DistilBertConfig):
        """Create the layers from ``config.dim``, ``config.hidden_dim``,
        ``config.dropout``, ``config.activation`` and
        ``config.chunk_size_feed_forward``."""
        super().__init__()
        model_dim, inner_dim = config.dim, config.hidden_dim

        self.dropout = nn.Dropout(p=config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Chunking, when enabled, splits along axis 1.
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=model_dim, out_features=inner_dim)
        self.lin2 = nn.Linear(in_features=inner_dim, out_features=model_dim)
        self.activation = get_activation(config.activation)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``ff_chunk`` over ``input`` through ``apply_chunking_to_forward``."""
        return apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            input,
        )

    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
        """Apply lin1 -> activation -> lin2 -> dropout to one chunk."""
        hidden = self.activation(self.lin1(input))
        return self.dropout(self.lin2(hidden))
class Transformer(nn.Module):
    """Stack of ``config.n_layers`` ``TransformerBlock`` layers applied in order."""

    def __init__(self, config: DistilBertConfig):
        super().__init__()
        self.n_layers = config.n_layers
        self.layer = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.
            head_mask: optional per-layer head masks, indexed by layer number.
                ``None`` means no head masking for any layer.
            output_attentions: collect the attention weights of every layer.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: when truthy return a ``BaseModelOutput``, otherwise a
                tuple of the non-``None`` values below.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim)
                Sequence of hidden states in the last (top) layer
            all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers + 1 (each layer's input followed by
                the final output). Only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each
                layer. Only if output_attentions=True

        Raises:
            ValueError: If a layer returns an unexpected number of outputs.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            # The signature allows head_mask=None; indexing it would raise.
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # NOTE(review): _gradient_checkpointing_func is not defined in
                # this class; presumably attached when checkpointing is enabled.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_state,
                    attn_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_state,
                    attn_mask,
                    layer_head_mask,
                    output_attentions,
                )

            # The block's hidden output is always the last element.
            hidden_state = layer_outputs[-1]

            if output_attentions:
                if len(layer_outputs) != 2:
                    raise ValueError(
                        f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}"
                    )

                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                if len(layer_outputs) != 1:
                    raise ValueError(
                        f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}"
                    )

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_state, all_hidden_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_state,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
""" config_class = DistilBertConfig load_tf_weights = None base_model_prefix = "distilbert" supports_gradient_checkpointing = True _supports_flash_attn_2 = True def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: create_sinusoidal_embeddings( self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight, ) class DistilBertModel(DistilBertPreTrainedModel): def __init__(self, config: DistilBertConfig): super().__init__(config) self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" # Initialize weights and apply final processing self.post_init() def get_position_embeddings(self) -> nn.Embedding: """ Returns the position embeddings """ return self.embeddings.position_embeddings def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. Arguments: new_num_position_embeddings (`int`): The number of new position embedding matrix. If position embeddings are learned, increasing the size will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. 
If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will add correct vectors at the end following the position encoding algorithm, whereas reducing the size will remove vectors from the end. """ num_position_embeds_diff = ( new_num_position_embeddings - self.config.max_position_embeddings ) # no resizing needs to be done if the length stays the same if num_position_embeds_diff == 0: return self.config.max_position_embeddings = new_num_position_embeddings old_position_embeddings_weight = ( self.embeddings.position_embeddings.weight.clone() ) self.embeddings.position_embeddings = nn.Embedding( self.config.max_position_embeddings, self.config.dim ) if self.config.sinusoidal_pos_embds: create_sinusoidal_embeddings( n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight, ) else: with torch.no_grad(): if num_position_embeds_diff > 0: self.embeddings.position_embeddings.weight[ :-num_position_embeds_diff ] = nn.Parameter(old_position_embeddings_weight) else: self.embeddings.position_embeddings.weight = nn.Parameter( old_position_embeddings_weight[:num_position_embeds_diff] ) # move position_embeddings to correct device self.embeddings.position_embeddings.to(self.device) def get_input_embeddings(self) -> nn.Embedding: return self.embeddings.word_embeddings def set_input_embeddings(self, new_embeddings: nn.Embedding): self.embeddings.word_embeddings = new_embeddings def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.transformer.layer[layer].attention.prune_heads(heads) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) if self._use_flash_attention_2: attention_mask = ( attention_mask if (attention_mask is not None and 0 in attention_mask) else None ) else: if attention_mask is None: attention_mask = torch.ones( input_shape, device=device ) # (bs, seq_length) return self.transformer( x=embeddings, attn_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, 
output_hidden_states=output_hidden_states, return_dict=return_dict, ) class DistilBertForSequenceClassification(S3DownloaderMixin, DistilBertPreTrainedModel): def __init__(self, config: DistilBertConfig, **kwargs): super().__init__(config, **kwargs) self.num_labels = config.num_labels self.config = config self.distilbert = DistilBertModel(config) self.pre_classifier = nn.Linear(config.dim, config.dim) self.classifier = nn.Linear(config.dim, config.num_labels) self.dropout = nn.Dropout(config.seq_classif_dropout) # Initialize weights and apply final processing self.post_init() def get_position_embeddings(self) -> nn.Embedding: """ Returns the position embeddings """ return self.distilbert.get_position_embeddings() def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. Arguments: new_num_position_embeddings (`int`): The number of new position embedding matrix. If position embeddings are learned, increasing the size will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will add correct vectors at the end following the position encoding algorithm, whereas reducing the size will remove vectors from the end. 
""" self.distilbert.resize_position_embeddings(new_num_position_embeddings) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) distilbert_output = self.distilbert( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_state = distilbert_output[0] # (bs, seq_len, dim) pooled_output = hidden_state[:, 0] # (bs, dim) pooled_output = self.pre_classifier(pooled_output) # (bs, dim) pooled_output = nn.ReLU()(pooled_output) # (bs, dim) pooled_output = self.dropout(pooled_output) # (bs, dim) logits = self.classifier(pooled_output) # (bs, num_labels) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), 
# Adapted from transformers.models.bert.tokenization_bert.load_vocab
def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file into an ordered token -> index mapping.

    The index of a token is its zero-based line number. Only the trailing newline
    is removed from each line, so any other whitespace stays part of the token.
    """
    with open(vocab_file, "r", encoding="utf-8") as reader:
        return collections.OrderedDict(
            (line.rstrip("\n"), position) for position, line in enumerate(reader)
        )
class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
    r"""
    WordPiece-based tokenizer for DistilBERT.

    Most of the public behaviour comes from [`PreTrainedTokenizer`]; refer to that superclass for the shared
    methods.

    Args:
        vocab_file (`str`):
            Path of the vocabulary file.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Lowercase the input while tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Run basic tokenization before WordPiece.
        never_split (`Iterable`, *optional*):
            Tokens that must never be split. Only used when `do_basic_tokenize=True`.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            Token substituted for anything that is not in the vocabulary.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            Separator placed between sequences and at the end of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            Token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            Classifier token; first token of a sequence built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            Token used for masking values in masked language modeling.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If left unset, it follows the value of `lowercase` (as in the
            original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )

        # The vocabulary and sub-tokenizers are set up before the base-class
        # initialiser runs, exactly as in the original statement order.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            (token_id, token) for token, token_id in self.vocab.items()
        )
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

    @property
    def do_lower_case(self):
        """Lower-casing flag, read from the basic tokenizer."""
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (added tokens not counted)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with the added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text, split_special_tokens=False):
        """Split `text` into WordPiece tokens, optionally after basic tokenization."""
        if not self.do_basic_tokenize:
            return self.wordpiece_tokenizer.tokenize(text)

        protected = None if split_special_tokens else self.all_special_tokens
        pieces = []
        for word in self.basic_tokenizer.tokenize(text, never_split=protected):
            if word in self.basic_tokenizer.never_split:
                # Members of the never_split set are emitted untouched.
                pieces.append(word)
            else:
                pieces.extend(self.wordpiece_tokenizer.tokenize(word))
        return pieces

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        unk_id = self.vocab.get(self.unk_token)
        return self.vocab.get(token, unk_id)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        joined = " ".join(tokens)
        return joined.replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Add the special tokens around one sequence or a pair of sequences.

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        cls, sep = [self.cls_token_id], [self.sep_token_id]
        ids = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            ids = ids + token_ids_1 + sep
        return ids

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Build the special-token mask for a sequence (or pair) that has no special tokens added yet.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model; in that case
                the base-class implementation is used.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs for a sequence or a pair of sequences:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, only the first portion of the mask (0s) is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        type_ids = [0] * len(cls + token_ids_0 + sep)
        if token_ids_1 is not None:
            type_ids += [1] * len(token_ids_1 + sep)
        return type_ids

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write the vocabulary, one token per line ordered by index, and return the path in a 1-tuple.

        If `save_directory` is an existing directory the file is created inside it; otherwise
        `save_directory` itself is used as the file path (with the optional prefix prepended).
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = prefix + save_directory

        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Gaps in the index sequence are tolerated silently (no warning is emitted).
            by_index = sorted(self.vocab.items(), key=lambda kv: kv[1])
            writer.writelines(token + "\n" for token, _ in by_index)
        return (vocab_file,)
def tokenize(self, text, never_split=None):
    """
    Basic tokenization of a piece of text: clean-up, optional CJK spacing, NFC normalisation,
    optional lower-casing / accent stripping, then punctuation splitting.
    For sub-word tokenization, see WordPieceTokenizer.

    Args:
        never_split (`List[str]`, *optional*)
            Kept for backward compatibility purposes. Now implemented directly at the base class level (see
            [`PreTrainedTokenizer.tokenize`]) List of token not to split.
    """
    # Merge the per-call tokens with the instance-level set without mutating either.
    if never_split:
        protected = self.never_split.union(never_split)
    else:
        protected = self.never_split

    cleaned = self._clean_text(text)
    if self.tokenize_chinese_chars:
        cleaned = self._tokenize_chinese_chars(cleaned)

    # NFC keeps one character written with different codepoint sequences from being treated as different characters.
    normalized = unicodedata.normalize("NFC", cleaned)

    pieces = []
    for word in normalized.split():
        if word not in protected:
            if self.do_lower_case:
                word = word.lower()
                # When lower-casing, accents are stripped unless explicitly disabled.
                strip = self.strip_accents is not False
            else:
                strip = bool(self.strip_accents)
            if strip:
                word = self._run_strip_accents(word)
        pieces.extend(self._run_split_on_punc(word, protected))

    return " ".join(pieces).split()
# Adapted from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Break text into word pieces with a greedy longest-match-first search over the vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have already been passed through
                *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens. A word that is longer than `max_input_chars_per_word`, or that cannot be
            fully covered by vocabulary entries, is replaced by the unknown token.
        """
        pieces = []
        for word in text.split():
            if len(word) > self.max_input_chars_per_word:
                pieces.append(self.unk_token)
                continue

            word_pieces = self._split_word(word)
            if word_pieces is None:
                pieces.append(self.unk_token)
            else:
                pieces.extend(word_pieces)
        return pieces

    def _split_word(self, word):
        """Return the greedy word-piece split of one word, or None if some span has no vocabulary match."""
        found = []
        start = 0
        while start < len(word):
            # Try the longest remaining span first and shrink it from the right.
            for end in range(len(word), start, -1):
                candidate = word[start:end]
                if start > 0:
                    candidate = "##" + candidate
                if candidate in self.vocab:
                    found.append(candidate)
                    start = end
                    break
            else:
                return None
        return found
def detect_and_slice_bboxes(
    self,
    images: List[Image.Image],
    task_names: List[str],
    det_predictor: DetectionPredictor,
    detection_batch_size: int | None = None,
    highres_images: List[Image.Image] | None = None,
):
    """
    Run text detection on every image and cut out one slice per detected polygon.

    Args:
        images: Images to run detection on.
        task_names: One task name per image; repeated for every slice of that image.
        det_predictor: Detection predictor, called as `det_predictor(images, batch_size=...)`.
        detection_batch_size: Batch size forwarded to the detection predictor.
        highres_images: Optional higher-resolution counterpart of each image. When an entry is
            given, polygons are scaled to it and slices are cut from it instead of the base image.
            `None` (the default) means no image has a high-resolution counterpart.

    Returns:
        A dict of flat, slice-aligned lists: `slices`, `polygons` (unscaled detection polygons),
        `task_names`, `input_text` (all `None`) and `res_scales` (the `(width, height)` scale used
        per slice, `(1, 1)` without a high-resolution image), plus `slice_map` with the number of
        slices per input image.
    """
    # `zip` cannot iterate over None; treat a missing list as "no highres image anywhere"
    # so the per-image fallback branch below is taken.
    if highres_images is None:
        highres_images = [None] * len(images)

    det_predictions = det_predictor(images, batch_size=detection_batch_size)

    all_slices = []
    slice_map = []
    all_polygons = []
    all_task_names = []
    all_res_scales = []

    for det_pred, image, highres_image, task_name in zip(
        det_predictions, images, highres_images, task_names
    ):
        polygons = [p.polygon for p in det_pred.bboxes]
        if highres_image:
            width_scaler = highres_image.size[0] / image.size[0]
            height_scaler = highres_image.size[1] / image.size[1]
            scaled_polygons = [
                [
                    [int(p[0] * width_scaler), int(p[1] * height_scaler)]
                    for p in polygon
                ]
                for polygon in polygons
            ]
            highres_image = self.processor.image_processor(highres_image)
            slices = slice_polys_from_image(highres_image, scaled_polygons)
            res_scales = [(width_scaler, height_scaler) for _ in range(len(slices))]
        else:
            image = self.processor.image_processor(image)
            slices = slice_polys_from_image(image, polygons)
            res_scales = [(1, 1) for _ in range(len(slices))]

        slice_map.append(len(slices))
        all_slices.extend(slices)
        all_polygons.extend(polygons)
        all_task_names.extend([task_name] * len(slices))
        all_res_scales.extend(res_scales)

    # Internal invariant: every flat list carries exactly one entry per slice.
    assert (
        len(all_slices)
        == sum(slice_map)
        == len(all_polygons)
        == len(all_task_names)
        == len(all_res_scales)
    )

    return {
        "slices": all_slices,
        "slice_map": slice_map,
        "polygons": all_polygons,
        "task_names": all_task_names,
        "input_text": [None] * len(all_slices),
        "res_scales": all_res_scales,
    }
predicted_tokens, predicted_polygons, scores, needs_boxes, ) ): blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]] if self.processor.no_output_token in image_tokens: char_predictions.append(None) continue # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely if drop_repeated_text and detect_repeat_token(image_tokens): char_predictions.append( [ TextChar( text="", polygon=blank_bbox, confidence=0, bbox_valid=False, ) ] ) continue image_polygons = image_polygons[: len(image_tokens)].cpu().numpy().tolist() detokenize_sequences = [] detokenize_sequence = [] past_char_qwen_token = False def _add_detokenize_sequence( special_token: bool, past_special_token: bool, force: bool = False, ): nonlocal detokenize_sequence, detokenize_sequences if ( special_token or past_special_token or force ) and detokenize_sequence: chars = [dt[0] for dt in detokenize_sequence] scores = [dt[1] for dt in detokenize_sequence] bboxes = [dt[2] for dt in detokenize_sequence] if past_special_token: detokenize_sequences.append((chars, scores, None, "special")) else: detokenize_sequences.append((chars, scores, bboxes, "ocr")) detokenize_sequence = [] # Split up into sequences to detokenize separately past_special_token = False for bbox, char_id, score in zip(image_polygons, image_tokens, image_scores): if char_id in [ self.processor.eos_token_id, self.processor.pad_token_id, ]: break special_token = ( char_id >= self.processor.ocr_tokenizer.ocr_tokenizer.SPECIAL_BASE ) _add_detokenize_sequence( special_token, past_special_token ) detokenize_sequence.append((char_id, score, bbox)) past_special_token = special_token _add_detokenize_sequence( False, past_special_token, force=True ) img_chars = [] for sequence in detokenize_sequences: token_ids, seq_score, bboxes, token_type = sequence if token_type == "ocr": text = self.processor.ocr_tokenizer.decode( token_ids, task=TaskNames.ocr_with_boxes ) bboxes = clean_close_polygons( bboxes ) # clean out 
def __call__(
    self,
    images: List[Image.Image],
    task_names: List[str] | None = None,
    det_predictor: DetectionPredictor | None = None,
    detection_batch_size: int | None = None,
    recognition_batch_size: int | None = None,
    highres_images: List[Image.Image] | None = None,
    bboxes: List[List[List[int]]] | None = None,
    polygons: List[List[List[List[int]]]] | None = None,
    input_text: List[List[str | None]] | None = None,
    sort_lines: bool = False,
    math_mode: bool = True,
    return_words: bool = False,
    drop_repeated_text: bool = False,
    max_sliding_window: int | None = None,
    max_tokens: int | None = None,
    filter_tag_list: List[str] = None
) -> List[OCRResult]:
    """
    Recognize the text lines of each image and return one OCRResult per image.

    Lines come either from ``det_predictor`` (when neither ``bboxes`` nor
    ``polygons`` is given) or from the caller-supplied ``bboxes`` /
    ``polygons``. All line crops are flattened, run through the foundation
    predictor in one prediction loop, and regrouped per image.

    Args:
        images: Input images; converted with ``convert_if_not_rgb``.
        task_names: One task per image; defaults to
            ``TaskNames.ocr_with_boxes`` for every image.
        det_predictor: Required when no ``bboxes``/``polygons`` are passed.
        detection_batch_size: Forwarded to detection.
        recognition_batch_size: Defaults to ``self.get_batch_size()``.
        highres_images: Optional list parallel to ``images``; only used on
            the detection path.
        bboxes: Per-image line boxes.
        polygons: Per-image line polygons.
        input_text: Per-image, per-line text; only forwarded to
            ``slice_bboxes`` (the bboxes/polygons path).
        sort_lines: Pass each image's lines through ``sort_text_lines``.
        math_mode: Forwarded to the prediction loop.
        return_words: Fill ``TextLine.words`` via ``words_from_chars``;
            otherwise ``words`` is an empty list.
        drop_repeated_text: Forwarded to ``get_bboxes_text``.
        max_sliding_window: Forwarded to the prediction loop.
        max_tokens: Forwarded to the prediction loop.
        filter_tag_list: Forwarded to ``filter_blacklist_tags``.

    Returns:
        One OCRResult per input image, in input order. If no slices were
        produced at all, every result has an empty ``text_lines`` list.
    """
    if task_names is None:
        task_names = [TaskNames.ocr_with_boxes] * len(images)
    if recognition_batch_size is None:
        recognition_batch_size = self.get_batch_size()

    assert len(images) == len(task_names), (
        "You need to pass in one task name for each image"
    )

    images = convert_if_not_rgb(images)
    if highres_images is not None:
        assert len(images) == len(highres_images), (
            "You need to pass in one highres image for each image"
        )

    # Normalise to a list parallel to `images` so downstream zip() works.
    highres_images = (
        convert_if_not_rgb(highres_images)
        if highres_images is not None
        else [None] * len(images)
    )

    if bboxes is None and polygons is None:
        assert det_predictor is not None, (
            "You need to pass in a detection predictor if you don't provide bboxes or polygons"
        )

        # Detect then slice
        flat = self.detect_and_slice_bboxes(
            images,
            task_names,
            det_predictor,
            detection_batch_size=detection_batch_size,
            highres_images=highres_images,
        )
    else:
        if bboxes is not None:
            assert len(images) == len(bboxes), (
                "You need to pass in one list of bboxes for each image"
            )
        if polygons is not None:
            assert len(images) == len(polygons), (
                "You need to pass in one list of polygons for each image"
            )

        flat = self.slice_bboxes(
            images,
            bboxes=bboxes,
            polygons=polygons,
            input_text=input_text,
            task_names=task_names,
        )

    # No images passed, or no boxes passed, or no text detected in the images
    if len(flat["slices"]) == 0:
        return [
            OCRResult(
                text_lines=[], image_bbox=[0, 0, im.size[0], im.size[1]]
            )
            for im in images
        ]

    # Sort by image sizes. Negative so that longer images come first, fits in with continuous batching better
    sorted_pairs = sorted(
        enumerate(flat["slices"]),
        key=lambda x: -(x[1].shape[0] * x[1].shape[1])  # height * width
    )
    # `indices` keeps each slice's original position for restoring order later.
    indices, sorted_slices = zip(*sorted_pairs)

    # Reorder input_text and task_names based on the new order
    # ("polygons" and "res_scales" stay in the original order on purpose;
    # predictions are mapped back to that order below.)
    flat["slices"] = list(sorted_slices)
    flat["input_text"] = [flat["input_text"][i] for i in indices]
    flat["task_names"] = [flat["task_names"][i] for i in indices]

    # Make predictions
    predicted_tokens, batch_bboxes, scores, _ = self.foundation_predictor.prediction_loop(
        images=flat["slices"],
        input_texts=flat["input_text"],
        task_names=flat["task_names"],
        batch_size=recognition_batch_size,
        math_mode=math_mode,
        drop_repeated_tokens=True,
        max_lookahead_tokens=self.foundation_predictor.model.config.multi_output_distance,
        max_sliding_window=max_sliding_window,
        max_tokens=max_tokens,
        tqdm_desc="Recognizing Text"
    )

    # Get text and bboxes in structured form
    bbox_size = self.bbox_size
    image_sizes = [img.shape for img in flat["slices"]]
    predicted_polygons = prediction_to_polygon_batch(
        batch_bboxes, image_sizes, bbox_size, bbox_size // 2
    )
    char_predictions = self.get_bboxes_text(
        flat,
        predicted_tokens,
        scores,
        predicted_polygons,
        drop_repeated_text=drop_repeated_text,
    )

    # Undo the size-based sort: put predictions back in original slice order.
    char_predictions = sorted(zip(indices, char_predictions), key=lambda x: x[0])
    char_predictions = [pred for _, pred in char_predictions]

    predictions_by_image = []
    slice_start = 0
    for idx, image in enumerate(images):
        # slice_map[idx] is the number of consecutive flat entries owned by this image.
        slice_end = slice_start + flat["slice_map"][idx]
        image_lines = char_predictions[slice_start:slice_end]
        polygons = flat["polygons"][slice_start:slice_end]
        res_scales = flat["res_scales"][slice_start:slice_end]
        slice_start = slice_end

        lines = []
        for text_line, polygon, res_scale in zip(image_lines, polygons, res_scales):
            # Special case when input text is good
            # (taken for a None or empty character list)
            if not text_line:
                lines.append(
                    TextLine(
                        text="",
                        polygon=polygon,
                        chars=[],
                        confidence=1,
                        original_text_good=True,
                    )
                )
            else:
                confidence = (
                    float(np.mean([char.confidence for char in text_line]))
                    if len(text_line) > 0
                    else 0
                )
                poly_box = PolygonBox(polygon=polygon)
                for char in text_line:
                    char.rescale(
                        res_scale, (1, 1)
                    )  # Rescale from highres if needed
                    char.shift(
                        poly_box.bbox[0], poly_box.bbox[1]
                    )  # Ensure character boxes match line boxes (relative to page)
                    char.clamp(poly_box.bbox)

                # Tag clean-up on the character list, then on the joined string.
                text_line = fix_unbalanced_tags(
                    text_line, self.processor.ocr_tokenizer.special_tokens
                )
                text_line = filter_blacklist_tags(text_line, filter_tag_list)
                text = "".join([char.text for char in text_line])
                text = unwrap_math(text)
                text = clean_math_tags(text)
                lines.append(
                    TextLine(
                        text=text,
                        polygon=polygon,
                        chars=text_line,
                        confidence=confidence,
                        words=words_from_chars(text_line, poly_box)
                        if return_words
                        else [],
                    )
                )

        if sort_lines:
            lines = sort_text_lines(lines)
        predictions_by_image.append(
            OCRResult(
                text_lines=lines, image_bbox=[0, 0, image.size[0], image.size[1]]
            )
        )

    return predictions_by_image
def truncate_repetitions(text: str, min_len=15):
    """
    Strip a repeated tail from ``text`` (adapted from nougat).

    Looks for the longest ``rep_len`` in ``[min_len, len(text) // 2)`` for
    which the final ``rep_len`` characters equal the ``rep_len`` characters
    right before them. If one is found, every trailing copy of that tail is
    removed and the remaining prefix is returned; otherwise ``text`` is
    returned unchanged.
    """
    total = len(text)
    if total < 2 * min_len:
        return text

    # Walk candidates from longest to shortest: the first hit is the longest.
    for rep_len in range(int(total / 2) - 1, min_len - 1, -1):
        tail_start = total - rep_len
        if text[tail_start:] == text[tail_start - rep_len:tail_start]:
            break
    else:
        # No candidate length produced a repeated tail.
        return text

    repeated = text[-rep_len:]

    # Peel off copies of the tail for as long as the text still ends with it.
    remaining = text
    while remaining.endswith(repeated):
        remaining = remaining[:-rep_len]

    return text[: len(remaining)]
def fix_unbalanced_tags(
    text_chars: List[TextChar], special_tokens: Dict[str, list]
) -> List[TextChar]:
    """
    Append a closing entry for every formatting/math tag left open.

    Scans ``text_chars`` for entries whose text matches ``tag_pattern`` and
    whose tag name is among the names ``extract_tags`` yields for
    ``special_tokens["formatting"]`` and ``special_tokens["math_external"]``,
    tracking opened tags on a stack.

    Note: ``text_chars`` is extended in place and the same list is returned.
    """
    self_closing_tags = ["br"]

    # Stack of tag names that were opened and not yet closed.
    open_tags = []

    format_tags = extract_tags(special_tokens["formatting"]) + extract_tags(
        special_tokens["math_external"]
    )

    for char in text_chars:
        # A single character (or empty text) cannot be a tag token.
        if len(char.text) <= 1:
            continue

        tag_match = re.match(tag_pattern, char.text)
        if not tag_match:
            continue

        is_closing = tag_match.group(1) == "/"
        tag_name = tag_match.group(2).lower()

        if tag_name not in format_tags:
            continue

        if tag_name in self_closing_tags:
            continue

        # Self-closing tags
        if tag_match.group(3) and tag_match.group(3).strip().endswith("/"):
            continue

        if is_closing:
            # Only a closing tag matching the innermost open tag pops it;
            # any other closing tag is ignored.
            if open_tags and open_tags[-1] == tag_name:
                open_tags.pop()
        else:
            open_tags.append(tag_name)

    # NOTE(review): the text appended below is an empty f-string, so `tag`
    # is unused; presumably a closing tag for `tag` was intended - confirm.
    # NOTE(review): entries are appended in opening order, not innermost
    # first - verify that this is the intended nesting.
    for tag in open_tags:
        text_chars.append(
            TextChar(
                text=f"",
                confidence=0,
                polygon=[[0, 0], [1, 0], [1, 1], [0, 1]],
                bbox_valid=False,
            )
        )
    return text_chars
# One recognized line of text. `text`, `confidence` and the polygon fields
# come from BaseChar / PolygonBox.
class TextLine(BaseChar):
    chars: List[TextChar]  # Individual characters in the line
    # Defaults to False; RecognitionPredictor.__call__ sets it to True for
    # lines whose prediction came back empty (text "", confidence 1).
    original_text_good: bool = False
    # Word-level grouping of `chars`; optional.
    words: List[TextWord] | None = None
def sort_text_lines(lines: "List[TextLine] | List[dict]", tolerance=1.25):
    """
    Sort lines into an approximate reading order.

    Lines are bucketed by their top y coordinate, snapped to a grid of
    ``tolerance`` units; buckets are emitted top to bottom and lines inside
    a bucket left to right. Not 100% accurate, this should only be used as
    a starting point for more advanced sorting.

    Args:
        lines: Objects exposing a ``bbox`` attribute (e.g. ``TextLine``) or
            dicts with a ``"bbox"`` key; ``bbox[0]`` is x, ``bbox[1]`` is y.
        tolerance: Grid size used to decide that two lines share a row.

    Returns:
        A new list with the same elements in sorted order.
    """

    def _bbox(line):
        return line["bbox"] if isinstance(line, dict) else line.bbox

    vertical_groups = {}
    for line in lines:
        # Divide *before* rounding for both input kinds. Previously the
        # division bound only to the dict branch of the conditional, so
        # object inputs were bucketed on a different grid than dict inputs.
        group_key = round(_bbox(line)[1] / tolerance) * tolerance
        vertical_groups.setdefault(group_key, []).append(line)

    # Sort each group horizontally and flatten the groups into a single list
    sorted_lines = []
    for group_key in sorted(vertical_groups):
        sorted_lines.extend(
            sorted(vertical_groups[group_key], key=lambda x: _bbox(x)[0])
        )

    return sorted_lines
class CLILoader:
    """
    Shared input handling for the surya command line scripts.

    Parses the common CLI options, loads the images (optionally a second,
    high-resolution set) from a file or folder, and creates the result
    directory ``<output_dir>/<input name>``.
    """

    def __init__(self, filepath: str, cli_options: dict, highres: bool = False):
        """
        Args:
            filepath: Input file or directory.
            cli_options: Parsed click options; reads ``page_range``,
                ``images``, ``debug`` and ``output_dir``.
            highres: Also load images at ``settings.IMAGE_DPI_HIGHRES``.
        """
        self.page_range = cli_options.get("page_range")
        if self.page_range:
            self.page_range = self.parse_range_str(self.page_range)
        self.filepath = filepath
        self.config = cli_options
        self.save_images = cli_options.get("images", False)
        self.debug = cli_options.get("debug", False)
        self.output_dir = cli_options.get("output_dir")

        self.load(highres)

    @staticmethod
    def common_options(fn):
        """Decorate ``fn`` with the click argument/options shared by all CLIs."""
        fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn)
        fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn)
        fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges. Example: 0,5-10,20")(fn)
        fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn)
        fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn)
        return fn

    def load(self, highres: bool = False):
        """Load images and names from ``self.filepath`` and create ``self.result_path``."""
        highres_images = None
        # Strip any trailing separator first: basename("docs/") is "" and
        # would make results land directly in output_dir.
        base_name = os.path.basename(os.path.normpath(self.filepath))
        if os.path.isdir(self.filepath):
            images, names = load_from_folder(self.filepath, self.page_range)
            folder_name = base_name
            if highres:
                highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)
        else:
            images, names = load_from_file(self.filepath, self.page_range)
            folder_name = base_name.split(".")[0]
            if highres:
                highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)

        self.images = images
        self.highres_images = highres_images
        self.names = names

        self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name))
        os.makedirs(self.result_path, exist_ok=True)

    @staticmethod
    def parse_range_str(range_str: str) -> List[int]:
        """
        Parse a page range such as ``"0,5-10,20"`` into sorted unique ints.

        Ranges are inclusive. Surrounding whitespace and empty segments
        (e.g. a trailing comma) are ignored.

        Raises:
            ValueError: If a segment is not an integer or a ``start-end`` pair.
        """
        pages = set()
        for part in range_str.split(","):
            part = part.strip()
            if not part:
                continue
            if "-" in part:
                start, end = part.split("-")
                pages.update(range(int(start), int(end) + 1))
            else:
                pages.add(int(part))
        # Deduplicate page numbers and sort in order
        return sorted(pages)
detect_layout_cli(input_path: str, **kwargs): loader = CLILoader(input_path, kwargs) foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) layout_predictor = LayoutPredictor(foundation_predictor) start = time.time() layout_predictions = layout_predictor(loader.images) if loader.debug: logger.debug(f"Layout took {time.time() - start} seconds") if loader.save_images: for idx, (image, layout_pred, name) in enumerate( zip(loader.images, layout_predictions, loader.names) ): polygons = [p.polygon for p in layout_pred.bboxes] labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes] bbox_image = draw_polys_on_image( polygons, copy.deepcopy(image), labels=labels ) bbox_image.save( os.path.join(loader.result_path, f"{name}_{idx}_layout.png") ) predictions_by_page = defaultdict(list) for idx, (pred, name, image) in enumerate( zip(layout_predictions, loader.names, loader.images) ): out_pred = pred.model_dump() out_pred["page"] = len(predictions_by_page[name]) + 1 predictions_by_page[name].append(out_pred) with open( os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" ) as f: json.dump(predictions_by_page, f, ensure_ascii=False) logger.info(f"Wrote results to {loader.result_path}") ================================================ FILE: surya/scripts/detect_text.py ================================================ import click import copy import json import time from collections import defaultdict from surya.detection import DetectionPredictor from surya.debug.draw import draw_polys_on_image from surya.logging import configure_logging, get_logger from surya.scripts.config import CLILoader import os configure_logging() logger = get_logger() @click.command(help="Detect bboxes in an input file or folder (PDFs or image).") @CLILoader.common_options def detect_text_cli(input_path: str, **kwargs): loader = CLILoader(input_path, kwargs) det_predictor = DetectionPredictor() start = time.time() predictions = 
# Simple wrapper for huggingface dataset
class SuryaOCRDataset(torch.utils.data.Dataset):
    """Expose the ``train`` split of a Hugging Face dataset as OCR training samples."""

    def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments):
        super().__init__()
        self.hf_dataset = load_dataset(data_args.dataset_name, num_proc=data_args.num_loading_proc, split="train")
        self.processor = processor

    def __len__(self):
        return len(self.hf_dataset)

    def get_script_text(self, text: str) -> str:
        """Return the script tokens for ``text``, concatenated."""
        scripts = get_top_scripts(text)
        script_text = "".join(SCRIPT_TOKEN_MAPPING[script] for script in scripts)
        return script_text

    def _build_example(self, index):
        """Build the processor input dict for the row at ``index``."""
        data = self.hf_dataset[index]
        image = data["image"]
        image = image.convert("RGB")
        image = np.asarray(image, dtype=np.float32)
        image = self.processor.scale_to_fit(image, max_size=OCR_MAX_IMAGE_SIZE)

        # Add in script information
        gt_text = data["text"]
        gt_text = self.get_script_text(gt_text) + gt_text

        return {
            "task": TaskNames.ocr_with_boxes,
            "inputs": [
                ImageInput(type="image", image=image, rotated=False),
                # This empty TextInput **must be included** to match the original format
                TextInput(type="text", text=""),
                TextInput(type="text", text=gt_text),
            ],
        }

    def __getitem__(self, index):
        """
        Return the sample at ``index``, falling back to following samples.

        A sample that fails to load has its traceback printed and the next
        index (wrapping around) is tried instead. At most ``len(self)``
        attempts are made, so a fully broken dataset fails fast.

        Raises:
            RuntimeError: If no attempted sample could be loaded.
        """
        import traceback

        total = len(self)
        candidate = index
        for _ in range(max(total, 1)):
            try:
                return self._build_example(candidate)
            except Exception:
                # Deliberately best-effort: skip bad rows but keep Ctrl-C and
                # SystemExit working (a bare `except:` would swallow them).
                traceback.print_exc()
            if total:
                candidate = (candidate + 1) % total

        raise RuntimeError(
            f"Could not load any sample starting from index {index}"
        )
Defaults to left for inference processed_batch = self.processor(inputs, padding_side="right") if self.max_sequence_length is not None: processed_batch["input_ids"] = processed_batch["input_ids"][:, :self.max_sequence_length] processed_batch["attention_mask"] = processed_batch["attention_mask"][:, :self.max_sequence_length] processed_batch["position_ids"] = processed_batch["position_ids"][:, :self.max_sequence_length] lm_labels = processed_batch["input_ids"].clone() skip_label_mask = ( (lm_labels == self.processor.pad_token_id ) | (lm_labels == self.processor.bos_token_id[TaskNames.ocr_with_boxes]) | (lm_labels == self.processor.eoi_token_id) | (lm_labels == self.processor.image_token_id) ) lm_labels[skip_label_mask] = -100 processed_batch["labels"] = lm_labels return processed_batch def load_model_and_processor(checkpoint_path: Optional[str] = None) -> Tuple[SuryaModel, SuryaOCRProcessor]: foundation_predictor = FoundationPredictor(checkpoint=checkpoint_path) return foundation_predictor.model, foundation_predictor.processor @dataclass class SuryaOCRModelArguments: pretrained_checkpoint_path: Optional[str] = field(default=None) @dataclass class SuryaOCRDataArguments: dataset_name: str = field(default="datalab-to/ocr_finetune_example") num_loading_proc: int = field(default=16) max_sequence_length: Optional[int] = field(default=None) @dataclass class SuryaOCRTrainingArguments(TrainingArguments): remove_unused_columns: bool = field(default=False) def main(): parser = HfArgumentParser((SuryaOCRModelArguments, SuryaOCRDataArguments, SuryaOCRTrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() model, processor = load_model_and_processor(model_args.pretrained_checkpoint_path) dataset = SuryaOCRDataset(processor, data_args) collator = SuryaOCRDataCollator(processor, data_args) trainer = Trainer( model=model, args=training_args, train_dataset=dataset, data_collator=collator ) trainer.train() if __name__ == "__main__": main() 
# Example usage - python scripts/hf_to_s3.py layout
# This will upload to s3://layout/TODAYS_DATE
@click.command(help="Uploads the data from huggingface to an S3 bucket")
@click.argument("hf_repo_id", type=str)
@click.argument("s3_path", type=str)
@click.option("--bucket_name", type=str, default="datalab")
@click.option("--revision_hash", type=str, default=None)
@click.option("--access_key_id", type=str, default="<access_key_id>")
@click.option("--access_key_secret", type=str, default="<access_key_secret>")
@click.option("--suffix", type=str, default="")
def main(
    hf_repo_id: str,
    s3_path: str,
    bucket_name: str,
    revision_hash: str,
    access_key_id: str,
    access_key_secret: str,
    suffix: str,
):
    """
    Download a Hugging Face snapshot and upload its top-level entries to S3.

    The destination prefix is ``<s3_path>/<YYYY_MM_DD>`` with ``_<suffix>``
    appended when a suffix is given. A ``manifest.json`` listing the
    snapshot's entries is written into the snapshot folder and uploaded with
    them. Upload errors are reported per file and do not abort the run; the
    downloaded snapshot folder is removed afterwards.
    """
    curr_date = datetime.datetime.now().strftime("%Y_%m_%d")
    s3_path = f"{s3_path}/{curr_date}"
    if suffix:
        s3_path = f"{s3_path}_{suffix}"

    download_folder = snapshot_download(repo_id=hf_repo_id, revision=revision_hash)
    download_folder = Path(download_folder)
    # Listed before the manifest is written, so the manifest is not in it.
    contained_files = [f.name for f in download_folder.glob("*")]  # Just get the base name
    manifest_file = download_folder / "manifest.json"

    with open(manifest_file, "w") as f:
        json.dump({"files": contained_files}, f)

    # Upload the files to S3
    s3_client = boto3.client(
        service_name="s3",
        endpoint_url=S3_API_URL,
        aws_access_key_id=access_key_id,
        aws_secret_access_key=access_key_secret,
        region_name="auto",
    )

    # Iterate through all files in the folder. Materialise the listing so
    # tqdm knows the total and can show real progress.
    upload_paths = list(download_folder.glob("*"))
    failed = []
    for file_path in tqdm(upload_paths, desc="Uploading files", unit="file"):
        s3_key = f"{s3_path}/{file_path.name}"
        try:
            s3_client.upload_file(str(file_path), bucket_name, s3_key)
        except Exception as e:
            # Best-effort: keep going, but remember what failed.
            failed.append(file_path.name)
            print(f"Error uploading {file_path}: {str(e)}")

    shutil.rmtree(download_folder)

    if failed:
        # Do not claim success when some uploads did not go through.
        print(f"Finished with {len(failed)} failed upload(s) to {s3_path}: {', '.join(failed)}")
    else:
        print(f"Uploaded files to {s3_path}")
@click.command(help="OCR text.")
@click.option("--task_name", type=str, default=TaskNames.ocr_with_boxes)
@click.option(
    "--disable_math", is_flag=True, default=False, help="Do not recognize math in OCR."
)
@CLILoader.common_options
def ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs):
    """Run detection + recognition over the loaded images and write results.json.

    Predictions are grouped by input name; each entry gets a 1-based ``page``
    number counted within that name. When ``loader.save_images`` is set, a
    rendering of the recognised text is saved per page next to the results.
    """
    loader = CLILoader(input_path, kwargs, highres=True)
    task_names = [task_name] * len(loader.images)

    foundation_predictor = FoundationPredictor()
    det_predictor = DetectionPredictor()
    rec_predictor = RecognitionPredictor(foundation_predictor)

    start = time.time()
    predictions_by_image = rec_predictor(
        loader.images,
        task_names=task_names,
        det_predictor=det_predictor,
        highres_images=loader.highres_images,
        math_mode=not disable_math,
    )

    if loader.debug:
        logger.debug(f"OCR took {time.time() - start:.2f} seconds")
        # default=0 keeps debug mode from crashing when nothing was recognised
        # (e.g. blank pages or an empty input set).
        max_chars = max(
            (len(line.text) for p in predictions_by_image for line in p.text_lines),
            default=0,
        )
        logger.debug(f"Max chars: {max_chars}")

    if loader.save_images:
        for idx, (name, image, pred) in enumerate(
            zip(loader.names, loader.images, predictions_by_image)
        ):
            bboxes = [line.bbox for line in pred.text_lines]
            pred_text = [line.text for line in pred.text_lines]
            page_image = draw_text_on_image(bboxes, pred_text, image.size)
            page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png"))

    out_preds = defaultdict(list)
    for name, pred in zip(loader.names, predictions_by_image):
        out_pred = pred.model_dump()
        # Page numbers are 1-based and restart for every distinct input name.
        out_pred["page"] = len(out_preds[name]) + 1
        out_preds[name].append(out_pred)

    with open(
        os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
    ) as f:
        json.dump(out_preds, f, ensure_ascii=False)

    logger.info(f"Wrote results to {loader.result_path}")
def layout_detection(img) -> (Image.Image, LayoutResult):
    """Run the layout predictor on a single image and draw the labelled regions.

    Each region is labelled ``<label>-<position>-<score>``, where the score is
    the ``top_k`` entry for the region's own label rounded to two decimals.
    Returns the annotated copy of ``img`` together with the raw prediction.
    """
    layout_result = predictors["layout"]([img])[0]

    region_polygons = []
    region_labels = []
    for box in layout_result.bboxes:
        region_polygons.append(box.polygon)
        score = round(box.top_k[box.label], 2)
        region_labels.append(f"{box.label}-{box.position}-{score}")

    # Draw on a copy so the caller's image is left untouched.
    annotated = draw_polys_on_image(
        region_polygons, img.copy(), labels=region_labels, label_font_size=18
    )
    return annotated, layout_result
# Function for OCR
def ocr(
    img: Image.Image,
    highres_img: Image.Image,
    skip_text_detection: bool = False,
    recognize_math: bool = True,
    with_bboxes: bool = True,
) -> tuple[Image.Image, OCRResult, Image.Image]:
    """Recognise text on one page and build the two preview images.

    Args:
        img: Page image used for detection and as the preview canvas.
        highres_img: Higher-resolution rendering passed to the recogniser.
        skip_text_detection: When true, ``highres_img`` replaces ``img`` and the
            whole image is submitted as a single box.
        recognize_math: Forwarded to the recogniser as ``math_mode``.
        with_bboxes: Selects the ``ocr_with_boxes`` task instead of
            ``ocr_without_boxes``.

    Returns:
        A 3-tuple ``(rec_img, img_pred, box_img)``: the rendered text preview,
        the recognition result, and a copy of the image with every word box
        outlined in red.
    """
    if skip_text_detection:
        img = highres_img
        bboxes = [[[0, 0, img.width, img.height]]]
    else:
        bboxes = None

    if with_bboxes:
        tasks = [TaskNames.ocr_with_boxes]
    else:
        tasks = [TaskNames.ocr_without_boxes]

    img_pred = predictors["recognition"](
        [img],
        task_names=tasks,
        bboxes=bboxes,
        det_predictor=predictors["detection"],
        highres_images=[highres_img],
        math_mode=recognize_math,
        return_words=True,
    )[0]

    bboxes = [line.bbox for line in img_pred.text_lines]
    text = [line.text for line in img_pred.text_lines]
    rec_img = draw_text_on_image(bboxes, text, img.size)

    # Lines without word-level output are skipped.
    word_boxes = []
    for line in img_pred.text_lines:
        if line.words:
            word_boxes.extend([word.bbox for word in line.words])

    box_img = img.copy()
    draw = ImageDraw.Draw(box_img)
    for word_box in word_boxes:
        draw.rectangle(word_box, outline="red", width=2)

    return rec_img, img_pred, box_img
@st.cache_data()
def page_counter(pdf_file):
    """Return the number of pages in the uploaded PDF (cached by Streamlit)."""
    document = open_pdf(pdf_file)
    num_pages = len(document)
    # Release the pdfium handle before handing the count back.
    document.close()
    return num_pages
detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.", ) skip_text_detection = st.sidebar.checkbox( "Skip text detection", value=False, help="OCR only: Skip text detection and treat the whole image as a single line.", ) recognize_math = st.sidebar.checkbox( "Recognize math in OCR", value=True, help="Enable math mode in OCR - this will recognize math.", ) ocr_with_boxes = st.sidebar.checkbox( "OCR with boxes", value=True, help="Enable OCR with boxes - this will predict character-level boxes.", ) if pil_image is None: st.stop() # Run Text Detection if run_text_det: det_img, text_pred = text_detection(pil_image) with col1: st.image(det_img, caption="Detected Text", use_container_width=True) st.json( text_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True ) # Run layout if run_layout_det: layout_img, pred = layout_detection(pil_image) with col1: st.image(layout_img, caption="Detected Layout", use_container_width=True) st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True) # Run OCR if run_text_rec: rec_img, pred, box_img = ocr( pil_image, pil_image_highres, skip_text_detection, recognize_math, with_bboxes=ocr_with_boxes, ) with col1: st.image(rec_img, caption="OCR Result", use_container_width=True) json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"]) with json_tab: st.json(pred.model_dump(), expanded=False) with text_tab: st.text("\n".join([p.text for p in pred.text_lines])) st.image( box_img, caption="OCR with Word Boxes (for debugging)", use_container_width=True, ) if run_table_rec: table_img, pred = table_recognition( pil_image, pil_image_highres, skip_table_detection ) with col1: st.image(table_img, caption="Table Recognition", use_container_width=True) st.json([p.model_dump() for p in pred], expanded=True) if run_ocr_errors: if "pdf" not in filetype: st.error("This feature only works with PDFs.") label, results = ocr_errors(in_file, page_count) with 
@click.command(help="Detect layout of an input file or folder (PDFs or image).")
@CLILoader.common_options
@click.option(
    "--skip_table_detection",
    is_flag=True,
    help="Tables are already cropped, so don't re-detect tables.",
    default=False,
)
def table_recognition_cli(input_path: str, skip_table_detection: bool, **kwargs):
    """Find tables on each page, recognise their structure and write results.json.

    Unless ``skip_table_detection`` is set, layout regions labelled ``Table`` or
    ``TableOfContents`` are cropped from the high-resolution page image and fed
    to the table recogniser; otherwise every page is treated as one table.
    Results are grouped by input name, each with a 1-based ``page`` and a
    per-page ``table_idx``. When ``loader.save_images`` is set, row/column and
    cell overlays are saved for every table.
    """
    loader = CLILoader(input_path, kwargs, highres=True)

    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_predictor = LayoutPredictor(foundation_predictor)
    table_rec_predictor = TableRecPredictor()

    # Zero-based page index of every image within its source document; the
    # counter restarts whenever the name changes.
    pnums = []
    prev_name = None
    for name in loader.names:
        if prev_name is None or prev_name != name:
            pnums.append(0)
        else:
            pnums.append(pnums[-1] + 1)

        prev_name = name

    layout_predictions = layout_predictor(loader.images)

    table_imgs = []
    table_counts = []

    for layout_pred, img, highres_img in zip(
        layout_predictions, loader.images, loader.highres_images
    ):
        # The table may already be cropped
        if skip_table_detection:
            table_imgs.append(highres_img)
            table_counts.append(1)
        else:
            # The bbox for the entire table
            bbox = [
                line.bbox
                for line in layout_pred.bboxes
                if line.label in ["Table", "TableOfContents"]
            ]
            # Number of tables per page
            table_counts.append(len(bbox))

            if len(bbox) == 0:
                continue

            for bb in bbox:
                highres_bb = rescale_bbox(bb, img.size, highres_img.size)
                highres_bb = expand_bbox(highres_bb)
                table_imgs.append(highres_img.crop(highres_bb))

    table_preds = table_rec_predictor(table_imgs)

    img_idx = 0
    prev_count = 0
    table_predictions = defaultdict(list)
    for i in range(sum(table_counts)):
        # Advance to the page that owns flat table index i (pages with zero
        # tables are skipped over here).
        while i >= prev_count + table_counts[img_idx]:
            prev_count += table_counts[img_idx]
            img_idx += 1

        pred = table_preds[i]
        orig_name = loader.names[img_idx]
        pnum = pnums[img_idx]
        table_img = table_imgs[i]

        out_pred = pred.model_dump()
        out_pred["page"] = pnum + 1
        table_idx = i - prev_count
        out_pred["table_idx"] = table_idx
        table_predictions[orig_name].append(out_pred)

        if loader.save_images:
            rows = [line.bbox for line in pred.rows]
            cols = [line.bbox for line in pred.cols]
            row_labels = [f"Row {line.row_id}" for line in pred.rows]
            col_labels = [f"Col {line.col_id}" for line in pred.cols]
            cells = [line.bbox for line in pred.cells]

            # File names use the name of the document this table came from, so
            # tables from different inputs never overwrite one another.
            rc_image = copy.deepcopy(table_img)
            rc_image = draw_bboxes_on_image(
                rows, rc_image, labels=row_labels, label_font_size=20, color="blue"
            )
            rc_image = draw_bboxes_on_image(
                cols, rc_image, labels=col_labels, label_font_size=20, color="red"
            )
            rc_image.save(
                os.path.join(
                    loader.result_path,
                    f"{orig_name}_page{pnum + 1}_table{table_idx}_rc.png",
                )
            )

            cell_image = copy.deepcopy(table_img)
            cell_image = draw_bboxes_on_image(cells, cell_image, color="green")
            cell_image.save(
                os.path.join(
                    loader.result_path,
                    f"{orig_name}_page{pnum + 1}_table{table_idx}_cells.png",
                )
            )

    with open(
        os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
    ) as f:
        json.dump(table_predictions, f, ensure_ascii=False)

    logger.info(f"Wrote results to {loader.result_path}")
def replace_fences(text):
    """Rewrite ``<math>`` markup as Markdown dollar fences.

    Block math (``<math display="block">``) becomes ``$$...$$``; plain
    ``<math>`` and ``<math display="inline">`` become ``$...$``. Text without
    math tags is returned unchanged.
    """
    # Each pattern is anchored on the literal tags. A bare ``(.*?)`` would
    # match the empty string at every position and scatter ``$`` through the text.
    text = re.sub(r'<math display="block">(.*?)</math>', r"$$\1$$", text)
    text = re.sub(r"<math>(.*?)</math>", r"$\1$", text)
    text = re.sub(r'<math display="inline">(.*?)</math>', r"$\1$", text)
    return text
around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right. """ st.markdown(top_message) col1, col2 = st.columns([0.7, 0.3]) predictor = load_predictor() in_file = st.sidebar.file_uploader( "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"] ) if in_file is None: st.stop() if in_file is None: st.stop() filetype = in_file.type page_count = None if "pdf" in filetype: page_count = page_counter(in_file) page_number = st.sidebar.number_input( f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count ) pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES) else: pil_image = Image.open(in_file).convert("RGB") page_number = None if pil_image is None: st.stop() pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) canvas_hash = get_canvas_hash(pil_image) with col1: # Create a canvas component canvas_result = st_canvas( fill_color="rgba(255, 165, 0, 0.1)", # Fixed fill color with some opacity stroke_width=1, stroke_color="#FFAA00", background_color="#FFF", background_image=pil_image, update_streamlit=True, height=pil_image.height, width=pil_image.width, drawing_mode="rect", point_display_radius=0, key=canvas_hash, ) if not canvas_result.json_data: st.stop() objects = pd.json_normalize( canvas_result.json_data["objects"] ) # need to convert obj to str because PyArrow bbox_list = None if objects.shape[0] > 0: boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]] boxes["right"] = boxes["left"] + boxes["width"] boxes["bottom"] = boxes["top"] + boxes["height"] bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist() if bbox_list: with col2: texts = [inference(pil_image, bbox) for bbox in bbox_list] for idx, latex in enumerate(reversed(texts)): st.markdown(f"### {len(texts) - idx}") st.markdown(replace_fences(latex), unsafe_allow_html=True) st.code(latex) st.divider() with col2: tips = 
""" ### Usage tips - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple. """ st.markdown(tips) ================================================ FILE: surya/settings.py ================================================ import os from typing import Callable, Dict, Optional import torch from dotenv import find_dotenv from pydantic import computed_field from pydantic_settings import BaseSettings from pathlib import Path from platformdirs import user_cache_dir class Settings(BaseSettings): # General TORCH_DEVICE: Optional[str] = None IMAGE_DPI: int = 96 # Used for detection, layout, reading order IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec IN_STREAMLIT: bool = False # Whether we're running in streamlit FLATTEN_PDF: bool = True # Flatten PDFs by merging form fields before processing DISABLE_TQDM: bool = False # Disable tqdm progress bars S3_BASE_URL: str = "https://models.datalab.to" PARALLEL_DOWNLOAD_WORKERS: int = ( 10 # Number of workers for parallel model downloads ) MODEL_CACHE_DIR: str = str(Path(user_cache_dir("datalab")) / "models") LOGLEVEL: str = "INFO" # Logging level # Paths DATA_DIR: str = "data" RESULT_DIR: str = "results" BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts") @computed_field def TORCH_DEVICE_MODEL(self) -> str: if self.TORCH_DEVICE is not None: return self.TORCH_DEVICE if torch.cuda.is_available(): return "cuda" if torch.backends.mps.is_available(): return "mps" try: import torch_xla if len(torch_xla.devices()) > 0: return "xla" except Exception: pass return "cpu" # Text detection DETECTOR_BATCH_SIZE: Optional[int] = None # Defaults to 2 for CPU/MPS, 32 otherwise DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_05_07" DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench" DETECTOR_IMAGE_CHUNK_HEIGHT: int = ( 1400 # 
Height at which to slice images vertically ) DETECTOR_TEXT_THRESHOLD: float = ( 0.6 # Threshold for text detection (above this is considered text) ) DETECTOR_BLANK_THRESHOLD: float = ( 0.35 # Threshold for blank space (below this is considered blank) ) DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min( 8, os.cpu_count() ) # Number of workers for postprocessing DETECTOR_MIN_PARALLEL_THRESH: int = ( 3 # Minimum number of images before we parallelize ) DETECTOR_BOX_Y_EXPAND_MARGIN: float = ( 0.05 # Margin by which to expand detected boxes vertically ) COMPILE_DETECTOR: bool = False # Text recognition FOUNDATION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23" FOUNDATION_MODEL_QUANTIZE: bool = False FOUNDATION_MAX_TOKENS: Optional[int] = None FOUNDATION_CHUNK_SIZE: Optional[int] = None FOUNDATION_PAD_TO_NEAREST: int = 256 COMPILE_FOUNDATION: bool = False FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE: float = 0.9 RECOGNITION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23" RECOGNITION_BATCH_SIZE: Optional[int] = ( None # Defaults to 8 for CPU/MPS, 256 otherwise ) RECOGNITION_RENDER_FONTS: Dict[str, str] = { "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"), "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), "ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), "ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), } RECOGNITION_FONT_DL_BASE: str = ( "https://github.com/satbyy/go-noto-universal/releases/download/v7.0" ) RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench" RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255 # Layout LAYOUT_MODEL_CHECKPOINT: str = "s3://layout/2025_09_23" LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768} LAYOUT_SLICE_MIN: Dict = { "height": 1500, "width": 1500, } # When to start slicing images LAYOUT_SLICE_SIZE: Dict = {"height": 1200, "width": 1200} # Size of slices LAYOUT_BATCH_SIZE: Optional[int] = None LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench" LAYOUT_MAX_BOXES: int = 100 COMPILE_LAYOUT: bool = False 
LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench" ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench" # Table Rec TABLE_REC_MODEL_CHECKPOINT: str = "s3://table_recognition/2025_02_18" TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768} TABLE_REC_MAX_BOXES: int = 150 TABLE_REC_BATCH_SIZE: Optional[int] = None TABLE_REC_BENCH_DATASET_NAME: str = "datalab-to/fintabnet_bench" COMPILE_TABLE_REC: bool = False # Texify TEXIFY_BENCHMARK_DATASET: str = "datalab-to/texify_bench" # OCR Error Detection OCR_ERROR_MODEL_CHECKPOINT: str = "s3://ocr_error_detection/2025_02_18" OCR_ERROR_BATCH_SIZE: Optional[int] = None COMPILE_OCR_ERROR: bool = False # Tesseract (for benchmarks only) TESSDATA_PREFIX: Optional[str] = None COMPILE_ALL: bool = False @computed_field def DETECTOR_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_DETECTOR or self.TORCH_DEVICE_MODEL == "xla" ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise @computed_field def LAYOUT_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_LAYOUT or self.TORCH_DEVICE_MODEL == "xla" ) @computed_field def FOUNDATION_XLA(self) -> bool: return ( self.TORCH_DEVICE_MODEL == "xla" ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise @computed_field def FOUNDATION_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_FOUNDATION or self.TORCH_DEVICE_MODEL == "xla" ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise @computed_field def TABLE_REC_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_TABLE_REC or self.TORCH_DEVICE_MODEL == "xla" ) @computed_field def OCR_ERROR_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_OCR_ERROR or self.TORCH_DEVICE_MODEL == "xla" ) @computed_field def MODEL_DTYPE(self) -> torch.dtype: if self.TORCH_DEVICE_MODEL == "cpu": return torch.float32 if self.TORCH_DEVICE_MODEL == 
"xla": return torch.bfloat16 return torch.float16 @computed_field def MODEL_DTYPE_BFLOAT(self) -> torch.dtype: if self.TORCH_DEVICE_MODEL == "cpu": return torch.float32 if self.TORCH_DEVICE_MODEL == "mps": return torch.bfloat16 return torch.bfloat16 @computed_field def INFERENCE_MODE(self) -> Callable: if self.TORCH_DEVICE_MODEL == "xla": return torch.no_grad return torch.inference_mode class Config: env_file = find_dotenv("local.env") extra = "ignore" settings = Settings() ================================================ FILE: surya/table_rec/__init__.py ================================================ from copy import deepcopy from itertools import chain from typing import List import numpy as np import torch from PIL import Image from tqdm import tqdm from surya.common.xla import mark_step from surya.common.predictor import BasePredictor from surya.table_rec.schema import TableCell, TableRow, TableCol, TableResult from surya.common.polygon import PolygonBox from surya.settings import settings from surya.table_rec.loader import TableRecModelLoader from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM, CATEGORY_TO_ID, MERGE_KEYS, \ MERGE_VALUES from surya.table_rec.shaper import LabelShaper class TableRecPredictor(BasePredictor): model_loader_cls = TableRecModelLoader batch_size = settings.TABLE_REC_BATCH_SIZE default_batch_sizes = { "cpu": 8, "mps": 8, "cuda": 32, "xla": 16 } def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]: return self.batch_table_recognition(images, batch_size) def inference_loop( self, encoder_hidden_states: torch.Tensor, batch_input_ids: torch.Tensor, current_batch_size: int, batch_size: int ): shaper = LabelShaper() batch_predictions = [[] for _ in range(current_batch_size)] max_tokens = settings.TABLE_REC_MAX_BOXES decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device).cumsum( 0) - 1 inference_token_count = 
batch_input_ids.shape[1] if settings.TABLE_REC_STATIC_CACHE: encoder_hidden_states = self.pad_to_batch_size(encoder_hidden_states, batch_size) batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) # Move to device after padding for XLA encoder_hidden_states = encoder_hidden_states.to(self.model.device) batch_input_ids = batch_input_ids.to(self.model.device) self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype) with settings.INFERENCE_MODE(): token_count = 0 all_done = torch.zeros(encoder_hidden_states.shape[0], dtype=torch.bool, device=self.model.device) while token_count < max_tokens: is_prefill = token_count == 0 return_dict = self.model.decoder( input_ids=batch_input_ids, encoder_hidden_states=encoder_hidden_states, cache_position=decoder_position_ids, use_cache=True, prefill=is_prefill ) decoder_position_ids = decoder_position_ids[-1:] + 1 # Get predictions for each box element box_properties = [] done = [] # Pre-process all logits at once processed_logits = {} for k, _, mode in BOX_PROPERTIES: k_logits = return_dict["box_property_logits"][k][:, -1, :] # Get all batch logits at once if mode == "classification": # Process all classification logits in one operation items = torch.argmax(k_logits, dim=-1) if k == "category": done = (items == self.model.decoder.config.eos_token_id) | (items == self.model.decoder.config.pad_token_id) items = items - SPECIAL_TOKENS processed_logits[k] = items elif mode == "regression": if k == "bbox": k_logits = k_logits * BOX_DIM processed_logits[k] = k_logits elif k == "colspan": k_logits = torch.clamp(k_logits, min=1) processed_logits[k] = torch.round(k_logits) items = {k: processed_logits[k].cpu() for k, _, _ in BOX_PROPERTIES} for j in range(current_batch_size): box_property = {} for k, _, mode in BOX_PROPERTIES: if mode == "classification": box_property[k] = int(items[k][j].item()) elif mode == "regression": if k == "bbox": box_property[k] = items[k][j].tolist() 
def batch_table_recognition(
        self,
        images: List,
        batch_size=None) -> List[TableResult]:
    """Predict rows, columns and cells for each table image.

    Every batch is processed in two decoding passes that share one encoder
    run:

    1. A single whole-image "Table" query per image is decoded to obtain
       row and column boxes.
    2. One query per predicted row is decoded (reusing that image's encoder
       hidden states) to obtain the cells of that row.

    Args:
        images: PIL images, one per table.
        batch_size: Number of images (and later, row queries) per inference
            call. Defaults to ``self.get_batch_size()``.

    Returns:
        One ``TableResult`` per input image, in input order. An empty input
        list yields an empty list.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = self.get_batch_size()

    if len(images) == 0:
        return []

    # Initial query for every image: one box covering the whole image.
    query_items = []
    for image in images:
        query_items.append({
            "polygon": [[0, 0], [image.width, 0], [image.width, image.height], [0, image.height]],
            "category": CATEGORY_TO_ID["Table"],
            "colspan": 0,
            "merges": 0,
            "is_header": 0
        })

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables", disable=self.disable_tqdm):
        batch_query_items = query_items[i:i + batch_size]

        batch_images = images[i:i + batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

        current_batch_size = len(batch_images)

        orig_sizes = [image.size for image in batch_images]
        model_inputs = self.processor(images=batch_images, query_items=batch_query_items)

        batch_pixel_values = model_inputs["pixel_values"]

        batch_input_ids = model_inputs["input_ids"]
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=self.model.dtype)

        if settings.TABLE_REC_STATIC_CACHE:
            batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size)

        # Move to device after padding for XLA
        batch_pixel_values = batch_pixel_values.to(self.model.device)

        shaper = LabelShaper()

        # We only need to process each image once
        with settings.INFERENCE_MODE():
            encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state

        # Inference to get rows and columns
        rowcol_predictions = self.inference_loop(
            encoder_hidden_states,
            batch_input_ids,
            current_batch_size,
            batch_size
        )
        mark_step()

        row_query_items = []
        row_encoder_hidden_states = []
        idx_map = []  # idx_map[k] is the batch index of the image that row query k belongs to
        # NOTE(review): columns from every image in the batch are collected
        # into one list and handed to the processor together - confirm this
        # is intended when batch_size > 1.
        columns = []
        for j, img_predictions in enumerate(rowcol_predictions):
            for row_prediction in img_predictions:
                polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
                if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]:
                    row_query_items.append({
                        "polygon": polygon,
                        "category": row_prediction["category"],
                        "colspan": 0,
                        "merges": 0,
                        "is_header": int(row_prediction["is_header"] == 1)
                    })
                    row_encoder_hidden_states.append(encoder_hidden_states[j])
                    idx_map.append(j)
                elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]:
                    columns.append({
                        "polygon": polygon,
                        "category": row_prediction["category"],
                        "colspan": 0,
                        "merges": 0,
                        "is_header": int(row_prediction["is_header"] == 1)
                    })

        # Re-inference to predict cells.
        # Skipped when no rows were predicted: torch.stack() rejects an empty
        # list and the processor indexes the first query label, so both would
        # raise. With no rows there are no cells to decode.
        cell_predictions = []
        if row_query_items:
            row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
            row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
            row_input_ids = row_inputs["input_ids"]
            for j in range(0, len(row_input_ids), batch_size):
                cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size]
                cell_batch_input_ids = row_input_ids[j:j + batch_size]
                cell_batch_size = len(cell_batch_input_ids)
                cell_predictions.extend(
                    self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
                )
                mark_step()

        result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper)
        output_order.extend(result)

    return output_order
colspan=colspan, merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]], merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]], is_header=row.is_header or z == 0 ) ) cell_id += 1 # Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col used_spanning_cells = set() skip_columns = 0 for l, col in enumerate(columns): if skip_columns: skip_columns -= 1 continue cell_polygon = row.intersection_polygon(col) cell_added = False for zz, spanning_cell in enumerate(spanning_cells): cell_polygonbox = PolygonBox(polygon=cell_polygon) intersection_pct = cell_polygonbox.intersection_pct(spanning_cell) # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns) correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]]) if intersection_pct > .9: if spanning_cell.width > (correct_col_width * .85): cell_added = True if zz not in used_spanning_cells: used_spanning_cells.add(zz) spanning_cell.col_id = l cells.append(spanning_cell) skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell else: used_spanning_cells.add(zz) # Skip this spanning cell if not cell_added: cells.append( TableCell( polygon=cell_polygon, row_id=z, rowspan=1, cell_id=cell_id, within_row_id=l, colspan=1, merge_up=False, merge_down=False, col_id=l, is_header=row.is_header or col.is_header or z == 0 ) ) cell_id += 1 # Turn cells into a row grid grid_cells = deepcopy([ [cell for cell in cells if cell.row_id == row.row_id] for row in rows ]) # Merge cells across rows for z, grid_row in enumerate(grid_cells[1:]): prev_row = grid_cells[z] for l, cell in enumerate(grid_row): if l >= len(prev_row): continue above_cell = prev_row[l] if all([ above_cell.merge_down, cell.merge_up, above_cell.col_id == cell.col_id, above_cell.colspan == cell.colspan, ]): above_cell.merge(cell) above_cell.rowspan += cell.rowspan 
class TableRecModelLoader(ModelLoader):
    """Builds the table recognition model and its processor from a checkpoint."""

    def __init__(self, checkpoint: Optional[str] = None):
        super().__init__(checkpoint)

        # No explicit checkpoint given: use the one configured in settings.
        if self.checkpoint is None:
            self.checkpoint = settings.TABLE_REC_MODEL_CHECKPOINT

    def model(
        self,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=settings.MODEL_DTYPE,
        attention_implementation: Optional[str] = None,
    ) -> TableRecEncoderDecoderModel:
        """Load the encoder/decoder model, move it to ``device`` and put it in eval mode.

        ``None`` for ``device`` or ``dtype`` falls back to the values in
        settings. The "mps" device is replaced by "cpu" with "float32".
        ``attention_implementation`` is accepted but not used here.
        """
        device = settings.TORCH_DEVICE_MODEL if device is None else device
        dtype = settings.MODEL_DTYPE if dtype is None else dtype

        if device == "mps":
            logger.warning(
                "`TableRecEncoderDecoderModel` is not compatible with mps backend. Defaulting to cpu instead"
            )
            device, dtype = "cpu", "float32"

        # The loaded config carries encoder/decoder settings that are unpacked
        # as keyword arguments into the typed sub-config classes.
        config = SuryaTableRecConfig.from_pretrained(self.checkpoint)
        config.decoder = SuryaTableRecDecoderConfig(**config.decoder)
        config.encoder = DonutSwinTableRecConfig(**config.encoder)

        model = TableRecEncoderDecoderModel.from_pretrained(
            self.checkpoint, config=config, dtype=dtype
        )
        model = model.to(device).eval()

        should_compile = settings.COMPILE_ALL or settings.COMPILE_TABLE_REC
        if should_compile:
            torch.set_float32_matmul_precision("high")
            torch._dynamo.config.cache_size_limit = 16
            torch._dynamo.config.suppress_errors = False

            logger.info(
                f"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}"
            )
            # XLA devices get the openxla backend; otherwise torch.compile defaults apply.
            compile_args = {}
            if device == "xla":
                compile_args = {"backend": "openxla"}
            model.encoder = torch.compile(model.encoder, **compile_args)
            model.decoder = torch.compile(model.decoder, **compile_args)

        logger.debug(
            f"Loaded table recognition model {self.checkpoint} from {TableRecEncoderDecoderModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}"
        )
        return model

    def processor(
        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE
    ) -> SuryaTableRecProcessor:
        """Create the processor and attach the special token ids it uses.

        ``device`` and ``dtype`` are accepted but not used here.
        """
        table_processor = SuryaTableRecProcessor(self.checkpoint)

        table_processor.token_pad_id = 0
        table_processor.token_eos_id = 1
        table_processor.token_bos_id = 1
        table_processor.token_query_end_id = 4
        return table_processor
class DonutSwinTableRecConfig(PretrainedConfig):
    """Configuration for the Donut-Swin encoder used by table recognition.

    All constructor arguments are stored as attributes of the same name.
    ``num_layers`` is derived from ``len(depths)`` and ``hidden_size`` from
    ``embed_dim`` and the number of stages.
    """
    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=[2, 2, 12, 2],
        num_heads=[4, 8, 16, 32],
        num_kv_heads=[4, 8, 16, 32],
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=1024,
        use_positional_embeddings=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # The list defaults in the signature are created once, when the
        # function is defined. Storing them by reference would let a mutation
        # of one config's attribute leak into every later default-constructed
        # config, so list arguments are copied. Other types are kept as given.
        self.depths = list(depths) if isinstance(depths, list) else depths
        self.num_layers = len(depths)
        self.num_heads = list(num_heads) if isinstance(num_heads, list) else num_heads
        self.num_kv_heads = list(num_kv_heads) if isinstance(num_kv_heads, list) else num_kv_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
        self.use_positional_embeddings = use_positional_embeddings
self_attn_layers=tuple(range(10)), global_attn_layers=tuple(range(10)), attention_dropout=0.0, num_key_value_heads=4, attention_bias=False, w_init_variance_scale=0.01, init_std=0.02, tie_word_embeddings=False, aux_heads=0, # How many n-token-ahead heads to add causal=True, layer_norm_eps=1e-5, dropout=0.0, special_token_count=SPECIAL_TOKENS, **kwargs, ): self.num_hidden_layers = num_hidden_layers self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_attention_heads = num_attention_heads self.lru_width = lru_width if lru_width is not None else hidden_size self.attention_window_size = attention_window_size self.conv1d_width = conv1d_width self.logits_soft_cap = logits_soft_cap self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.block_types = list(block_types) self.hidden_activation = hidden_activation self.head_dim = self.hidden_size // self.num_attention_heads self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads if self.num_key_value_heads > self.num_attention_heads: raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") self.cross_attn_layers = cross_attn_layers self.self_attn_layers = self_attn_layers self.global_attn_layers = global_attn_layers self.attention_dropout = attention_dropout self.attention_bias = attention_bias self.w_init_variance_scale = w_init_variance_scale self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers self.init_std = init_std self.tie_word_embeddings = tie_word_embeddings self.aux_heads = aux_heads self.encoder_hidden_size=encoder_hidden_size self.causal = causal self.encoder_cross_attn_layers = encoder_cross_attn_layers self.layer_norm_eps = layer_norm_eps self.dropout = dropout self.bbox_size = bbox_size self.pause_token_id = pause_token_id self.box_properties = BOX_PROPERTIES self.property_embed_size = property_embed_size 
class LabelEmbedding(nn.Module):
    """Embeds box label vectors (bbox components plus box properties).

    The output is the concatenation, along the last dimension, of a summed
    bbox embedding (``config.box_embed_size`` wide) and a summed property
    embedding (``config.property_embed_size`` wide).
    """

    def __init__(self, config):
        super().__init__()
        # Bboxes
        self.w_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.h_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.cx_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.cy_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.xskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.yskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)

        self.x1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        # x2/y2/x4/y4 are not read by forward(); they are kept so the module's
        # parameter set stays the same.
        self.x2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.x3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.x4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)

        # Get indexes for passed in tensor
        shaper = LabelShaper()
        self.component_idxs = shaper.component_idx_dict()
        merge_count = shaper.get_box_property("merges")[1] + config.special_token_count
        category_count = shaper.get_box_property("category")[1] + config.special_token_count

        # Other box properties
        self.category_embed = nn.Embedding(category_count, config.property_embed_size)
        self.merge_embed = nn.Embedding(merge_count, config.property_embed_size)
        self.colspan_embed = nn.Embedding(config.vocab_size, config.property_embed_size)

        self.config = config

    def forward(self, boxes: torch.LongTensor, *args):
        """Embed ``boxes``, whose last dimension holds the label components.

        The slices of the last dimension that hold "bbox", "category",
        "merges" and "colspan" come from ``self.component_idxs``. Extra
        positional arguments are ignored.
        """
        # Need to keep *args for compatibility with common decoder
        # Each vocab-sized table has `vocab_size` rows, so the largest valid
        # index is vocab_size - 1. Clamping to vocab_size would let one
        # out-of-range value through to the lookup.
        boxes = boxes.to(torch.long).clamp(0, self.config.vocab_size - 1)

        boxes_unbound = boxes.unbind(dim=-1)
        cx, cy, w, h, xskew, yskew = boxes_unbound[self.component_idxs["bbox"][0]:self.component_idxs["bbox"][1]]
        category = boxes_unbound[self.component_idxs["category"][0]:self.component_idxs["category"][1]][0]
        merges = boxes_unbound[self.component_idxs["merges"][0]:self.component_idxs["merges"][1]][0]
        colspan = boxes_unbound[self.component_idxs["colspan"][0]:self.component_idxs["colspan"][1]][0]

        # Skew values are stored offset by bbox_size // 2; undo the offset and halve it.
        xskew_actual = ((xskew - self.config.bbox_size // 2) / 2).to(torch.long)
        yskew_actual = ((yskew - self.config.bbox_size // 2) / 2).to(torch.long)

        # Two opposite corners derived from centre, size and skew, kept inside [0, bbox_size].
        x1 = (cx - w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
        y1 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
        x3 = (cx + w // 2 + xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
        y3 = (cy + h // 2 + yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)

        size_embeds = self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
        skew_embeds = self.xskew_embed(xskew) + self.yskew_embed(yskew)
        corner_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x3_embed(x3) + self.y3_embed(y3)
        box_embeds = size_embeds + skew_embeds + corner_embeds

        property_embeds = self.category_embed(category) + self.merge_embed(merges) + self.colspan_embed(colspan)

        # Cat bbox and property embeddings
        embedded = torch.cat([box_embeds, property_embeds], dim=-1)
        return embedded
class DonutSwinModel(DonutSwinPreTrainedModel):
    """Donut-Swin image encoder with an optional learned position embedding on its output."""

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config

        stage_count = len(config.depths)
        self.num_layers = stage_count
        self.num_features = int(config.embed_dim * 2 ** (stage_count - 1))

        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        # A zero-initialised position embedding is created only when the
        # config defines `encoder_length`; otherwise none is used.
        if hasattr(config, "encoder_length"):
            self.position_embeddings = nn.Parameter(
                torch.zeros(1, config.encoder_length, config.hidden_size)
            )
        else:
            self.position_embeddings = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch embedding module."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prune attention heads.

        heads_to_prune: dict mapping a layer number to the list of heads to
        prune in that layer. See base class PreTrainedModel.
        """
        for layer_idx, head_list in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(head_list)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DonutSwinModelOutput]:
        r"""
        Encode `pixel_values` and return a `DonutSwinModelOutput` that carries
        only `last_hidden_state`.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Raises `ValueError` when `pixel_values` is None.
        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        stage_count = len(self.config.depths)
        head_mask = self.get_head_mask(head_mask, stage_count)

        embedding_output, input_dimensions = self.embeddings(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        if self.position_embeddings is not None:
            # Added in place, using only as many positions as the sequence has.
            seq_len = last_hidden_state.size(1)
            last_hidden_state += self.position_embeddings[:, :seq_len, :]

        return DonutSwinModelOutput(
            last_hidden_state=last_hidden_state,
        )
def forward(
    self,
    decoder_input_ids: torch.LongTensor = None,
    decoder_cache_position: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:
    """Run the decoder over already-computed encoder outputs.

    The encoder is not invoked here; ``encoder_outputs`` is passed to the
    decoder as ``encoder_hidden_states``. Any extra keyword argument whose
    name starts with ``decoder_`` is forwarded to the decoder with that
    prefix removed; other extra keyword arguments are ignored.
    ``return_dict`` is accepted but not used: a ``TableRecOutput`` is always
    returned.
    """
    kwargs_decoder = {
        argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
    }

    # Decode
    # The decoder's forward() names its token argument `input_ids` and
    # swallows unknown keywords via **kwargs. Passing the ids under any other
    # keyword would silently drop them and decode with input_ids=None.
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        cache_position=decoder_cache_position,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_outputs,
        encoder_attention_mask=None,
        use_cache=use_cache,
        **kwargs_decoder,
    )

    return TableRecOutput(
        box_property_logits=decoder_outputs.box_property_logits,
        decoder_hidden_states=decoder_outputs.hidden_states,
    )
def resize_polygon(self, polygon, orig_size, new_size):
    """Rescale a polygon from ``orig_size`` to ``new_size``, in place.

    Args:
        polygon: Sequence of mutable ``[x, y]`` corners; it is modified and
            also returned.
        orig_size: ``(width, height)`` the corners are currently expressed in.
        new_size: ``(width, height)`` to rescale into. Each coordinate is
            clamped to ``[0, new_size[axis]]`` after scaling.

    Returns:
        The same ``polygon`` object that was passed in.
    """
    max_x, max_y = new_size[0], new_size[1]
    x_ratio = max_x / orig_size[0]
    y_ratio = max_y / orig_size[1]

    for point in polygon:
        scaled_x = point[0] * x_ratio
        scaled_y = point[1] * y_ratio
        # Clamp to the lower bound first, then to the upper bound.
        point[0] = min(max(scaled_x, 0), max_x)
        point[1] = min(max(scaled_y, 0), max_y)

    return polygon
class LabelShaper:
    """Translate table box dictionaries to and from flat label vectors.

    Every method walks ``BOX_PROPERTIES`` as ``(key, count, mode)`` triples,
    where ``mode`` is ``"classification"`` or ``"regression"``. The order of
    those triples fixes the order of the values inside each vector.
    """

    def __init__(self):
        # Property names, in the order they are laid out in a label vector.
        self.property_keys = [k for (k, kcount, mode) in BOX_PROPERTIES]

    def dict_to_labels(self, label_components: List[dict]):
        """Flatten label dictionaries into one numeric vector per dictionary.

        Every dictionary must contain each key named in ``BOX_PROPERTIES``.
        ``label_component["bbox"]`` is clamped to ``[0, BOX_DIM]`` in place,
        list/tuple values are appended element by element, and classification
        values are shifted up by ``SPECIAL_TOKENS``.

        Args:
            label_components: Dictionaries to convert.

        Returns:
            A list with one flat vector per input dictionary, or ``[]`` for
            empty input.

        Raises:
            ValueError: If a key is missing, a property has an unknown mode,
                or a value is neither a number nor a list/tuple.
            AssertionError: If a value does not fit its property's mode or
                expected length.
        """
        if len(label_components) == 0:
            return []

        # Validate every component before building any vector.
        for (k, kcount, mode) in BOX_PROPERTIES:
            for label_component in label_components:
                if k not in label_component:
                    raise ValueError(f"Missing key {k} in label component {label_component}")

                value = label_component[k]
                if mode == "classification":
                    assert isinstance(value, int), f"Expected an int for key {k}"
                elif mode == "regression":
                    assert (
                        (isinstance(value, (int, float)) and kcount == 1)
                        or len(value) == kcount
                    ), f"Expected {kcount} value(s) for key {k}"
                else:
                    # The original formatted ``k['mode']`` here; ``k`` is a
                    # string, so that raised TypeError instead of this error.
                    raise ValueError(f"Invalid mode {mode} for key {k}")

        out_list = []
        for label_component in label_components:
            # Clamp the box in place so the values copied below are in range.
            bbox = label_component["bbox"]
            for i, coord in enumerate(bbox):
                if coord < 0:
                    bbox[i] = 0
                elif coord > BOX_DIM:
                    bbox[i] = BOX_DIM

            vector = []
            for (k, kcount, mode) in BOX_PROPERTIES:
                item = label_component[k]
                if isinstance(item, (list, tuple)):
                    vector += list(item)
                elif isinstance(item, (float, int)):
                    if mode == "classification":
                        # Shift up for model
                        item += SPECIAL_TOKENS
                    vector.append(item)
                else:
                    raise ValueError(f"Invalid item {item} for key {k}")

            out_list.append(vector)

        return out_list

    def component_idx(self, key):
        """Return the ``(start, end)`` slice that ``key`` occupies in a vector.

        A regression property takes ``count`` slots and a classification
        property takes one.

        Raises:
            ValueError: If ``key`` is not in ``BOX_PROPERTIES`` or a property
                visited before it is found has an unknown mode.
        """
        idx = 0
        for (k, kcount, mode) in BOX_PROPERTIES:
            if mode == "regression":
                incr = kcount
            elif mode == "classification":
                incr = 1
            else:
                raise ValueError(f"Invalid mode {mode} for key {k}")

            if k == key:
                return (idx, idx + incr)

            idx += incr
        raise ValueError(f"Key {key} not found in properties")

    def get_box_property(self, key, add_special_tokens=True):
        """Return the ``(key, count, mode)`` triple for ``key``.

        For classification properties the count is increased by
        ``SPECIAL_TOKENS`` unless ``add_special_tokens`` is false.

        Raises:
            ValueError: If ``key`` is not in ``BOX_PROPERTIES``.
        """
        for (k, kcount, mode) in BOX_PROPERTIES:
            if k == key:
                # Add special token count
                if mode == "classification" and add_special_tokens:
                    kcount += SPECIAL_TOKENS
                return (k, kcount, mode)

        raise ValueError(f"Key {key} not found in properties")

    def component_idx_dict(self):
        """Map every property key to its ``(start, end)`` slice."""
        return {k: self.component_idx(k) for (k, kcount, mode) in BOX_PROPERTIES}

    def convert_polygons_to_bboxes(self, label_components: List[Dict]):
        """Add a six-value ``"bbox"`` entry to every label dictionary.

        Each ``"polygon"`` must hold four ``(x, y)`` corners; the arithmetic
        treats corners 1 and 2 as the top edge and 3 and 4 as the bottom edge,
        with 1 and 4 on the left. Corners are clipped to ``[0, BOX_DIM]``
        first. The stored box is
        ``[cx, cy, width, height, x_skew, y_skew]``, where both skews are
        offset by ``BOX_DIM // 2``.

        Returns:
            The same ``label_components`` list, modified in place.
        """
        for label_component in label_components:
            poly = np.clip(label_component["polygon"], 0, BOX_DIM)

            (x1, y1), (x2, y2), (x3, y3), (x4, y4) = poly
            cx = (x1 + x2 + x3 + x4) / 4
            cy = (y1 + y2 + y3 + y4) / 4
            width = (x2 + x3) / 2 - (x1 + x4) / 2
            height = (y3 + y4) / 2 - (y2 + y1) / 2
            bottom_avg_x = (x3 + x4) / 2
            top_avg_x = (x1 + x2) / 2
            right_avg_y = (y2 + y3) / 2
            left_avg_y = (y1 + y4) / 2

            x_skew = bottom_avg_x - top_avg_x
            y_skew = right_avg_y - left_avg_y
            x_skew += BOX_DIM // 2  # Shift up into positive space
            y_skew += BOX_DIM // 2  # Shift up into positive space

            label_component["bbox"] = [cx, cy, width, height, x_skew, y_skew]

        return label_components

    def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001):
        """Rebuild a four-corner polygon from a six-value box.

        Args:
            box: ``[cx, cy, width, height, x_skew, y_skew]`` with the skews
                offset by ``skew_scaler``.
            skew_scaler: Offset removed from both skew values.
            skew_min: Skews whose magnitude is below this are set to zero.

        Returns:
            Four ``[x, y]`` corners, listed as the first, second, third and
            fourth corner consumed by ``convert_polygons_to_bboxes``.
        """
        cx = box[0]
        cy = box[1]
        width = box[2]
        height = box[3]
        x1 = cx - width / 2
        y1 = cy - height / 2
        x2 = cx + width / 2
        y2 = cy + height / 2

        # The stored skew is the full difference between opposite edges; it is
        # applied as +/- on each side below, hence the halving.
        skew_x = math.floor((box[4] - skew_scaler) / 2)
        skew_y = math.floor((box[5] - skew_scaler) / 2)

        # Ensures we don't get slightly warped boxes
        # Note that the values are later scaled, so this is in 1/1024 space
        # NOTE(review): the skews are already floored to integers, so with the
        # default skew_min this check cannot change them.
        if abs(skew_x) < skew_min:
            skew_x = 0

        if abs(skew_y) < skew_min:
            skew_y = 0

        return [
            [x1 - skew_x, y1 - skew_y],
            [x2 - skew_x, y1 + skew_y],
            [x2 + skew_x, y2 + skew_y],
            [x1 + skew_x, y2 - skew_y],
        ]
@pytest.fixture()
def test_image_latex():
    """Return ``assets/test_latex.png`` (located beside this file) as RGB."""
    here = os.path.dirname(__file__)
    latex_png = os.path.join(here, "assets", "test_latex.png")
    return Image.open(latex_png).convert("RGB")
def test_foundation_flash2():
    """Check that requesting flash_attention_2 reaches both sub-model configs.

    Only construction is wrapped in ``try``. The attention checks run outside
    it so that a failing check is reported as itself rather than being caught
    and re-labelled as a constructor exception.
    """
    try:
        predictor = FoundationPredictor(None, None, None, "flash_attention_2")
    except Exception as e:
        # ``raise`` instead of ``assert False``: asserts are removed under -O,
        # and ``from e`` keeps the original traceback attached.
        raise AssertionError(
            f"FoundationPredictor with flash_attention_2 raised an exception: {e}"
        ) from e

    assert predictor.model.decoder.config._attn_implementation == "flash_attention_2"
    assert (
        predictor.model.vision_encoder.config._attn_implementation
        == "flash_attention_2"
    )
def test_recognition_input_text(recognition_predictor, detection_predictor, test_image):
    """Check that a long ``input_text`` neither slows nor changes recognition.

    The same image is recognised once without and once with a 400-character
    ``input_text``; the slower run must take less than 1.5x the faster one.
    """
    # perf_counter is monotonic and high resolution; time.time() follows the
    # wall clock and can jump or be too coarse for short intervals.
    start = time.perf_counter()
    recognition_predictor([test_image], None, detection_predictor)
    baseline_duration = time.perf_counter() - start

    input_text = "a" * 400
    start = time.perf_counter()
    recognition_results = recognition_predictor(
        [test_image], None, detection_predictor, input_text=[input_text]
    )
    input_text_duration = time.perf_counter() - start

    slower = max(baseline_duration, input_text_duration)
    faster = min(baseline_duration, input_text_duration)
    # Multiplying instead of dividing avoids ZeroDivisionError when a run
    # measures as zero.
    assert slower < 1.5 * faster, (
        "Input text should be truncated and not change inference time"
    )

    assert len(recognition_results) == 1
    assert recognition_results[0].image_bbox == [0, 0, 1024, 1024]

    text_lines = recognition_results[0].text_lines
    assert len(text_lines) == 4
    assert "Hello World" in text_lines[0].text
def test_recognition_clean_math():
    """Check that ``clean_math_tags`` strips ``<br>`` from inside a math block.

    The tag literals are restored from the assertion messages below: without
    them the checks compared against an empty string and a newline.
    """
    # NOTE(review): the attributes on the opening <math> tag are a
    # reconstruction - confirm against the upstream test.
    math = """<math display="block">na_n^{1+2r} \\text{cov}(\\hat{f}_n^{(r)}(x), \\hat{f}_n^{(r)}(y)) = \\frac{1}{n} \\sum_{j=1}^n \\frac{a_n^{1+2r}}{a_j^{1+2r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_j}{a_j}\\right)\\right)<br>+ \\frac{a_n^{1+2r}}{n} \\sum_{\\substack{j \\neq k \\\\ 1 \\le j, k \\le n}} \\frac{1}{(a_j a_k)^{1+r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_k}{a_k}\\right)\\right)<br>=: I_1 + I_2. (1.7)</math>'"""
    clean_math = clean_math_tags(math)

    assert clean_math.count("</math>") == 1, "Should have one closing math tag"
    assert "<br>" not in clean_math, "Should not have <br> tags in cleaned math"