Repository: datalab-to/surya
Branch: master
Commit: e735028979a2
Files: 136
Total size: 740.2 KB
Directory structure:
gitextract_x32e43uo/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── breaking-bug-report.md
│ │ ├── feature_request.md
│ │ └── output-bug-report.md
│ └── workflows/
│ ├── benchmarks.yml
│ ├── ci.yml
│ ├── cla.yml
│ ├── publish.yml
│ └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CLA.md
├── LICENSE
├── MODEL_LICENSE
├── README.md
├── benchmark/
│ ├── detection.py
│ ├── layout.py
│ ├── ordering.py
│ ├── recognition.py
│ ├── table_recognition.py
│ ├── texify.py
│ └── utils/
│ ├── __init__.py
│ ├── bbox.py
│ ├── metrics.py
│ ├── scoring.py
│ ├── tatr.py
│ ├── tesseract.py
│ ├── textract.py
│ └── verify_benchmark_scores.py
├── detect_layout.py
├── detect_text.py
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── pyproject.toml
├── pytest.ini
├── signatures/
│ └── version1/
│ └── cla.json
├── static/
│ └── fonts/
│ └── .gitignore
├── surya/
│ ├── __init__.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── adetr/
│ │ │ └── decoder.py
│ │ ├── donut/
│ │ │ ├── encoder.py
│ │ │ └── processor.py
│ │ ├── load.py
│ │ ├── polygon.py
│ │ ├── predictor.py
│ │ ├── pretrained.py
│ │ ├── s3.py
│ │ ├── surya/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── decoder/
│ │ │ │ ├── __init__.py
│ │ │ │ └── config.py
│ │ │ ├── embedder/
│ │ │ │ └── __init__.py
│ │ │ ├── encoder/
│ │ │ │ ├── __init__.py
│ │ │ │ └── config.py
│ │ │ ├── flash_attn_utils.py
│ │ │ ├── processor/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── schema.py
│ │ │ │ └── tokenizer.py
│ │ │ └── schema.py
│ │ ├── util.py
│ │ └── xla.py
│ ├── debug/
│ │ ├── draw.py
│ │ ├── fonts.py
│ │ ├── katex.js
│ │ ├── render_html.py
│ │ └── text.py
│ ├── detection/
│ │ ├── __init__.py
│ │ ├── heatmap.py
│ │ ├── loader.py
│ │ ├── model/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── encoderdecoder.py
│ │ ├── parallel.py
│ │ ├── processor.py
│ │ ├── schema.py
│ │ └── util.py
│ ├── foundation/
│ │ ├── __init__.py
│ │ ├── cache/
│ │ │ ├── __init__.py
│ │ │ ├── dynamic_ops.py
│ │ │ └── static_ops.py
│ │ ├── loader.py
│ │ └── util.py
│ ├── input/
│ │ ├── load.py
│ │ └── processing.py
│ ├── layout/
│ │ ├── __init__.py
│ │ ├── label.py
│ │ └── schema.py
│ ├── logging.py
│ ├── models.py
│ ├── ocr_error/
│ │ ├── __init__.py
│ │ ├── loader.py
│ │ ├── model/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── encoder.py
│ │ ├── schema.py
│ │ └── tokenizer.py
│ ├── recognition/
│ │ ├── __init__.py
│ │ ├── languages.py
│ │ ├── postprocessing.py
│ │ ├── schema.py
│ │ └── util.py
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── detect_layout.py
│ │ ├── detect_text.py
│ │ ├── finetune_ocr.py
│ │ ├── hf_to_s3.py
│ │ ├── ocr_latex.py
│ │ ├── ocr_text.py
│ │ ├── run_streamlit_app.py
│ │ ├── run_texify_app.py
│ │ ├── streamlit_app.py
│ │ ├── table_recognition.py
│ │ └── texify_app.py
│ ├── settings.py
│ └── table_rec/
│ ├── __init__.py
│ ├── loader.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── decoder.py
│ │ ├── encoder.py
│ │ └── encoderdecoder.py
│ ├── processor.py
│ ├── schema.py
│ └── shaper.py
├── table_recognition.py
├── tests/
│ ├── conftest.py
│ ├── test_detection.py
│ ├── test_foundation.py
│ ├── test_latex_ocr.py
│ ├── test_layout.py
│ ├── test_ocr_errors.py
│ ├── test_recognition.py
│ └── test_table_rec.py
└── texify_app.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/breaking-bug-report.md
================================================
---
name: Breaking bug report
about: Create a report about a breaking bug
title: "[BUG: Breaking]"
labels: 'bug: breaking'
assignees: ''
---
## 🧨 Describe the Bug
A clear and concise description of the breaking issue (e.g., a crash, OOM, or exception).
## 📄 Input Document
Attach the PDF or input file that triggered the error.
## 📤 Output Trace / Stack Trace
Paste the **complete** stack trace or error output, if available.
```
Paste stack trace here
```
## ⚙️ Environment
Please fill in all relevant details:
- **Marker version**:
- **Surya version**:
- **Python version**:
- **PyTorch version**:
- **Transformers version**:
- **Operating System** (incl. container info if relevant):
## ✅ Expected Behavior
What did you expect Marker to do?
## 📟 Command or Code Used
Paste the **exact bash command** or **Python code** you used to run Marker:
```bash
# or Python code block
your_command_here --with-flags
```
## 📎 Additional Context
Any other context that might help us debug this (e.g., CLI options, working directory, runtime settings).
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: "[FEAT]"
labels: enhancement
assignees: ''
---
## ✨ Is your feature request related to a problem?
A clear and concise description of what the problem is.
## 💡 Describe the Solution You'd Like
A concise description of what you want to happen or how you envision it working.
## 📋 Alternatives Considered
Any alternative solutions or workarounds you've tried.
## 🧩 Additional Context
Any additional context, references, or related issues.
================================================
FILE: .github/ISSUE_TEMPLATE/output-bug-report.md
================================================
---
name: Output bug report
about: Create a report about poor output quality
title: "[BUG: Output]"
labels: 'bug: output'
assignees: ''
---
## 📝 Describe the Output Issue
A clear and concise description of the incorrect or unexpected output.
## 📄 Input Document
Attach the PDF or input file used.
## 📤 Current Output
Paste the Markdown or HTML that Marker generated:
````markdown
Paste output here
````
## ✅ Expected Output
Describe or paste what you expected Marker to generate.
## ⚙️ Environment
Please fill in all relevant details:
* **Marker version**:
* **Surya version**:
* **Python version**:
* **PyTorch version**:
* **Transformers version**:
* **Operating System**:
## 📟 Command or Code Used
Paste the **exact bash command** or **Python code** you used to run Marker:
```bash
# or Python code block
your_command_here --with-flags
```
## 📎 Additional Context
Any other relevant info, configs, or assumptions.
================================================
FILE: .github/workflows/benchmarks.yml
================================================
name: Integration test
on: [push]
env:
  PYTHONIOENCODING: "utf-8"
jobs:
  build:
    runs-on: t4_gpu
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Run detection benchmark test
        run: |
          poetry run python benchmark/detection.py --max_rows 2
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
      - name: Run recognition benchmark test
        run: |
          poetry run python benchmark/recognition.py --max_rows 2
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition
      - name: Run layout benchmark test
        run: |
          poetry run python benchmark/layout.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
      - name: Run ordering benchmark
        run: |
          poetry run python benchmark/ordering.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
      - name: Run table recognition benchmark
        run: |
          poetry run python benchmark/table_recognition.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
      - name: Run texify benchmark
        run: |
          poetry run python benchmark/texify.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify
================================================
FILE: .github/workflows/ci.yml
================================================
name: Unit tests
on: [push]
jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [t4_gpu, ubuntu-latest, windows-latest]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Run tests
        run: poetry run pytest
================================================
FILE: .github/workflows/cla.yml
================================================
name: "Surya CLA Assistant"
on:
  issue_comment:
    types: [created]
  pull_request_target:
    types: [opened,closed,synchronize]
# explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings
permissions:
  actions: write
  contents: write
  pull-requests: write
  statuses: write
jobs:
  CLAAssistant:
    runs-on: ubuntu-latest
    steps:
      - name: "Surya CLA Assistant"
        if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
        uses: contributor-assistant/github-action@v2.3.0
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          # the below token should have repo scope and must be manually added by you in the repository's secret
          # This token is required only if you have configured to store the signatures in a remote repository/organization
          PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
        with:
          path-to-signatures: 'signatures/version1/cla.json'
          path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md'
          # branch should not be protected
          branch: 'master'
          allowlist: VikParuchuri
================================================
FILE: .github/workflows/publish.yml
================================================
name: Python package
on:
  push:
    tags:
      - "v*.*.*"
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Build package
        run: |
          poetry build
      - name: Publish package
        env:
          PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
        run: |
          poetry config pypi-token.pypi "$PYPI_TOKEN"
          poetry publish
================================================
FILE: .github/workflows/scripts.yml
================================================
name: Test CLI scripts
on: [push]
jobs:
  build:
    runs-on: t4_gpu
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Download benchmark data
        run: |
          wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi"
          unzip -o benchmark_data.zip
      - name: Test detection
        run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test OCR
        env:
          RECOGNITION_MAX_TOKENS: 25
        run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test layout
        run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test table
        run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test texify
        env:
          TEXIFY_MAX_TOKENS: 25
        run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test detection folder
        run: poetry run surya_detect benchmark_data/pdfs --page_range 0
================================================
FILE: .gitignore
================================================
private.py
.DS_Store
local.env
experiments
test_data
training
wandb
notebooks
results
data
slices
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
  # Ruff version.
  rev: v0.9.10
  hooks:
    # Run the linter.
    - id: ruff
      types_or: [ python, pyi ]
      args: [ --fix ]
    # Run the formatter.
    - id: ruff-format
      types_or: [ python, pyi ]
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "If you use this software, please cite it using the following metadata."
title: "Surya: A lightweight framework for analyzing documents and PDFs at scale"
authors:
  - family-names: Paruchuri
    given-names: Vikas
  - name: Datalab Team
date-released: 2025-05-13
url: https://github.com/VikParuchuri/surya
version: 0.14.0
repository-code: https://github.com/VikParuchuri/surya
================================================
FILE: CLA.md
================================================
Surya Contributor Agreement
This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below.
If you agree to be bound by these terms, sign by writing "I have read the CLA Document and I hereby sign the CLA" in response to the CLA bot GitHub comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement.
1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution:
- you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free, unrestricted license to exercise all rights under those copyrights. This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers;
- you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work;
- you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees;
- you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and
- you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution.
3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to:
- make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and
- at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements.
If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed.
4. Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license.
5. You covenant, represent, warrant and agree that:
- each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA;
- to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and
- each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws.
You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply.
================================================
FILE: LICENSE
================================================
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
Surya OCR
Copyright (C) 2024 Endless Labs, Inc.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
Surya OCR Copyright (C) 2024 Endless Labs, Inc.
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
.
================================================
FILE: MODEL_LICENSE
================================================
AI PUBS OPEN RAIL-M LICENSE (MODIFIED)
Version 0.1, March 2, 2023 (Modified)
http://licenses.ai/
PLEASE READ THESE TERMS CAREFULLY BEFORE USING THE MODEL OR A DERIVATIVE WORK OF THE MODEL MADE AVAILABLE IN CONNECTION WITH THESE TERMS. BY DOWNLOADING, REPRODUCING, DISTRIBUTING OR USING THE MODEL OR A DERIVATIVE WORK OF THE MODEL IN ANY MANNER, YOU (“YOU”) AGREE TO BE BOUND BY THESE TERMS (THE “AGREEMENT”) TO THE EXCLUSION OF ALL OTHER TERMS. YOU REPRESENT AND WARRANT THAT YOU HAVE THE AUTHORITY TO ENTER INTO THIS AGREEMENT; IF YOU ARE ENTERING INTO THIS AGREEMENT ON BEHALF OF AN ORGANIZATION OR ENTITY, REFERENCES TO “YOU” IN THIS AGREEMENT REFER TO THAT ORGANIZATION OR ENTITY. IF YOU DO NOT AGREE TO ALL OF THE FOLLOWING, YOU MAY NOT DOWNLOAD, REPRODUCE, DISTRIBUTE OR USE THE MODEL OR A DERIVATIVE WORK OF THE MODEL IN ANY MANNER.
Section I: PREAMBLE
This OpenRAIL-M License, as modified, is generally applicable to any machine-learning Model.
The “Open” nomenclature indicates that the licensed Model is to be freely accessible to downstream and other users. The “RAIL” nomenclature indicates that there are use restrictions prohibiting certain uses of the Model. These restrictions are intended to avoid potential misuse. This License specifies that the use restrictions in the original License must apply to such derivatives.
NOW THEREFORE, You and Licensor agree as follows:
1. Definitions
(a) “Complementary Material” means the applicable source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, and any related information, if any. Complementary Material is not licensed under this License.
(b) "Contribution" means any work, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the rights owner or by an individual or legal entity authorized to submit on behalf of the rights owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the rights owner as "Not a Contribution."
(c) "Contributor" means Licensor and any individual or legal entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
(d) “Data” means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
(e) “Derivatives of the Model” means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
(f) “Distribution” means any transmission, reproduction, publication, distribution, or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means, including but not limited to API-based or web access.
(g) “Harm” includes but is not limited to physical, mental, psychological, financial and reputational damage, pain, or loss.
(h) "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
(i) “Licensor” means the rights owner or entity authorized by the rights owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
(j) “Model” means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
(k) “Output” means the results of operating a Model as embodied in informational content resulting therefrom.
(l) “Third Parties” means individuals or legal entities that are not under common control with Licensor or You.
(m) "You" (or "Your") means an individual or legal entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application, including but not limited to a chatbot, translator, or image generator.
Section II: INTELLECTUAL PROPERTY RIGHTS
Both copyright and patent grants may apply to the Model and Derivatives of the Model. The Model and Derivatives of the Model are subject to additional terms as described in Section III, which shall govern the use of the Model and Derivatives of the Model even in the event Section II is held unenforceable.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Model and Derivatives of the Model.
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and/or Derivatives of the Model where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model or Derivatives of the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model or Derivative of the Model and/or a Contribution incorporated within the Model or Derivative of the Model constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Derivative of the Model shall terminate as of the date such litigation is asserted or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
4. Distribution and Redistribution. You may host the Model or Derivatives of the Model for remote access by Third Parties, including but not limited to software-as-a-service, reproduce, or Distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the conditions in this Section III:
(a) Use-based restrictions in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (for example, a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model and Derivatives of the Model are subject to paragraph 5;
(b) You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
(c) You must cause any modified files to carry prominent notices stating that You changed the files; and
(d) You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model or Derivatives of the Model.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions, consistent with paragraph 4.a., for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Accordingly, You cannot use the Model or the Derivatives of the Model in violation of such restrictions. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, fine-tuning, updating, running, training, evaluating and/or re-parametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph 5.
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are solely responsible for the Output you generate and its subsequent uses. No use of the Output can contravene any provision as stated in the License.
7. Attribution. In connection with any Output, or use of Distribution of any Model or Derivatives of the Model, You agree to give appropriate credit and attribution to Licensor, provide a link to the original Model or Derivatives of the Model, provide a copy of this License, and identify any changes You have made to the Model or Derivatives of the Model (collectively, the “Attribution”). The Attribution must not suggest endorsement by any Licensor.
8. Share-a-Like. As a condition to the license and authorizations herein, You agree to apply this License (to the exclusion of all others) to any and all copies of the Model, Derivatives of the Model, any changes or improvements to the Model or Derivatives of the Model, and to the Output and any derivatives, changes or improvements to or of the Output.
Section IV: OTHER PROVISIONS
9. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or cause modification to the Output resulting from updates to the Model based.
10. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
11. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model and Derivatives of the Model, and assume any risks associated with Your exercise of permissions under this License.
12. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
13. Accepting Warranty or Additional Liability. While Distributing the Model or Derivatives of the Model, You may choose to charge a fee in exchange for support, warranty, indemnity, or other obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor or Licensor, and only if You agree to indemnify, defend, and hold each Contributor and the Licensor harmless for any liability incurred by, or claims asserted against, such Contributor or Licensor by reason of your accepting any such warranty or additional liability.
14. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
END OF TERMS AND CONDITIONS
Attachment A
USE RESTRICTIONS
As conditions to the Licenses set forth in this Agreement, You agree not to use, reproduce, modify, create or Distribute the Model, Derivatives of the Model, or Output (collectively, “Use”) in any of the following ways:
1. Legal:
(a) In any way that violates any applicable national, federal, state, local or international law or regulation; or
(b) to directly or indirectly infringe or misappropriate any third party intellectual property rights (including those of Licensor or any Contributor)
2. Commercial:
(a) for any purpose if You (your employer, or the entity you are affiliated with) generated more than two million US Dollars ($2,000,000) in gross revenue in the prior year, except where Your Use is limited to personal use or research purposes;
(b) for any purpose if You (your employer, or the entity you are affiliated with) has raised more than two million US dollars ($2,000,000) in total equity or debt funding from any source, except where Your Use is limited to personal use or research purposes; or
(c) for any purpose if You (your employer, or the entity you are affiliated with) provides or otherwise makes available any product or service that competes with any product or service offered by or made available by Licensor or any of its affiliates.
Commercial and broader use licenses may be available from Licensor at the following URL: https://www.datalab.to/
================================================
FILE: README.md
================================================
# Surya
Surya is a document OCR toolkit that does:
- OCR in 90+ languages that benchmarks favorably vs cloud services
- Line-level text detection in any language
- Layout analysis (table, image, header, etc detection)
- Reading order detection
- Table recognition (detecting rows/columns)
- LaTeX OCR
It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
For our managed API or on-prem document intelligence solution, check out [our platform here](https://datalab.to?utm_source=gh-surya).
| Detection | OCR |
|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|
| | |
| Layout | Reading Order |
|:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|
| | |
| Table Recognition | LaTeX OCR |
|:-------------------------------------------------------------:|:------------------------------------------------------:|
| | |
Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.
## Community
[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development.
## Examples
| Name | Detection | OCR | Layout | Order | Table Rec |
|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:|
| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) |
| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) | |
| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) | |
| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) | |
| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) | |
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) | [Image](static/images/pres_tablerec.png) |
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) | [Image](static/images/paper_tablerec.png) |
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) | [Image](static/images/scanned_tablerec.png) |
| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) | |
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) |
| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) | |
# Hosted API
There is a hosted API for all surya models available [here](https://www.datalab.to?utm_source=gh-surya):
- Works with PDF, images, word docs, and powerpoints
- Consistent speed, with no latency spikes
- High reliability and uptime
# Commercial usage
Our model weights use a modified AI Pubs Open Rail-M license (free for research, personal use, and startups under $2M funding/revenue) and our code is GPL. For broader commercial licensing or to remove GPL requirements, visit our pricing page [here](https://www.datalab.to/pricing?utm_source=gh-surya).
# Installation
You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details.
Install with:
```shell
pip install surya-ocr
```
Model weights will automatically download the first time you run surya.
# Usage
- Inspect the settings in `surya/settings.py`. You can override any settings with environment variables.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
## Interactive App
I've included a streamlit app that lets you interactively try Surya on images or PDF files. Run it with:
```shell
pip install streamlit pdftext
surya_gui
```
## OCR (text recognition)
This command will write out a json file with the detected text and bboxes:
```shell
surya_ocr DATA_PATH
```
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--task_name` will specify which task to use for predicting the lines. `ocr_with_boxes` is the default, which will format text and give you bboxes. If you get bad performance, try `ocr_without_boxes`, which will give you potentially better performance but no bboxes. For blocks like equations and paragraphs, try `block_without_boxes`.
- `--images` will save images of the pages and detected text lines (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
- `--disable_math` - by default, surya will recognize math in text. This can lead to false positives - you can disable this with this flag.
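To make the `--page_range` syntax concrete, here is a minimal sketch of how a spec like `0,5-10,20` could be expanded into page indices. The `parse_page_range` helper is hypothetical and is not surya's own parser. It assumes ranges are inclusive on both ends.

```python
def parse_page_range(spec: str) -> list[int]:
    """Expand a spec such as "0,5-10,20" into a sorted list of page indices."""
    pages: set[int] = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            start, end = part.split("-", 1)
            # Assumption: both ends of a range are included.
            pages.update(range(int(start), int(end) + 1))
        else:
            pages.add(int(part))
    return sorted(pages)


print(parse_page_range("0,5-10,20"))  # [0, 5, 6, 7, 8, 9, 10, 20]
```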
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
- `text_lines` - the detected text and bounding boxes for each line
  - `text` - the text in the line
  - `confidence` - the confidence of the model in the detected text (0-1)
  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `chars` - the individual characters in the line
    - `text` - the text of the character
    - `bbox` - the character bbox (same format as line bbox)
    - `polygon` - the character polygon (same format as line polygon)
    - `confidence` - the confidence of the model in the detected character (0-1)
    - `bbox_valid` - if the character is a special token or math, the bbox may not be valid
  - `words` - the individual words in the line (computed from the characters)
    - `text` - the text of the word
    - `bbox` - the word bbox (same format as line bbox)
    - `polygon` - the word polygon (same format as line polygon)
    - `confidence` - mean character confidence
    - `bbox_valid` - if the word is a special token or math, the bbox may not be valid
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
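As a rough illustration of the structure described above, the sketch below walks a `results.json`-shaped dictionary and yields the text of each line above a confidence cutoff. The `iter_text_lines` helper, the `min_confidence` parameter, and the sample values are made up for the example. Only the key names come from the format above.

```python
def iter_text_lines(results: dict, min_confidence: float = 0.0):
    """Yield (document, page, text) for each line at or above min_confidence."""
    for doc_name, pages in results.items():
        for page in pages:
            for line in page["text_lines"]:
                if line["confidence"] >= min_confidence:
                    yield doc_name, page["page"], line["text"]


# Stand-in for the dictionary you would get from json.load() on results.json.
sample = {
    "invoice": [
        {
            "page": 1,
            "image_bbox": [0, 0, 800, 1000],
            "text_lines": [
                {"text": "Total due", "confidence": 0.98, "bbox": [40, 50, 200, 70]},
                {"text": "$12.00", "confidence": 0.61, "bbox": [600, 50, 700, 70]},
            ],
        }
    ]
}

for doc, page, text in iter_text_lines(sample, min_confidence=0.9):
    print(doc, page, text)
```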
**Performance tips**
Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `40MB` of VRAM, so very high batch sizes are possible. The default batch size is `512`, which will use about 20GB of VRAM. Tuning the batch size can also help on CPU, depending on your core count - the default CPU batch size is `32`.
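The batch size arithmetic above can be sketched as follows. This is a back-of-the-envelope estimate only: both helpers are hypothetical, they treat 1GB as 1024MB, and they ignore the memory taken by the model weights and other overhead.

```python
def estimated_vram_gb(batch_size: int, mb_per_item: float) -> float:
    """Approximate VRAM used by the batch items alone, in GB."""
    return batch_size * mb_per_item / 1024


def largest_batch_size(vram_gb: float, mb_per_item: float) -> int:
    """Largest batch whose items fit in the given VRAM budget."""
    return int(vram_gb * 1024 // mb_per_item)


print(estimated_vram_gb(512, 40))  # 20.0, matching the "about 20GB" figure
print(largest_batch_size(8, 40))   # 204
```

You could then pass a value in that range through `RECOGNITION_BATCH_SIZE`, leaving some headroom for the model itself.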
### From python
```python
from PIL import Image
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor
image = Image.open(IMAGE_PATH)
foundation_predictor = FoundationPredictor()
recognition_predictor = RecognitionPredictor(foundation_predictor)
detection_predictor = DetectionPredictor()
predictions = recognition_predictor([image], det_predictor=detection_predictor)
```
## Text line detection
This command will write out a json file with the detected bboxes.
```shell
surya_detect DATA_PATH
```
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected text lines (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
- `bboxes` - detected bounding boxes for text
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
  - `confidence` - the confidence of the model in the detected text (0-1)
- `vertical_lines` - vertical lines detected in the document
  - `bbox` - the axis-aligned line coordinates.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
**Performance tips**
Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `440MB` of VRAM, so very high batch sizes are possible. The default batch size is `36`, which will use about 16GB of VRAM. Tuning the batch size can also help on CPU, depending on your core count - the default CPU batch size is `6`.
### From python
```python
from PIL import Image
from surya.detection import DetectionPredictor
image = Image.open(IMAGE_PATH)
det_predictor = DetectionPredictor()
# predictions is a list of dicts, one per image
predictions = det_predictor([image])
```
## Layout and reading order
This command will write out a json file with the detected layout and reading order.
```shell
surya_layout DATA_PATH
```
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected text lines (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
- `bboxes` - detected bounding boxes for text
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
  - `position` - the reading order of the box.
  - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.
  - `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
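Here is a small sketch of how the `position` and `label` fields described above could be used to list a page's blocks in reading order. The `blocks_in_reading_order` helper and the sample page are illustrative only.

```python
def blocks_in_reading_order(page: dict) -> list[tuple[int, str]]:
    """Return (position, label) pairs sorted by the predicted reading order."""
    ordered = sorted(page["bboxes"], key=lambda box: box["position"])
    return [(box["position"], box["label"]) for box in ordered]


# Made-up page dictionary that follows the key names listed above.
sample_page = {
    "page": 1,
    "image_bbox": [0, 0, 800, 1000],
    "bboxes": [
        {"bbox": [50, 300, 750, 600], "position": 1, "label": "Text"},
        {"bbox": [50, 40, 750, 90], "position": 0, "label": "Section-header"},
        {"bbox": [50, 650, 750, 900], "position": 2, "label": "Table"},
    ],
}

for position, label in blocks_in_reading_order(sample_page):
    print(position, label)
```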
**Performance tips**
Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `220MB` of VRAM, so very high batch sizes are possible. The default batch size is `32`, which will use about 7GB of VRAM. Tuning the batch size can also help on CPU, depending on your core count - the default CPU batch size is `4`.
### From python
```python
from PIL import Image
from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.settings import settings
image = Image.open(IMAGE_PATH)
layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))
# layout_predictions is a list of dicts, one per image
layout_predictions = layout_predictor([image])
```
## Table Recognition
This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get cell positions and text, along with nice formatting, check out the [marker](https://www.github.com/VikParuchuri/marker) repo. You can use the `TableConverter` to detect and extract tables in images and PDFs. It supports output in json (with bboxes), markdown, and html.
```shell
surya_table DATA_PATH
```
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected table cells + rows and columns (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
- `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible.
- `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table.
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
- `rows` - detected table rows
  - `bbox` - the bounding box of the table row
  - `row_id` - the id of the row
  - `is_header` - if it is a header row.
- `cols` - detected table columns
  - `bbox` - the bounding box of the table column
  - `col_id` - the id of the column
  - `is_header` - if it is a header column
- `cells` - detected table cells
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `text` - if text could be pulled out of the pdf, the text of this cell.
  - `row_id` - the id of the row the cell belongs to.
  - `col_id` - the id of the column the cell belongs to.
  - `colspan` - the number of columns spanned by the cell.
  - `rowspan` - the number of rows spanned by the cell.
  - `is_header` - whether it is a header cell.
- `page` - the page number in the file
- `table_idx` - the index of the table on the page (sorted in vertical order)
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
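To show how the `row_id` and `col_id` fields above fit together, this sketch arranges a table's cells into a simple 2D grid of text. `cells_to_grid` is a hypothetical helper. It ignores `rowspan` and `colspan`, and it assumes ids can be ordered numerically. For properly formatted table output, the marker repo mentioned above is the better route.

```python
def cells_to_grid(cells: list[dict]) -> list[list[str]]:
    """Place each cell's text at its (row_id, col_id) slot. Spans are ignored."""
    row_index = {rid: i for i, rid in enumerate(sorted({c["row_id"] for c in cells}))}
    col_index = {cid: i for i, cid in enumerate(sorted({c["col_id"] for c in cells}))}
    grid = [["" for _ in col_index] for _ in row_index]
    for cell in cells:
        # `text` may be missing if it could not be pulled out of the pdf.
        grid[row_index[cell["row_id"]]][col_index[cell["col_id"]]] = cell.get("text") or ""
    return grid


# Made-up cells that follow the key names listed above.
sample_cells = [
    {"row_id": 0, "col_id": 0, "text": "Item", "is_header": True},
    {"row_id": 0, "col_id": 1, "text": "Price", "is_header": True},
    {"row_id": 1, "col_id": 0, "text": "Paper", "is_header": False},
    {"row_id": 1, "col_id": 1, "text": None, "is_header": False},
]

for row in cells_to_grid(sample_cells):
    print(row)
```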
**Performance tips**
Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default batch size is `64`, which will use about 10GB of VRAM. Tuning the batch size can also help on CPU, depending on your core count - the default CPU batch size is `8`.
### From python
```python
from PIL import Image
from surya.table_rec import TableRecPredictor
image = Image.open(IMAGE_PATH)
table_rec_predictor = TableRecPredictor()
table_predictions = table_rec_predictor([image])
```
## LaTeX OCR
This command will write out a json file with the LaTeX of the equations. You must pass in images that are already cropped to the equations. You can do this by running the layout model, then cropping, if you want.
```shell
surya_latex_ocr DATA_PATH
```
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. See the OCR section above for the format of the output.
### From python
```python
from PIL import Image
from surya.texify import TexifyPredictor
image = Image.open(IMAGE_PATH)
predictor = TexifyPredictor()
predictor([image])
```
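LaTeX OCR expects images already cropped to equations. One possible way to get crop boxes from layout output is sketched below. It assumes a page dictionary in the layout `results.json` format, picks boxes labeled `Formula`, and clamps them to `image_bbox`. The `formula_crop_boxes` helper and its `padding` parameter are hypothetical.

```python
def formula_crop_boxes(page: dict, padding: int = 4) -> list[tuple[int, int, int, int]]:
    """Return padded (x1, y1, x2, y2) boxes for every block labeled Formula."""
    left, top, right, bottom = page["image_bbox"]
    crops = []
    for box in page["bboxes"]:
        if box["label"] != "Formula":
            continue
        x1, y1, x2, y2 = box["bbox"]
        # Pad a little, but never go outside the page image.
        crops.append((
            max(left, int(x1) - padding),
            max(top, int(y1) - padding),
            min(right, int(x2) + padding),
            min(bottom, int(y2) + padding),
        ))
    return crops


# Made-up layout page for illustration.
layout_page = {
    "image_bbox": [0, 0, 100, 100],
    "bboxes": [
        {"bbox": [2, 10, 50, 30], "label": "Formula"},
        {"bbox": [5, 40, 95, 90], "label": "Text"},
    ],
}

print(formula_crop_boxes(layout_page))  # [(0, 6, 54, 34)]
```

Each tuple is in the (left, upper, right, lower) order that PIL's `Image.crop` accepts, so the crops can be produced with `image.crop(box)` before being passed to the predictor.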
### Interactive app
You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with:
```shell
pip install streamlit==1.40 streamlit-drawable-canvas-jsretry
texify_gui
```
## Compilation
The following models have support for compilation. You will need to set the following environment variables to enable compilation:
- Detection: `COMPILE_DETECTOR=true`
- Layout: `COMPILE_LAYOUT=true`
- Table recognition: `COMPILE_TABLE_REC=true`
Alternatively, you can also set `COMPILE_ALL=true` which will compile all models.
Here are the speedups on an A10 GPU:
| Model | Time per page (s) | Compiled time per page (s) | Speedup (%) |
| ----------------- | ----------------- | -------------------------- | ----------- |
| Detection | 0.108808 | 0.10521 | 3.306742151 |
| Layout | 0.27319 | 0.27063 | 0.93707676 |
| Table recognition | 0.0219 | 0.01938 | 11.50684932 |
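The speedup column appears to be the relative reduction in time per page. The sketch below recomputes it from the two timing columns of the table. The `speedup_percent` name is just for this example.

```python
def speedup_percent(base_seconds: float, compiled_seconds: float) -> float:
    """Relative time saved by compilation, as a percentage of the base time."""
    return (base_seconds - compiled_seconds) / base_seconds * 100


# (time per page, compiled time per page) taken from the table above.
timings = {
    "Detection": (0.108808, 0.10521),
    "Layout": (0.27319, 0.27063),
    "Table recognition": (0.0219, 0.01938),
}

for model, (base, compiled) in timings.items():
    print(f"{model}: {speedup_percent(base, compiled):.2f}%")
```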
# Limitations
- This is specialized for document OCR. It will likely not work on photos or other images.
- It is for printed text, not handwriting (though it may work on some handwriting).
- The text detection model has trained itself to ignore advertisements.
- You can find language support for OCR in `surya/recognition/languages.py`. Text detection, layout analysis, and reading order will work with any language.
## Troubleshooting
If OCR isn't working properly:
- Try increasing resolution of the image so the text is bigger. If the resolution is already very high, try decreasing it to no more than a `2048px` width.
- Preprocessing the image (binarizing, deskewing, etc) can help with very old/blurry images.
- You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results. `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space. `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text. `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range. Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds).
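When experimenting with the two detector thresholds, a quick sanity check can catch invalid combinations before a long run. The sketch below only encodes the two constraints stated above (both values in the 0-1 range, text threshold higher than blank threshold). The helper is hypothetical, and the example values are illustrative, not surya's defaults.

```python
def check_detector_thresholds(blank: float, text: float) -> None:
    """Raise ValueError if the pair violates the constraints described above."""
    if not (0 <= blank <= 1 and 0 <= text <= 1):
        raise ValueError("both thresholds must be in the 0-1 range")
    if text <= blank:
        raise ValueError(
            "DETECTOR_TEXT_THRESHOLD must be higher than DETECTOR_BLANK_THRESHOLD"
        )


# Illustrative values only.
check_detector_thresholds(blank=0.3, text=0.6)
print("thresholds look consistent")
```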
# Manual install
If you want to develop surya, you can install it manually:
- `git clone https://github.com/VikParuchuri/surya.git`
- `cd surya`
- `poetry install` - installs main and dev dependencies
- `poetry shell` - activates the virtual environment
# Benchmarks
## OCR

| Model | Time per page (s) | Avg similarity (⬆) |
|-----------|-------------------|--------------------|
| surya | .62 | 0.97 |
| tesseract | .45 | 0.88 |
[Full language results](static/images/rec_acc_table.png)
Tesseract is CPU-based, and surya is CPU or GPU. I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean).
### Google Cloud Vision
I benchmarked OCR against Google Cloud vision since it has similar language coverage to Surya.

[Full language results](static/images/gcloud_full_langs.png)
**Methodology**
I measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs. I sampled PDFs from common crawl, then filtered out the ones with bad OCR. I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those.
I used the reference line bboxes from the PDFs with both tesseract and surya, to just evaluate the OCR quality.
For Google Cloud, I aligned the output from Google Cloud with the ground truth. I had to skip RTL languages since they didn't align well.
## Text line detection

| Model | Time (s) | Time per page (s) | precision | recall |
|-----------|------------|---------------------|-------------|----------|
| surya | 47.2285 | 0.094452 | 0.835857 | 0.960807 |
| tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 |
Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU. This was the resource usage:
- tesseract - 32 CPU cores, or 8 workers using 4 cores each
- surya - 36 batch size, for 16GB VRAM usage
**Methodology**
Surya predicts line-level bboxes, while tesseract and others predict word-level or character-level. It's hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation.
I instead used coverage, which calculates:
- Precision - how well the predicted bboxes cover ground truth bboxes
- Recall - how well ground truth bboxes cover predicted bboxes
First calculate coverage for each bbox, then add a small penalty for double coverage, since we want the detection to have non-overlapping bboxes. Anything with a coverage of 0.5 or higher is considered a match.
Then we calculate precision and recall for the whole dataset.
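The coverage idea can be sketched roughly as below. This is a simplified illustration, not the benchmark's actual code: it assumes integer pixel coordinates, and the exact form of the double-coverage penalty (one pixel subtracted per pixel covered more than once) is my assumption. The precision and recall lines mirror the wording of the definitions above.

```python
Box = tuple[int, int, int, int]  # (x1, y1, x2, y2)


def coverage(box: Box, others: list[Box], penalize_double: bool = True) -> float:
    """Fraction of `box` covered by `others`, minus a penalty for double coverage."""
    x1, y1, x2, y2 = box
    area = (x2 - x1) * (y2 - y1)
    if area <= 0:
        return 0.0
    counts: dict[tuple[int, int], int] = {}
    for ox1, oy1, ox2, oy2 in others:
        for x in range(max(x1, ox1), min(x2, ox2)):
            for y in range(max(y1, oy1), min(y2, oy2)):
                counts[(x, y)] = counts.get((x, y), 0) + 1
    doubled = sum(1 for n in counts.values() if n > 1)
    penalty = doubled if penalize_double else 0
    return max(0, len(counts) - penalty) / area


def fraction_matched(boxes: list[Box], covering: list[Box], threshold: float = 0.5) -> float:
    """Share of `boxes` whose coverage by `covering` reaches the match threshold."""
    if not boxes:
        return 0.0
    return sum(coverage(b, covering) >= threshold for b in boxes) / len(boxes)


ground_truth = [(0, 0, 10, 10)]
predicted = [(0, 0, 10, 6)]
precision = fraction_matched(ground_truth, predicted)  # predicted covering ground truth
recall = fraction_matched(predicted, ground_truth)     # ground truth covering predicted
print(precision, recall)
```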
## Layout analysis
| Layout Type | precision | recall |
|---------------|-------------|----------|
| Image | 0.91265 | 0.93976 |
| List | 0.80849 | 0.86792 |
| Table | 0.84957 | 0.96104 |
| Text | 0.93019 | 0.94571 |
| Title | 0.92102 | 0.95404 |
Time per image - .13 seconds on GPU (A10).
**Methodology**
I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data. I had to align publaynet labels with the surya layout labels. I was then able to find coverage for each layout type:
- Precision - how well the predicted bboxes cover ground truth bboxes
- Recall - how well ground truth bboxes cover predicted bboxes
## Reading Order
88% mean accuracy, and .4 seconds per image on an A10 GPU. See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.
**Methodology**
I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth.
The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct.
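A minimal sketch of that pairwise accuracy, assuming each box has a predicted position and a reference position at the same list index. The function name and the example orderings are made up.

```python
from itertools import combinations


def pairwise_order_accuracy(predicted: list[int], reference: list[int]) -> float:
    """Fraction of box pairs whose relative order matches the reference."""
    pairs = list(combinations(range(len(reference)), 2))
    if not pairs:
        return 1.0
    correct = sum(
        (predicted[i] < predicted[j]) == (reference[i] < reference[j])
        for i, j in pairs
    )
    return correct / len(pairs)


# Boxes 1 and 2 are swapped in the prediction, so 5 of the 6 pairs are correct.
print(pairwise_order_accuracy(predicted=[0, 2, 1, 3], reference=[0, 1, 2, 3]))
```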
## Table Recognition
| Model | Row Intersection | Col Intersection | Time Per Image |
|-------------------|--------------------|--------------------|------------------|
| Surya | 1 | 0.98625 | 0.30202 |
| Table transformer | 0.84 | 0.86857 | 0.08082 |
Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions. This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker).
**Methodology**
The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns.
## LaTeX OCR
| Method | edit ⬇ | time taken (s) ⬇ |
|--------|----------|------------------|
| texify | 0.122617 | 35.6345 |
This runs texify inference on a ground truth set of LaTeX, then computes the edit distance between the predictions and the ground truth. This is a bit noisy, since two LaTeX strings that render the same can have different symbols in them.
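To illustrate why edit distance is noisy here, the sketch below computes a Levenshtein distance between two LaTeX strings that render the same. Normalizing by the longer string's length is an assumption for the example. The benchmark's exact normalization is not stated here.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]


def normalized_edit_distance(pred: str, ref: str) -> float:
    """Edit distance divided by the longer length (assumed normalization)."""
    longest = max(len(pred), len(ref))
    return levenshtein(pred, ref) / longest if longest else 0.0


# Both strings render as the same fraction, yet the distance is not zero.
print(levenshtein(r"\frac12", r"\frac{1}{2}"))  # 4
print(round(normalized_edit_distance(r"\frac12", r"\frac{1}{2}"), 3))
```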
## Running your own benchmarks
You can benchmark the performance of surya on your machine.
- Follow the manual install instructions above.
- `poetry install --group dev` - installs dev dependencies
**Text line detection**
This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).
```shell
python benchmark/detection.py --max_rows 256
```
- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images and detected bboxes
- `--pdf_path` will let you specify a pdf to benchmark instead of the default data
- `--results_dir` will let you specify a directory to save results to instead of the default one
**Text recognition**
This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).
```shell
python benchmark/recognition.py --tesseract
```
- `--max_rows` controls how many images to process for the benchmark
- `--debug 2` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.
- Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark.
- Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking. This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus).
**Layout analysis**
This will evaluate surya on the publaynet dataset.
```shell
python benchmark/layout.py
```
- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with the ground truth layout boxes
- `--results_dir` will let you specify a directory to save results to instead of the default one
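Per-label layout metrics are averaged across pages with each page weighted by its number of ground truth boxes for that label, so pages without a given label do not dilute the score. A minimal sketch of that weighting (the function name is illustrative):

```python
def weighted_mean(page_scores, page_weights):
    # page_weights[i] is the number of ground truth boxes of this label on page i
    total = sum(page_weights)
    if total == 0:
        return 0.0
    return sum(s * w for s, w in zip(page_scores, page_weights)) / total


# One page with a single table scored 1.0, another with three tables scored 0.5
table_precision = weighted_mean([1.0, 0.5], [1, 3])
```

The second page contributes three times as much as the first, giving 0.625 rather than the unweighted 0.75.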
**Reading Order**
```shell
python benchmark/ordering.py
```
- `--max_rows` controls how many images to process for the benchmark (defaults to the full dataset)
- `--results_dir` will let you specify a directory to save results to instead of the default one
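The reading order script reports mean accuracy as the percentage of correct ranking pairs. One plausible way to compute that is sketched below. The real metric is `rank_accuracy` in `benchmark/utils/metrics.py`; the function name, the handling of ties, and returning 1.0 when there are fewer than two boxes are assumptions made for this example.

```python
from itertools import combinations


def pairwise_rank_accuracy(pred_positions, true_positions):
    # Every unordered pair of boxes is one ranking decision
    pairs = list(combinations(range(len(true_positions)), 2))
    if not pairs:
        return 1.0  # assumption: nothing to get wrong with fewer than two boxes
    correct = 0
    for i, j in pairs:
        true_order = true_positions[i] < true_positions[j]
        pred_order = pred_positions[i] < pred_positions[j]
        if true_order == pred_order:
            correct += 1
    return correct / len(pairs)


# The last two boxes are swapped: two of three pairs are still ordered correctly
swapped = pairwise_rank_accuracy([0, 2, 1], [0, 1, 2])
```

A fully reversed order scores 0.0 and a perfect order scores 1.0 under this definition.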
**Table Recognition**
```shell
python benchmark/table_recognition.py --max_rows 1024 --tatr
```
- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with the predicted and ground truth rows and columns
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tatr` specifies whether to also run table transformer
**LaTeX OCR**
```shell
python benchmark/texify.py --max_rows 128
```
- `--max_rows` controls how many images to process for the benchmark
- `--results_dir` will let you specify a directory to save results to instead of the default one
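The detection, layout, reading order, recognition, and table recognition scripts each write a `results.json` into a subfolder of the results directory (`det_bench`, `layout_bench`, `order_bench`, `rec_bench`, `table_rec_bench`). A small helper for loading those files might look like the sketch below. The function name is hypothetical, and the results directory is passed in explicitly because the default depends on your settings.

```python
import json
from pathlib import Path


def load_benchmark_results(results_dir, bench_name):
    # e.g. bench_name="det_bench" reads <results_dir>/det_bench/results.json
    path = Path(results_dir) / bench_name / "results.json"
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```

For the detection benchmark, the loaded dictionary has `times`, `metrics`, and `page_metrics` keys, so `load_benchmark_results(results_dir, "det_bench")["metrics"]["surya"]` returns the mean surya metrics.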
# Training
Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation.
Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).
# Finetuning Surya OCR
You can now take Surya OCR further by training it on your own data with our [finetuning script](/surya/scripts/finetune_ocr.py).
It’s built on the Hugging Face Trainer and supports all the [arguments](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) that the Trainer provides, as well as integrations like torchrun and deepspeed.
To set up your dataset, follow the example dataset format [here](https://huggingface.co/datasets/datalab-to/ocr_finetune_example) and provide the path to your own dataset when launching the training script.
```bash
# Tested on 1xH100 GPU
# Set --pretrained_checkpoint_path to load from a custom checkpoint, otherwise
# the default surya ocr weights will be loaded as the initialization
python surya/scripts/finetune_ocr.py \
--output_dir $OUTPUT_DIR \
--dataset_name datalab-to/ocr_finetune_example \
--per_device_train_batch_size 64 \
--gradient_checkpointing true \
--max_sequence_length 1024
```
This is a minimal training script to get you started finetuning Surya. Our internal training stack includes character bounding box finetuning, sliding window attention with specialized attention masks, custom kernels, augmentations, and other optimizations that can push OCR accuracy well beyond standard finetuning. If you want to get the most out of your data, reach us at hi@datalab.to!
# Thanks
This work would not have been possible without amazing open source AI work:
- [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA
- [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT
- [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman
- [Donut](https://github.com/clovaai/donut) from Naver
- [transformers](https://github.com/huggingface/transformers) from huggingface
- [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model
Thank you to everyone who makes open source AI possible.
# Citation
If you use surya (or the associated models) in your work or research, please consider citing us using the following BibTeX entry:
```bibtex
@misc{paruchuri2025surya,
author = {Vikas Paruchuri and Datalab Team},
title = {Surya: A lightweight document OCR and analysis toolkit},
year = {2025},
howpublished = {\url{https://github.com/VikParuchuri/surya}},
note = {GitHub repository},
}
```
================================================
FILE: benchmark/detection.py
================================================
import collections
import copy
import json
import click
from benchmark.utils.bbox import get_pdf_lines
from benchmark.utils.metrics import precision_recall
from benchmark.utils.tesseract import tesseract_parallel
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.debug.draw import draw_polys_on_image
from surya.common.util import rescale_bbox
from surya.settings import settings
from surya.detection import DetectionPredictor
import os
import time
from tabulate import tabulate
import datasets
@click.command(help="Benchmark detection model.")
@click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
@click.option("--results_dir", type=str, help="Directory to save benchmark results to.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
@click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False)
def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool):
det_predictor = DetectionPredictor()
if pdf_path is not None:
pathname = pdf_path
doc = open_pdf(pdf_path)
page_count = len(doc)
page_indices = list(range(page_count))
page_indices = page_indices[:max_rows]
images = get_page_images(doc, page_indices)
doc.close()
image_sizes = [img.size for img in images]
correct_boxes = get_pdf_lines(pdf_path, image_sizes)
else:
pathname = "det_bench"
# These have already been shuffled randomly, so sampling from the start is fine
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
images = list(dataset["image"])
images = convert_if_not_rgb(images)
correct_boxes = []
for i, boxes in enumerate(dataset["bboxes"]):
img_size = images[i].size
# 1000,1000 is bbox size for doclaynet
correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])
if settings.DETECTOR_STATIC_CACHE:
# Run through one batch to compile the model
det_predictor(images[:1])
start = time.time()
predictions = det_predictor(images)
surya_time = time.time() - start
if tesseract:
start = time.time()
tess_predictions = tesseract_parallel(images)
tess_time = time.time() - start
else:
tess_predictions = [None] * len(images)
tess_time = None
folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
page_metrics = collections.OrderedDict()
for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
surya_boxes = [s.bbox for s in sb.bboxes]
surya_polys = [s.polygon for s in sb.bboxes]
surya_metrics = precision_recall(surya_boxes, cb)
if tb is not None:
tess_metrics = precision_recall(tb, cb)
else:
tess_metrics = None
page_metrics[idx] = {
"surya": surya_metrics,
"tesseract": tess_metrics
}
if debug:
bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))
mean_metrics = {}
metric_types = sorted(page_metrics[0]["surya"].keys())
models = ["surya"]
if tesseract:
models.append("tesseract")
for k in models:
for m in metric_types:
metric = []
for page in page_metrics:
metric.append(page_metrics[page][k][m])
if k not in mean_metrics:
mean_metrics[k] = {}
mean_metrics[k][m] = sum(metric) / len(metric)
out_data = {
"times": {
"surya": surya_time,
"tesseract": tess_time
},
"metrics": mean_metrics,
"page_metrics": page_metrics
}
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
json.dump(out_data, f, indent=4)
table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
table_data = [
["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
]
if tesseract:
table_data.append(
["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types]
)
print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
print(f"Wrote results to {result_path}")
if __name__ == "__main__":
main()
================================================
FILE: benchmark/layout.py
================================================
import collections
import copy
import json
import click
from benchmark.utils.metrics import precision_recall
from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.input.processing import convert_if_not_rgb
from surya.debug.draw import draw_bboxes_on_image
from surya.settings import settings
import os
import time
from tabulate import tabulate
import datasets
@click.command(help="Benchmark surya layout model.")
@click.option(
"--results_dir",
type=str,
help="Path to JSON file with OCR results.",
default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
"--max_rows",
type=int,
help="Maximum number of images to run benchmark on.",
default=100,
)
@click.option("--debug", is_flag=True, help="Run in debug mode.", default=False)
def main(results_dir: str, max_rows: int, debug: bool):
foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_predictor = LayoutPredictor(foundation_predictor)
pathname = "layout_bench"
# These have already been shuffled randomly, so sampling from the start is fine
dataset = datasets.load_dataset(
settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]"
)
images = list(dataset["image"])
images = convert_if_not_rgb(images)
if settings.LAYOUT_STATIC_CACHE:
layout_predictor(images[:1])
start = time.time()
layout_predictions = layout_predictor(images)
surya_time = time.time() - start
folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
label_alignment = { # First is publaynet, second is surya
"Image": [["Figure"], ["Picture", "Figure"]],
"Table": [["Table"], ["Table", "Form", "TableOfContents"]],
"Text": [
["Text"],
[
"Text",
"Formula",
"Footnote",
"Caption",
"TextInlineMath",
"Code",
"Handwriting",
],
],
"List": [["List"], ["ListItem"]],
"Title": [["Title"], ["SectionHeader", "Title"]],
}
page_metrics = collections.OrderedDict()
for idx, pred in enumerate(layout_predictions):
row = dataset[idx]
all_correct_bboxes = []
page_results = {}
for label_name in label_alignment:
correct_cats, surya_cats = label_alignment[label_name]
correct_bboxes = [
b
for b, category in zip(row["bboxes"], row["labels"])
if category in correct_cats
]
all_correct_bboxes.extend(correct_bboxes)
pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats]
metrics = precision_recall(
pred_bboxes, correct_bboxes, penalize_double=False
)
weight = len(correct_bboxes)
metrics["weight"] = weight
page_results[label_name] = metrics
page_metrics[idx] = page_results
if debug:
bbox_image = draw_bboxes_on_image(
all_correct_bboxes, copy.deepcopy(images[idx])
)
bbox_image.save(os.path.join(result_path, f"{idx}_layout.png"))
mean_metrics = collections.defaultdict(dict)
layout_types = sorted(page_metrics[0].keys())
metric_types = sorted(page_metrics[0][layout_types[0]].keys())
metric_types.remove("weight")
for label in layout_types:
for m in metric_types:
metric = []
total = 0
for page in page_metrics:
metric.append(
page_metrics[page][label][m] * page_metrics[page][label]["weight"]
)
total += page_metrics[page][label]["weight"]
value = sum(metric)
if value > 0:
value /= total
mean_metrics[label][m] = value
out_data = {
"time": surya_time,
"metrics": mean_metrics,
"page_metrics": page_metrics,
}
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
json.dump(out_data, f, indent=4)
table_headers = [
"Layout Type",
] + metric_types
table_data = []
for layout_type in layout_types:
table_data.append(
[
layout_type,
]
+ [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types]
)
print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print(
f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total."
)
print(
"Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold."
)
print(f"Wrote results to {result_path}")
if __name__ == "__main__":
main()
================================================
FILE: benchmark/ordering.py
================================================
import collections
import json
import click
from surya.foundation import FoundationPredictor
from surya.input.processing import convert_if_not_rgb
from surya.layout import LayoutPredictor
from surya.common.polygon import PolygonBox
from surya.settings import settings
from benchmark.utils.metrics import rank_accuracy
import os
import time
import datasets
@click.command(help="Benchmark surya layout for reading order.")
@click.option(
"--results_dir",
type=str,
help="Path to JSON file with benchmark results.",
default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
"--max_rows",
type=int,
help="Maximum number of images to run benchmark on.",
default=None,
)
def main(results_dir: str, max_rows: int):
foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_predictor = LayoutPredictor(foundation_predictor)
pathname = "order_bench"
# These have already been shuffled randomly, so sampling from the start is fine
split = "train"
if max_rows is not None:
split = f"train[:{max_rows}]"
dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
images = list(dataset["image"])
images = convert_if_not_rgb(images)
start = time.time()
layout_predictions = layout_predictor(images)
surya_time = time.time() - start
folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
page_metrics = collections.OrderedDict()
mean_accuracy = 0
for idx, order_pred in enumerate(layout_predictions):
row = dataset[idx]
labels = row["labels"]
bboxes = row["bboxes"]
pred_positions = []
for label, bbox in zip(labels, bboxes):
max_intersection = 0
matching_idx = 0
for pred_box in order_pred.bboxes:
intersection = pred_box.intersection_pct(PolygonBox(polygon=bbox))
if intersection > max_intersection:
max_intersection = intersection
matching_idx = pred_box.position
pred_positions.append(matching_idx)
accuracy = rank_accuracy(pred_positions, labels)
mean_accuracy += accuracy
page_results = {"accuracy": accuracy, "box_count": len(labels)}
page_metrics[idx] = page_results
mean_accuracy /= len(layout_predictions)
out_data = {
"time": surya_time,
"mean_accuracy": mean_accuracy,
"page_metrics": page_metrics,
}
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
json.dump(out_data, f, indent=4)
print(f"Mean accuracy is {mean_accuracy:.2f}.")
print(
f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total."
)
print("Mean accuracy is the % of correct ranking pairs.")
print(f"Wrote results to {result_path}")
if __name__ == "__main__":
main()
================================================
FILE: benchmark/recognition.py
================================================
import re
import unicodedata
from collections import defaultdict
import click
from benchmark.utils.scoring import overlap_score, overlap_score_exact
from surya.input.processing import convert_if_not_rgb
from surya.debug.text import draw_text_on_image
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.settings import settings
from surya.recognition.languages import CODE_TO_LANGUAGE
from benchmark.utils.tesseract import (
tesseract_ocr_parallel,
surya_lang_to_tesseract,
TESS_CODE_TO_LANGUAGE,
)
from benchmark.utils.textract import textract_ocr_parallel
import os
import datasets
import json
import time
from tabulate import tabulate
KEY_LANGUAGES = [
"Chinese",
"Spanish",
"English",
"Arabic",
"Hindi",
"Bengali",
"Russian",
"Japanese",
]
def list_in(lst: str | list, lst2: list):
if isinstance(lst, str):
lst = [lst]
return any([item in lst for item in lst2])
def standardize_bullets(text):
patterns = [
r"•\s+",
r"·\s+",
r"○\s+",
r"◦\s+",
r"▪\s+",
r"▫\s+",
r"➢\s+",
r"➤\s+",
r"★\s+",
r"✓\s+",
r"✗\s+",
r"✦\s+",
r"\\bullet\s+",
]
combined_pattern = "|".join(patterns)
text = re.sub(combined_pattern, "*", text)
return text
def normalize_text(text: str) -> str:
# Remove HTML tags
text = re.sub(r"<[^>]+>", "", text)
# Remove LaTeX tags
text = re.sub(r"\\[a-zA-Z]+", "", text)
text = standardize_bullets(text)
text = unicodedata.normalize("NFKC", text)
return text.strip().lower().replace(",", ".")
@click.command(help="Benchmark recognition model.")
@click.option(
"--results_dir",
type=str,
help="Path to JSON file with OCR results.",
default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
"--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None
)
@click.option("--debug", type=int, help="Debug level: 0 is off, 1 writes low-scoring pages to a file, 2 also renders images.", default=0)
@click.option(
"--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False
)
@click.option(
"--textract", is_flag=True, help="Run benchmarks on textract.", default=False
)
@click.option(
"--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28
)
@click.option(
"--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28
)
@click.option(
"--languages",
type=str,
help="Comma-separated list of languages to benchmark.",
default=None,
)
@click.option(
"--print_results",
is_flag=True,
)
def main(
results_dir: str,
max_rows: int,
debug: bool,
tesseract: bool,
textract: bool,
tess_cpus: int,
textract_cpus: int,
languages: str | None,
print_results: bool,
):
foundation_predictor = FoundationPredictor()
rec_predictor = RecognitionPredictor(foundation_predictor)
split = "train"
dataset = datasets.load_dataset(
settings.RECOGNITION_BENCH_DATASET_NAME, split=split
)
if languages:
languages = languages.split(",")
dataset = dataset.filter(
lambda x: list_in(x["language"], languages), num_proc=4
)
if max_rows and max_rows < len(dataset):
dataset = dataset.shuffle(seed=1).select(range(max_rows))
images = list(dataset["image"])
images = convert_if_not_rgb(images)
bboxes = list(dataset["bboxes"])
line_text = list(dataset["text"])
languages = list(dataset["language"])
print(f"Loaded {len(images)} images. Running OCR...")
start = time.time()
predictions_by_image = rec_predictor(images, None, bboxes=bboxes)
surya_time = time.time() - start
lang_list = []
for lang in languages:
if not isinstance(lang, list):
lang_list.append([lang])
else:
lang_list.append(lang)
surya_scores = defaultdict(list)
img_surya_scores = []
outputs = []
for idx, (pred, ref_text, langs) in enumerate(
zip(predictions_by_image, line_text, lang_list)
):
pred_text = [line.text for line in pred.text_lines]
score_ref_text = [normalize_text(line) for line in ref_text]
score_pred_text = [normalize_text(text) for text in pred_text]
image_scores, image_weights = overlap_score_exact(
score_pred_text, score_ref_text
)
normalized_scores = [
score / max(1, weight) for score, weight in zip(image_scores, image_weights)
]
image_score = sum(image_scores) / max(1, sum(image_weights))
img_surya_scores.append(image_score)
for lang in langs:
surya_scores[CODE_TO_LANGUAGE[lang]].append(image_score)
assert len(pred_text) == len(ref_text) == len(bboxes[idx])
if debug:
for j, (pred_line, ref_line, score, bbox) in enumerate(
zip(pred_text, ref_text, normalized_scores, bboxes[idx])
):
image_slice = images[idx].crop(bbox)
outputs.append(
{
"image": image_slice,
"bbox": bbox,
"score": score,
"pred": pred_line,
"ref": ref_line,
"langs": ",".join(langs),
}
)
if debug:
out_ds = datasets.Dataset.from_list(outputs)
out_ds.push_to_hub("datalab-to/rec_bench_outputs", private=True)
flat_surya_scores = [score for lang in surya_scores for score in surya_scores[lang]]
benchmark_stats = {
"surya": {
"avg_score": sum(flat_surya_scores) / max(1, len(flat_surya_scores)),
"lang_scores": {
lang: sum(scores) / max(1, len(scores))
for lang, scores in surya_scores.items()
},
"time_per_img": surya_time / max(1, len(images)),
}
}
result_path = os.path.join(results_dir, "rec_bench")
os.makedirs(result_path, exist_ok=True)
with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
json.dump(surya_scores, f)
if tesseract:
tess_valid = []
tess_langs = []
for idx, lang in enumerate(lang_list):
# Tesseract does not support all languages
tess_lang = surya_lang_to_tesseract(lang[0])
if tess_lang is None:
continue
tess_valid.append(idx)
tess_langs.append(tess_lang)
tess_imgs = [images[i] for i in tess_valid]
tess_bboxes = [bboxes[i] for i in tess_valid]
tess_reference = [line_text[i] for i in tess_valid]
start = time.time()
tess_predictions = tesseract_ocr_parallel(
tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus
)
tesseract_time = time.time() - start
tess_scores = defaultdict(list)
for idx, (pred, ref_text, lang) in enumerate(
zip(tess_predictions, tess_reference, tess_langs)
):
image_scores, image_weights, _ = overlap_score(pred, ref_text)
image_score = sum(image_scores) / max(1, sum(image_weights))
tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)
flat_tess_scores = [
score for lang in tess_scores for score in tess_scores[lang]
]
benchmark_stats["tesseract"] = {
"avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
"lang_scores": {
lang: sum(scores) / len(scores) for lang, scores in tess_scores.items()
},
"time_per_img": tesseract_time / len(tess_imgs),
}
with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
json.dump(tess_scores, f)
if textract:
start = time.time()
textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus)
textract_time = time.time() - start
textract_scores = defaultdict(list)
for idx, (pred, ref_text, lang) in enumerate(
zip(textract_predictions, line_text, lang_list)
):
image_scores, image_weights, _ = overlap_score(pred, ref_text)
image_score = sum(image_scores) / max(1, sum(image_weights))
for lang in lang:
textract_scores[CODE_TO_LANGUAGE[lang]].append(image_score)
flat_textract_scores = [
score for lang in textract_scores for score in textract_scores[lang]
]
benchmark_stats["textract"] = {
"avg_score": sum(flat_textract_scores) / len(flat_textract_scores),
"lang_scores": {
lang: sum(scores) / len(scores)
for lang, scores in textract_scores.items()
},
"time_per_img": textract_time / len(images),
}
print(len(flat_textract_scores))
with open(os.path.join(result_path, "textract_scores.json"), "w+") as f:
json.dump(textract_scores, f)
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
json.dump(benchmark_stats, f)
key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages
table_data = [
[
"surya",
benchmark_stats["surya"]["time_per_img"],
benchmark_stats["surya"]["avg_score"],
]
+ [benchmark_stats["surya"]["lang_scores"][lang] for lang in key_languages],
]
if tesseract:
table_data.append(
[
"tesseract",
benchmark_stats["tesseract"]["time_per_img"],
benchmark_stats["tesseract"]["avg_score"],
]
+ [
benchmark_stats["tesseract"]["lang_scores"].get(lang, 0)
for lang in key_languages
]
)
if textract:
table_data.append(
[
"textract",
benchmark_stats["textract"]["time_per_img"],
benchmark_stats["textract"]["avg_score"],
]
+ [
benchmark_stats["textract"]["lang_scores"][lang]
for lang in key_languages
],
)
print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print(
"Only a few major languages are displayed. See the result path for additional languages."
)
if debug >= 1:
bad_detections = []
for idx, (score, lang) in enumerate(zip(img_surya_scores, lang_list)):
if score < 0.8:
bad_detections.append((idx, lang, score))
print(f"Found {len(bad_detections)} bad detections. Writing to file...")
with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
json.dump(bad_detections, f)
if debug == 2:
for idx, (image, pred, ref_text, bbox, lang) in enumerate(
zip(images, predictions_by_image, line_text, bboxes, lang_list)
):
pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
pred_text = [line.text for line in pred.text_lines]
pred_image = draw_text_on_image(bbox, pred_text, image.size)
pred_image.save(os.path.join(result_path, pred_image_name))
ref_image = draw_text_on_image(bbox, ref_text, image.size)
ref_image.save(os.path.join(result_path, ref_image_name))
image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))
print(f"Wrote results to {result_path}")
if print_results:
for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)):
print(f"Image {idx}")
print("----")
for line_idx, (pred_line, ref_line) in enumerate(
zip(pred.text_lines, ref_text)
):
print(f"Sample {line_idx}")
print(f"Pred: {pred_line.text}")
print(f"Ref: {ref_line}")
print()
if settings.TORCH_DEVICE == "xla":
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
if __name__ == "__main__":
main()
================================================
FILE: benchmark/table_recognition.py
================================================
import click
import collections
import json
from surya.debug.draw import draw_bboxes_on_image
from tabulate import tabulate
from surya.input.processing import convert_if_not_rgb
from surya.table_rec import TableRecPredictor
from surya.settings import settings
from benchmark.utils.metrics import penalized_iou_score
from benchmark.utils.tatr import load_tatr, batch_inference_tatr
import os
import time
import datasets
@click.command(help="Benchmark table rec dataset")
@click.option(
"--results_dir",
type=str,
help="Path to JSON file with benchmark results.",
default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
"--max_rows",
type=int,
help="Maximum number of images to run benchmark on.",
default=512,
)
@click.option("--tatr", is_flag=True, help="Run table transformer.", default=False)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
def main(results_dir: str, max_rows: int, tatr: bool, debug: bool):
table_rec_predictor = TableRecPredictor()
pathname = "table_rec_bench"
# These have already been shuffled randomly, so sampling from the start is fine
split = "train"
if max_rows is not None:
split = f"train[:{max_rows}]"
dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
images = list(dataset["image"])
images = convert_if_not_rgb(images)
if settings.TABLE_REC_STATIC_CACHE:
# Run through one batch to compile the model
table_rec_predictor(images[:1])
start = time.time()
table_rec_predictions = table_rec_predictor(images)
surya_time = time.time() - start
folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
page_metrics = collections.OrderedDict()
mean_col_iou = 0
mean_row_iou = 0
for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)):
row = dataset[idx]
pred_row_boxes = [p.bbox for p in pred.rows]
pred_col_bboxes = [p.bbox for p in pred.cols]
actual_row_bboxes = [r["bbox"] for r in row["rows"]]
actual_col_bboxes = [c["bbox"] for c in row["columns"]]
row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
page_results = {
"row_score": row_score,
"col_score": col_score,
"row_count": len(actual_row_bboxes),
"col_count": len(actual_col_bboxes),
}
mean_col_iou += col_score
mean_row_iou += row_score
page_metrics[idx] = page_results
if debug:
# Save debug images
draw_img = image.copy()
draw_bboxes_on_image(
pred_row_boxes,
draw_img,
[f"Row {i}" for i in range(len(pred_row_boxes))],
)
draw_bboxes_on_image(
pred_col_bboxes,
draw_img,
[f"Col {i}" for i in range(len(pred_col_bboxes))],
color="blue",
)
draw_img.save(os.path.join(result_path, f"{idx}_bbox.png"))
actual_draw_image = image.copy()
draw_bboxes_on_image(
actual_row_bboxes,
actual_draw_image,
[f"Row {i}" for i in range(len(actual_row_bboxes))],
)
draw_bboxes_on_image(
actual_col_bboxes,
actual_draw_image,
[f"Col {i}" for i in range(len(actual_col_bboxes))],
color="blue",
)
actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png"))
mean_col_iou /= len(table_rec_predictions)
mean_row_iou /= len(table_rec_predictions)
out_data = {
"surya": {
"time": surya_time,
"mean_row_iou": mean_row_iou,
"mean_col_iou": mean_col_iou,
"page_metrics": page_metrics,
}
}
if tatr:
tatr_model = load_tatr()
start = time.time()
tatr_predictions = batch_inference_tatr(tatr_model, images, 1)
tatr_time = time.time() - start
page_metrics = collections.OrderedDict()
mean_col_iou = 0
mean_row_iou = 0
for idx, pred in enumerate(tatr_predictions):
row = dataset[idx]
pred_row_boxes = [p["bbox"] for p in pred["rows"]]
pred_col_bboxes = [p["bbox"] for p in pred["cols"]]
actual_row_bboxes = [r["bbox"] for r in row["rows"]]
actual_col_bboxes = [c["bbox"] for c in row["columns"]]
row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
page_results = {
"row_score": row_score,
"col_score": col_score,
"row_count": len(actual_row_bboxes),
"col_count": len(actual_col_bboxes),
}
mean_col_iou += col_score
mean_row_iou += row_score
page_metrics[idx] = page_results
mean_col_iou /= len(tatr_predictions)
mean_row_iou /= len(tatr_predictions)
out_data["tatr"] = {
"time": tatr_time,
"mean_row_iou": mean_row_iou,
"mean_col_iou": mean_col_iou,
"page_metrics": page_metrics,
}
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
json.dump(out_data, f, indent=4)
table = [
["Model", "Row Intersection", "Col Intersection", "Time Per Image"],
[
"Surya",
f"{out_data['surya']['mean_row_iou']:.2f}",
f"{out_data['surya']['mean_col_iou']:.5f}",
f"{surya_time / len(images):.5f}",
],
]
if tatr:
table.append(
[
"Table transformer",
f"{out_data['tatr']['mean_row_iou']:.2f}",
f"{out_data['tatr']['mean_col_iou']:.5f}",
f"{tatr_time / len(images):.5f}",
]
)
print(tabulate(table, headers="firstrow", tablefmt="github"))
print(
"Intersection is the average of the intersection % between each actual row/column, and the predictions. With penalties for too many/few predictions."
)
print(
"Note that table transformers is unbatched, since the example code in the repo is unbatched."
)
print(f"Wrote results to {result_path}")
if __name__ == "__main__":
main()
================================================
FILE: benchmark/texify.py
================================================
import os.path
import re
import time
from pathlib import Path
from typing import List
import click
import datasets
from tabulate import tabulate
from bs4 import BeautifulSoup
from surya.common.surya.schema import TaskNames
from surya.settings import settings
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor, OCRResult
import json
from rapidfuzz.distance import Levenshtein
def normalize_text(text):
soup = BeautifulSoup(text, "html.parser")
# Unwrap math tags
for tag in soup.find_all():
if tag.name == "math":
tag.unwrap()
text = soup.get_text()
text = re.sub(r"\n", " ", text)
text = re.sub(r"\s+", " ", text)
return text.strip()
def score_text(predictions, references):
lev_dist = []
for p, r in zip(predictions, references):
p = normalize_text(p)
r = normalize_text(r)
lev_dist.append(Levenshtein.normalized_distance(p, r))
return sum(lev_dist) / len(lev_dist)
def inference_texify(
source_data, predictor: RecognitionPredictor, line_mode: bool = False
):
images = [sd["image"] for sd in source_data]
mode = TaskNames.ocr_with_boxes if line_mode else TaskNames.block_without_boxes
tasks = [mode] * len(images)
bboxes = [[[0, 0, image.width, image.height]] for image in images]
texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes)
out_data = [
{
"text": texify_predictions[i].text_lines[0].text,
"equation": source_data[i]["equation"],
}
for i in range(len(texify_predictions))
]
return out_data
@click.command(help="Benchmark the performance of texify.")
@click.option(
"--ds_name",
type=str,
help="Name or path of the dataset with source images/equations.",
default=settings.TEXIFY_BENCHMARK_DATASET,
)
@click.option(
"--results_dir",
type=str,
help="Path to the directory where benchmark results are written.",
default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
"--max_rows", type=int, help="Maximum number of images to benchmark.", default=None
)
@click.option(
"--line_mode", is_flag=True, help="Use line mode for texify.", default=False
)
def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool):
foundation_predictor = FoundationPredictor()
predictor = RecognitionPredictor(foundation_predictor)
ds = datasets.load_dataset(ds_name, split="train")
if max_rows:
ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True)
start = time.time()
predictions = inference_texify(ds, predictor, line_mode)
time_taken = time.time() - start
text = [p["text"] for p in predictions]
references = [p["equation"] for p in predictions]
scores = score_text(text, references)
write_data = {
"scores": scores,
"text": [{"prediction": p, "reference": r} for p, r in zip(text, references)],
}
score_table = [["texify", write_data["scores"], time_taken]]
score_headers = ["edit", "time taken (s)"]
score_dirs = ["⬇", "⬇"]
score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
table = tabulate(score_table, headers=["Method", *score_headers])
print()
print(table)
result_path = Path(results_dir) / "texify_bench"
result_path.mkdir(parents=True, exist_ok=True)
with open(result_path / "results.json", "w", encoding="utf-8") as f:
json.dump(write_data, f, indent=4)
if __name__ == "__main__":
main()
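# Illustrative sketch (not part of the benchmark): the "edit" column above is a mean
# normalized edit distance, so 0.0 is a perfect match and lower is better. The helper
# names below are hypothetical, and the normalization (distance divided by the longer
# string's length) is an assumption about how rapidfuzz behaves with default weights.

```python
from typing import List


def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance with unit costs."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            # Deletion, insertion, substitution (or match)
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def normalized_distance(a: str, b: str) -> float:
    """Edit distance scaled to [0, 1] by the longer string's length."""
    longest = max(len(a), len(b))
    return levenshtein(a, b) / longest if longest else 0.0


def mean_normalized_distance(preds: List[str], refs: List[str]) -> float:
    """Average the per-pair distances, mirroring the shape of score_text."""
    dists = [normalized_distance(p, r) for p, r in zip(preds, refs)]
    return sum(dists) / len(dists)


example_score = mean_normalized_distance(["x^2", "a+b"], ["x^2", "a-b"])
```

# In this sketch the first pair matches exactly and the second differs by one of three
# characters, so the mean lands between 0 and the 0.2 threshold used in verify_texify.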
================================================
FILE: benchmark/utils/__init__.py
================================================
================================================
FILE: benchmark/utils/bbox.py
================================================
import fitz as pymupdf
from surya.common.util import rescale_bbox
def get_pdf_lines(pdf_path, img_sizes):
doc = pymupdf.open(pdf_path)
page_lines = []
for idx, img_size in enumerate(img_sizes):
page = doc[idx]
blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"]
line_boxes = []
for block_idx, block in enumerate(blocks):
for l in block["lines"]:
line_boxes.append(list(l["bbox"]))
page_box = page.bound()
pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1]
line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes]
page_lines.append(line_boxes)
return page_lines
def merge_boxes(box1, box2):
return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]))
def join_lines(bboxes, max_gap=5):
to_merge = {}
for i, box1 in bboxes:
for z, box2 in bboxes[i + 1:]:
j = i + z + 1
if box1 == box2:
continue
if box1[0] <= box2[0] and box1[2] >= box2[2]:
if abs(box1[1] - box2[3]) <= max_gap:
if i not in to_merge:
to_merge[i] = []
to_merge[i].append(j)
merged_boxes = set()
merged = []
for i, box in bboxes:
if i in merged_boxes:
continue
if i in to_merge:
for j in to_merge[i]:
box = merge_boxes(box, bboxes[j][1])
merged_boxes.add(j)
merged.append(box)
return merged
================================================
FILE: benchmark/utils/metrics.py
================================================
from functools import partial
from itertools import repeat
import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
def calculate_iou(box1, box2, box1_only=False):
intersection = intersection_area(box1, box2)
union = box_area(box1)
if not box1_only:
union += box_area(box2) - intersection
if union == 0:
return 0
return intersection / union
def match_boxes(preds, references):
num_actual = len(references)
num_predicted = len(preds)
iou_matrix = np.zeros((num_actual, num_predicted))
for i, actual in enumerate(references):
for j, pred in enumerate(preds):
iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True)
sorted_indices = np.argsort(iou_matrix, axis=None)[::-1]
sorted_ious = iou_matrix.flatten()[sorted_indices]
actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape)
assigned_actual = set()
assigned_pred = set()
matches = []
for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious):
i, j = idx
if i not in assigned_actual and j not in assigned_pred:
iou_val = iou_matrix[i, j]
if iou_val > .95: # Account for rounding on box edges
iou_val = 1.0
matches.append((i, j, iou_val))
assigned_actual.add(i)
assigned_pred.add(j)
unassigned_actual = set(range(num_actual)) - assigned_actual
unassigned_pred = set(range(num_predicted)) - assigned_pred
matches.extend([(i, None, -1.0) for i in unassigned_actual])
matches.extend([(None, j, 0.0) for j in unassigned_pred])
return matches
def penalized_iou_score(preds, references):
matches = match_boxes(preds, references)
iou = sum([match[2] for match in matches]) / len(matches)
return iou
def intersection_pixels(box1, box2):
x_left = max(box1[0], box2[0])
y_top = max(box1[1], box2[1])
x_right = min(box1[2], box2[2])
y_bottom = min(box1[3], box2[3])
if x_right < x_left or y_bottom < y_top:
return set()
x_left, x_right = int(x_left), int(x_right)
y_top, y_bottom = int(y_top), int(y_bottom)
coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom))
pixels = set(zip(coords[0].flat, coords[1].flat))
return pixels
def calculate_coverage(box, other_boxes, penalize_double=False):
box_area = (box[2] - box[0]) * (box[3] - box[1])
if box_area == 0:
return 0
# find total coverage of the box
covered_pixels = set()
double_coverage = list()
for other_box in other_boxes:
ia = intersection_pixels(box, other_box)
double_coverage.append(list(covered_pixels.intersection(ia)))
covered_pixels = covered_pixels.union(ia)
# Penalize double coverage - having multiple bboxes overlapping the same pixels
double_coverage_penalty = len(double_coverage)
if not penalize_double:
double_coverage_penalty = 0
covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty)
return covered_pixels_count / box_area
def intersection_area(box1, box2):
x_left = max(box1[0], box2[0])
y_top = max(box1[1], box2[1])
x_right = min(box1[2], box2[2])
y_bottom = min(box1[3], box2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
return (x_right - x_left) * (y_bottom - y_top)
def calculate_coverage_fast(box, other_boxes, penalize_double=False):
box = np.array(box)
other_boxes = np.array(other_boxes)
# Calculate box area
box_area = (box[2] - box[0]) * (box[3] - box[1])
if box_area == 0:
return 0
x_left = np.maximum(box[0], other_boxes[:, 0])
y_top = np.maximum(box[1], other_boxes[:, 1])
x_right = np.minimum(box[2], other_boxes[:, 2])
y_bottom = np.minimum(box[3], other_boxes[:, 3])
widths = np.maximum(0, x_right - x_left)
heights = np.maximum(0, y_bottom - y_top)
intersect_areas = widths * heights
total_intersect = np.sum(intersect_areas)
return min(1.0, total_intersect / box_area)
def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
if len(references) == 0:
return {
"precision": 1,
"recall": 1,
}
if len(preds) == 0:
return {
"precision": 0,
"recall": 0,
}
# If we're not penalizing double coverage, we can use a faster calculation
coverage_func = calculate_coverage_fast
if penalize_double:
coverage_func = calculate_coverage
with ThreadPoolExecutor(max_workers=workers) as executor:
precision_func = partial(coverage_func, penalize_double=penalize_double)
precision_iou = executor.map(precision_func, preds, repeat(references))
reference_iou = executor.map(coverage_func, references, repeat(preds))
precision_classes = [1 if i > threshold else 0 for i in precision_iou]
precision = sum(precision_classes) / len(precision_classes)
recall_classes = [1 if i > threshold else 0 for i in reference_iou]
recall = sum(recall_classes) / len(recall_classes)
return {
"precision": precision,
"recall": recall,
}
def mean_coverage(preds, references):
coverages = []
for box1 in references:
coverage = calculate_coverage(box1, preds)
coverages.append(coverage)
for box2 in preds:
coverage = calculate_coverage(box2, references)
coverages.append(coverage)
# Calculate the average coverage over all comparisons
if len(coverages) == 0:
return {"coverage": 0}  # Keep the return type consistent with the non-empty case
coverage = sum(coverages) / len(coverages)
return {"coverage": coverage}
def rank_accuracy(preds, references):
# Preds and references need to be aligned so each position refers to the same bbox
pairs = []
for i, pred in enumerate(preds):
for j, pred2 in enumerate(preds):
if i == j:
continue
pairs.append((i, j, pred > pred2))
# Find how many of the prediction rankings are correct
correct = 0
for i, ref in enumerate(references):
for j, ref2 in enumerate(references):
if (i, j, ref > ref2) in pairs:
correct += 1
return correct / len(pairs)
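# Illustrative sketch (not part of the module): a worked example of the two overlap
# variants used above and of the penalty scheme in match_boxes. The function names and
# the sample boxes are hypothetical; the arithmetic follows calculate_iou and
# penalized_iou_score as written in this file.

```python
def intersection_area(box1, box2):
    """Area of the overlap between two [x0, y0, x1, y1] boxes (0 if disjoint)."""
    width = min(box1[2], box2[2]) - max(box1[0], box2[0])
    height = min(box1[3], box2[3]) - max(box1[1], box2[1])
    return max(0.0, width) * max(0.0, height)


def box_area(box):
    return (box[2] - box[0]) * (box[3] - box[1])


def overlap(box1, box2, box1_only=False):
    """IoU by default; with box1_only, the fraction of box1 that box2 covers."""
    inter = intersection_area(box1, box2)
    denom = box_area(box1)
    if not box1_only:
        denom += box_area(box2) - inter
    return inter / denom if denom else 0.0


reference = [0, 0, 10, 10]
prediction = [5, 0, 15, 10]

iou = overlap(reference, prediction)                       # 50 / 150
coverage = overlap(reference, prediction, box1_only=True)  # 50 / 100

# One matched pair (1.0), one reference with no prediction (-1.0),
# and one prediction with no reference (0.0), averaged as in penalized_iou_score.
match_scores = [1.0, -1.0, 0.0]
penalized = sum(match_scores) / len(match_scores)
```

# The penalized mean shows why a single missed reference can cancel a perfect match:
# unmatched references subtract from the total, while extra predictions only dilute it.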
================================================
FILE: benchmark/utils/scoring.py
================================================
import math
from typing import List
from rapidfuzz import fuzz
def overlap_score(pred_lines: List[str], reference_lines: List[str]):
line_scores = []
line_weights = []
line_match = {}
for i, pred_line in enumerate(pred_lines):
max_score = 0
line_weight = 1
match = None
for j, ref_line in enumerate(reference_lines):
score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
if score > max_score:
max_score = score
line_weight = math.sqrt(len(ref_line))
match = j
line_scores.append(max_score)
line_weights.append(line_weight)
line_match[i] = match
line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))]
return line_scores, line_weights, line_match
def overlap_score_exact(pred_lines: List[str], reference_lines: List[str]):
line_scores = []
line_weights = []
assert len(pred_lines) == len(reference_lines)
for i, (pred_line, ref_line) in enumerate(zip(pred_lines, reference_lines)):
score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
weight = math.sqrt(len(ref_line))
line_scores.append(score * weight)
line_weights.append(weight)
return line_scores, line_weights
================================================
FILE: benchmark/utils/tatr.py
================================================
import torch
from transformers import AutoModelForObjectDetection
from surya.settings import settings
import numpy as np
class MaxResize(object):
def __init__(self, max_size=800):
self.max_size = max_size
def __call__(self, image):
width, height = image.size
current_max_size = max(width, height)
scale = self.max_size / current_max_size
resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
return resized_image
def to_tensor(image):
# Convert PIL Image to NumPy array
np_image = np.array(image).astype(np.float32)
# Rearrange dimensions to [C, H, W] format
np_image = np_image.transpose((2, 0, 1))
# Normalize to [0.0, 1.0]
np_image /= 255.0
return torch.from_numpy(np_image)
def normalize(tensor, mean, std):
for t, m, s in zip(tensor, mean, std):
t.sub_(m).div_(s)
return tensor
def structure_transform(image):
image = MaxResize(1000)(image)
tensor = to_tensor(image)
normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
return normalized_tensor
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
width, height = size
boxes = box_cxcywh_to_xyxy(out_bbox)
boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
return boxes
def outputs_to_objects(outputs, img_sizes, id2label):
m = outputs.logits.softmax(-1).max(-1)
batch_labels = list(m.indices.detach().cpu().numpy())
batch_scores = list(m.values.detach().cpu().numpy())
batch_bboxes = outputs['pred_boxes'].detach().cpu()
batch_objects = []
for i in range(len(img_sizes)):
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])]
pred_scores = batch_scores[i]
pred_labels = batch_labels[i]
objects = []
for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
class_label = id2label[int(label)]
if not class_label == 'no object':
objects.append({
'label': class_label,
'score': float(score),
'bbox': [float(elem) for elem in bbox]}
)
rows = []
cols = []
for cell in objects:
if cell["label"] == "table column":
cols.append(cell)
if cell["label"] == "table row":
rows.append(cell)
batch_objects.append({
"rows": rows,
"cols": cols
})
return batch_objects
def load_tatr():
return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL)
def batch_inference_tatr(model, images, batch_size):
device = model.device
rows_cols = []
for i in range(0, len(images), batch_size):
batch_images = images[i:i + batch_size]
pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)
# forward pass
with torch.no_grad():
outputs = model(pixel_values)
id2label = dict(model.config.id2label)  # Copy so the model config is not mutated on every batch
id2label[len(id2label)] = "no object"
rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
return rows_cols
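# Illustrative sketch (not part of the module): the box conversion done by
# box_cxcywh_to_xyxy and rescale_bboxes, written with plain lists instead of tensors.
# The function names and sample values are hypothetical.

```python
def cxcywh_to_xyxy(box):
    """Convert a (center_x, center_y, width, height) box to corner format."""
    x_c, y_c, w, h = box
    return [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]


def rescale_box(box, size):
    """Scale a normalized cxcywh box to pixel xyxy coordinates for an image size."""
    width, height = size
    x0, y0, x1, y1 = cxcywh_to_xyxy(box)
    return [x0 * width, y0 * height, x1 * width, y1 * height]


# A box centered in an 800x400 image, a quarter of its width and half of its height
pixel_box = rescale_box([0.5, 0.5, 0.25, 0.5], (800, 400))
```

# The size tuple is (width, height), matching PIL's image.size as passed in
# batch_inference_tatr.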
================================================
FILE: benchmark/utils/tesseract.py
================================================
from typing import List, Optional
import numpy as np
from tqdm import tqdm
from surya.input.processing import slice_bboxes_from_image
from surya.settings import settings
import os
from concurrent.futures import ProcessPoolExecutor
from surya.recognition.languages import CODE_TO_LANGUAGE
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor
def surya_lang_to_tesseract(code: str) -> Optional[str]:
lang_str = CODE_TO_LANGUAGE[code]
try:
tess_lang = TESS_LANGUAGE_TO_CODE[lang_str]
except KeyError:
return None
return tess_lang
def tesseract_ocr(img, bboxes, lang: str):
import pytesseract
line_imgs = slice_bboxes_from_image(img, bboxes)
config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
lines = []
for line_img in line_imgs:
line = pytesseract.image_to_string(line_img, lang=lang, config=config)
lines.append(line)
return lines
def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size())
if not cpus:
cpus = os.cpu_count()
tess_parallel_cores = min(tess_parallel_cores, cpus)
# Tesseract uses up to 4 processes per instance
# Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images
tess_parallel = max(tess_parallel_cores // 2, 1)
with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR")
tess_text = list(tess_text)
return tess_text
def tesseract_bboxes(img):
import pytesseract
from pytesseract import Output
arr_img = np.asarray(img, dtype=np.uint8)
ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT)
bboxes = []
n_boxes = len(ocr['level'])
for i in range(n_boxes):
# It is possible to merge by line here with line number, but it gives bad results.
_, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i]
bbox = (x, y, x + w, y + h)
bboxes.append(bbox)
return bboxes
def tesseract_parallel(imgs):
# Cap the worker count by the number of images and the detection batch size
tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size())
cpus = os.cpu_count()
tess_parallel_cores = min(tess_parallel_cores, cpus)
# Tesseract uses 4 threads per instance
tess_parallel = max(tess_parallel_cores // 4, 1)
with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection")
tess_bboxes = list(tess_bboxes)
return tess_bboxes
TESS_CODE_TO_LANGUAGE = {
"afr": "Afrikaans",
"amh": "Amharic",
"ara": "Arabic",
"asm": "Assamese",
"aze": "Azerbaijani",
"bel": "Belarusian",
"ben": "Bengali",
"bod": "Tibetan",
"bos": "Bosnian",
"bre": "Breton",
"bul": "Bulgarian",
"cat": "Catalan",
"ceb": "Cebuano",
"ces": "Czech",
"chi_sim": "Chinese",
"chr": "Cherokee",
"cym": "Welsh",
"dan": "Danish",
"deu": "German",
"dzo": "Dzongkha",
"ell": "Greek",
"eng": "English",
"epo": "Esperanto",
"est": "Estonian",
"eus": "Basque",
"fas": "Persian",
"fin": "Finnish",
"fra": "French",
"fry": "Western Frisian",
"guj": "Gujarati",
"gla": "Scottish Gaelic",
"gle": "Irish",
"glg": "Galician",
"heb": "Hebrew",
"hin": "Hindi",
"hrv": "Croatian",
"hun": "Hungarian",
"hye": "Armenian",
"iku": "Inuktitut",
"ind": "Indonesian",
"isl": "Icelandic",
"ita": "Italian",
"jav": "Javanese",
"jpn": "Japanese",
"kan": "Kannada",
"kat": "Georgian",
"kaz": "Kazakh",
"khm": "Khmer",
"kir": "Kyrgyz",
"kor": "Korean",
"lao": "Lao",
"lat": "Latin",
"lav": "Latvian",
"lit": "Lithuanian",
"mal": "Malayalam",
"mar": "Marathi",
"mkd": "Macedonian",
"mlt": "Maltese",
"mon": "Mongolian",
"msa": "Malay",
"mya": "Burmese",
"nep": "Nepali",
"nld": "Dutch",
"nor": "Norwegian",
"ori": "Oriya",
"pan": "Punjabi",
"pol": "Polish",
"por": "Portuguese",
"pus": "Pashto",
"ron": "Romanian",
"rus": "Russian",
"san": "Sanskrit",
"sin": "Sinhala",
"slk": "Slovak",
"slv": "Slovenian",
"snd": "Sindhi",
"spa": "Spanish",
"sqi": "Albanian",
"srp": "Serbian",
"swa": "Swahili",
"swe": "Swedish",
"syr": "Syriac",
"tam": "Tamil",
"tel": "Telugu",
"tgk": "Tajik",
"tha": "Thai",
"tir": "Tigrinya",
"tur": "Turkish",
"uig": "Uyghur",
"ukr": "Ukrainian",
"urd": "Urdu",
"uzb": "Uzbek",
"vie": "Vietnamese",
"yid": "Yiddish"
}
TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()}
================================================
FILE: benchmark/utils/textract.py
================================================
import os
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import traceback
from surya.input.processing import slice_bboxes_from_image
from surya.recognition import RecognitionPredictor
def textract_ocr(extractor, img):
try:
document = extractor.detect_document_text(file_source=img)
return [line.text for line in document.lines]
except Exception:  # Do not swallow KeyboardInterrupt/SystemExit
traceback.print_exc()
return [None]
def textract_ocr_parallel(imgs, cpus=None):
from textractor import Textractor # Optional dependency
extractor = Textractor(profile_name='default')
parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size())
if not cpus:
cpus = os.cpu_count()
parallel_cores = min(parallel_cores, cpus)
with ThreadPoolExecutor(max_workers=parallel_cores) as executor:
textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR")
textract_text = list(textract_text)
return textract_text
================================================
FILE: benchmark/utils/verify_benchmark_scores.py
================================================
import json
import click
def verify_layout(data):
scores = data["metrics"]
for layout_type, metrics in scores.items():
if layout_type == "List": # Skip lists since none appear early on
continue
if metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6:
raise ValueError("Scores do not meet the required threshold")
def verify_det(data):
scores = data["metrics"]["surya"]
if scores["precision"] <= 0.9 or scores["recall"] <= 0.9:
raise ValueError("Scores do not meet the required threshold")
def verify_rec(data):
scores = data["surya"]
if scores["avg_score"] <= 0.9:
raise ValueError("Scores do not meet the required threshold")
def verify_order(data):
score = data["mean_accuracy"]
if score < 0.75:
raise ValueError("Scores do not meet the required threshold")
def verify_table_rec(data):
row_score = data["surya"]["mean_row_iou"]
col_score = data["surya"]["mean_col_iou"]
if row_score < 0.75 or col_score < 0.75:
raise ValueError("Scores do not meet the required threshold")
def verify_texify(data):
edit_dist = data["scores"]
if edit_dist > 0.2:
raise ValueError("Scores do not meet the required threshold")
@click.command(help="Verify benchmark scores")
@click.argument("file_path", type=str)
@click.option(
"--bench_type", type=str, help="Type of benchmark to verify", default="detection"
)
def main(file_path, bench_type):
with open(file_path, "r") as file:
data = json.load(file)
if bench_type == "detection":
verify_det(data)
elif bench_type == "recognition":
verify_rec(data)
elif bench_type == "layout":
verify_layout(data)
elif bench_type == "ordering":
verify_order(data)
elif bench_type == "table_recognition":
verify_table_rec(data)
elif bench_type == "texify":
verify_texify(data)
else:
raise ValueError("Invalid benchmark type")
if __name__ == "__main__":
main()
================================================
FILE: detect_layout.py
================================================
from surya.scripts.detect_layout import detect_layout_cli
if __name__ == "__main__":
detect_layout_cli()
================================================
FILE: detect_text.py
================================================
from surya.scripts.detect_text import detect_text_cli
if __name__ == "__main__":
detect_text_cli()
================================================
FILE: ocr_app.py
================================================
from surya.scripts.run_streamlit_app import streamlit_app_cli
if __name__ == "__main__":
streamlit_app_cli()
================================================
FILE: ocr_latex.py
================================================
from surya.scripts.ocr_latex import ocr_latex_cli
if __name__ == "__main__":
ocr_latex_cli()
================================================
FILE: ocr_text.py
================================================
from surya.scripts.ocr_text import ocr_text_cli
if __name__ == "__main__":
ocr_text_cli()
================================================
FILE: pyproject.toml
================================================
[tool.poetry]
name = "surya-ocr"
version = "0.17.1"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri "]
readme = "README.md"
license = "GPL-3.0-or-later"
repository = "https://github.com/VikParuchuri/surya"
keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
packages = [
{include = "surya"}
]
[tool.poetry.dependencies]
python = "^3.10"
transformers = ">=4.56.1"
torch = "^2.7.0"
pydantic = "^2.5.3"
pydantic-settings = "^2.1.0"
python-dotenv = "^1.0.0"
pillow = "^10.2.0"
pypdfium2 = "=4.30.0"
filetype = "^1.2.0"
click = "^8.1.8"
platformdirs = "^4.3.6"
opencv-python-headless = "==4.11.0.86"
einops = "^0.8.1"
pre-commit = "^4.2.0"
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
pytesseract = "^0.3.10"
pymupdf = "^1.23.8"
datasets = "^2.16.1"
rapidfuzz = "^3.6.1"
streamlit = "^1.31.0"
pytest = "^8.3.4"
pdftext = "^0.5.1"
tabulate = "^0.9.0"
[tool.poetry.scripts]
surya_detect = "surya.scripts.detect_text:detect_text_cli"
surya_ocr = "surya.scripts.ocr_text:ocr_text_cli"
surya_layout = "surya.scripts.detect_layout:detect_layout_cli"
surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli"
surya_table = "surya.scripts.table_recognition:table_recognition_cli"
surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli"
texify_gui = "surya.scripts.run_texify_app:texify_app_cli"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[[tool.poetry.source]]
name = "libtpu-releases"
url = "https://storage.googleapis.com/libtpu-releases/index.html"
priority = "supplemental"
[[tool.poetry.source]]
name = "libtpu-wheels"
url = "https://storage.googleapis.com/libtpu-wheels/index.html"
priority = "supplemental"
[tool.poetry.group.xla]
optional = true
[tool.poetry.group.xla.dependencies]
torch-xla = {version = "^2.4.1", extras = ["tpu"]}
================================================
FILE: pytest.ini
================================================
[pytest]
testpaths=tests
pythonpath=.
filterwarnings =
ignore::UserWarning
ignore::PendingDeprecationWarning
ignore::DeprecationWarning
================================================
FILE: signatures/version1/cla.json
================================================
{
"signedContributors": [
{
"name": "rishiraj",
"id": 44090649,
"comment_id": 2170578748,
"created_at": "2024-06-15T19:31:20Z",
"repoId": 741297064,
"pullRequestNo": 135
},
{
"name": "mmacvicar",
"id": 59354,
"comment_id": 2236493182,
"created_at": "2024-07-18T13:17:43Z",
"repoId": 741297064,
"pullRequestNo": 152
},
{
"name": "jimexist",
"id": 622789,
"comment_id": 2255151376,
"created_at": "2024-07-29T07:23:55Z",
"repoId": 741297064,
"pullRequestNo": 160
},
{
"name": "michaeldriscoll-avant",
"id": 85255083,
"comment_id": 2259143427,
"created_at": "2024-07-30T20:21:33Z",
"repoId": 741297064,
"pullRequestNo": 161
},
{
"name": "EdoardoPona",
"id": 29152472,
"comment_id": 2271115922,
"created_at": "2024-08-06T11:58:00Z",
"repoId": 741297064,
"pullRequestNo": 167
},
{
"name": "hidenori-endo",
"id": 15546605,
"comment_id": 2307217499,
"created_at": "2024-08-23T14:31:17Z",
"repoId": 741297064,
"pullRequestNo": 182
},
{
"name": "dobosevych",
"id": 12053536,
"comment_id": 2430376828,
"created_at": "2024-10-22T21:48:34Z",
"repoId": 741297064,
"pullRequestNo": 220
},
{
"name": "iammosespaulr",
"id": 28682735,
"comment_id": 2447941238,
"created_at": "2024-10-30T17:55:23Z",
"repoId": 741297064,
"pullRequestNo": 235
},
{
"name": "ArthurMor4is",
"id": 42987302,
"comment_id": 2515315717,
"created_at": "2024-12-03T18:37:45Z",
"repoId": 741297064,
"pullRequestNo": 255
},
{
"name": "tarun-menta",
"id": 66506307,
"comment_id": 2543457960,
"created_at": "2024-12-15T05:43:33Z",
"repoId": 741297064,
"pullRequestNo": 261
},
{
"name": "jonaskahn",
"id": 4338500,
"comment_id": 2556622097,
"created_at": "2024-12-20T09:36:20Z",
"repoId": 741297064,
"pullRequestNo": 269
},
{
"name": "kumsumit",
"id": 95072784,
"comment_id": 2574534622,
"created_at": "2025-01-07T07:05:59Z",
"repoId": 741297064,
"pullRequestNo": 276
},
{
"name": "kevinhu",
"id": 6051736,
"comment_id": 2614135351,
"created_at": "2025-01-25T23:34:12Z",
"repoId": 741297064,
"pullRequestNo": 291
},
{
"name": "zanussbaum",
"id": 33707069,
"comment_id": 3008673416,
"created_at": "2025-06-26T14:20:46Z",
"repoId": 741297064,
"pullRequestNo": 403
},
{
"name": "mebriki",
"id": 35892987,
"comment_id": 3154706976,
"created_at": "2025-08-05T10:54:27Z",
"repoId": 741297064,
"pullRequestNo": 418
},
{
"name": "starikovplusplus",
"id": 56602036,
"comment_id": 3168958011,
"created_at": "2025-08-08T18:29:50Z",
"repoId": 741297064,
"pullRequestNo": 423
},
{
"name": "sandy0kwon",
"id": 78377296,
"comment_id": 3207932260,
"created_at": "2025-08-20T20:07:15Z",
"repoId": 741297064,
"pullRequestNo": 434
},
{
"name": "n0kovo",
"id": 16690056,
"comment_id": 3208251881,
"created_at": "2025-08-20T22:22:06Z",
"repoId": 741297064,
"pullRequestNo": 435
},
{
"name": "davidxifeng",
"id": 158052,
"comment_id": 3249594859,
"created_at": "2025-09-03T14:52:16Z",
"repoId": 741297064,
"pullRequestNo": 445
},
{
"name": "u-ashish",
"id": 14264791,
"comment_id": 3258734182,
"created_at": "2025-09-05T15:16:48Z",
"repoId": 741297064,
"pullRequestNo": 447
},
{
"name": "Mohking1",
"id": 63689545,
"comment_id": 3314908963,
"created_at": "2025-09-20T11:21:42Z",
"repoId": 741297064,
"pullRequestNo": 462
},
{
"name": "wkpark",
"id": 232347,
"comment_id": 3330009557,
"created_at": "2025-09-24T17:42:55Z",
"repoId": 741297064,
"pullRequestNo": 464
},
{
"name": "coval3nte",
"id": 65908512,
"comment_id": 3848768229,
"created_at": "2026-02-04T17:28:32Z",
"repoId": 741297064,
"pullRequestNo": 483
},
{
"name": "bailey-coding",
"id": 29517254,
"comment_id": 3955014177,
"created_at": "2026-02-24T22:09:52Z",
"repoId": 741297064,
"pullRequestNo": 487
},
{
"name": "Br1an67",
"id": 29810238,
"comment_id": 3979412700,
"created_at": "2026-03-01T07:32:18Z",
"repoId": 741297064,
"pullRequestNo": 489
}
]
}
================================================
FILE: static/fonts/.gitignore
================================================
*
!.gitignore
================================================
FILE: surya/__init__.py
================================================
================================================
FILE: surya/common/__init__.py
================================================
================================================
FILE: surya/common/adetr/decoder.py
================================================
from typing import Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PretrainedConfig
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.xla import mark_step
_MAX_SQRT_GRADIENT = 1000.0
class WrappedEmbedding(nn.Embedding):
def forward(self, input_ids, *args, **kwargs):
return super().forward(input_ids)
class SuryaADETRDecoderRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
# Add clipping to prevent division by zero
variance = torch.clamp(variance, min=self.eps)
return x * torch.rsqrt(variance)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst SuryaADETRDecoder is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
# Clamp to float16 range
f16_info = torch.finfo(x.dtype)
output = output.clamp(min=f16_info.min, max=f16_info.max)
output = torch.where(
torch.isnan(output), torch.tensor(0.0, device=output.device), output
)
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
ALL_LAYERNORM_LAYERS.append(SuryaADETRDecoderRMSNorm)
class SuryaADETRDecoderRotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000, device=None):
super().__init__()
self.dim = dim
self.base = base
inv_freq = 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
)
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaADETRDecoder
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# Tensor.to is not in-place, so use the returned tensor
inv_freq = self.inv_freq.to(x.device)
inv_freq_expanded = (
inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
1, 2
)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
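# Illustrative sketch (not part of the original module): the rotary embedding above,
# restated for a single vector in plain Python. The `rope` helper is a hypothetical name;
# it combines the inv_freq / cos / sin computation with the rotate_half formula so the
# rotation can be checked by hand.

```python
import math


def rotate_half(x):
    """Swap the two halves of x and negate the (new) first half."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope(x, position, base=10000.0):
    """Rotate one head vector by the angles assigned to `position`."""
    dim = len(x)
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    freqs = [position * f for f in inv_freq]
    emb = freqs + freqs  # mirrors torch.cat((freqs, freqs), dim=-1)
    cos = [math.cos(a) for a in emb]
    sin = [math.sin(a) for a in emb]
    rotated = rotate_half(x)
    return [v * c + r * s for v, c, r, s in zip(x, cos, rotated, sin)]


q = [1.0, 2.0, 3.0, 4.0]
q_rot = rope(q, position=5)
```

# Each pair (x[i], x[i + dim // 2]) is rotated by the same angle, so the vector norm is
# preserved and position 0 leaves the vector unchanged.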
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
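# Illustrative sketch (not part of the original module): the head ordering produced by
# the expand + reshape above, shown on a plain list of per-head items. The name
# `repeat_kv_heads` is hypothetical and only stands in for the tensor version.

```python
def repeat_kv_heads(heads, n_rep):
    """Repeat every key/value head n_rep times, keeping repeats adjacent."""
    return [head for head in heads for _ in range(n_rep)]


kv_heads = ["kv0", "kv1"]
expanded = repeat_kv_heads(kv_heads, n_rep=3)
```

# With this layout, query head h reads from key/value head h // n_rep, which is what
# grouped-query attention expects.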
class SuryaADETRDecoderSdpaCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper
Modified for GQA
"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.q_proj = nn.Linear(
self.hidden_size,
self.num_attention_heads * self.head_dim,
bias=config.attention_bias,
)
self.k_proj = nn.Linear(
self.config.encoder_hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.v_proj = nn.Linear(
self.config.encoder_hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = nn.Linear(
self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
)
self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(
self.head_dim,
base=config.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
) -> torch.Tensor:
# Encoder attention mask currently ignored
bsz, q_len, _ = hidden_states.size()
_, v_len, _ = encoder_hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_attention_heads, self.head_dim
).transpose(1, 2)
if self.key_states is None:
key_states = self.k_proj(encoder_hidden_states)
value_states = self.v_proj(encoder_hidden_states)
key_states = key_states.view(
bsz, v_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, v_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if use_cache:
self._update_cache(key_states, value_states)
else:
key_states = self.key_states
value_states = self.value_states
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=self.attention_dropout if self.training else 0.0,
scale=self.head_dim**-0.5,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
def _clear_cache(self):
if self.value_states is not None:
del self.value_states
if self.key_states is not None:
del self.key_states
def _setup_cache(self, batch_size, device, dtype=None):
# Setup initial caches
self.value_states = None
self.key_states = None
@torch.no_grad()
def _update_cache(self, key_states, value_states, **cache_kwargs):
self.value_states = value_states
self.key_states = key_states
class SuryaADETRDecoderSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig, static_cache=False, max_boxes=None):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.q_proj = nn.Linear(
self.hidden_size,
self.num_attention_heads * self.head_dim,
bias=config.attention_bias,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = nn.Linear(
self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
)
self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(
self.head_dim,
base=config.rope_theta,
)
self.static_cache = static_cache
self.max_boxes = max_boxes
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: bool = False,
window_attn: bool = False,
) -> torch.Tensor:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Final is bsz, num_attention_heads, seq_len, head_dim
query_states = query_states.view(
bsz, q_len, self.num_attention_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if use_cache and hasattr(self, "key_states"):
cache_kwargs = {
"cache_position": cache_position,
"window_attn": window_attn,
}
key_states, value_states = self._update_cache(
key_states, value_states, **cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
# Mask is batch, head, seq_len, kv_len
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
if cache_position is not None and self.static_cache:
current_pos = cache_position[-1]
causal_mask[:, :, :, current_pos + 1 :] = torch.finfo(
causal_mask.dtype
).min
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
scale=self.head_dim**-0.5,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
def _setup_cache(self, batch_size, device, dtype=None):
if dtype is None and self.config.torch_dtype is not None:
dtype = self.config.torch_dtype
dtype = dtype if dtype is not None else torch.float32
# Setup initial caches
self.value_states = None
self.key_states = None
if self.static_cache:
cache_shape = (
batch_size,
self.num_key_value_heads,
self.max_boxes,
self.head_dim,
)
self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
def _clear_cache(self):
if self.value_states is not None:
del self.value_states
if self.key_states is not None:
del self.key_states
def _update_static_cache(self, key_states, value_states, **cache_kwargs):
cache_position = cache_kwargs.get("cache_position")
k_out, v_out = (
self.key_states.to(key_states.device),
self.value_states.to(value_states.device),
)
k_out[:, :, cache_position] = key_states.to(k_out.dtype)
v_out[:, :, cache_position] = value_states.to(v_out.dtype)
self.key_states, self.value_states = k_out, v_out
return k_out, v_out
def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
k_out = key_states
if self.key_states is not None:
k_out = torch.cat([self.key_states, key_states], dim=2)
v_out = value_states
if self.value_states is not None:
v_out = torch.cat([self.value_states, value_states], dim=2)
self.key_states, self.value_states = k_out, v_out
return k_out, v_out
@torch.no_grad()
def _update_cache(self, key_states, value_states, **cache_kwargs):
if self.static_cache:
return self._update_static_cache(key_states, value_states, **cache_kwargs)
return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
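# Illustrative sketch (not part of the original module): the two cache strategies above,
# reduced to one scalar per position. `DynamicCache` and `StaticCache` are hypothetical
# names for this example; the real code stores key/value tensors on the attention module.

```python
class DynamicCache:
    """Grows by concatenation, like _update_dynamic_cache."""

    def __init__(self):
        self.keys = None

    def update(self, new_keys):
        self.keys = new_keys if self.keys is None else self.keys + new_keys
        return self.keys


class StaticCache:
    """Preallocated buffer written by index, like _update_static_cache."""

    def __init__(self, max_len):
        self.keys = [0.0] * max_len

    def update(self, new_keys, cache_position):
        for pos, key in zip(cache_position, new_keys):
            self.keys[pos] = key
        return self.keys


dynamic = DynamicCache()
dynamic.update([1.0, 2.0])  # prefill with two positions
dynamic_out = dynamic.update([3.0])  # one decode step

static = StaticCache(max_len=5)
static.update([1.0, 2.0], cache_position=[0, 1])
static_out = static.update([3.0], cache_position=[2])
```

# The static cache always returns the full buffer, including unwritten slots, which is why
# the forward pass masks every position after the current cache_position.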
class SuryaADETRDecoderMlp(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_activation is None:
config.hidden_activation = "gelu_pytorch_tanh"
hidden_activation = config.hidden_activation
self.act_fn = ACT2FN[hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class SuryaADETRDecoderLayer(nn.Module):
def __init__(self, config, layer_idx, static_cache=False, max_boxes=None):
super().__init__()
self.cross_pre_norm = SuryaADETRDecoderRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.temporal_pre_norm = SuryaADETRDecoderRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.temporal_block = None
if layer_idx in config.self_attn_layers:
self.temporal_block = SuryaADETRDecoderSdpaAttention(
config, static_cache=static_cache, max_boxes=max_boxes
)
self.cross_attn_block = None
if layer_idx in config.cross_attn_layers:
self.cross_attn_block = SuryaADETRDecoderSdpaCrossAttention(config)
self.window_attn = layer_idx not in config.global_attn_layers
self.channel_pre_norm = SuryaADETRDecoderRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.mlp_block = SuryaADETRDecoderMlp(config)
self.double_residual_flow = getattr(config, "double_residual_flow", False)
def forward(
self,
activations: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_attention_mask: torch.Tensor = None,
cache_position: torch.Tensor = None,
use_cache: bool = None,
) -> torch.Tensor:
if self.double_residual_flow:
return self.double_res_forward(
activations,
position_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
cache_position,
use_cache,
)
hidden_states = activations
if self.cross_attn_block is not None:
# Do cross-attention on encoder outputs
cross_attn_inputs = self.cross_pre_norm(hidden_states)
cross_attn_path = self.cross_attn_block(
cross_attn_inputs,
encoder_hidden_states,
attention_mask,
encoder_attention_mask,
use_cache=use_cache,
)
hidden_states = cross_attn_path + hidden_states
if self.temporal_block is not None:
temporal_inputs = self.temporal_pre_norm(
hidden_states
) # RMSNorm introduces slight differences
temporal_path = self.temporal_block(
temporal_inputs,
position_ids,
attention_mask,
cache_position=cache_position,
use_cache=use_cache,
window_attn=self.window_attn,
)
hidden_states = temporal_path + hidden_states
block_input = hidden_states
hidden_states = self.channel_pre_norm(block_input)
hidden_states = self.mlp_block(hidden_states)
hidden_states = hidden_states + block_input
return hidden_states
def double_res_forward(
self,
activations: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_attention_mask: torch.Tensor = None,
cache_position: torch.Tensor = None,
use_cache: bool = None,
) -> torch.Tensor:
raw_activations = activations
if self.cross_attn_block is not None:
# Do cross-attention on encoder outputs
cross_attn_inputs = self.cross_pre_norm(activations)
cross_attn_path = self.cross_attn_block(
cross_attn_inputs,
encoder_hidden_states,
attention_mask,
encoder_attention_mask,
use_cache=use_cache,
)
cross_attn_output = cross_attn_path + raw_activations
else:
cross_attn_output = raw_activations
if self.temporal_block is not None:
inputs_normalized = self.temporal_pre_norm(
cross_attn_output
) # RMSNorm introduces slight differences
hidden_states = self.temporal_block(
inputs_normalized,
position_ids,
attention_mask,
cache_position=cache_position,
use_cache=use_cache,
window_attn=self.window_attn,
)
residual = hidden_states + raw_activations
else:
residual = cross_attn_output
hidden_states = self.channel_pre_norm(residual)
hidden_states = self.mlp_block(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
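# Illustrative sketch (not part of the original module): how the two residual flows in
# SuryaADETRDecoderLayer differ when both the cross-attention and temporal blocks exist.
# The sub-blocks are replaced by hypothetical scalar functions so the arithmetic is easy
# to follow.

```python
def standard_flow(x, cross, temporal, mlp, norm):
    """forward(): each block adds onto the running hidden state."""
    h = x + cross(norm(x))
    h = h + temporal(norm(h))
    return h + mlp(norm(h))


def double_residual_flow(x, cross, temporal, mlp, norm):
    """double_res_forward(): the temporal skip connection uses the raw input."""
    cross_out = x + cross(norm(x))
    residual = x + temporal(norm(cross_out))
    return residual + mlp(norm(residual))


# Stand-in blocks chosen only to make the two results easy to tell apart.
blocks = dict(
    cross=lambda v: 10.0,
    temporal=lambda v: 2.0 * v,
    mlp=lambda v: 0.0,
    norm=lambda v: v,
)
standard = standard_flow(1.0, **blocks)
double = double_residual_flow(1.0, **blocks)
```

# In the double-residual path the cross-attention output only feeds the temporal block;
# it is not carried forward on the skip connection.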
class SuryaADETRDecoderPreTrainedModel(SuryaPreTrainedModel):
config_class = PretrainedConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["SuryaADETRDecoderLayer"]
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
_supports_quantized_cache = True
def _init_weights(self, module):
if isinstance(module, SuryaADETRDecoderSdpaAttention):
torch.nn.init.normal_(
module.q_proj.weight, mean=0.0, std=self.config.init_std
)
torch.nn.init.normal_(
module.k_proj.weight, mean=0.0, std=self.config.init_std
)
torch.nn.init.normal_(
module.v_proj.weight, mean=0.0, std=self.config.init_std
)
torch.nn.init.normal_(
module.o_proj.weight, mean=0.0, std=self.config.init_std
)
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
if getattr(module, "bias", None) is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _setup_cache(self, config, batch, device, dtype):
layers = getattr(self, "model", self).layers
for layer in layers:
if layer.temporal_block:
layer.temporal_block._setup_cache(batch, device, dtype)
if layer.cross_attn_block:
layer.cross_attn_block._setup_cache(batch, device, dtype)
def _clear_cache(self):
layers = getattr(self, "model", self).layers
for layer in layers:
if layer.temporal_block:
layer.temporal_block._clear_cache()
if layer.cross_attn_block:
layer.cross_attn_block._clear_cache()
def reset_cache(self, batch, device, dtype):
pass
def _tie_weights(self):
pass
def tie_weights(self):
pass
class SuryaADETRDecoderModel(SuryaADETRDecoderPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaADETRDecoderLayer`].
Args:
config: PretrainedConfig
"""
def __init__(
self,
config: PretrainedConfig,
embedder: nn.Module = None,
max_boxes: int = None,
static_cache: bool = False,
):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.causal = config.causal
self.embed_tokens = embedder
self.max_boxes = max_boxes
self.static_cache = static_cache
self.layers = nn.ModuleList(
[
SuryaADETRDecoderLayer(
config, layer_idx, static_cache=static_cache, max_boxes=max_boxes
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_norm = SuryaADETRDecoderRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
self.register_buffer(
"normalizer",
torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32),
persistent=False,
)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
def get_input_embeddings(self):
return self.embed_tokens
# Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
input_boxes_counts: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
prefill: bool = False,
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
hidden_states = inputs_embeds
if use_cache and prefill:
self._setup_cache(
self.config,
hidden_states.shape[0],
hidden_states.device,
hidden_states.dtype,
)
if cache_position is None:
cache_position = torch.arange(
hidden_states.shape[1], device=hidden_states.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position
)
all_hidden_states = () if output_hidden_states else None
for i, residual_block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
residual_block.__call__,
hidden_states,
position_ids,
causal_mask,
encoder_hidden_states,
encoder_attention_mask,
cache_position,
use_cache,
)
else:
hidden_states = residual_block(
hidden_states,
position_ids,
causal_mask,
encoder_hidden_states,
encoder_attention_mask,
cache_position,
use_cache,
)
hidden_states = self.final_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
)
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# Ignore copy
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if not self.causal:
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
target_length = max(self.max_boxes, sequence_length)
diagonal = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
causal_mask = diagonal
if sequence_length != 1:
# Select the upper triangular part of the matrix, but unmask current token (the diagonal)
# triu will be the min_dtype, everything else is 0 (attended to)
causal_mask = torch.triu(diagonal, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(
input_tensor.shape[0], 1, -1, -1
)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
# Mask positions in the causal mask that are masked in the attention mask
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
:, None, None, :
].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[
..., :mask_length
].masked_fill(padding_mask, min_dtype)
if attention_mask is not None and attention_mask.device.type == "cuda":
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
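# Illustrative sketch (not part of the original module): the causal mask built by
# _update_causal_mask when no padding mask is supplied, as nested lists. `MIN` is a
# hypothetical stand-in for torch.finfo(dtype).min, and `build_causal_mask` is a name
# used only in this example.

```python
MIN = -1.0e9  # stand-in for the most negative value of the model dtype


def build_causal_mask(cache_position, target_length):
    """Return one row per query position; MIN marks keys that must not be attended."""
    seq_len = len(cache_position)
    mask = []
    for row, pos in enumerate(cache_position):
        line = []
        for col in range(target_length):
            # torch.triu(..., diagonal=1) is only applied when seq_len != 1
            above_diagonal = col > row if seq_len != 1 else True
            in_future = col > pos  # arange(target_length) > cache_position
            line.append(MIN if (above_diagonal and in_future) else 0.0)
        mask.append(line)
    return mask


prefill = build_causal_mask(cache_position=[0, 1, 2], target_length=4)
decode = build_causal_mask(cache_position=[3], target_length=4)
```

# During prefill each row can see itself and earlier positions; during a single decode
# step the one query row can see every position up to its own cache_position.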
================================================
FILE: surya/common/donut/encoder.py
================================================
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.pytorch_utils import (
find_pruneable_heads_and_indices,
meshgrid,
prune_linear_layer,
)
from transformers.utils import ModelOutput
from transformers import DonutSwinConfig
from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.xla import mark_step
_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class DonutSwinModelOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.view(
batch_size,
height // window_size,
window_size,
width // window_size,
window_size,
num_channels,
)
windows = (
input_feature.permute(0, 1, 3, 2, 4, 5)
.contiguous()
.view(-1, window_size, window_size, num_channels)
)
return windows
# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
num_channels = windows.shape[-1]
windows = windows.view(
-1,
height // window_size,
width // window_size,
window_size,
window_size,
num_channels,
)
windows = (
windows.permute(0, 1, 3, 2, 4, 5)
.contiguous()
.view(-1, height, width, num_channels)
)
return windows
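# Illustrative sketch (not part of the original module): window_partition and
# window_reverse on a single 2D grid, with the batch and channel axes dropped. The
# `_2d` helper names are hypothetical; the window ordering (row block first, then column
# block) follows the permute(0, 1, 3, 2, 4, 5) used above.

```python
def window_partition_2d(grid, window_size):
    """Split a height x width grid into non-overlapping square windows."""
    height, width = len(grid), len(grid[0])
    windows = []
    for top in range(0, height, window_size):
        for left in range(0, width, window_size):
            windows.append(
                [row[left:left + window_size] for row in grid[top:top + window_size]]
            )
    return windows


def window_reverse_2d(windows, window_size, height, width):
    """Place the windows back at their original grid locations."""
    grid = [[None] * width for _ in range(height)]
    per_row = width // window_size
    for index, window in enumerate(windows):
        top = (index // per_row) * window_size
        left = (index % per_row) * window_size
        for dy, row in enumerate(window):
            for dx, value in enumerate(row):
                grid[top + dy][left + dx] = value
    return grid


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
windows = window_partition_2d(grid, window_size=2)
restored = window_reverse_2d(windows, 2, height=4, width=4)
```

# The sketch assumes height and width are multiples of window_size, as the tensor
# versions do.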
# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
class DonutSwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings. Optionally, also the mask token.
"""
def __init__(self, config, use_mask_token=False):
super().__init__()
self.config = config  # interpolate_pos_encoding reads self.config.patch_size
self.patch_embeddings = DonutSwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
self.mask_token = (
nn.Parameter(torch.zeros(1, 1, config.embed_dim))
if use_mask_token
else None
)
self.position_embeddings = None
self.row_embeddings = None
self.column_embeddings = None
if config.use_absolute_embeddings:
self.position_embeddings = nn.Parameter(
torch.zeros(1, num_patches + 1, config.embed_dim)
)
if hasattr(config, "use_2d_embeddings") and config.use_2d_embeddings:
self.row_embeddings = nn.Parameter(
torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim)
)
self.column_embeddings = nn.Parameter(
torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim)
)
self.norm = nn.LayerNorm(config.embed_dim)
def interpolate_pos_encoding(
self, embeddings: torch.Tensor, height: int, width: int
) -> torch.Tensor:
"""
This method interpolates the pre-trained position encodings so that the model can be used on
higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: Optional[torch.FloatTensor],
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]:
_, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None:
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(
embeddings, height, width
)
else:
embeddings = embeddings + self.position_embeddings[:, :seq_len]
if self.row_embeddings is not None and self.column_embeddings is not None:
# Row embeddings repeat each row index across that row (0, 0, 0, 0, 1, 1, 1, 1, ...);
# column embeddings tile the column indices once per row (0, 1, 2, 3, 0, 1, 2, 3, ...)
row_embeddings = self.row_embeddings[
:, : output_dimensions[0], :
].repeat_interleave(output_dimensions[1], dim=1)
column_embeddings = self.column_embeddings[
:, : output_dimensions[1], :
].repeat(1, output_dimensions[0], 1)
embeddings = embeddings + row_embeddings + column_embeddings
return embeddings, output_dimensions
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
class DonutSwinPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = (
image_size
if isinstance(image_size, collections.abc.Iterable)
else (image_size, image_size)
)
patch_size = (
patch_size
if isinstance(patch_size, collections.abc.Iterable)
else (patch_size, patch_size)
)
num_patches = (image_size[1] // patch_size[1]) * (
image_size[0] // patch_size[0]
)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (
image_size[0] // patch_size[0],
image_size[1] // patch_size[1],
)
self.projection = nn.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(
self, pixel_values: Optional[torch.FloatTensor]
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, output_dimensions
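# Illustrative sketch (not part of the original module): the arithmetic behind maybe_pad
# and the resulting patch grid, assuming a square patch. `padded_size` and `patch_grid`
# are hypothetical helper names for this example.

```python
def padded_size(size, patch):
    """Round size up to the next multiple of patch, as maybe_pad does."""
    remainder = size % patch
    return size if remainder == 0 else size + (patch - remainder)


def patch_grid(height, width, patch_size):
    """Number of patches along each axis after padding."""
    return (
        padded_size(height, patch_size) // patch_size,
        padded_size(width, patch_size) // patch_size,
    )


grid = patch_grid(height=30, width=33, patch_size=4)
```

# Here the image is padded to 32 x 36, so the projection yields an 8 x 9 grid and a
# flattened sequence of 72 patch embeddings.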
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class DonutSwinPatchMerging(nn.Module):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(
self,
input_resolution: Tuple[int],
dim: int,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(
self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
) -> torch.Tensor:
height, width = input_dimensions
# `dim` is height * width
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
# pad input to be divisible by width and height, if needed
input_feature = self.maybe_pad(input_feature, height, width)
# [batch_size, height/2, width/2, num_channels]
input_feature_0 = input_feature[:, 0::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
# batch_size height/2 width/2 4*num_channels
input_feature = torch.cat(
[input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
)
input_feature = input_feature.view(
batch_size, -1, 4 * num_channels
) # batch_size height/2*width/2 4*C
input_feature = self.norm(input_feature)
input_feature = self.reduction(input_feature)
return input_feature
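# Illustrative sketch (not part of the original module): the 2x2 gather performed by
# DonutSwinPatchMerging before its LayerNorm and linear reduction, on nested lists. It
# assumes even height and width, so the padding step is omitted; `merge_patches` is a
# hypothetical name.

```python
def merge_patches(grid):
    """grid[h][w] is a channel list; returns a (h/2, w/2) grid with 4x the channels."""
    height, width = len(grid), len(grid[0])
    merged = []
    for top in range(0, height, 2):
        row = []
        for left in range(0, width, 2):
            row.append(
                grid[top][left]  # input_feature_0: even row, even column
                + grid[top + 1][left]  # input_feature_1: odd row, even column
                + grid[top][left + 1]  # input_feature_2: even row, odd column
                + grid[top + 1][left + 1]  # input_feature_3: odd row, odd column
            )
        merged.append(row)
    return merged


# Each cell holds one channel labelled "<row><col>" so the gather order is visible.
grid = [[[f"{r}{c}"] for c in range(4)] for r in range(2)]
merged = merge_patches(grid)
```

# The real layer then normalises the 4 * C channels and projects them down to 2 * C,
# halving the spatial resolution while doubling the channel width.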
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
class DonutSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.num_kv_heads = num_kv_heads
self.kv_repeats = self.num_attention_heads // self.num_kv_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.kv_head_size = self.num_kv_heads * self.attention_head_size
self.window_size = (
window_size
if isinstance(window_size, collections.abc.Iterable)
else (window_size, window_size)
)
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.query = nn.Linear(
self.all_head_size, self.all_head_size, bias=config.qkv_bias
)
self.key = nn.Linear(
self.all_head_size, self.kv_head_size, bias=config.qkv_bias
)
self.value = nn.Linear(
self.all_head_size, self.kv_head_size, bias=config.qkv_bias
)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def transpose_kv_for_scores(self, x, repeats):
new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size)
x = x.view(new_x_shape)
x = x.repeat(
1, 1, repeats, 1
) # repeat the values for each key-value head to match query dim
return x.permute(0, 2, 1, 3).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
# Final is (batch_size, num_attention_heads, seq_len, attention_head_size)
key_layer = self.transpose_kv_for_scores(
self.key(hidden_states), self.kv_repeats
)
value_layer = self.transpose_kv_for_scores(
self.value(hidden_states), self.kv_repeats
)
query_layer = self.transpose_for_scores(mixed_query_layer)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
)
relative_position_bias = (
relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
)
relative_position_bias = relative_position_bias.repeat(batch_size, 1, 1, 1)
if attention_mask is None:
attention_mask = relative_position_bias
else:
mask_shape = attention_mask.shape[0]
repeat_count = batch_size // mask_shape
attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
attention_mask = attention_mask + relative_position_bias
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=0.0,
scale=self.attention_head_size**-0.5,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, dim, num_channels)
outputs = (attn_output,)
return outputs
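# Illustrative sketch (not part of the original module): how the
# relative_position_index buffer built in __init__ maps a pair of tokens inside a
# window to one row of relative_position_bias_table. `relative_position_index`
# below is a hypothetical pure-Python stand-in for the tensor code.

```python
from typing import List, Tuple


def relative_position_index(window: Tuple[int, int]) -> List[List[int]]:
    """Return an N x N table of bias-table rows, with N = win_h * win_w."""
    win_h, win_w = window
    # Token coordinates in row-major order, like flattening the "ij" meshgrid.
    coords = [(r, c) for r in range(win_h) for c in range(win_w)]
    index = []
    for r_i, c_i in coords:
        row = []
        for r_j, c_j in coords:
            # Shift both offsets so they start at zero.
            d_r = r_i - r_j + win_h - 1
            d_c = c_i - c_j + win_w - 1
            # Row-major encoding of the (d_r, d_c) pair.
            row.append(d_r * (2 * win_w - 1) + d_c)
        index.append(row)
    return index


index = relative_position_index((2, 2))
table_rows = (2 * 2 - 1) * (2 * 2 - 1)  # size of the bias table for a 2x2 window
print(index[0])
```

# Pairs with the same relative offset share one learned bias per head, so the
# table needs only (2 * win_h - 1) * (2 * win_w - 1) rows rather than N * N.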
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
return self.dense(hidden_states)
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
class DonutSwinAttention(nn.Module):
def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
super().__init__()
self.self = DonutSwinSelfAttention(
config, dim, num_heads, num_kv_heads, window_size
)
self.output = DonutSwinSelfOutput(config, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads,
self.self.num_attention_heads,
self.self.attention_head_size,
self.pruned_heads,
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = (
self.self.attention_head_size * self.self.num_attention_heads
)
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
class DonutSwinIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinOutput
class DonutSwinOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.dense(hidden_states)
# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
class DonutSwinLayer(nn.Module):
def __init__(
self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0
):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = DonutSwinAttention(
config, dim, num_heads, num_kv_heads, window_size=self.window_size
)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = DonutSwinIntermediate(config, dim)
self.output = DonutSwinOutput(config, dim)
def set_shift_and_window_size(self, input_resolution):
if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = int(0)
self.window_size = (
torch.min(torch.tensor(input_resolution))
if torch.jit.is_tracing()
else min(input_resolution)
)
def get_attn_mask(self, height, width, dtype, device):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)
).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not always_partition:
self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.view(batch_size, height, width, channels)
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
# cyclic shift
if self.shift_size > 0:
shifted_hidden_states = torch.roll(
hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
)
else:
shifted_hidden_states = hidden_states
# partition windows
hidden_states_windows = window_partition(
shifted_hidden_states, self.window_size
)
hidden_states_windows = hidden_states_windows.view(
-1, self.window_size * self.window_size, channels
)
attn_mask = self.get_attn_mask(
height_pad,
width_pad,
dtype=hidden_states.dtype,
device=hidden_states_windows.device,
)
attention_outputs = self.attention(
hidden_states_windows,
attn_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
attention_windows = attention_output.view(
-1, self.window_size, self.window_size, channels
)
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
)
# reverse cyclic shift
if self.shift_size > 0:
attention_windows = torch.roll(
shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
)
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = shortcut + attention_windows
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
layer_outputs = (
(layer_output, attention_outputs[1])
if output_attentions
else (layer_output,)
)
return layer_outputs
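# Illustrative sketch (not part of the original module): the region labelling
# behind get_attn_mask for shifted windows. `region_ids` and `window_mask` are
# hypothetical helpers, and the 4x4 grid with window_size=2 and shift_size=1 is
# an example chosen for readability, not a real configuration.

```python
from typing import List


def region_ids(height: int, width: int, window_size: int, shift_size: int) -> List[List[int]]:
    """Label every position with the region it belongs to after the cyclic shift."""

    def band(pos: int, extent: int) -> int:
        # Mirrors the three slices: [0, -window), [-window, -shift), [-shift, end).
        if pos < extent - window_size:
            return 0
        if pos < extent - shift_size:
            return 1
        return 2

    # The nested loops in get_attn_mask number regions row band first.
    return [[3 * band(r, height) + band(c, width) for c in range(width)] for r in range(height)]


def window_mask(ids: List[List[int]], top: int, left: int, window_size: int,
                masked_value: float = -100.0) -> List[List[float]]:
    """Additive mask for one window: 0 within a region, masked_value across regions."""
    tokens = [ids[top + r][left + c] for r in range(window_size) for c in range(window_size)]
    return [[0.0 if a == b else masked_value for b in tokens] for a in tokens]


ids = region_ids(4, 4, window_size=2, shift_size=1)
first = window_mask(ids, 0, 0, 2)  # top-left window: a single region
last = window_mask(ids, 2, 2, 2)   # bottom-right window: four different regions
for row in ids:
    print(row)
```

# In the top-left window every token may attend to every other token, while in
# the bottom-right window each token comes from a different region and only the
# diagonal of the mask stays at zero.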
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
def __init__(
self,
config,
layer_num,
dim,
input_resolution,
depth,
num_heads,
num_kv_heads,
downsample,
):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
DonutSwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, norm_layer=nn.LayerNorm
)
else:
self.downsample = None
self.pointing = False
self.positional_encoding = None
if config.use_positional_embeddings:
self.positional_encoding = self.build_2d_sincos_position_embedding(
input_resolution[1],
input_resolution[0],
embed_dim=dim,
)
@staticmethod
def build_2d_sincos_position_embedding(
width,
height,
embed_dim=256,
temperature=10000.0,
device="cpu",
dtype=torch.float32,
):
grid_w = torch.arange(int(width), dtype=dtype, device=device)
grid_h = torch.arange(int(height), dtype=dtype, device=device)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
if embed_dim % 4 != 0:
raise ValueError(
"Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
)
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
omega = 1.0 / (temperature**omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
return torch.concat(
[out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1
)[None, :, :]
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
if self.positional_encoding is not None:
hidden_states = hidden_states + self.positional_encoding.to(
hidden_states.dtype
).to(hidden_states.device)
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(
hidden_states_before_downsampling, input_dimensions
)
else:
output_dimensions = (height, width, height, width)
stage_outputs = (
hidden_states,
hidden_states_before_downsampling,
output_dimensions,
)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
class DonutSwinEncoder(nn.Module):
def __init__(self, config, grid_size):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
self.layers = nn.ModuleList(
[
DonutSwinStage(
config=config,
layer_num=i_layer,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(
grid_size[0] // (2**i_layer),
grid_size[1] // (2**i_layer),
),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
num_kv_heads=config.num_kv_heads[i_layer]
if hasattr(config, "num_kv_heads")
else config.num_heads[i_layer],
downsample=DonutSwinPatchMerging
if (i_layer < self.num_layers - 1)
else None,
)
for i_layer in range(self.num_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
always_partition: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, DonutSwinEncoderOutput]:
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if output_hidden_states:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(
batch_size, *input_dimensions, hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
else:
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
output_dimensions = layer_outputs[2]
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
if output_hidden_states and output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
# rearrange b (h w) c -> b c h w
# here we use the original (not downsampled) height and width
reshaped_hidden_state = hidden_states_before_downsampling.view(
batch_size,
*(output_dimensions[0], output_dimensions[1]),
hidden_size,
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(
batch_size, *input_dimensions, hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
if output_attentions:
all_self_attentions += layer_outputs[3:]
if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions]
if v is not None
)
return DonutSwinEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
reshaped_hidden_states=all_reshaped_hidden_states,
)
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
class DonutSwinPreTrainedModel(SuryaPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DonutSwinConfig
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["DonutSwinStage"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
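# Illustrative sketch (not part of the original module): the per-stage width and
# resolution schedule that DonutSwinEncoder.__init__ derives from the config.
# `stage_schedule` is a hypothetical helper, and embed_dim=128 with a 64x48 grid
# and four stages are made-up example values.

```python
from typing import Dict, List, Tuple


def stage_schedule(embed_dim: int, grid_size: Tuple[int, int], num_stages: int) -> List[Dict]:
    """List the dim, input resolution and downsample flag for each stage."""
    stages = []
    for i in range(num_stages):
        stages.append(
            {
                "dim": embed_dim * 2**i,  # channels double per stage
                "input_resolution": (grid_size[0] // 2**i, grid_size[1] // 2**i),
                "downsample": i < num_stages - 1,  # no patch merging after the last stage
            }
        )
    return stages


schedule = stage_schedule(embed_dim=128, grid_size=(64, 48), num_stages=4)
for stage in schedule:
    print(stage)
```

# The constructor floors each side, whereas DonutSwinStage.forward reports
# (height + 1) // 2, so the two agree only while every intermediate size is even.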
================================================
FILE: surya/common/donut/processor.py
================================================
from typing import Dict, Union, Optional, List, Iterable
import cv2
from torch import TensorType
from transformers import ImageProcessingMixin
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import pad, normalize
from transformers.image_utils import (
ImageInput,
ChannelDimension,
make_list_of_images,
get_image_size,
)
import numpy as np
from PIL import Image
import PIL
from transformers.utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from surya.common.s3 import S3DownloaderMixin
from surya.settings import settings
class SuryaEncoderImageProcessor(S3DownloaderMixin, ImageProcessingMixin):
def __init__(
self,
*args,
max_size=None,
align_long_axis=False,
rescale_factor: Union[int, float] = 1 / 255,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.patch_size = kwargs.get("patch_size", (4, 4))
self.max_size = max_size
self.do_align_long_axis = align_long_axis
self.resample = Image.Resampling.BILINEAR
self.rescale_factor = rescale_factor
self.image_mean = (
image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
)
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def __call__(self, images, **kwargs) -> BatchFeature:
"""Preprocess an image or a batch of images."""
return self.preprocess(images, **kwargs)
@classmethod
def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
max_width, max_height = size["width"], size["height"]
resized_image = cv2.resize(
image, (max_width, max_height), interpolation=interpolation
)
resized_image = resized_image.transpose(2, 0, 1)
return resized_image
def process_inner(self, images: List[np.ndarray]):
assert images[0].shape[2] == 3 # RGB input images, channel dim last
if self.do_align_long_axis:
# Rotate when the image orientation does not match the target size (see align_long_axis)
images = [
SuryaEncoderImageProcessor.align_long_axis(
image, size=self.max_size, input_data_format=ChannelDimension.LAST
)
for image in images
]
# Verify that the image is wider than it is tall
for img in images:
assert img.shape[1] >= img.shape[0]
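# numpy_resize resizes to max_size and also transposes to channel x height x width.
# Note: self.resample is a PIL enum (BILINEAR == 2), which cv2.resize reads as cv2.INTER_CUBIC.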
images = [
SuryaEncoderImageProcessor.numpy_resize(img, self.max_size, self.resample)
for img in images
]
assert images[0].shape[0] == 3 # RGB input images, channel dim first
# Convert to float32 for rescale/normalize
images = [img.astype(np.float32) for img in images]
# Pads with 255 (whitespace)
# Pad to max size to improve performance
max_size = self.max_size
images = [
SuryaEncoderImageProcessor.pad_image(
image=image,
size=max_size,
input_data_format=ChannelDimension.FIRST,
pad_value=settings.RECOGNITION_PAD_VALUE,
)
for image in images
]
# Rescale and normalize
for idx in range(len(images)):
images[idx] = (images[idx].astype(np.float64) * self.rescale_factor).astype(
np.float32
)
images = [
SuryaEncoderImageProcessor.normalize(
img,
mean=self.image_mean,
std=self.image_std,
input_data_format=ChannelDimension.FIRST,
)
for img in images
]
return images
def preprocess(
self,
images: ImageInput,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
images = make_list_of_images(images)
# Convert to numpy for later processing steps
images = [np.array(img) for img in images]
images = self.process_inner(images)
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
@classmethod
def pad_image(
cls,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
pad_value: float = 0.0,
) -> np.ndarray:
output_height, output_width = size["height"], size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
delta_width = output_width - input_width
delta_height = output_height - input_height
assert delta_width >= 0 and delta_height >= 0
pad_top = delta_height // 2
pad_left = delta_width // 2
pad_bottom = delta_height - pad_top
pad_right = delta_width - pad_left
padding = ((pad_top, pad_bottom), (pad_left, pad_right))
return pad(
image,
padding,
data_format=data_format,
input_data_format=input_data_format,
constant_values=pad_value,
)
@classmethod
def align_long_axis(
cls, image: np.ndarray, size: Dict[str, int], **kwargs
) -> np.ndarray:
input_height, input_width = image.shape[:2]
output_height, output_width = size["height"], size["width"]
if (output_width < output_height and input_width > input_height) or (
output_width > output_height and input_width < input_height
):
image = np.rot90(image, 3)
return image
@classmethod
def normalize(
cls,
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
return normalize(
image,
mean=mean,
std=std,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
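# Illustrative sketch (not part of the original module): the centred padding
# arithmetic used by pad_image, isolated from numpy and transformers.
# `centered_padding` is a hypothetical helper and the sizes are example values.

```python
from typing import Tuple

Padding = Tuple[Tuple[int, int], Tuple[int, int]]


def centered_padding(input_hw: Tuple[int, int], output_hw: Tuple[int, int]) -> Padding:
    """Return ((top, bottom), (left, right)) padding that centres the image."""
    (in_h, in_w), (out_h, out_w) = input_hw, output_hw
    delta_h, delta_w = out_h - in_h, out_w - in_w
    if delta_h < 0 or delta_w < 0:
        raise ValueError("image is larger than the target size")
    pad_top = delta_h // 2
    pad_left = delta_w // 2
    # Any odd leftover pixel goes to the bottom / right side.
    return (pad_top, delta_h - pad_top), (pad_left, delta_w - pad_left)


padding = centered_padding((5, 8), (8, 8))
print(padding)  # ((1, 2), (0, 0))
```

# Within process_inner the images have already been resized to max_size, so in
# that path both deltas are expected to be zero and the padding is a no-op.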
================================================
FILE: surya/common/load.py
================================================
from typing import Optional, Any
import torch
from surya.settings import settings
class ModelLoader:
def __init__(self, checkpoint: Optional[str] = None):
self.checkpoint = checkpoint
def model(
self,
device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,
attention_implementation: Optional[str] = None,
) -> Any:
raise NotImplementedError()
def processor(
self,
device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,
) -> Any:
raise NotImplementedError()
================================================
FILE: surya/common/polygon.py
================================================
import copy
from typing import List, Optional
import numpy as np
from pydantic import BaseModel, field_validator, computed_field
import numbers
class PolygonBox(BaseModel):
polygon: List[List[float]]
confidence: Optional[float] = None
@field_validator("polygon", mode="before")
@classmethod
def convert_bbox_to_polygon(cls, value):
if isinstance(value, (list, tuple)) and len(value) == 4:
if all(isinstance(x, numbers.Number) for x in value):
value = [float(v) for v in value]
x_min, y_min, x_max, y_max = value
polygon = [
[x_min, y_min],
[x_max, y_min],
[x_max, y_max],
[x_min, y_max],
]
return polygon
elif all(
isinstance(point, (list, tuple)) and len(point) == 2 for point in value
):
value = [[float(v) for v in point] for point in value]
return value
elif isinstance(value, np.ndarray):
if value.shape == (4, 2):
return value.tolist()
raise ValueError(
f"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)]. All values must be numeric. You passed {value} of type {type(value)}. The first value is of type {type(value[0])}."
)
@property
def height(self):
return self.bbox[3] - self.bbox[1]
@property
def width(self):
return self.bbox[2] - self.bbox[0]
@property
def area(self):
return self.width * self.height
@computed_field
@property
def bbox(self) -> List[float]:
x_coords = [point[0] for point in self.polygon]
y_coords = [point[1] for point in self.polygon]
return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
def rescale(self, processor_size, image_size):
# Point is in x, y format
page_width, page_height = processor_size
img_width, img_height = image_size
width_scaler = img_width / page_width
height_scaler = img_height / page_height
for corner in self.polygon:
corner[0] = int(corner[0] * width_scaler)
corner[1] = int(corner[1] * height_scaler)
def round(self, divisor):
for corner in self.polygon:
corner[0] = int(corner[0] / divisor) * divisor
corner[1] = int(corner[1] / divisor) * divisor
def fit_to_bounds(self, bounds):
new_corners = copy.deepcopy(self.polygon)
for corner in new_corners:
corner[0] = max(min(corner[0], bounds[2]), bounds[0])
corner[1] = max(min(corner[1], bounds[3]), bounds[1])
self.polygon = new_corners
def merge(self, other):
x1 = min(self.bbox[0], other.bbox[0])
y1 = min(self.bbox[1], other.bbox[1])
x2 = max(self.bbox[2], other.bbox[2])
y2 = max(self.bbox[3], other.bbox[3])
self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
def merge_left(self, other):
x1 = min(self.bbox[0], other.bbox[0])
self.polygon[0][0] = x1
self.polygon[3][0] = x1
def merge_right(self, other):
x2 = max(self.bbox[2], other.bbox[2])
self.polygon[1][0] = x2
self.polygon[2][0] = x2
def expand(self, x_margin: float, y_margin: float):
new_polygon = []
x_margin = x_margin * self.width
y_margin = y_margin * self.height
for idx, poly in enumerate(self.polygon):
if idx == 0:
new_polygon.append([int(poly[0] - x_margin), int(poly[1] - y_margin)])
elif idx == 1:
new_polygon.append([int(poly[0] + x_margin), int(poly[1] - y_margin)])
elif idx == 2:
new_polygon.append([int(poly[0] + x_margin), int(poly[1] + y_margin)])
elif idx == 3:
new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)])
self.polygon = new_polygon
def intersection_polygon(self, other) -> List[List[float]]:
new_poly = []
for i in range(4):
if i == 0:
new_corner = [
max(self.polygon[0][0], other.polygon[0][0]),
max(self.polygon[0][1], other.polygon[0][1]),
]
elif i == 1:
new_corner = [
min(self.polygon[1][0], other.polygon[1][0]),
max(self.polygon[1][1], other.polygon[1][1]),
]
elif i == 2:
new_corner = [
min(self.polygon[2][0], other.polygon[2][0]),
min(self.polygon[2][1], other.polygon[2][1]),
]
elif i == 3:
new_corner = [
max(self.polygon[3][0], other.polygon[3][0]),
min(self.polygon[3][1], other.polygon[3][1]),
]
new_poly.append(new_corner)
return new_poly
def intersection_area(self, other, x_margin=0, y_margin=0):
x_overlap = self.x_overlap(other, x_margin)
y_overlap = self.y_overlap(other, y_margin)
return x_overlap * y_overlap
def x_overlap(self, other, x_margin=0):
return max(
0,
min(self.bbox[2] + x_margin, other.bbox[2] + x_margin)
- max(self.bbox[0] - x_margin, other.bbox[0] - x_margin),
)
def y_overlap(self, other, y_margin=0):
return max(
0,
min(self.bbox[3] + y_margin, other.bbox[3] + y_margin)
- max(self.bbox[1] - y_margin, other.bbox[1] - y_margin),
)
def intersection_pct(self, other, x_margin=0, y_margin=0):
assert 0 <= x_margin <= 1
assert 0 <= y_margin <= 1
if self.area == 0:
return 0
if x_margin:
x_margin = int(min(self.width, other.width) * x_margin)
if y_margin:
y_margin = int(min(self.height, other.height) * y_margin)
intersection = self.intersection_area(other, x_margin, y_margin)
return intersection / self.area
def shift(self, x_shift: float | None = None, y_shift: float | None = None):
if x_shift is not None:
for corner in self.polygon:
corner[0] += x_shift
if y_shift is not None:
for corner in self.polygon:
corner[1] += y_shift
def clamp(self, bbox: List[float]):
for corner in self.polygon:
corner[0] = max(min(corner[0], bbox[2]), bbox[0])
corner[1] = max(min(corner[1], bbox[3]), bbox[1])
@property
def center(self):
return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2]
def distance(self, other):
center = self.center
other_center = other.center
return (
(center[0] - other_center[0]) ** 2 + (center[1] - other_center[1]) ** 2
) ** 0.5
def __hash__(self):
return hash(tuple(self.bbox))
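# Illustrative sketch (not part of the original module): the bbox-to-polygon
# conversion and the intersection_pct computation from PolygonBox, rewritten as
# standalone functions without pydantic. All function names are hypothetical and
# the margin arguments are omitted for brevity.

```python
from typing import List


def bbox_to_polygon(bbox: List[float]) -> List[List[float]]:
    """Corners in the same order as the validator: TL, TR, BR, BL."""
    x_min, y_min, x_max, y_max = (float(v) for v in bbox)
    return [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]


def polygon_bbox(polygon: List[List[float]]) -> List[float]:
    """Axis-aligned bounding box of a polygon."""
    xs = [p[0] for p in polygon]
    ys = [p[1] for p in polygon]
    return [min(xs), min(ys), max(xs), max(ys)]


def intersection_pct(a: List[List[float]], b: List[List[float]]) -> float:
    """Share of a's bbox area that is covered by b's bbox."""
    ax1, ay1, ax2, ay2 = polygon_bbox(a)
    bx1, by1, bx2, by2 = polygon_bbox(b)
    area = (ax2 - ax1) * (ay2 - ay1)
    if area == 0:
        return 0.0
    x_overlap = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    y_overlap = max(0.0, min(ay2, by2) - max(ay1, by1))
    return (x_overlap * y_overlap) / area


box_a = bbox_to_polygon([0, 0, 10, 10])
box_b = bbox_to_polygon([5, 5, 15, 15])
print(intersection_pct(box_a, box_b))  # 0.25
```

# The ratio is relative to the first box, so swapping the arguments gives a
# different value whenever the two boxes differ in area.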
================================================
FILE: surya/common/predictor.py
================================================
from typing import Optional
import torch
import torch.nn.functional as F
from surya.common.load import ModelLoader
from surya.settings import settings
class BasePredictor:
model_loader_cls = ModelLoader
batch_size: Optional[int] = None
default_batch_sizes = {"cpu": 1, "mps": 1, "cuda": 1}
torch_dtype = settings.MODEL_DTYPE
@property
def disable_tqdm(self) -> bool:
return self._disable_tqdm
@disable_tqdm.setter
def disable_tqdm(self, value: bool) -> None:
self._disable_tqdm = bool(value)
def __init__(
self,
checkpoint: Optional[str] = None,
device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
dtype: Optional[torch.dtype | str] = None,
attention_implementation: Optional[str] = None,
):
if dtype is None:
dtype = self.torch_dtype
self.model = None
self.processor = None
loader = self.model_loader_cls(checkpoint)
self.model = loader.model(device, dtype, attention_implementation)
self.processor = loader.processor()
self._disable_tqdm = settings.DISABLE_TQDM
def to(self, device_dtype: torch.device | str | None = None):
model_moved = False
if hasattr(self, "model") and self.model:
self.model.to(device_dtype)
model_moved = True
if hasattr(self, "foundation_predictor") and self.foundation_predictor:
self.foundation_predictor.model.to(device_dtype)
model_moved = True
if not model_moved:
raise ValueError("Model not loaded")
def get_batch_size(self):
batch_size = self.batch_size
if batch_size is None:
batch_size = self.default_batch_sizes["cpu"]
if settings.TORCH_DEVICE_MODEL in self.default_batch_sizes:
batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL]
return batch_size
@staticmethod
def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
current_batch_size = tensor.shape[0]
if current_batch_size >= batch_size:
return tensor
if len(tensor.shape) == 1:
# If tensor is 1D, we need to pad it to the batch size
pad_size = batch_size - current_batch_size
return F.pad(tensor, (0, pad_size), mode="constant", value=0)
pad_size = batch_size - current_batch_size
padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
return F.pad(tensor, padding, mode="constant", value=0)
def __call__(self, *args, **kwargs):
raise NotImplementedError()
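# Illustrative sketch (not part of the library): how the padding tuple in
# `pad_to_batch_size` above is laid out. `leading_dim_padding` is a
# hypothetical helper that works on a plain shape tuple so it runs without
# torch. It assumes the F.pad convention that (left, right) pairs are read
# starting from the LAST dimension, so the pair for dimension 0 comes last.

```python
from typing import Sequence, Tuple


def leading_dim_padding(shape: Sequence[int], batch_size: int) -> Tuple[int, ...]:
    """Build an F.pad-style padding tuple that only grows dimension 0."""
    current = shape[0]
    if current >= batch_size:
        # The original method returns the tensor unchanged here; the sketch
        # expresses the same thing as an all-zero padding tuple.
        return (0, 0) * len(shape)
    pad_size = batch_size - current
    # One (0, 0) pair per trailing dimension, then the pair for dimension 0.
    return (0, 0) * (len(shape) - 1) + (0, pad_size)


print(leading_dim_padding((3, 4, 5), 8))  # (0, 0, 0, 0, 0, 5)
print(leading_dim_padding((3,), 8))       # (0, 5)
```

# The 1D case falls out of the general formula, since (0, 0) * 0 is the
# empty tuple.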
================================================
FILE: surya/common/pretrained.py
================================================
from typing import Optional
from transformers import PreTrainedModel
from transformers.utils import is_flash_attn_2_available
class SuryaPreTrainedModel(PreTrainedModel):
# No-op if an attention implementation is passed explicitly, so it can be set however we want in the config
def _check_and_adjust_attn_implementation(
self, attn_implementation: Optional[str], **kwargs
):
if attn_implementation is None:
try:
self._sdpa_can_dispatch(True)
attn_implementation = "sdpa"
except (ValueError, ImportError):
attn_implementation = "eager"
if self._supports_flash_attn and is_flash_attn_2_available():
attn_implementation = "flash_attention_2"
return attn_implementation
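# Illustrative sketch (not part of the library): the selection order that
# `_check_and_adjust_attn_implementation` above appears to follow.
# `choose_attention` is a hypothetical stand-in that takes plain booleans in
# place of the `_sdpa_can_dispatch` and `is_flash_attn_2_available` checks.
# It assumes the flash-attention check sits inside the `is None` branch,
# which the flattened indentation of this dump does not show directly.

```python
from typing import Optional


def choose_attention(requested: Optional[str], sdpa_ok: bool, flash_ok: bool) -> str:
    """Pick an attention backend, leaving an explicit request untouched."""
    if requested is not None:
        # An explicit choice is passed through unchanged.
        return requested
    chosen = "sdpa" if sdpa_ok else "eager"
    if flash_ok:
        # Flash attention takes priority whenever it is supported and installed.
        chosen = "flash_attention_2"
    return chosen


print(choose_attention("eager", sdpa_ok=True, flash_ok=True))
print(choose_attention(None, sdpa_ok=True, flash_ok=False))
print(choose_attention(None, sdpa_ok=False, flash_ok=True))
```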
================================================
FILE: surya/common/s3.py
================================================
import json
import os
import shutil
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import requests
from tqdm import tqdm
from surya.logging import get_logger
from surya.settings import settings
logger = get_logger()
# Lock file expiration time in seconds (10 minutes)
LOCK_EXPIRATION = 600
def join_urls(url1: str, url2: str):
url1 = url1.rstrip("/")
url2 = url2.lstrip("/")
return f"{url1}/{url2}"
def get_model_name(pretrained_model_name_or_path: str):
return pretrained_model_name_or_path.split("/")[0]
def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024):
local_path = Path(local_path)
try:
response = requests.get(remote_path, stream=True, allow_redirects=True)
response.raise_for_status() # Raise an exception for bad status codes
# Get file size from headers for progress bar
total_size = int(response.headers.get('content-length', 0))
# Create progress bar with file name and size info
filename = local_path.name
pbar = tqdm(
total=total_size,
unit='B',
unit_scale=True,
unit_divisor=1024,
desc=f"Downloading {filename}",
miniters=1
)
with open(local_path, "wb") as f:
downloaded = 0
for chunk in response.iter_content(chunk_size=chunk_size):
if chunk:
f.write(chunk)
downloaded += len(chunk)
pbar.update(len(chunk))
pbar.close()
return local_path
except Exception as e:
if local_path.exists():
local_path.unlink()
logger.error(f"Download error for file {remote_path}: {str(e)}")
raise
def check_manifest(local_dir: str):
local_dir = Path(local_dir)
manifest_path = local_dir / "manifest.json"
if not os.path.exists(manifest_path):
return False
try:
with open(manifest_path, "r") as f:
manifest = json.load(f)
for file in manifest["files"]:
if not os.path.exists(local_dir / file):
return False
except Exception:
return False
return True
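# Illustrative sketch (not part of the library): the manifest check above,
# exercised against a temporary directory. `manifest_complete` is a
# hypothetical re-statement of `check_manifest`, and the file names are made
# up for the example. It assumes a manifest of the form {"files": [...]}.

```python
import json
import tempfile
from pathlib import Path


def manifest_complete(local_dir: Path) -> bool:
    """Return True only if manifest.json exists and every listed file exists."""
    manifest_path = local_dir / "manifest.json"
    if not manifest_path.exists():
        return False
    try:
        files = json.loads(manifest_path.read_text())["files"]
    except (OSError, ValueError, KeyError, TypeError):
        # Unreadable or malformed manifests count as "not downloaded".
        return False
    return all((local_dir / name).exists() for name in files)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    manifest = {"files": ["config.json", "model.safetensors"]}
    (root / "manifest.json").write_text(json.dumps(manifest))
    (root / "config.json").write_text("{}")
    before = manifest_complete(root)  # one listed file is still missing
    (root / "model.safetensors").write_bytes(b"")
    after = manifest_complete(root)   # every listed file is now present

print(before, after)
```

# The check looks only at file presence, so a truncated file would still
# pass. That matches the original function, which does not verify sizes or
# hashes.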
def download_directory(remote_path: str, local_dir: str):
model_name = get_model_name(remote_path)
s3_url = join_urls(settings.S3_BASE_URL, remote_path)
# Check to see if it's already downloaded
model_exists = check_manifest(local_dir)
if model_exists:
return
# Use tempfile.TemporaryDirectory to automatically clean up
with tempfile.TemporaryDirectory() as temp_dir:
# Download the manifest file
manifest_file = join_urls(s3_url, "manifest.json")
manifest_path = os.path.join(temp_dir, "manifest.json")
download_file(manifest_file, manifest_path)
# List and download all files
with open(manifest_path, "r") as f:
manifest = json.load(f)
pbar = tqdm(
desc=f"Downloading {model_name} model to {local_dir}",
total=len(manifest["files"]),
)
with ThreadPoolExecutor(
max_workers=settings.PARALLEL_DOWNLOAD_WORKERS
) as executor:
futures = []
for file in manifest["files"]:
remote_file = join_urls(s3_url, file)
local_file = os.path.join(temp_dir, file)
futures.append(executor.submit(download_file, remote_file, local_file))
for future in futures:
future.result()
pbar.update(1)
pbar.close()
# Move all files to new directory
for file in os.listdir(temp_dir):
shutil.move(os.path.join(temp_dir, file), local_dir)
class S3DownloaderMixin:
s3_prefix = "s3://"
@classmethod
def get_local_path(cls, pretrained_model_name_or_path) -> str:
if pretrained_model_name_or_path.startswith(cls.s3_prefix):
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
cls.s3_prefix, ""
)
cache_dir = settings.MODEL_CACHE_DIR
local_path = os.path.join(cache_dir, pretrained_model_name_or_path)
os.makedirs(local_path, exist_ok=True)
else:
local_path = ""
return local_path
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
# Allow loading models directly from the hub, or using s3
if not pretrained_model_name_or_path.startswith(cls.s3_prefix):
return super().from_pretrained(
pretrained_model_name_or_path, *args, **kwargs
)
local_path = cls.get_local_path(pretrained_model_name_or_path)
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
cls.s3_prefix, ""
)
# Retry logic for downloading the model folder
retries = 3
delay = 5
attempt = 0
success = False
while not success and attempt < retries:
try:
download_directory(pretrained_model_name_or_path, local_path)
success = True # If download succeeded
except Exception as e:
logger.error(
f"Error downloading model from {pretrained_model_name_or_path}. Attempt {attempt + 1} of {retries}. Error: {e}"
)
attempt += 1
if attempt < retries:
logger.info(f"Retrying in {delay} seconds...")
time.sleep(delay) # Wait before retrying
else:
logger.error(
f"Failed to download {pretrained_model_name_or_path} after {retries} attempts."
)
raise e # Reraise exception after max retries
return super().from_pretrained(local_path, *args, **kwargs)
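# Illustrative sketch (not part of the library): the fixed-delay retry
# pattern used in `from_pretrained` above, written as a standalone helper.
# `call_with_retries` and the injectable `sleep` argument are hypothetical;
# the injection only exists so the example runs without real waiting.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    retries: int = 3,
    delay: float = 5.0,
    sleep: Callable[[float], object] = time.sleep,
) -> T:
    """Call fn up to `retries` times, waiting `delay` seconds between failures."""
    for attempt in range(1, retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == retries:
                raise  # out of attempts: surface the last error
            sleep(delay)
    raise ValueError("retries must be at least 1")


calls = {"n": 0}


def flaky() -> str:
    # Fails twice, then succeeds, to mimic a transient download error.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


waits: List[float] = []
result = call_with_retries(flaky, retries=3, delay=5.0, sleep=waits.append)
print(result, calls["n"], waits)
```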
================================================
FILE: surya/common/surya/__init__.py
================================================
import warnings
from typing import Optional, Tuple, TypedDict
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.config import SuryaModelConfig
from surya.common.surya.decoder import SuryaDecoderModel
from surya.common.surya.embedder import SimpleTokenEmbedder
from surya.common.surya.encoder import SuryaEncoderModel
from surya.common.util import pad_to_batch_size, pad_to_batch_size_repeat
from surya.common.xla import get_nearest_pad
from surya.settings import settings
from surya.logging import get_logger
logger = get_logger()
@dataclass
class SuryaModelOutput(CausalLMOutputWithPast):
bbox_logits: torch.FloatTensor = None
lm_logits: torch.FloatTensor = None
class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with Compile.
Attributes:
cu_seq_lens_q (`torch.LongTensor`, *optional*)
Cumulative sequence length for the query state.
cu_seq_lens_k (`torch.LongTensor`, *optional*)
Cumulative sequence length for the key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""
cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]
class KwargsForCausalLM(FlashAttentionKwargs): ...
class DistanceProjection(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.fc1 = nn.Linear(in_features, out_features)
self.act = nn.SiLU()
self.fc2 = nn.Linear(out_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.zeros_(self.fc1.bias)
nn.init.zeros_(self.fc2.bias)
class BboxHead(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.proj_layers = nn.ModuleList(
[nn.Linear(in_features, in_features) for _ in range(6)]
)
self.act = nn.SiLU()
self.out_proj = nn.Linear(in_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.proj_layers:
x = layer(x)
x = self.act(x)
x = self.out_proj(x)
return x
class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
config_class = SuryaModelConfig
supports_gradient_checkpointing = True
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
main_input_name = "input_ids"
_tied_weights_keys = ["lm_head.weight"]
def __init__(
self,
config: SuryaModelConfig,
embedder: SimpleTokenEmbedder = None,
vision_encoder: SuryaEncoderModel = None,
decoder: SuryaDecoderModel = None,
**kwargs,
):
super().__init__(config, **kwargs)
if vision_encoder is None:
vision_encoder = SuryaEncoderModel(config.vision_encoder)
if decoder is None:
decoder = SuryaDecoderModel(config.decoder)
if embedder is None:
embedder = SimpleTokenEmbedder(config)
self.vision_encoder = vision_encoder
self.decoder = decoder
self.embedder = embedder
# Simple encoding for image patches
self.img_w_embed = nn.Embedding(
self.config.image_embed_encoding_size,
self.config.hidden_size,
)
self.img_h_embed = nn.Embedding(
self.config.image_embed_encoding_size,
self.config.hidden_size,
)
# Tying configs
self.vision_encoder.config = self.config.vision_encoder
self.decoder.config = self.config.decoder
self.bbox_head = BboxHead(config.hidden_size, 6)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
if (
self.config.multi_output_distance is not None
and self.config.multi_output_distance > 0
):
self.multi_output_projections = nn.ModuleList(
[
DistanceProjection(
in_features=config.hidden_size, out_features=config.hidden_size
)
for _ in range(self.config.multi_output_distance)
]
)
def tie_weights(self):
self._tie_weights()
def _tie_weights(self):
# Tie weights of lm head and token embedder
self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def get_input_embeddings(self) -> nn.Module:
return self.embedder.token_embed
def set_output_embeddings(self, new_embeddings: nn.Module):
self.lm_head = new_embeddings
def set_input_embeddings(self, new_embeddings: nn.Module):
self.embedder.token_embed = new_embeddings
def maybe_static_pad_image_inputs(
self,
chunk_pixels: torch.Tensor,
chunk_grid_thw: torch.Tensor,
actual_chunk_len: int,
encoder_chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
valid_embed_len = actual_chunk_len // (
self.vision_encoder.spatial_merge_size**2
)
if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size:
padding_len = encoder_chunk_size - actual_chunk_len
chunk_pixels = F.pad(
chunk_pixels,
(0, 0, 0, padding_len),
mode="constant",
value=0.0,
)
padding_grid = torch.tensor(
[[1, 2, padding_len // 2]],
device=chunk_grid_thw.device,
dtype=chunk_grid_thw.dtype,
)
chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0)
return chunk_pixels, chunk_grid_thw, valid_embed_len
def get_image_embeddings(
self,
pixel_values: torch.Tensor,
grid_thw: torch.Tensor,
encoder_chunk_size: int,
valid_batch_size: torch.Tensor | None = None,
max_batch_size: int | None = None,
):
# embed all images with the vision encoder after they have already been tiled and flattened into a single batch
chunks = [0]
grid_chunks = [0]
curr_chunk_len = 0
curr_seq_len = 0
for i in range(len(grid_thw)):
curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item()
if curr_chunk_len > encoder_chunk_size:
chunks.append(curr_chunk_len + curr_seq_len)
curr_seq_len += curr_chunk_len
curr_chunk_len = 0
grid_chunks.append(i + 1)
if curr_chunk_len > 0:
chunks.append(pixel_values.shape[0])
grid_chunks.append(len(grid_thw))
assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], (
f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}"
)
logger.debug(
f"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}"
)
embeddings = []
for i in range(len(chunks) - 1):
start = chunks[i]
end = chunks[i + 1]
grid_start = grid_chunks[i]
grid_end = grid_chunks[i + 1]
chunk_pixels = pixel_values[start:end]
chunk_grid_thw = grid_thw[grid_start:grid_end]
actual_chunk_len = end - start
chunk_pixels, chunk_grid_thw, valid_embed_len = (
self.maybe_static_pad_image_inputs(
chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size
)
)
chunk_embeddings = self.vision_encoder.embed_images(
image_batch=chunk_pixels.unsqueeze(0).to(device=self.device),
grid_thw=chunk_grid_thw.unsqueeze(0).to(device=self.device),
)
embeddings.append(chunk_embeddings[:valid_embed_len].squeeze(0))
if len(embeddings) == 0:
raise ValueError(
"No image embeddings were generated. Check the input images and grid sizes."
)
elif len(embeddings) == 1:
embeddings = embeddings[0]
else:
embeddings = torch.cat(embeddings, dim=0)
encoding_2d = self.get_2d_learned_embeddings(
grid_thw,
device=embeddings.device,
bbox_size=self.config.image_embed_encoding_multiplier,
)
assert embeddings.shape[0] == encoding_2d.shape[0], (
f"Mismatch in image embedding seq len: {embeddings.shape} vs {encoding_2d.shape}"
)
assert embeddings.shape[1] == encoding_2d.shape[1], (
f"Mismatch in image embedding token counts: {embeddings.shape} vs {encoding_2d.shape}"
)
embeddings = embeddings + encoding_2d
return embeddings
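# Illustrative sketch (not part of the library): the chunk-boundary
# bookkeeping at the top of `get_image_embeddings` above, on plain Python
# tuples. `chunk_boundaries` is a hypothetical helper, and the grid sizes
# below are invented for the example.

```python
from typing import List, Sequence, Tuple


def chunk_boundaries(
    grids: Sequence[Tuple[int, int, int]], chunk_size: int
) -> Tuple[List[int], List[int]]:
    """Return (pixel offsets, grid offsets) that delimit each encoder chunk."""
    chunks, grid_chunks = [0], [0]
    curr_chunk_len = 0  # patches accumulated in the open chunk
    curr_seq_len = 0    # patches already assigned to closed chunks
    for i, (t, h, w) in enumerate(grids):
        curr_chunk_len += t * h * w
        if curr_chunk_len > chunk_size:
            # Close the chunk AFTER adding the grid that crossed the limit.
            chunks.append(curr_chunk_len + curr_seq_len)
            curr_seq_len += curr_chunk_len
            curr_chunk_len = 0
            grid_chunks.append(i + 1)
    if curr_chunk_len > 0:
        # Whatever is left forms the final, possibly short, chunk.
        chunks.append(curr_seq_len + curr_chunk_len)
        grid_chunks.append(len(grids))
    return chunks, grid_chunks


chunks, grid_chunks = chunk_boundaries([(1, 4, 4), (1, 4, 4), (1, 2, 2)], chunk_size=20)
print(chunks, grid_chunks)
```

# Because a chunk is closed only once its running length exceeds
# `chunk_size`, a chunk can end up larger than `chunk_size` (32 patches
# against a limit of 20 in this example).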
def embed_ids_boxes_images(
self,
input_ids,
image_embeddings,
encoder_chunk_size: int,
valid_batch_size: torch.Tensor | None = None,
input_boxes: torch.Tensor | None = None,
embed_boxes: torch.Tensor | None = None,
):
"""
Insert embedded image tiles at their corresponding positions in the full input sequence.
Positions to insert new tokens are indicated by the special image token index.
"""
# This is batched in the inner call
inputs_embeds = self.embedder.embed(
input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes
)
if image_embeddings is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
if inputs_embeds[special_image_mask].numel() != image_embeddings.numel():
n_image_tokens = torch.sum((input_ids == self.config.image_token_id))
n_image_features = image_embeddings.shape[0] * image_embeddings.shape[1]
warnings.warn(
f"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. This may lead to unexpected results"
)
image_features = image_embeddings.to(inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
else:
assert (input_ids == self.config.image_token_id).sum() == 0, (
"Image tokens were present in the input but no input images were provided"
)
return inputs_embeds
def get_2d_learned_embeddings(
self,
grid_thw,
device: str | torch.device = "cpu",
bbox_size: int = 256,
):
all_embeddings = []
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.config.merge_size,
grid_w // self.config.merge_size,
)
# Scale grid positions to the range [0, bbox_size]
llm_grid_h = (
torch.arange(llm_grid_h, device=device)
/ max(1, (llm_grid_h - 1))
* bbox_size
)
llm_grid_w = (
torch.arange(llm_grid_w, device=device)
/ max(1, (llm_grid_w - 1))
* bbox_size
)
llm_grid_w_idx = llm_grid_w.to(torch.long)
llm_grid_h_idx = llm_grid_h.to(torch.long)
llm_grid_w = self.img_w_embed(llm_grid_w_idx)
llm_grid_h = self.img_h_embed(llm_grid_h_idx)
full_grid = llm_grid_h[:, None] + llm_grid_w[None, :]
flattened = full_grid.flatten(
0, 1
) # Flatten first dimension, so they are seq_len x embed_dim
all_embeddings.append(flattened)
return torch.concat(
all_embeddings, dim=0
) # Shape is num_image_tokens x embed_dim
def get_logits(self, hidden_states):
assert hidden_states.shape[1] == 1, (
"Multi output predictions only applied on the last token"
)
all_lm_logits = []
all_bbox_logits = []
current_hidden = hidden_states
# Loop includes initial prediction (i=0) plus multi_output_distance additional predictions
for i in range(self.config.multi_output_distance + 1):
if i > 0:
current_hidden = self.multi_output_projections[i - 1](current_hidden)
lm_logits = self.lm_head(current_hidden)
bbox_logits = F.sigmoid(self.bbox_head(current_hidden))
all_lm_logits.append(lm_logits)
all_bbox_logits.append(bbox_logits)
# Concatenate along sequence dimension (dim=1)
final_lm_logits = torch.cat(all_lm_logits, dim=1)
final_bbox_logits = torch.cat(all_bbox_logits, dim=1)
return final_lm_logits, final_bbox_logits
def forward(
self,
input_ids=None,
image_embeddings=None,
labels=None,
image_tiles=None,
grid_thw=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
cache_position=None,
past_key_values=None,
output_hidden_states=False,
output_attentions=False,
use_cache=False,
encoder_chunk_size=32768,
cache_idxs=None,
num_valid_tokens=None,
prefill=True,
text_lengths=None,
valid_batch_size: torch.Tensor = None,
input_boxes=None,
embed_boxes=None,
logits_to_keep=None,
**kwargs: KwargsForCausalLM,
):
if any([
input_ids is None,
position_ids is None,
cache_position is None,
(
prefill
and not (
(image_tiles is not None and grid_thw is not None)
or image_embeddings is not None
)
),
]):
raise ValueError(
"`input_ids`, `position_ids`, and `cache_position` **must** be specified. "
"For prefill, you must provide either (`image_tiles` and `grid_thw`) or `image_embeddings`."
)
inputs_embeds = self.embed_ids_boxes_images(
input_ids, image_embeddings, encoder_chunk_size, valid_batch_size, input_boxes, embed_boxes
)
# Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder
# Skipped during decoding since not required
if self.decoder.config._attn_implementation == "flash_attention_2" and prefill:
# Needed for CPU -> GPU
from surya.common.surya.flash_attn_utils import _get_unpad_data
batch_size, query_length, _ = inputs_embeds.shape
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
attention_mask
)
kwargs["batch_size"] = batch_size
kwargs["query_length"] = query_length
kwargs["indices_k"] = indices_k
kwargs["cu_seqlens_k"] = cu_seqlens_k
kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
attention_mask = causal_mask
outputs = self.decoder(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=True,
use_cache=use_cache,
cache_idxs=cache_idxs,
num_valid_tokens=num_valid_tokens,
prefill=prefill,
text_lengths=text_lengths,
**kwargs,
)
hidden_states = outputs.last_hidden_state
if logits_to_keep is not None:
hidden_states = hidden_states[:, -logits_to_keep:, :]
hidden_states = hidden_states.contiguous()
loss = None
if labels is not None:
# Training, return full logits
lm_logits = self.lm_head(hidden_states)
bbox_logits = None
vocab_size = lm_logits.shape[-1]
labels = torch.roll(labels, shifts=-1, dims=-1)
loss = F.cross_entropy(
lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean"
)
else:
lm_logits, bbox_logits = self.get_logits(hidden_states)
return SuryaModelOutput(
loss=loss,
bbox_logits=bbox_logits,
lm_logits=lm_logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions if output_attentions else None,
past_key_values=outputs.past_key_values,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.decoder.config._attn_implementation == "flash_attention_2":
return attention_mask
# We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_key_values.max_cache_len
)
# In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: SuryaModelConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence. Shape `(batch_size, sequence_length)`.
batch_size (`int`):
Batch size.
config (`SuryaModelConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
# Batch-aware diagonal attend mask
diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(
0
) > cache_position.unsqueeze(-1)
causal_mask = (
causal_mask.unsqueeze(0) * diagonal_attend_mask
) # (batch_size, seq_len, target_len)
causal_mask = causal_mask[
:, None, :, :
] # (batch_size, 1, seq_len, target_len)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
:, None, None, :
].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
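# Illustrative sketch (not part of the library): what the 2D-to-4D mask
# construction above computes for a single sample, on plain Python lists.
# `causal_mask_rows` is a hypothetical helper and `MASKED` is a stand-in for
# `torch.finfo(dtype).min`. A key column is masked when it lies after the
# query's cache position, or when the 2D attention mask marks it as padding.

```python
from typing import List, Sequence

MASKED = -1.0e9  # stand-in for the most negative value of the dtype


def causal_mask_rows(
    cache_position: Sequence[int],
    target_length: int,
    attention_mask: Sequence[int],
) -> List[List[float]]:
    """Build one additive mask row per query position."""
    rows = []
    for pos in cache_position:
        row = []
        for key in range(target_length):
            is_future = key > pos
            # Only columns covered by the 2D mask can be marked as padding.
            is_padding = key < len(attention_mask) and attention_mask[key] == 0
            row.append(MASKED if (is_future or is_padding) else 0.0)
        rows.append(row)
    return rows


# Two query tokens at cache positions 2 and 3, one left-padded key slot.
rows = causal_mask_rows([2, 3], target_length=5, attention_mask=[0, 1, 1, 1, 1])
for row in rows:
    print(row)
```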
class SuryaXLAModel(SuryaModel):
def get_image_embeddings(
self,
pixel_values: torch.Tensor,
grid_thw: torch.Tensor,
encoder_chunk_size: int,
valid_batch_size: torch.Tensor | None = None,
max_batch_size: int | None = None,
):
# embed all images with the vision encoder after they have already been tiled and flattened into a single batch
unpadded_max_grid_size = (
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).max().item()
)
max_grid_size = get_nearest_pad(
unpadded_max_grid_size,
) # If we need zero padding, we still need to allocate a bit of room for the extra grid_thw
# Always need 2 items in each row batch
if max_grid_size == unpadded_max_grid_size:
max_grid_size += 16
full_image_grid = torch.zeros(
(valid_batch_size, max_grid_size, pixel_values.shape[-1]),
dtype=pixel_values.dtype,
)
# Roll out into a full grid
seq_len = 0
row_grids = []
for i in range(valid_batch_size):
curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]
full_image_grid[i, -curr_sample_len:] = pixel_values[
seq_len : seq_len + curr_sample_len
]
padded_len = max_grid_size - curr_sample_len
if padded_len > 0:
row_grid = torch.tensor(
[
[1, 4, padded_len // 4],
grid_thw[i].tolist(),
],
dtype=torch.long,
)
else:
row_grid = torch.tensor(
[
grid_thw[i].tolist(),
],
dtype=torch.long,
)
row_grids.append(row_grid)
seq_len += curr_sample_len
# bsz, 2, 3
row_grids = torch.stack(row_grids, dim=0)
if settings.FOUNDATION_STATIC_CACHE:
# Pad to max batch size, repeat the final row
row_grids = pad_to_batch_size_repeat(
row_grids,
batch_size=max_batch_size,
)
full_image_grid = pad_to_batch_size(
full_image_grid,
batch_size=max_batch_size,
)
full_image_grid = full_image_grid.to(self.device)
embeddings = self.vision_encoder.embed_images(
image_batch=full_image_grid, grid_thw=row_grids.to(self.device)
)
encoding_2d = self.get_2d_learned_embeddings(
row_grids,
bbox_size=self.config.image_embed_encoding_multiplier,
)
embeddings += encoding_2d
return embeddings
def embed_ids_boxes_images(
self,
input_ids,
image_embeddings,
encoder_chunk_size: int,
valid_batch_size: torch.Tensor | None = None,
input_boxes: torch.Tensor | None = None,
embed_boxes: torch.Tensor | None = None,
):
"""
Insert embedded image tiles at their corresponding positions in the full input sequence.
Positions to insert new tokens are indicated by the special image token index.
"""
# This is batched in the inner call
inputs_embeds = self.embedder.embed(
input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes
)
if image_embeddings is not None:
image_token_id_tensor = torch.tensor(
self.config.image_token_id,
device=inputs_embeds.device,
dtype=torch.long,
)
mask = input_ids == image_token_id_tensor
last_image_token_pos = (
mask.size(1)
- 1
- mask.flip(dims=[1]).long().argmax(dim=1, keepdim=True)
)
# Calculate start position to replace N positions ending at (and including) the last image token
start_positions = last_image_token_pos - image_embeddings[0].shape[0]
batch_size, insert_len = image_embeddings.shape[:2]
# Create position indices for each insertion
pos_indices = torch.arange(
insert_len, device=inputs_embeds.device
).unsqueeze(0)
insert_positions = start_positions + pos_indices
idx = insert_positions.unsqueeze(-1).expand(
-1, -1, inputs_embeds.size(-1)
) # [B,N,D]
inputs_embeds = inputs_embeds.scatter(1, idx, image_embeddings)
inputs_embeds = inputs_embeds * (
input_ids != self.config.pad_token_id
).unsqueeze(-1).to(inputs_embeds.dtype)
return inputs_embeds
def get_2d_learned_embeddings(
self,
grid_thw,
bbox_size: int = 256,
):
dev = grid_thw.device
all_row_coords = []
all_col_coords = []
for row_grid in grid_thw:
merge = self.config.merge_size
# per-sample grid sizes after merge
H = (row_grid[:, 1] // merge).long() # (B,)
W = (row_grid[:, 2] // merge).long() # (B,)
row_coords = torch.cat(
[
torch.linspace(0, bbox_size, steps=int(h), device=dev)
.round()
.repeat_interleave(w) # repeat each row value w times
for h, w in zip(H.tolist(), W.tolist())
]
) # (full_grid_size,)
col_coords = torch.cat(
[
torch.linspace(0, bbox_size, steps=int(w), device=dev)
.round()
.repeat(int(h)) # tile the column vector h times
for h, w in zip(H.tolist(), W.tolist())
]
) # (full_grid_size,)
all_row_coords.append(row_coords)
all_col_coords.append(col_coords)
row_coords = torch.stack(all_row_coords, dim=0).to(self.device)
col_coords = torch.stack(all_col_coords, dim=0).to(self.device)
emb = self.img_h_embed(row_coords.long()) + self.img_w_embed(col_coords.long())
return emb
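# Illustrative sketch (not part of the library): the index arithmetic behind
# the learned 2D position embeddings, following the arange-based variant in
# `SuryaModel.get_2d_learned_embeddings`. `axis_indices` and
# `grid_index_pairs` are hypothetical helpers, and `merge_size=2` is an
# assumed value chosen for the example. The XLA variant above uses
# `torch.linspace(...).round()` instead, which rounds rather than truncates.

```python
from typing import List, Tuple


def axis_indices(n: int, bbox_size: int = 256) -> List[int]:
    """Spread n grid positions over [0, bbox_size] and truncate to integers."""
    denom = max(1, n - 1)  # avoids division by zero for a single row or column
    return [int(i / denom * bbox_size) for i in range(n)]


def grid_index_pairs(
    grid_h: int, grid_w: int, merge_size: int = 2, bbox_size: int = 256
) -> List[Tuple[int, int]]:
    """Row-major (row index, column index) pairs, one per merged image token."""
    rows = axis_indices(grid_h // merge_size, bbox_size)
    cols = axis_indices(grid_w // merge_size, bbox_size)
    return [(r, c) for r in rows for c in cols]


print(axis_indices(5))
pairs = grid_index_pairs(4, 6)
print(pairs)
```

# Each pair selects one row of the height embedding table and one row of the
# width embedding table, and the two vectors are summed to form that token's
# position encoding.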
================================================
FILE: surya/common/surya/config.py
================================================
from typing import Optional
from transformers import PretrainedConfig
from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.encoder.config import SuryaEncoderConfig
from surya.common.surya.decoder.config import SuryaDecoderConfig
class SuryaModelConfig(S3DownloaderMixin, PretrainedConfig):
model_type = "surya-multimodal-foundation"
is_composition = True
def __init__(
self,
vocab_size=65536,
bbox_size=1025,
blank_bbox_token_id=1025,
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
image_token_id=3,
register_token_ids=(4, 5, 6, 7),
eoi_token_id=8,
beacon_token_id=9,
special_token_count=4,
max_sequence_length=1536,
special_ocr_tokens=None,
vision_encoder=None,
decoder=None,
tasks: dict | None = None,
bbox_embed_size: int = 64,
num_register_tokens: int = 4,
image_embed_encoding_size: int = 1024,
image_embed_encoding_multiplier: int = 256,
num_beacon_tokens: int = 1,
beacon_token_interval: int = 4096,
sliding_window: Optional[int] = None,
multi_output_distance: int = 4,
max_multi_out: int = 8,
**kwargs,
):
super().__init__(**kwargs)
self.is_encoder_decoder = False
self.vocab_size = vocab_size
self.bbox_size = bbox_size
self.blank_bbox_token_id = blank_bbox_token_id
self.image_token_id = image_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.eoi_token_id = eoi_token_id
self.beacon_token_id = beacon_token_id
self.special_ocr_tokens = special_ocr_tokens
self.special_token_count = special_token_count # pad, bos, etc, tokens
self.max_sequence_length = max_sequence_length
self.tasks = tasks
self.tie_word_embeddings = True
self.bbox_embed_size = bbox_embed_size
self.num_register_tokens = num_register_tokens
self.register_token_ids = register_token_ids
self.image_embed_encoding_size = image_embed_encoding_size
self.image_embed_encoding_multiplier = image_embed_encoding_multiplier
self.num_beacon_tokens = num_beacon_tokens
self.beacon_token_interval = beacon_token_interval
self.sliding_window = sliding_window
self.multi_output_distance = multi_output_distance
self.max_multi_out = max_multi_out
if self.sliding_window is None:
self.sliding_window = self.max_sequence_length
if isinstance(vision_encoder, dict):
vision_encoder = SuryaEncoderConfig(**vision_encoder)
elif vision_encoder is None:
vision_encoder = SuryaEncoderConfig()
self.vision_encoder = vision_encoder
if isinstance(decoder, dict):
decoder = SuryaDecoderConfig(**decoder)
elif decoder is None:
decoder = SuryaDecoderConfig()
self.decoder = decoder
self.hidden_size = self.decoder.hidden_size
self.patch_size = self.vision_encoder.spatial_patch_size
self.merge_size = self.vision_encoder.spatial_merge_size
================================================
FILE: surya/common/surya/decoder/__init__.py
================================================
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import (
Cache,
)
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import (
logging,
)
from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.surya.decoder.config import SuryaDecoderConfig
logger = logging.get_logger(__name__)
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which to unsqueeze `cos` and `sin` so that they broadcast against q and k.
            `cos` and `sin` have the shape [batch_size, seq_len, head_dim]. If q and k have the shape
            [batch_size, heads, seq_len, head_dim], set unsqueeze_dim=1. If q and k have the shape
            [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
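
# Illustrative sketch (not part of surya): the same rotation written in plain Python for a single
# head vector, to show that the rotated query/key dot product depends only on the relative offset
# between positions. `apply_rope` and `dot` are hypothetical helper names, and the frequency
# formula assumes the "default" RoPE type with theta=10000.0.

```python
import math


def rotate_half(x):
    """Return (-x2, x1), where x1 and x2 are the first and second halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x, position, theta=10000.0):
    """Rotate one head vector `x` (even length) for an integer position."""
    dim = len(x)
    half = dim // 2
    # One inverse frequency per rotated pair, duplicated to cover both halves.
    inv_freq = [theta ** (-2 * i / dim) for i in range(half)]
    angles = [position * f for f in inv_freq] * 2
    rotated = rotate_half(x)
    return [
        x[i] * math.cos(angles[i]) + rotated[i] * math.sin(angles[i])
        for i in range(dim)
    ]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# Same relative offset (3) at two different absolute positions.
score_a = dot(apply_rope(q, 5), apply_rope(k, 2))
score_b = dot(apply_rope(q, 13), apply_rope(k, 10))
print(abs(score_a - score_b) < 1e-9)
```

# The two scores agree up to floating point error, and position 0 leaves the vector unchanged.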
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
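
# Illustrative sketch (not part of surya): the head ordering produced by an interleaved repeat,
# using labels instead of tensors. `repeat_kv_heads` is a hypothetical name. The point is that
# copies of each KV head stay adjacent, so query head i pairs with KV head i // n_rep.

```python
def repeat_kv_heads(heads, n_rep):
    """Repeat every key/value head n_rep times, keeping the copies adjacent."""
    return [head for head in heads for _ in range(n_rep)]


n_rep = 3
kv_heads = ["kv0", "kv1"]
expanded = repeat_kv_heads(kv_heads, n_rep)

# Index of the KV head that each of the 6 query heads reads from.
mapping = [i // n_rep for i in range(len(expanded))]
print(expanded, mapping)
```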
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Expand grouped KV heads so they line up with the query heads
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask is additive and is cropped to the key length
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax in float32 for numerical stability, then cast back
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query.dtype
    )
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
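
# Illustrative sketch (not part of surya): how an additive mask removes a key from the softmax,
# shown for one query row in plain Python. `masked_attention_weights` is a hypothetical name, and
# NEG is the float32 minimum, assumed here as the fill value for masked positions.

```python
import math

NEG = -3.4028234663852886e38  # float32 minimum, used as the "masked" fill value


def masked_attention_weights(scores, mask, scaling):
    """Scale raw scores, add the mask, and apply a numerically stable softmax."""
    shifted = [s * scaling + m for s, m in zip(scores, mask)]
    peak = max(shifted)
    exps = [math.exp(v - peak) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


# Three keys; the last one is masked out.
weights = masked_attention_weights(
    scores=[2.0, 1.0, 0.5], mask=[0.0, 0.0, NEG], scaling=0.5
)
print(weights)
```

# The masked key receives exactly zero weight and the remaining weights still sum to one.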
class Qwen2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=True
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_idxs: Optional[List[int]] = None,
num_valid_tokens: Optional[List[int]] = None,
text_lengths: Optional[List[int]] = None,
prefill: bool = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
# cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
"cache_idxs": cache_idxs,
"num_valid_tokens": num_valid_tokens,
"prefill": prefill,
"text_lengths": text_lengths,
}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get(
"output_attentions", False
):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
elif self.config._attn_implementation == "flash_attention_2":
# Needed for CPU -> GPU
from surya.common.surya.flash_attn_utils import (
flash_attn_decode,
flash_attn_prefill,
)
if prefill:
attention_interface = flash_attn_prefill
else:
attention_interface = flash_attn_decode
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
"""
        IMPORTANT:
        We sometimes use a custom sliding window implementation during training.
        We force `sliding_window` to None so that the HF attention integrations do not
        apply any special handling. FA2 in particular ignores the 4D mask and would
        otherwise use `sliding_window` to infer the final mask.
        SDPA ignores `sliding_window` completely and depends entirely on the 4D mask
        (https://github.com/huggingface/transformers/blob/b9faf2f93085e3cf2c65184a69d1d9e502f95786/src/transformers/integrations/sdpa_attention.py#L23)
"""
sliding_window = None
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=sliding_window, # main diff with Llama
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
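
# Illustrative sketch (not part of surya): the RMS normalization formula on a plain Python list.
# `rms_norm` is a hypothetical name. Unlike LayerNorm, no mean is subtracted; the vector is only
# divided by its root mean square and scaled by a learned weight.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Divide x by sqrt(mean(x^2) + eps), then scale elementwise by weight."""
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)]


out = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
scaled = rms_norm([30.0, 40.0], weight=[1.0, 1.0])
print(out, scaled)
```

# With unit weights the output has a mean square close to 1, and rescaling the input barely changes it.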
class Qwen2DecoderLayer(nn.Module):
def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
cache_idxs: Optional[List[int]] = None,
num_valid_tokens: Optional[List[int]] = None,
text_lengths: Optional[List[int]] = None,
prefill: bool = False,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
cache_idxs=cache_idxs,
num_valid_tokens=num_valid_tokens,
text_lengths=text_lengths,
prefill=prefill,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, config: SuryaDecoderConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get(
"rope_type", config.rope_scaling.get("type")
)
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len
)
self.register_buffer(
"inv_freq", inv_freq, persistent=False
) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if (
seq_len < self.original_max_seq_len
and self.max_seq_len_cached > self.original_max_seq_len
): # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
freqs = (
inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Qwen2PreTrainedModel(SuryaPreTrainedModel):
config_class = SuryaDecoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class SuryaDecoderModel(Qwen2PreTrainedModel):
"""
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`].
    This variant removes the embedding layer entirely, so it only accepts `inputs_embeds` as input.
    `cache_position` and `position_ids` are also required.

    Args:
        config: SuryaDecoderConfig
"""
def __init__(self, config: SuryaDecoderConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.layers = nn.ModuleList(
[
Qwen2DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_idxs: Optional[List[int]] = None,
num_valid_tokens: Optional[List[int]] = None,
text_lengths: Optional[List[int]] = None,
prefill: bool = False,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
raise ValueError("You must specify inputs_embeds")
if cache_position is None:
raise ValueError("You must specify cache_position")
if position_ids is None:
raise ValueError("You must specify position_ids")
hidden_states = inputs_embeds
causal_mask = (
attention_mask # We make the 4D mask in the combined model when needed
)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
cache_idxs=cache_idxs,
num_valid_tokens=num_valid_tokens,
prefill=prefill,
text_lengths=text_lengths,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
return output if return_dict else output.to_tuple()
================================================
FILE: surya/common/surya/decoder/config.py
================================================
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class SuryaDecoderConfig(PretrainedConfig):
model_type = "qwen2"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
        self.use_sliding_window = False  # Always disabled; the `use_sliding_window` argument is ignored
self.sliding_window = (
sliding_window # we check `use_sliding_window` in the modeling code
)
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
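
# Illustrative sketch (not part of surya): the backward-compatibility handling of the legacy
# "type" key in `rope_scaling`, using plain dicts. `normalize_rope_scaling` and
# `resolve_rope_type` are hypothetical names that mirror the config and rotary embedding logic.

```python
def normalize_rope_scaling(rope_scaling):
    """Copy a legacy "type" entry to "rope_type" (mutates the dict in place)."""
    if rope_scaling is not None and "type" in rope_scaling:
        rope_scaling["rope_type"] = rope_scaling["type"]
    return rope_scaling


def resolve_rope_type(rope_scaling):
    """Pick the RoPE variant, falling back to "default" when nothing is configured."""
    if rope_scaling is None:
        return "default"
    return rope_scaling.get("rope_type", rope_scaling.get("type"))


legacy = normalize_rope_scaling({"type": "dynamic", "factor": 2.0})
print(legacy, resolve_rope_type(legacy), resolve_rope_type(None))
```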
================================================
FILE: surya/common/surya/embedder/__init__.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleTokenEmbedder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        # One embedding table per bbox component, indexed by input_boxes[:, :, i]
        self.bbox_embed = nn.ModuleList(
            [
                nn.Embedding(
                    config.bbox_size + config.special_token_count,
                    config.bbox_embed_size,
                )
                for _ in range(6)
            ]
        )
        self.max_bbox_embedding = config.bbox_size + config.special_token_count - 1
        self.max_bbox_size = config.bbox_size

    def embed(
        self,
        input_tokens: torch.Tensor,
        input_boxes: torch.Tensor | None,
        embed_boxes: torch.Tensor,
    ) -> torch.Tensor:
        # Embed tokens
        token_embeds = self.token_embed(input_tokens)

        # Optionally embed boxes
        if input_boxes is not None and embed_boxes.any():  # Is none in prefill
            input_boxes = input_boxes.to(torch.long)
            # Positions whose first component is out of range get no bbox embedding
            bbox_loss_ignore_mask = (
                (input_boxes[:, :, 0] < 0) | (input_boxes[:, :, 0] > self.max_bbox_size)
            ).unsqueeze(-1)
            input_boxes = torch.clamp(input_boxes, 0, self.max_bbox_embedding)

            # Sum the per-component embeddings into one vector per position
            bbox_embeds = torch.sum(
                torch.stack(
                    [
                        self.bbox_embed[i](input_boxes[:, :, i])
                        for i in range(len(self.bbox_embed))
                    ],
                    dim=-1,
                ),
                dim=-1,
            )

            # Left-pad the last dim with zeros up to the token embedding width
            bbox_embeds = F.pad(
                bbox_embeds, (token_embeds.shape[-1] - bbox_embeds.shape[-1], 0)
            )
            embed_boxes = embed_boxes.unsqueeze(1).unsqueeze(1).expand_as(bbox_embeds)
            bbox_loss_ignore_mask = bbox_loss_ignore_mask.expand_as(bbox_embeds)

            # Keep bbox embeddings only where requested and where the box is valid
            mask = embed_boxes & ~bbox_loss_ignore_mask
            bbox_embeds *= mask.float()
            token_embeds = token_embeds + bbox_embeds

        return token_embeds
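
# Illustrative sketch (not part of surya): how the bbox embedding is combined with a token
# embedding for one position, using plain Python lists. All names and the sizes below are
# hypothetical, and only two box components are used instead of six for brevity.

```python
BBOX_SIZE = 1024  # hypothetical value for config.bbox_size


def is_ignored(box):
    """A box is skipped when its first component is outside [0, BBOX_SIZE]."""
    return box[0] < 0 or box[0] > BBOX_SIZE


def pad_left(vec, width):
    """Zero-pad on the left so the values occupy the last len(vec) channels."""
    return [0.0] * (width - len(vec)) + list(vec)


def combine(token_embed, component_embeds, box, embed_boxes):
    """Add the summed, left-padded bbox embedding to the token embedding."""
    summed = [sum(vals) for vals in zip(*component_embeds)]
    padded = pad_left(summed, len(token_embed))
    keep = 1.0 if (embed_boxes and not is_ignored(box)) else 0.0
    return [t + keep * b for t, b in zip(token_embed, padded)]


token = [1.0, 1.0, 1.0, 1.0]
components = [[0.5, 0.25], [0.5, 0.25]]  # already looked-up embedding rows

with_box = combine(token, components, box=[10, 20], embed_boxes=True)
ignored = combine(token, components, box=[-1, 20], embed_boxes=True)
print(with_box, ignored)
```

# Only the trailing channels change when a valid box is embedded; an out-of-range box leaves the
# token embedding untouched.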
================================================
FILE: surya/common/surya/encoder/__init__.py
================================================
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.surya.encoder.config import SuryaEncoderConfig
from surya.common.xla import get_nearest_pad
from surya.logging import get_logger
from surya.settings import settings
if settings.FOUNDATION_XLA:
import torch_xla.experimental.custom_kernel
logger = get_logger()
class Qwen2_5_VLMLP(nn.Module):
def __init__(self, config, bias: bool = False):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
return self.down_proj(
self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
)
class Qwen2_5_VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(
in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
target_dtype = self.proj.weight.dtype
bsz = hidden_states.shape[0]
hidden_states = hidden_states.view(
-1,
self.in_channels,
self.temporal_patch_size,
self.patch_size,
self.patch_size,
)
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
bsz, -1, self.embed_dim
)
return hidden_states
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.inv_freq = 1.0 / (
theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device="cpu", dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2_5_VLPatchMerger(nn.Module):
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.GELU(),
nn.Linear(self.hidden_size, dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
bsz = x.shape[0]
x = self.mlp(self.ln_q(x).view(bsz, -1, self.hidden_size))
return x
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
from flash_attn.layers.rotary import apply_rotary_emb
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
class Qwen2_5_VLVisionXLASdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
self.head_dim = dim // num_heads
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
q, k, v = (
self.qkv(hidden_states)
.reshape(bsz, seq_length, 3, self.num_heads, -1)
.permute(0, 2, 1, 3, 4)
.unbind(1)
)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
attention_mask = torch.zeros([bsz, 1, seq_length, seq_length], dtype=torch.bool)
cu_seqlens_cpu = cu_seqlens.cpu()
for j in range(bsz):
batch_seqlens = cu_seqlens_cpu[j]
for i in range(1, len(batch_seqlens)):
attention_mask[
j,
...,
batch_seqlens[i - 1] : batch_seqlens[i],
batch_seqlens[i - 1] : batch_seqlens[i],
] = True
attention_mask = attention_mask.to(q.device)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attention_mask,
dropout_p=0.0,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
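
# Illustrative sketch (not part of surya): building a block-diagonal boolean attention mask from
# cumulative sequence boundaries for one batch item, in plain Python. `block_diagonal_mask` is a
# hypothetical name. Each consecutive pair of boundaries defines a segment whose positions may
# attend only to each other.

```python
def block_diagonal_mask(cu_seqlens, seq_length):
    """Return a seq_length x seq_length mask that is True inside each segment."""
    mask = [[False] * seq_length for _ in range(seq_length)]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for row in range(start, end):
            for col in range(start, end):
                mask[row][col] = True
    return mask


# Two segments: positions 0-1 and positions 2-4.
mask = block_diagonal_mask([0, 2, 5], seq_length=5)
for row in mask:
    print(["x" if allowed else "." for allowed in row])
```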
class Qwen2_5_VLVisionXLAFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
self.head_dim = dim // num_heads
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
# Note, this is faster than SDPA, but pretty memory inefficient
# It also has significant accuracy issues
bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
# Single reshape to target layout - avoid multiple operations
q, k, v = (
self.qkv(hidden_states)
.reshape(bsz, seq_length, 3, self.num_heads, -1)
.permute(0, 2, 1, 3, 4)
.unbind(1)
)
# Apply rotary embeddings if provided
if position_embeddings is not None:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
# Single reshape to flash attention format [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2) # [bsz, num_heads, seq_len, head_dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
total_seqlen = q.shape[2]
        # Build an additive bias that masks the keys before cu_seqlens[i][1] (the end of the leading padding)
additive_bias = torch.zeros((bsz, 1, total_seqlen, total_seqlen), dtype=q.dtype)
min_val = torch.finfo(q.dtype).min
for i in range(bsz):
padding_end = cu_seqlens[i][1].item()
additive_bias[i, :, :, :padding_end] = min_val
additive_bias = additive_bias.to(hidden_states.device)
attn_scale = 1 / math.sqrt(self.head_dim)
attn_output = torch_xla.experimental.custom_kernel.flash_attention(
q, k, v, sm_scale=attn_scale, ab=additive_bias
)
attn_output = (
attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_length, -1)
)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5_VLVisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
from flash_attn import flash_attn_varlen_func
bsz = hidden_states.shape[0]
seq_length = hidden_states.shape[1]
q, k, v = (
self.qkv(hidden_states)
.reshape(bsz, seq_length, 3, self.num_heads, -1)
.permute(0, 2, 1, 3, 4)
.unbind(1)
)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_flashatt(q, k, cos.squeeze(0), sin.squeeze(0))
q = q.squeeze(0)
k = k.squeeze(0)
v = v.squeeze(0)
cu_seqlens = cu_seqlens.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(
q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
).reshape(bsz, seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
class Qwen2_5_VLVisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
q, k, v = (
self.qkv(hidden_states)
.reshape(bsz, seq_length, 3, self.num_heads, -1)
.permute(0, 2, 1, 3, 4)
.unbind(1)
)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
attention_mask = torch.full(
[bsz, 1, seq_length, seq_length],
torch.finfo(q.dtype).min,
device=q.device,
dtype=q.dtype,
)
for j in range(bsz):
batch_seqlens = cu_seqlens[j]
for i in range(1, len(batch_seqlens)):
attention_mask[
j,
...,
batch_seqlens[i - 1] : batch_seqlens[i],
batch_seqlens[i - 1] : batch_seqlens[i],
] = 0
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
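# Illustrative sketch (not part of the module): how cumulative sequence
# lengths translate into the block-diagonal additive mask built in the eager
# attention above. Plain lists stand in for tensors, and -inf stands in for
# `torch.finfo(dtype).min`; `block_diagonal_mask` is a hypothetical helper.

```python
def block_diagonal_mask(cu_seqlens, masked=float("-inf")):
    """Return a (total, total) additive mask that is 0 inside each sequence block."""
    total = cu_seqlens[-1]
    mask = [[masked] * total for _ in range(total)]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for row in range(start, end):
            for col in range(start, end):
                mask[row][col] = 0.0  # tokens attend only within their own sequence
    return mask

# Two packed sequences of lengths 2 and 3.
mask = block_diagonal_mask([0, 2, 5])
```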
class Qwen2_5_VLVisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def unpack_qkv_with_mask(self, q, k, v, cu_seqlens):
"""
Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask.
Args:
q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim)
cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths
Returns:
batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
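attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len)
with 0 for valid tokens and -inf for padding (for additive attention)
batch_indices: Tensor of shape (total_seq_len,) with the batch row of each packed token
position_indices: Tensor of shape (total_seq_len,) with the position of each packed token within its sequence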
"""
device = q.device
dtype = q.dtype
batch_size = cu_seqlens.shape[0] - 1
num_heads = q.shape[1]
head_dim = q.shape[2]
seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] # Keep as tensor
max_seq_len = seq_lengths.max().item() # Use .max() on tensor
if settings.FOUNDATION_STATIC_CACHE:
# Pad max_seq_len to the nearest multiple for compilation
max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16)
# Pad batch_size to the nearest multiple for compilation
batch_size = get_nearest_pad(batch_size, pad_multiple=2)
# Ensure seq_lengths is a tensor of the correct size
seq_lengths = F.pad(
seq_lengths, (0, batch_size - seq_lengths.size(0)), "constant", 0
)
# some day, you may look at this, and think: "what if I used repeat_interleave or some other fancy torch op instead"?
# don't do this - it's a path to madness. For some reason, this loop is optimal
batch_indices = []
position_indices = []
for i, seq_len in enumerate(
seq_lengths.tolist()
): # Convert to list only for iteration
batch_indices.extend([i] * seq_len)
position_indices.extend(list(range(seq_len)))
batch_indices = torch.tensor(batch_indices, device=device)
position_indices = torch.tensor(position_indices, device=device)
batched_q = torch.zeros(
(batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype
)
batched_k = torch.zeros_like(batched_q)
batched_v = torch.zeros_like(batched_q)
# Create additive attention mask
attention_mask = torch.full(
(batch_size, max_seq_len, max_seq_len),
fill_value=float("-inf"),
device=device,
dtype=dtype,
)
# Create mask for valid positions
seq_range = torch.arange(max_seq_len, device=device)
valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze(
1
) # (batch_size, max_seq_len)
valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(
1
) # (batch_size, max_seq_len, max_seq_len)
# Simply use boolean indexing to set valid positions to 0
attention_mask[valid_2d] = 0
attention_mask = attention_mask.unsqueeze(
1
) # (batch_size, 1, max_seq_len, max_seq_len)
batched_q[batch_indices, position_indices] = q
batched_k[batch_indices, position_indices] = k
batched_v[batch_indices, position_indices] = v
return (
batched_q,
batched_k,
batched_v,
attention_mask,
batch_indices,
position_indices,
)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
hidden_states = hidden_states.squeeze(0)
cu_seqlens = cu_seqlens.squeeze(0)
seq_length = hidden_states.shape[0]
q, k, v = (
self.qkv(hidden_states)
.reshape(seq_length, 3, self.num_heads, -1)
.permute(1, 0, 2, 3)
.unbind(0)
)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)
q, k, v, attention_mask, batch_indices, position_indices = (
self.unpack_qkv_with_mask(q, k, v, cu_seqlens)
)
batch_size, max_seqlen = q.shape[:2]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attention_mask,
dropout_p=0.0,
)
attn_output = attn_output.permute(0, 2, 1, 3).reshape(
batch_size, max_seqlen, -1
) # Bring back to (batch_size, max_seqlen, hidden_dim)
attn_output = attn_output[batch_indices, position_indices]
attn_output = self.proj(attn_output)
return attn_output.unsqueeze(0)
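# Illustrative sketch (not part of the module): the pack -> pad -> repack
# round trip that `unpack_qkv_with_mask` and `forward` perform, using plain
# lists. The helper names are hypothetical, and the static-cache padding of
# batch size and sequence length is left out.

```python
def unpack_indices(seq_lengths):
    """Build (batch_indices, position_indices) for every packed token."""
    batch_indices, position_indices = [], []
    for i, n in enumerate(seq_lengths):
        batch_indices.extend([i] * n)
        position_indices.extend(range(n))
    return batch_indices, position_indices

def unpack(packed, cu_seqlens, pad_value=None):
    """Scatter a packed token list into a padded (batch, max_len) grid."""
    seq_lengths = [b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    max_len = max(seq_lengths)
    batched = [[pad_value] * max_len for _ in seq_lengths]
    b_idx, p_idx = unpack_indices(seq_lengths)
    for token, b, p in zip(packed, b_idx, p_idx):
        batched[b][p] = token
    return batched, b_idx, p_idx

def repack(batched, b_idx, p_idx):
    """Gather valid tokens back into packed order, dropping the padding."""
    return [batched[b][p] for b, p in zip(b_idx, p_idx)]

packed = ["a0", "a1", "a2", "b0", "b1"]
batched, b_idx, p_idx = unpack(packed, [0, 3, 5])
```

# The same index pair is reused after attention to drop the padded slots,
# which is why the method returns it alongside q, k, v and the mask.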
QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5_VLVisionAttention,
"flash_attention_2": Qwen2_5_VLVisionXLAFlashAttention2
if settings.FOUNDATION_XLA
else Qwen2_5_VLVisionFlashAttention2,
"sdpa": Qwen2_5_VLVisionXLASdpaAttention
if settings.FOUNDATION_XLA
else Qwen2_5_VLVisionSdpaAttention,
}
class Qwen2_5_VLVisionBlock(nn.Module):
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
super().__init__()
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
config.hidden_size, num_heads=config.num_heads
)
self.mlp = Qwen2_5_VLMLP(config, bias=True)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
Qwen2_5_VL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Qwen2_5_VLConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
class Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel):
config_class = SuryaEncoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
config_class = SuryaEncoderConfig
_no_split_modules = ["Qwen2_5_VLVisionBlock"]
def __init__(self, config, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
self.spatial_merge_size = config.spatial_merge_size
self.patch_size = config.patch_size
self.fullatt_block_indexes = config.fullatt_block_indexes
self.window_size = config.window_size
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=config.patch_size,
temporal_patch_size=config.temporal_patch_size,
in_channels=config.in_channels,
embed_dim=config.hidden_size,
)
head_dim = config.hidden_size // config.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[
Qwen2_5_VLVisionBlock(config, config._attn_implementation)
for _ in range(config.depth)
]
)
self.merger = Qwen2_5_VLPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
spatial_merge_size=config.spatial_merge_size,
)
self.gradient_checkpointing = False
def rot_pos_emb(self, grid_thw):
rotary_pos_emb = []
grid_thw_list = grid_thw.cpu().tolist()
for batch_item in grid_thw_list:
row_pos_ids = []
heights = [h for _, h, _ in batch_item]
widths = [w for _, _, w in batch_item]
max_grid_size = max(heights + widths)
for t, h, w in batch_item:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
# shape: token_count, 2
row_pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
)
# shape: token_count, 2
pos_ids = torch.cat(row_pos_ids, dim=0)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb_row = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb.append(rotary_pos_emb_row)
rotary_pos_emb = torch.stack(rotary_pos_emb, dim=0)
return rotary_pos_emb
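# Illustrative sketch (not part of the module): the token order produced by
# the reshape/permute/flatten steps in `rot_pos_emb`. Patches are visited one
# merge window at a time, so the patches that will be merged together sit
# next to each other. `merged_order_pos_ids` is a hypothetical helper.

```python
def merged_order_pos_ids(h, w, merge):
    """Return (row, col) patch ids in merge-window-major order."""
    ids = []
    for block_row in range(h // merge):
        for block_col in range(w // merge):
            for inner_row in range(merge):
                for inner_col in range(merge):
                    ids.append((block_row * merge + inner_row,
                                block_col * merge + inner_col))
    return ids

# A 2 x 4 patch grid with 2 x 2 merge windows.
pos_ids = merged_order_pos_ids(2, 4, 2)
```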
def forward(
self,
hidden_states: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(bsz, seq_len, hidden_size)`):
The flattened image patches, which are passed through `patch_embed` first.
grid_thw (`torch.Tensor` of shape `(bsz, num_images_or_videos, 3)`):
The temporal, height and width dimensions of the patch grid for each image.
Returns:
`torch.Tensor`: hidden_states.
"""
bsz, seq_len, _ = hidden_states.size()
hidden_states = self.patch_embed(hidden_states) # (bsz, seq_len, hidden_dim)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# hidden_states = hidden_states.reshape(bsz, seq_len, -1)
# rotary_pos_emb = rotary_pos_emb.reshape(bsz, seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to(
hidden_states.device
)
position_embeddings = (emb.cos(), emb.sin())
cu_seqlens = (grid_thw[:, :, 1] * grid_thw[:, :, 2]).cumsum(
dim=1,
# Select dtype based on the following factors:
# - FA2 requires that cu_seqlens_q must have dtype int32
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
# See https://github.com/huggingface/transformers/pull/34852 for more information
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for layer_num, blk in enumerate(self.blocks):
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__,
hidden_states,
cu_seqlens,
None,
position_embeddings,
)
else:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
hidden_states = self.merger(hidden_states)
return hidden_states
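# Illustrative sketch (not part of the module): how `cu_seqlens` follows from
# the patch grids of one batch row. As in `forward` above, only height and
# width are multiplied and a leading 0 is prepended. `cu_seqlens_from_grids`
# is a hypothetical helper working on plain tuples.

```python
from itertools import accumulate

def cu_seqlens_from_grids(grids):
    """grids: list of (t, h, w) patch grids for the images in one batch row."""
    lengths = [h * w for _, h, w in grids]
    return [0] + list(accumulate(lengths))

# Two images with 4 x 6 and 2 x 8 patch grids.
cu = cu_seqlens_from_grids([(1, 4, 6), (1, 2, 8)])
```

# Consecutive entries delimit each image's tokens: image i occupies
# positions cu[i] .. cu[i + 1] - 1 of the packed sequence.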
class SuryaEncoderModel(Qwen2_5_VisionTransformerPretrainedModel):
@property
def image_size(self) -> Tuple[int, int]:
config: SuryaEncoderConfig = self.config
if isinstance(config.image_size, tuple) and len(config.image_size) == 2:
return config.image_size
elif isinstance(config.image_size, int):
return (config.image_size, config.image_size)
raise ValueError(
f"The `image_size` for SuryaEncoderConfig should be a tuple of (int, int) or a single int but found {type(config.image_size)}"
)
@property
def hidden_size(self) -> int:
config: SuryaEncoderConfig = self.config
return config.hidden_size
def embed_images(
self,
image_batch: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
return super().forward(
hidden_states=image_batch,
grid_thw=grid_thw,
)
================================================
FILE: surya/common/surya/encoder/config.py
================================================
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class SuryaEncoderConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "depth",
}
def __init__(
self,
depth=8,
hidden_size=1280,
hidden_act="silu",
intermediate_size=3420,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
spatial_patch_size=14,
temporal_patch_size=1,
tokens_per_second=4,
window_size=112,
out_hidden_size=1280,
fullatt_block_indexes=(3, 7),
initializer_range=0.02,
image_size=4096,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.tokens_per_second = tokens_per_second
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range
self.spatial_patch_size = spatial_patch_size
self.image_size = image_size
================================================
FILE: surya/common/surya/flash_attn_utils.py
================================================
from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
from flash_attn.bert_padding import index_first_axis as _index_first_axis
from flash_attn.bert_padding import pad_input
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
Arguments:
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
cu_seqlens (`torch.Tensor`):
The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
max_seqlen_in_batch (`int`):
Maximum sequence length in batch.
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
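# Illustrative sketch (not part of the module): the three values returned by
# `_get_unpad_data`, computed on a nested list instead of a tensor. The name
# `get_unpad_data` (no underscore) marks it as a stand-in for this example.

```python
from itertools import accumulate

def get_unpad_data(attention_mask):
    """attention_mask: list of rows, 1 = valid token, 0 = padding."""
    seqlens = [sum(row) for row in attention_mask]
    flat = [v for row in attention_mask for v in row]
    indices = [i for i, v in enumerate(flat) if v]  # positions of valid tokens
    cu_seqlens = [0] + list(accumulate(seqlens))    # shape (batch_size + 1,)
    return indices, cu_seqlens, max(seqlens)

indices, cu_seqlens, max_len = get_unpad_data([[1, 1, 0], [1, 1, 1]])
```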
def _upad_input(
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
query_length: int,
indices_k,
cu_seqlens_k,
max_seqlen_in_batch_k
):
"""
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
tensors for query, key, value tensors.
Arguments:
query_layer (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key_layer (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value_layer (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
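query_length (`int`):
Target length.
indices_k (`torch.Tensor`):
The indices of non-masked key/value tokens from the flattened input sequence.
cu_seqlens_k (`torch.Tensor`):
The cumulative key/value sequence lengths. Shape: (batch_size + 1,).
max_seqlen_in_batch_k (`int`):
Maximum key/value sequence length in the batch.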
Return:
query_layer (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key_layer (`torch.Tensor`):
Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value_layer (`torch.Tensor`):
Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = _index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
value_layer = _index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = _index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
raise NotImplementedError()
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
def flash_attn_prefill(
module: torch.nn.Module,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
dropout: float,
scaling: float,
query_length: int,
batch_size: int,
indices_k: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_in_batch_k: int,
**kwargs
):
"""
Wrapper for flash attention during the prefill stage
query_states must have shape (batch_size, num_heads, seq_len, head_dim)
key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)
This is the opposite of what is required by flash attention, but keeps parity with the HF convention
query_length, batch_size, indices_k, cu_seqlens_k, and max_seqlen_in_batch_k should come from the flash attention kwargs
"""
query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)
q_flash, k_flash, v_flash, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query_states, key_states, value_states, query_length, indices_k, cu_seqlens_k, max_seqlen_in_batch_k
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
# Returning None for attn_weights to match other attention interfaces
flash_attn_out = _flash_attn_varlen_func(
q_flash,
k_flash,
v_flash,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=scaling,
causal=module.is_causal,
)
return pad_input(flash_attn_out, indices_q, batch_size, query_length), None
# NOTE: Does not support dropout, accepts argument as kwargs to maintain compatibility
# This function is an order of magnitude faster than the prefill variant, or using the HF interface
def flash_attn_decode(
module: torch.nn.Module,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
scaling: float,
**kwargs,
):
"""
Wrapper for flash attention during the decode stage
query_states must have shape (batch_size, num_heads, seq_len, head_dim), 1 is the seq length in the decoding stage
key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)
This is the opposite of what is required by flash attention, but keeps parity with the HF convention
This function computes the left pad and cache seqlens to pass into FA2. For example -
Given an attention_mask shaped (batch_size=2, seq_len=8), where 0 = padding, 1 = real token
attention_mask =
tensor([
[0, 0, 1, 1, 1, 0, 0, 0], # ← batch 0
[0, 1, 1, 1, 1, 1, 1, 0], # ← batch 1
])
cache_leftpad = tensor([2, 1], dtype=torch.int32)
cache_seqlens = tensor([5, 7], dtype=torch.int32)
These values allow FlashAttention to use a static cache layout with efficient slicing during decoding.
"""
query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)
cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32)
cache_seqlens = (attention_mask * torch.arange(attention_mask.size(1), device=attention_mask.device)).argmax(dim=1).to(torch.int32) + 1
# Returning None for attn_weights to match other attention interfaces
return _flash_attn_with_kvcache(
q=query_states,
k_cache=key_states,
v_cache=value_states,
cache_leftpad=cache_leftpad,
cache_seqlens=cache_seqlens,
causal=module.is_causal,
softmax_scale=scaling,
), None
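# Illustrative sketch (not part of the module): the `cache_leftpad` and
# `cache_seqlens` computation from `flash_attn_decode`, reproduced on nested
# lists with the mask from its docstring. `leftpad_and_seqlens` is a
# hypothetical helper.

```python
def leftpad_and_seqlens(attention_mask):
    """Return (cache_leftpad, cache_seqlens) for a 0/1 attention mask."""
    leftpad, seqlens = [], []
    for row in attention_mask:
        # Count leading zeros, like (mask == 0).cumprod(dim=1).sum(dim=1).
        pad = 0
        for v in row:
            if v != 0:
                break
            pad += 1
        leftpad.append(pad)
        # argmax of mask * position is the last valid index; +1 gives the end offset.
        weighted = [v * i for i, v in enumerate(row)]
        seqlens.append(weighted.index(max(weighted)) + 1)
    return leftpad, seqlens

mask = [
    [0, 0, 1, 1, 1, 0, 0, 0],
    [0, 1, 1, 1, 1, 1, 1, 0],
]
cache_leftpad, cache_seqlens = leftpad_and_seqlens(mask)
```

# Note that `cache_seqlens` is the end offset of the valid span measured from
# position 0, so it includes the left padding.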
================================================
FILE: surya/common/surya/processor/__init__.py
================================================
import math
import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from typing import List, Optional, Tuple
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.processor.schema import (
TextInput,
ImageInput,
ProcessorOutput,
)
from surya.common.surya.schema import TaskNames
from surya.logging import get_logger
from surya.settings import settings
logger = get_logger()
# Task agnostic tokens - Every task will use these in some form or another
EOS_TOKEN = ""
EOI_TOKEN = "" # This is end of INPUT, not image. Images are always followed by a task specific BOS token, so that serves as a delimiter anyways.
IMAGE_TOKEN = ""
PAD_TOKEN = ""
NO_OUTPUT_TOKEN = ""
IMAGE_ROTATED_TOKEN = ""
REGISTER_TOKENS = ["", "", "", ""]
BEACON_TOKEN = ""
NOMATH_TOKEN = ""
# Task specific tokens
OCR_WITH_BOXES_BOS_TOKEN = ""
OCR_WITHOUT_BOXES_BOS_TOKEN = ""
BLOCK_WITHOUT_BOXES_TOKEN = ""
LAYOUT_BOS_TOKEN = ""
TABLE_STRUCTURE_BOS_TOKEN = ""
class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin):
attributes = ["image_processor", "ocr_tokenizer"]
image_processor_class = "BaseImageProcessor"
ocr_tokenizer_class = "PreTrainedTokenizer"
rescale_factor = 1 / 255.0
image_mean = (0.485, 0.456, 0.406)
image_std = (0.229, 0.224, 0.225)
def __init__(
self,
ocr_tokenizer: PreTrainedTokenizer,
blank_bbox_token_id: int,
num_register_tokens: int,
patch_size: int,
merge_size: int,
num_beacon_tokens: int,
beacon_token_interval: int,
model_device: str,
**kwargs,
):
self.ocr_tokenizer = ocr_tokenizer
self.patch_size = patch_size
self.merge_size = merge_size
self.num_register_tokens = num_register_tokens
self.num_beacon_tokens = num_beacon_tokens
self.beacon_token_interval = beacon_token_interval
self.tokenizer_vocab_size = 0
for attr in self.attributes:
if "tokenizer" in attr:
self.tokenizer_vocab_size += getattr(self, attr).vocab_size
self.offsets = {"ocr": 0}
# Create special token mapping
self.special_token_mapping = self.ocr_tokenizer.system_tokens
self.register_token_ids = [
self.special_token_mapping.get(r) for r in REGISTER_TOKENS
]
self.beacon_token_id = self.special_token_mapping.get(BEACON_TOKEN)
self.image_token_id = self.special_token_mapping.get(IMAGE_TOKEN)
self.pad_token_id = self.special_token_mapping.get(PAD_TOKEN)
self.eos_token_id = self.special_token_mapping.get(EOS_TOKEN)
self.eoi_token_id = self.special_token_mapping.get(EOI_TOKEN)
self.no_output_token = self.special_token_mapping.get(NO_OUTPUT_TOKEN)
self.image_rotated_token = self.special_token_mapping.get(IMAGE_ROTATED_TOKEN)
self.nomath_token = self.special_token_mapping.get(NOMATH_TOKEN)
self.bos_token_id = {
TaskNames.ocr_with_boxes: self.special_token_mapping.get(
OCR_WITH_BOXES_BOS_TOKEN
),
TaskNames.ocr_without_boxes: self.special_token_mapping.get(
OCR_WITHOUT_BOXES_BOS_TOKEN
),
TaskNames.block_without_boxes: self.special_token_mapping.get(
BLOCK_WITHOUT_BOXES_TOKEN
),
TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN),
TaskNames.table_structure: self.special_token_mapping.get(
TABLE_STRUCTURE_BOS_TOKEN
),
}
if self.image_token_id is None:
logger.warning("Warning: Image token not found in special tokens")
self.blank_bbox_token_id = blank_bbox_token_id
self.bbox_pad_token_id = self.blank_bbox_token_id
self.ignore_bbox_token_ids = [
v
for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
if k not in self.ocr_tokenizer.special_tokens["math_external"]
]
math_end_token = ""
self.math_start_token_ids = [
v
for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
if k in self.ocr_tokenizer.special_tokens["math_external"]
and k != math_end_token
]
self.math_end_token_ids = [
v
for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
if k == math_end_token
]
if self.num_register_tokens > len(self.register_token_ids):
raise ValueError(
"The number of register tokens requested exceeds the number of register tokens defined in the special token mapping."
)
self.image_mean = np.array(self.image_mean, dtype=np.float32)
self.image_std = np.array(self.image_std, dtype=np.float32)
self.model_device = model_device
@property
def vocab_size(self):
return self.tokenizer_vocab_size
def image_processor(self, image: Image.Image) -> np.ndarray:
# Convert to array
image = np.asarray(image, dtype=np.float32)
return image
@staticmethod
def scale_to_fit(
img: np.ndarray,
max_size: Tuple[int, int],
min_size: Tuple[int, int] = (168, 168),
):
# Get current dimensions
height, width = img.shape[:2]
# Check for empty or invalid image
if width == 0 or height == 0:
return img
max_width, max_height = max_size
min_width, min_height = min_size
# Calculate pixel counts
current_pixels = width * height
max_pixels = max_width * max_height
min_pixels = min_width * min_height
if current_pixels > max_pixels:
scale_factor = (max_pixels / current_pixels) ** 0.5
new_width = math.floor(width * scale_factor)
new_height = math.floor(height * scale_factor)
elif current_pixels == 0:
return img
elif current_pixels < min_pixels:
scale_factor = (min_pixels / current_pixels) ** 0.5
new_width = math.ceil(width * scale_factor)
new_height = math.ceil(height * scale_factor)
else:
return img
return cv2.resize(
img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4
)
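# Illustrative sketch (not part of the module): the size arithmetic of
# `scale_to_fit` without the actual resize. The pixel-count budget is what
# gets compared, so width and height are scaled by the same square-root
# factor. `target_size` is a hypothetical helper.

```python
import math

def target_size(width, height, max_size, min_size=(168, 168)):
    """Return the (width, height) that scale_to_fit would resize to."""
    pixels = width * height
    if pixels == 0:
        return width, height
    max_pixels = max_size[0] * max_size[1]
    min_pixels = min_size[0] * min_size[1]
    if pixels > max_pixels:
        scale = (max_pixels / pixels) ** 0.5
        return math.floor(width * scale), math.floor(height * scale)
    if pixels < min_pixels:
        scale = (min_pixels / pixels) ** 0.5
        return math.ceil(width * scale), math.ceil(height * scale)
    return width, height

shrunk = target_size(2000, 1000, max_size=(1000, 500))
grown = target_size(84, 84, max_size=(1000, 1000))
kept = target_size(500, 400, max_size=(1000, 1000))
```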
def _image_processor(self, image: np.ndarray):
image = image.astype(np.float64) * self.rescale_factor
image = (image.astype(np.float32) - self.image_mean) / self.image_std
return image
def _process_and_tile(
self, image: np.ndarray
) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
"""
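Resizes the input image up to the next multiple of `patch_size * merge_size` in each dimension
(times 4 under XLA), so the aspect ratio is only approximately preserved, and returns a tensor
of flattened patches together with the `(grid_t, grid_h, grid_w)` patch grid.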
"""
extra_multipler = (
4 if settings.FOUNDATION_XLA else 1
) # Needed to force same size grid_thws per row with padding
factor = (
self.patch_size * self.merge_size * extra_multipler
) # Make a multiple of window size
height, width = image.shape[:2]
h_bar = math.ceil(height / factor) * factor
w_bar = math.ceil(width / factor) * factor
if h_bar != height or w_bar != width:
if height == 0 or width == 0:
image = np.zeros((h_bar, w_bar, 3), dtype=np.uint8)
else:
image = cv2.resize(image, (w_bar, h_bar), interpolation=cv2.INTER_CUBIC)
# Handle scaling and normalization
image = self._image_processor(image)
height, width = image.shape[:2]
# Numpy array to torch tensor
img_tensor = torch.from_numpy(image.transpose(2, 0, 1))
patches = img_tensor.unsqueeze(0)
channel = patches.shape[1]
grid_t = patches.shape[0]
grid_h, grid_w = height // self.patch_size, width // self.patch_size
patches = patches.reshape(
grid_t,
1,
channel,
grid_h // self.merge_size,
self.merge_size,
self.patch_size,
grid_w // self.merge_size,
self.merge_size,
self.patch_size,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * 1 * self.patch_size * self.patch_size
)
return flatten_patches, (grid_t, grid_h, grid_w)
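# Illustrative sketch (not part of the module): the shape bookkeeping of
# `_process_and_tile` and the image-token count derived from it in
# `_process_image_input`, for the non-XLA case. The defaults patch_size=14
# and merge_size=2 are assumptions taken from the encoder config defaults;
# `tiled_grid` is a hypothetical helper.

```python
import math

def tiled_grid(height, width, patch_size=14, merge_size=2):
    """Return padded size, (t, h, w) patch grid, patch count and image-token count."""
    factor = patch_size * merge_size
    h_bar = math.ceil(height / factor) * factor  # round up to a full merge window
    w_bar = math.ceil(width / factor) * factor
    grid_h, grid_w = h_bar // patch_size, w_bar // patch_size
    num_patches = grid_h * grid_w
    num_image_tokens = num_patches // merge_size**2  # one token per merge window
    return (h_bar, w_bar), (1, grid_h, grid_w), num_patches, num_image_tokens

size, grid_thw, num_patches, num_image_tokens = tiled_grid(100, 200)
```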
# Handle image input dictionaries - Process image, tile accordingly, and setup the input ids and boxes correspondingly
def _process_image_input(self, image_input: ImageInput) -> ProcessorOutput:
rotated = image_input.get("rotated", False)
image = image_input.get("image", None)
assert image is not None, (
"A PIL Image must be provided when the input type is `image`"
)
image_tiles, grid_thw = self._process_and_tile(image)
num_tokens = image_tiles.shape[0] / self.merge_size**2
assert num_tokens.is_integer(), (
f"Expected number of tokens to be an integer, got {num_tokens}"
)
input_ids = [self.image_token_id] * int(num_tokens)
input_ids += self.register_token_ids[: self.num_register_tokens]
# Handle the image being rotated in the imdataset
if rotated:
input_ids = [self.image_rotated_token] + input_ids
return ProcessorOutput(
input_ids=input_ids,
image_tiles=image_tiles,
grid_thw=grid_thw,
)
def _process_text_input(self, text_input: TextInput, task: str) -> ProcessorOutput:
input_text = text_input.get("text", None)
math_mode = text_input.get("math", False)
input_ids = self.ocr_tokenizer(input_text, tasks=task)["input_ids"][0]
input_ids = [self.offsets["ocr"] + id for id in input_ids]
# nomath token does not work for layout
if not math_mode and task != "layout":
input_ids.insert(0, self.nomath_token)
return ProcessorOutput(
input_ids=input_ids,
image_tiles=None,
grid_thw=None,
)
def _process_input(self, input_dict: dict, task: str):
input_type = input_dict["type"]
if input_type == "image":
return self._process_image_input(input_dict)
elif input_type == "text":
return self._process_text_input(input_dict, task)
raise NotImplementedError(f"Input of type `{input_type}` is not implemented")
# Preprocessing for OCR task
# The task is expected to have - image_dict, user_input_dict, output_dict
# user_input_dict is allowed to be empty, but it must be present
def _process_ocr_with_boxes(
self,
mixed_input: List[dict],
bos_token_id: int,
task: str = TaskNames.ocr_with_boxes,
):
processed_input_ids = []
all_image_tiles = []
all_grid_thw = []
# 1. Process the image input
for i, input_dict in enumerate(mixed_input):
processor_output = self._process_input(input_dict, task)
input_ids = processor_output["input_ids"]
image_tiles = processor_output["image_tiles"]
grid_thw = processor_output["grid_thw"]
# Special handling of some delimiter tokens
if i == 1:
assert input_dict["type"] == "text", (
"Expected text input for model input."
)
# Case for input - Add task specific bos token + end_of_input token
# We do not want the model to learn how to predict inputs. Hence IGNORE_INDEX for these
input_ids = [bos_token_id] + input_ids + [self.eoi_token_id]
if i == 2:
assert input_dict["type"] == "text", (
"Expected text for final model input"
)
input_ids = input_ids + [self.eos_token_id]
elif i > 2:
raise ValueError(f"Too many inputs received. Expected 2 for inference, 3 for training. Received: {len(mixed_input)}")
# Some input types don't return any image tiles, accounting for that
if image_tiles is not None:
all_image_tiles.append(image_tiles)
all_grid_thw.append(grid_thw)
processed_input_ids.extend(input_ids)
return (
torch.tensor(processed_input_ids, dtype=torch.long),
all_image_tiles,
all_grid_thw,
)
def _process_layout(self, mixed_input: List[dict], bos_token_id: int):
return self._process_ocr_with_boxes(
mixed_input, bos_token_id=bos_token_id, task="layout"
)
def _process_table_structure(self, mixed_input: List[dict], bos_token_id: int):
return self._process_ocr_with_boxes(
mixed_input, bos_token_id=bos_token_id, task="table_structure"
)
def _process_ocr_without_boxes(
self,
mixed_input: List[dict],
bos_token_id: int,
task: str = "ocr_without_boxes",
):
# Boxes are set to None, so this will work
# TODO: improve this behavior
return self._process_ocr_with_boxes(
mixed_input, bos_token_id=bos_token_id, task=task
)
def _process_block_without_boxes(
self,
mixed_input: List[dict],
bos_token_id: int,
task: str = "block_without_boxes",
):
return self._process_ocr_with_boxes(
mixed_input, bos_token_id=bos_token_id, task=task
)
def align_long_axis(self, image: np.ndarray) -> Tuple[np.ndarray, bool]:
height, width, _ = image.shape
if height > width: # Rotate vertical lines
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
return image, True
return image, False
def __call__(
self,
mixed_batch: List[dict],
padding_side: Optional[str] = "left",
device: Optional[torch.device] = None,
pad_to_multiple: Optional[int] = None,
):
all_image_tiles = []
all_input_ids = []
all_grid_thw = []
for b in mixed_batch:
mixed_input = b["inputs"]
task = b["task"]
assert task in self.bos_token_id, f"Task {task} has no bos token defined."
# Select the correct processing function based on the task type
input_ids, image_tiles, grid_thw = getattr(self, f"_process_{task}")(
mixed_input, self.bos_token_id[task]
)
all_input_ids.append(input_ids)
all_image_tiles.extend(image_tiles)
all_grid_thw.extend(grid_thw)
batched_input_ids = pad_sequence(
all_input_ids,
batch_first=True,
padding_side=padding_side,
padding_value=self.pad_token_id,
)
if pad_to_multiple is not None:
current_len = batched_input_ids.shape[1]
# Calculate the next multiple of pad_to_multiple
padded_len = (
(current_len + pad_to_multiple - 1) // pad_to_multiple
) * pad_to_multiple
if padded_len > current_len:
pad_len = padded_len - current_len
# Note: the extra padding is always prepended (left side), regardless of padding_side
batched_input_ids = torch.nn.functional.pad(
batched_input_ids, (pad_len, 0), value=self.pad_token_id
)
attention_mask = batched_input_ids.ne(self.pad_token_id)
# Generate position IDs that are independent of left vs. right padding.
# This should give the same results for either padding side. The exact position IDs of the pad tokens do not matter since they are masked.
position_ids = attention_mask.cumsum(dim=-1) - 1
position_ids[position_ids < 0] = (
0 # For left padding, the position ids for padding will become -1 because of the shift; Setting to 0
)
position_ids = (
attention_mask.to(torch.long) * position_ids
) # Ensure right pad ids get set to zero
batched_image_tiles = torch.cat(all_image_tiles, dim=0)
batched_grid_thw = torch.from_numpy(np.array(all_grid_thw))
# Pin memory for CUDA
if device is not None and torch.device(device).type == "cuda":  # also matches indexed devices such as cuda:0
batched_image_tiles = batched_image_tiles.pin_memory()
batched_grid_thw = batched_grid_thw.pin_memory()
attention_mask = attention_mask.pin_memory()
batched_input_ids = batched_input_ids.pin_memory()
position_ids = position_ids.pin_memory()
return BatchFeature(
{
"input_ids": batched_input_ids,
"image_tiles": batched_image_tiles,
"attention_mask": attention_mask,
"position_ids": position_ids,
"grid_thw": batched_grid_thw,
}
)
# Decode model outputs; Strips special tokens
def decode(self, tokens: List[int], task: str):
filtered_tokens = [
t
for t in tokens
if t not in self.special_token_mapping.values() and t != -100
] # Skip special tokens and loss ignore index
return self.ocr_tokenizer.decode(filtered_tokens, task=task)
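# Illustrative sketch (not part of the library): the padding arithmetic and the
# padding-independent position ids computed in __call__ above, redone with plain
# Python lists so the behaviour can be checked by hand. The helper names
# round_up_to_multiple and position_ids_from_mask are hypothetical.

```python
def round_up_to_multiple(length: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= length (ceil division)."""
    return ((length + multiple - 1) // multiple) * multiple


def position_ids_from_mask(mask):
    """Cumulative count of real tokens minus one, with pad positions set to 0."""
    ids = []
    running = 0
    for is_real in mask:
        running += int(is_real)
        pos = max(running - 1, 0)  # leading pads would give -1, clamp to 0
        ids.append(pos if is_real else 0)  # trailing pads are zeroed as well
    return ids


left_padded = [False, False, True, True, True]
right_padded = [True, True, True, False, False]

print(position_ids_from_mask(left_padded))   # [0, 0, 0, 1, 2]
print(position_ids_from_mask(right_padded))  # [0, 1, 2, 0, 0]
print(round_up_to_multiple(13, 8))           # 16
```

# In both cases the three real tokens receive positions 0, 1, 2, which is the
# property the comments in __call__ rely on.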
================================================
FILE: surya/common/surya/processor/schema.py
================================================
from typing import TypedDict, Literal, List, Tuple
import torch
from PIL import Image
class TaskDict(TypedDict):
datasets: List[str]
img_size: Tuple[int, int]
class TasksDict(TypedDict):
ocr_with_boxes: TaskDict
ocr_without_boxes: TaskDict
block_without_boxes: TaskDict
class ProcessorInput(TypedDict):
type: Literal["image", "ocr", "text", "empty_output"]
class ImageInput(ProcessorInput):
type: Literal["image"]
image: Image.Image
rotated: bool
class TextInput(ProcessorInput):
type: Literal["text"]
text: str
math: bool
class ProcessorOutput(TypedDict):
input_ids: List[int]
image_tiles: torch.Tensor | None
grid_thw: torch.Tensor | None
================================================
FILE: surya/common/surya/processor/tokenizer.py
================================================
import html
import re
from typing import List, Union, Dict, Optional, Tuple, Iterable
import numpy as np
import torch
from tokenizers import AddedToken
import json
import os
from transformers import PreTrainedTokenizer, Qwen2Tokenizer as Qwen2OriginalTokenizer
from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.schema import TASK_NAMES, TaskNames
from surya.logging import get_logger
from surya.settings import settings
logger = get_logger()
def create_token_regex(tokens):
escaped_tokens = [re.escape(token) for token in tokens]
escaped_tokens.sort(key=len, reverse=True)
pattern = r"^(" + "|".join(escaped_tokens) + r")"
regex = re.compile(pattern)
return regex
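# Illustrative sketch (not part of the library): why create_token_regex sorts
# the escaped tokens longest-first. Python's regex alternation takes the first
# alternative that matches, so a shorter token listed earlier would shadow a
# longer one that shares its prefix. The token list and the helper name
# build_prefix_regex are made up for this example.

```python
import re


def build_prefix_regex(tokens):
    # Same idea as create_token_regex: escape, sort longest-first, anchor at start.
    escaped = sorted((re.escape(t) for t in tokens), key=len, reverse=True)
    return re.compile(r"^(" + "|".join(escaped) + r")")


tokens = ["\\sin", "\\sinh"]  # hypothetical tokens sharing a prefix

longest_first = build_prefix_regex(tokens)
# Unsorted alternation, for comparison only.
naive = re.compile(r"^(" + "|".join(re.escape(t) for t in tokens) + r")")

text = "\\sinh x"
print(longest_first.match(text).group(1))  # \sinh
print(naive.match(text).group(1))          # \sin
```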
class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer):
pass
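# Illustrative sketch (not part of the library): the greedy longest-match idea
# used by GreedyMathUTF16Tokenizer below, reduced to a dict-based trie. Math
# tokens map to ids starting at 65536, and any character that starts no math
# token falls back to its UTF-16 code units. The toy vocabulary and the
# function names here are assumptions made for the example; the real class
# uses _TrieNode objects and loads its vocabulary from files.

```python
MATH_BASE = 65536  # first id after the UTF-16 code-unit range [0..65535]


def utf16_units(s):
    b = s.encode("utf-16le")
    return [int.from_bytes(b[i:i + 2], "little") for i in range(0, len(b), 2)]


def build_trie(vocab):
    root = {}
    for tok, tid in vocab.items():
        node = root
        for ch in tok:
            node = node.setdefault(ch, {})
        node[None] = tid  # the None key marks "a token ends here"
    return root


def encode_greedy(s, trie):
    out, i = [], 0
    while i < len(s):
        node, j, best = trie, i, None
        # Walk as far as the trie allows, remembering the last complete token.
        while j < len(s) and s[j] in node:
            node = node[s[j]]
            j += 1
            if None in node:
                best = (node[None], j)
        if best is not None:
            out.append(MATH_BASE + best[0])
            i = best[1]
        else:
            out.extend(utf16_units(s[i]))  # fallback: raw UTF-16 units
            i += 1
    return out


toy_vocab = {"\\frac": 0, "\\alpha": 1, "\\al": 2}
trie = build_trie(toy_vocab)
print(encode_greedy("\\alpha+x", trie))  # [65537, 43, 120]
```

# "\alpha" wins over the shorter "\al" because the walk keeps going after the
# first complete match and only commits to the longest one it reached.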
class GreedyMathUTF16Tokenizer(S3DownloaderMixin, PreTrainedTokenizer):
"""
HuggingFace slow tokenizer implementing:
- UTF-16 code units as the base [0..65535]
- Math tokens as greedy-longest-match ids after UTF-16
- Literal special tokens after math tokens
Absolute ID layout:
[0 .. 65535] : UTF-16 units
[65536 .. 65536+M-1] : math tokens
[65536+M .. 65536+M+S-1] : special tokens
"""
vocab_files_names = {
"vocab_file": "vocab_math.json", # {"\\frac": 0, "\\alpha": 1, ...} raw contiguous ids 0..M-1
"specials_file": "specials.json", # [flat list for legacy]
"specials_dict_file": "specials_dict.json", # category dict (preferred)
}
model_input_names = ["input_ids", "attention_mask"]
is_fast = False
# ---------- helpers ----------
@staticmethod
def _to_utf16_units(s: str) -> List[int]:
b = s.encode("utf-16le")
return [int.from_bytes(b[i : i + 2], "little") for i in range(0, len(b), 2)]
@staticmethod
def _from_utf16_units(units: List[int]) -> str:
b = bytearray()
for u in units:
b += int(u).to_bytes(2, "little")
return b.decode("utf-16le", errors="ignore")
class _TrieNode:
__slots__ = ("child", "id", "leaf")
def __init__(self):
self.child: Dict[str, "GreedyMathUTF16Tokenizer._TrieNode"] = {}
self.id: Optional[int] = None
self.leaf: bool = False
@classmethod
def _build_trie(
cls, token_to_id: Dict[str, int]
) -> "GreedyMathUTF16Tokenizer._TrieNode":
root = cls._TrieNode()
for tok, tid in token_to_id.items():
node = root
for ch in tok:
node = node.child.setdefault(ch, cls._TrieNode())
node.leaf = True
node.id = tid
return root
def _build_escape_patterns(self, math_token_to_rawid):
"""Build pattern list from vocab commands that start with control characters.
Scans the math vocab for LaTeX commands that could be corrupted by JSON
escape sequence interpretation (e.g., \\begin becomes egin).
"""
control_chars = {
'\x08': 'b', # backspace
'\t': 't', # tab
'\n': 'n', # newline
'\r': 'r', # carriage return
'\f': 'f', # form feed
'\x07': 'a', # bell
'\x0b': 'v', # vertical tab
}
patterns = {char: [] for char in control_chars}
for token in math_token_to_rawid.keys():
if token.startswith('\\') and len(token) > 1:
letter = token[1:2] # First char after backslash
for ctrl_char, ctrl_letter in control_chars.items():
if letter == ctrl_letter:
# This token could be corrupted: \token -> oken
suffix = token[2:] # Everything after \X
patterns[ctrl_char].append((suffix, token))
# Sort by length (longest first) to avoid partial matches
for char in patterns:
patterns[char].sort(key=lambda x: len(x[0]), reverse=True)
return patterns
@classmethod
def _encode_math_greedy(
cls,
s: str,
trie: "GreedyMathUTF16Tokenizer._TrieNode",
math_base: int,
debug: bool = False,
) -> List[int]:
i, n = 0, len(s)
out: List[int] = []
while i < n:
node = trie
j = i
last_id = None
last_j = i
while j < n and (ch := s[j]) in node.child:
node = node.child[ch]
j += 1
if node.leaf:
last_id, last_j = node.id, j
if last_id is not None:
if debug:
print(f"[MATH] matched {s[i:last_j]!r} -> {last_id}")
out.append(math_base + last_id)
i = last_j
else:
units = cls._to_utf16_units(s[i])
if debug:
print(f"[MATH] fallback {s[i]!r} -> utf16 {units}")
out.extend(units)
i += 1
return out
# ---------- init ----------
def __init__(
self,
vocab_file: Optional[str] = None,
specials_file: Optional[str] = None,
specials_dict_file: Optional[str] = None,
*,
# You can also pass programmatically instead of files:
math_vocab: Optional[Dict[str, int]] = None,
special_tokens: Optional[List[str]] = None,
special_tokens_dict: Optional[Dict[str, List[str]]] = None,
debug: bool = False,
# Standard HF special token kwargs:
bos_token: Optional[str] = None,
eos_token: Optional[str] = None,
pad_token: Optional[str] = None,
unk_token: Optional[str] = None,
**kwargs,
):
# Load math vocab
if vocab_file and os.path.isfile(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as f:
mv = json.load(f)
else:
mv = math_vocab or {}
# Make math ids contiguous if needed
if mv:
max_id = max(mv.values())
if set(mv.values()) != set(range(max_id + 1)):
items = sorted(mv.items(), key=lambda kv: kv[1])
mv = {tok: i for i, (tok, _) in enumerate(items)}
# Load special tokens (prefer category dict; fallback to flat list or defaults)
sp_dict = None
if specials_dict_file and os.path.isfile(specials_dict_file):
with open(specials_dict_file, "r", encoding="utf-8") as f:
sp_dict = json.load(f)
elif special_tokens_dict is not None:
sp_dict = dict(special_tokens_dict)
if sp_dict is None:
# Legacy path: flat list from file or provided/default list
if specials_file and os.path.isfile(specials_file):
with open(specials_file, "r", encoding="utf-8") as f:
sp_list_flat = json.load(f)
else:
sp_list_flat = special_tokens or SPECIAL_TOKENS
sp_dict = {"all": list(sp_list_flat)}
# Ensure "all" exists and is unique/preserved in order.
if "all" not in sp_dict or not isinstance(sp_dict["all"], list):
order = [
"system",
"formatting",
"math_external",
"script",
"layout",
"reasoning",
"table_structure",
"reserved",
]
seen = set()
all_tokens: List[str] = []
for k in order:
if k in sp_dict and isinstance(sp_dict[k], list):
for t in sp_dict[k]:
if t not in seen:
all_tokens.append(t)
seen.add(t)
sp_dict["all"] = all_tokens
# Keep a copy of categories (if present) for downstream processor logic.
self.special_tokens = sp_dict
sp_list = list(sp_dict.get("all", []))
# Regex list should favor longest-first to avoid partial matches.
specials_for_regex = sorted(sp_list, key=len, reverse=True)
self.debug = debug
self.UTF16_SPACE = 65536
self.math_token_to_rawid = dict(mv) # 0..M-1
self.math_vocab_size = len(self.math_token_to_rawid)
self.MATH_BASE = self.UTF16_SPACE
self.SPECIAL_BASE = self.UTF16_SPACE + self.math_vocab_size
# Maps
self.math_absid_to_token = {
self.MATH_BASE + rid: tok for tok, rid in self.math_token_to_rawid.items()
}
self.special_tokens_list = sp_list # ID assignment order
self.special_to_absid = {
tok: self.SPECIAL_BASE + i for i, tok in enumerate(self.special_tokens_list)
}
self.absid_to_special = {v: k for k, v in self.special_to_absid.items()}
# Public attributes for legacy/processor:
# All specials mapping (token -> absolute id)
self.SPECIAL_TOKEN_MAPPING: Dict[str, int] = dict(self.special_to_absid)
# Reverse lookup (absolute id -> special token), used by the processor for quick access
self.reverse_special_token_mapping = {
v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()
}
self.LAYOUT_LABEL2ID = {
k: v
for k, v in self.SPECIAL_TOKEN_MAPPING.items()
if k in self.special_tokens["layout"]
}
self.TABLE_STRUCTURE_LABEL2ID = {
k: v
for k, v in self.SPECIAL_TOKEN_MAPPING.items()
if k in self.special_tokens["table_structure"]
}
if not self.special_tokens.get("system", []):
print("Warning: No system tokens found in special_tokens")
self.MATH_TAG_START = "