Repository: Unstructured-IO/unstructured-inference
Branch: main
Commit: 56fadb3e9593
Files: 62
Total size: 14.6 MB
Directory structure:
gitextract_bzec1cqm/
├── .github/
│ ├── dependabot.yml
│ └── workflows/
│ ├── ci.yml
│ ├── claude.yml
│ ├── create_issue.yml
│ ├── release.yml
│ └── version-bump.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── benchmarks/
│ ├── __init__.py
│ └── test_benchmark_yolox.py
├── examples/
│ └── ocr/
│ ├── engine.py
│ ├── requirements.txt
│ └── validate_ocr_performance.py
├── logger_config.yaml
├── pyproject.toml
├── renovate.json
├── sample-docs/
│ └── loremipsum.tiff
├── scripts/
│ ├── docker-build.sh
│ ├── shellcheck.sh
│ ├── test-unstructured-ingest-helper.sh
│ └── version-sync.sh
├── test_unstructured_inference/
│ ├── conftest.py
│ ├── inference/
│ │ ├── test_layout.py
│ │ ├── test_layout_element.py
│ │ └── test_layout_rotation.py
│ ├── models/
│ │ ├── test_detectron2onnx.py
│ │ ├── test_eval.py
│ │ ├── test_model.py
│ │ ├── test_tables.py
│ │ └── test_yolox.py
│ ├── test_config.py
│ ├── test_elements.py
│ ├── test_logger.py
│ ├── test_math.py
│ ├── test_utils.py
│ └── test_visualization.py
└── unstructured_inference/
├── __init__.py
├── __version__.py
├── config.py
├── constants.py
├── inference/
│ ├── __init__.py
│ ├── elements.py
│ ├── layout.py
│ ├── layoutelement.py
│ └── pdf_image.py
├── logger.py
├── math.py
├── models/
│ ├── __init__.py
│ ├── base.py
│ ├── detectron2onnx.py
│ ├── eval.py
│ ├── table_postprocess.py
│ ├── tables.py
│ ├── unstructuredmodel.py
│ └── yolox.py
├── utils.py
└── visualize.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/dependabot.yml
================================================
# Dependabot configuration: monthly dependency-update checks for the
# Python (uv) lockfile and for GitHub Actions workflow pins.
version: 2
updates:
  - package-ecosystem: "uv"
    directory: "/"
    schedule:
      interval: "monthly"
  - package-ecosystem: "github-actions"
    # NOTE(robinson) - Workflow files stored in the
    # default location of `.github/workflows`
    directory: "/"
    schedule:
      interval: "monthly"
================================================
FILE: .github/workflows/ci.yml
================================================
# CI pipeline: lint (ruff), shellcheck, tests across supported Python
# versions, and changelog enforcement for source-touching PRs.
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

# Default every job to a read-only token; jobs needing more declare it.
permissions:
  contents: read

jobs:
  lint:
    runs-on: opensource-linux-8core
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.11", "3.12", "3.13"]
    steps:
      - uses: actions/checkout@v4
      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          enable-cache: true
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install lint dependencies
        run: make install-lint
      - name: Lint
        run: make check

  shellcheck:
    runs-on: opensource-linux-8core
    steps:
      - uses: actions/checkout@v4
      - name: ShellCheck
        uses: ludeeus/action-shellcheck@master

  test:
    runs-on: opensource-linux-8core
    needs: lint
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.11", "3.12", "3.13"]
    steps:
      - uses: actions/checkout@v4
      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          enable-cache: true
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get -y install poppler-utils tesseract-ocr
      - name: Install dependencies
        run: make install
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: us-east-2
      - name: Test
        env:
          UNSTRUCTURED_HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          aws s3 cp s3://utic-dev-models/ci_test_model/test_ci_model.onnx test_unstructured_inference/models/
          CI=true make test
          make check-coverage

  changelog:
    runs-on: opensource-linux-8core
    # dorny/paths-filter lists changed files through the pull request API on
    # pull_request events, which requires `pull-requests: read`. The
    # workflow-level `permissions: contents: read` block sets every other
    # scope to none, so grant the scope explicitly for this job.
    permissions:
      contents: read
      pull-requests: read
    steps:
      - uses: actions/checkout@v4
      - if: github.ref != 'refs/heads/main'
        uses: dorny/paths-filter@v2
        id: changes
        with:
          filters: |
            src:
              - 'unstructured_inference/**'
      # Require a CHANGELOG entry only when library source changed.
      - if: steps.changes.outputs.src == 'true' && github.ref != 'refs/heads/main'
        uses: dangoslen/changelog-enforcer@v3
================================================
FILE: .github/workflows/claude.yml
================================================
# Runs Claude Code whenever `@claude` is mentioned in an issue title/body,
# an issue comment, a PR review comment, or a PR review body.
name: Claude Code

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]
  issues:
    types: [opened, assigned]
  pull_request_review:
    types: [submitted]

jobs:
  claude:
    # Gate on the `@claude` mention in whichever text body the event carries.
    if: |
      (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
      (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
      (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
      (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: read
      issues: read
      # NOTE(review): presumably required by the action's OIDC auth flow —
      # confirm against the claude-code-action documentation.
      id-token: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@beta
        with:
          anthropic_api_key: ${{ secrets.GH_ANTHROPIC_API_KEY }}
          allowed_tools: "Bash(git:*),View,GlobTool,GrepTool,BatchTool"
================================================
FILE: .github/workflows/create_issue.yml
================================================
# Mirrors newly opened GitHub issues into Jira (project CORE).
name: create_jira_issue

on:
  issues:
    types:
      - opened

jobs:
  create:
    runs-on: ubuntu-latest
    name: Create JIRA Issue
    steps:
      - name: Login to Jira
        uses: atlassian/gajira-login@v3
        env:
          JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
          JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
          JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
      - name: Create Jira issue
        # `id` is required so the final step can read
        # `steps.create.outputs.issue`; without it that expression silently
        # resolves to an empty string.
        id: create
        uses: atlassian/gajira-create@v3
        with:
          project: CORE
          issuetype: Task
          summary: ${{ github.event.issue.title }}
          description: |
            Created from github issue: ${{ github.event.issue.html_url }}
            ----
            ${{ github.event.issue.body }}
          fields: '{ "labels": ["github-issue"] }'
      - name: Log created issue
        run: echo "Issue ${{ steps.create.outputs.issue }} was created"
================================================
FILE: .github/workflows/release.yml
================================================
# Publishes the package to PyPI (trusted publishing / OIDC) and, best-effort,
# to Azure Artifacts when a GitHub release is published.
name: Release

on:
  release:
    types: [published]

permissions:
  contents: read
  id-token: write # Required for PyPI trusted publishing / attestations

# Serialize releases; never cancel an in-flight publish.
concurrency:
  group: release
  cancel-in-progress: false

jobs:
  release:
    runs-on: ubuntu-latest
    steps:
      # Actions are pinned to full commit SHAs for supply-chain safety.
      - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
      - name: Install uv
        uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
        with:
          enable-cache: true
      - name: Set up Python
        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
        with:
          python-version: "3.12"
      # Refuse to publish when the release tag (v<version>) disagrees with
      # the version declared in unstructured_inference/__version__.py.
      - name: Verify tag matches package version
        run: |
          PKG_VERSION=$(python -c "exec(open('unstructured_inference/__version__.py').read()); print(__version__)")
          TAG_VERSION="${GITHUB_REF_NAME#v}"
          if [ "$PKG_VERSION" != "$TAG_VERSION" ]; then
            echo "::error::Tag ($TAG_VERSION) does not match package version ($PKG_VERSION)"
            exit 1
          fi
      - name: Install release dependencies
        run: uv sync --locked --only-group release --no-install-project
      - name: Build package
        id: build
        run: uv build
      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1
      # Best-effort: attempt Azure upload even if PyPI fails, but only if build succeeded.
      # continue-on-error allows the workflow to pass when Azure secrets are not configured.
      - name: Publish to Azure Artifacts
        if: always() && steps.build.outcome == 'success'
        continue-on-error: true
        run: |
          uv run --no-sync twine upload \
            --repository-url "${{ secrets.AZURE_ARTIFACTS_FEED }}" \
            --username "${{ secrets.AZURE_ARTIFACTS_USERNAME }}" \
            --password "${{ secrets.AZURE_ARTIFACTS_PAT }}" \
            dist/*
================================================
FILE: .github/workflows/version-bump.yml
================================================
# Automatically bumps the package version on renovate-bot PRs by delegating
# to the shared Unstructured-IO/infra reusable workflow.
name: Version Bump

on:
  pull_request:
    branches: [main]
    types: [opened, synchronize, reopened]

permissions:
  contents: write
  pull-requests: read

jobs:
  version-bump:
    # Only run for PRs opened by the renovate bot account.
    if: github.event.pull_request.user.login == 'utic-renovate[bot]'
    uses: Unstructured-IO/infra/.github/workflows/version-bump.yml@main
    with:
      component-paths: '["."]'
      default-bump: patch
      update-changelog: true
      update-lockfile: true
      renovate-app-id: ${{ vars.RENOVATE_APP_ID }}
    secrets:
      token: ${{ secrets.GITHUB_TOKEN }}
      private-pypi-url: ${{ secrets.PRIVATE_PYPI_INDEX_URL }}
      renovate-app-private-key: ${{ secrets.RENOVATE_APP_PRIVATE_KEY }}
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
nbs/
# IPython
profile_default/
ipython_config.py
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# Pycharm
.idea/
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Model artifacts
.models/*
!.models/.gitkeep
# Mac stuff
.DS_Store
# VSCode
.vscode/
sample-docs/*_images
examples/**/output
figures
================================================
FILE: .pre-commit-config.yaml
================================================
# pre-commit hooks: generic hygiene checks plus ruff lint/format.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: "v5.0.0"
    hooks:
      - id: check-added-large-files
      - id: check-toml
      - id: check-yaml
      - id: check-json
      - id: check-xml
      - id: end-of-file-fixer
        # `files` (not `include`) is the pre-commit key that limits a hook to
        # matching paths; an unrecognized `include` key fails pre-commit's
        # config schema validation, breaking the whole config.
        files: \.py$
        exclude: \.json$
      - id: trailing-whitespace
      - id: mixed-line-ending
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: "v0.15.0"
    hooks:
      - id: ruff
        args: ["--fix"]
      - id: ruff-format
================================================
FILE: CHANGELOG.md
================================================
## 1.6.11
### Enhancement
- Add `table_extraction_method` field to `LayoutElements` and `LayoutElement` to track which algorithm produced a table (grid, tatr, vlm).
## 1.6.10
### Enhancement
- Add Python 3.13 support.
## 1.6.9
### Enhancement
- Restore support for Python 3.11 alongside Python 3.12.
## 1.6.8
### Fix
- Reject PDF pages that would render beyond the configured pixel limit before
allocating the page bitmap.
## 1.6.7
### Fix
- `get_model` now materializes `LazyDict` model configs into a plain dict before
unpacking into `initialize(**...)`. Uses `__iter__` + `__getitem__` to avoid
depending on `Mapping.keys()`, which has been observed to fail at `**`
unpacking with "argument after ** must be a mapping, not LazyDict" in some
deployment environments.
## 1.6.6
### Enhancement
- Relax the lower bound of the pandas and numpy dependency
## 1.6.5
### Enhancement
- Store `pdf_rotation` in `page.image_metadata` so downstream consumers can check page rotation after the page image is freed
- Add targeted unittest coverage for PDF page rotation handling in `convert_pdf_to_image`
- Speed up the targeted rotation unittest by isolating the PDF image conversion surface into a lightweight module and mocking the PDFium rendering path for the timing-critical test
## 1.6.4
### Fix
- Apply PDF `/Rotate` metadata during page rendering - pypdfium2's `page.render()` ignores the flag, producing sideways images for rotated pages
## 1.6.3
### Security
- **security:** fix(deps): upgrade vulnerable transitive dependencies [security]
## 1.6.2
### Enhancement
- Make `dpi` an explicit parameter on `convert_pdf_to_image` (default 200) instead of reading from config internally, enabling unstructured to use this as the single source of truth for PDF rendering
## 1.6.1
### Enhancement
- Free intermediate arrays (`origin_img`, `img`, `ort_inputs`, `output`) and PIL pixel buffer at dead points during YoloX `image_processing()` to reduce peak memory during inference
## 1.6.0
### Fix
- Relax `huggingface-hub` lower bound from `>=1.4.1` to `>=0.22.0` (the `>=1.4.1` was an artifact of the uv migration and broke compatibility with `transformers<5.0`)
## 1.5.5
### Enhancement
- Lazy page rendering in `convert_pdf_to_image` to reduce peak memory from O(N pages) to O(1 page)
## 1.5.4
### Enhancement
- Use `np.full()` instead of `np.ones() * scalar` in YoloX preprocessing to avoid a redundant temporary array
## 1.5.3
- Store routing in LayoutElement
## 1.5.2
### Fix
- Switch to PyPI trusted publishing (OIDC) and remove API token auth
## 1.5.1
### Fix
- Add `id-token: write` permission to release workflow for PyPI attestations
## 1.5.0
### Enhancement
- Automate PyPI and Azure Artifacts publishing via GitHub release workflow
- Replace `--frozen` with `--locked` across Makefile and Dockerfile for stricter lockfile validation
- Add `release` dependency group with `twine` for Azure Artifacts upload
- Constrain pillow to >=12.1.1 to address CVE for out-of-bounds write when loading PSD images
## 1.4.0
### Enhancement
- Switch CI runners to `opensource-linux-8core` for faster builds
- Add pytest-xdist parallelization (`-n auto`) to `docker-test` target
- Remove mypy from lint pipeline; ruff covers linting needs sufficiently
- Add `install-lint` target; CI lint job no longer downloads full project dependencies
## 1.3.0
### Enhancement
- Migrate project to native uv with hatchling build backend
- Consolidate all configuration into pyproject.toml
- Replace pip/requirements workflow with uv sync/lock
- Parallelize test runs with pytest-xdist (`-n auto`)
### Breaking
- Drop support for Python 3.10 and 3.11; require Python >=3.12, <3.13
## 1.2.0
### Enhancement
- **Per-model locks for parallel model loading**: Replace single global lock with per-model locks
- Allows concurrent loading of different models (detectron2, yolox, etc.)
- 10x+ concurrency improvement in multi-model environments
- Maintains thread-safe initialization with double-check pattern
- Backward compatible - no API changes
## 1.1.9
### Fix
- **TableTransformer device_map fix**: Remove device_map parameter to prevent meta tensor errors
- Device normalization (cuda -> cuda:0) for consistent caching
- Load models without device_map, use explicit .to(device, dtype=torch.float32)
- Fixes concurrent PDF processing AssertionError
- Prevents "Trying to set a tensor of type Float but got Meta" errors
- Use context manager for `pdfium.PdfDocument`
## 1.1.8
- put `pdfium` call behind a thread lock
## 1.1.7
- Update OpenCV-Python to 4.13.0.90 to squash ffmpeg vulnerability CVE-2023-6605
## 1.1.6
- Use inference_config to set default rendering DPI
## 1.1.5
- Render PDF to image using PyPDFium instead of pdf2image, due to much improved performance for certain docs
## 1.1.4
- Constrain urllib3 to urllib3>=2.6.0 to address CVE-2025-66471 and CVE-2025-66418
## 1.1.3
- Constrain fonttools to >=4.60.2 to address CVE-2025-66034
## 1.1.2
* chore(deps): Bump several dependencies to resolve open high CVEs
* fix: Exclude pip and setuptools pinning based on cursor comment
* fix: With the newer version of transformers 4.57.1, the type checking became stricter, and mypy correctly flagged that DetrImageProcessor.from_pretrained() expects str | PathLike[Any], not a model object.
* fix: Update test to explicitly cast numpy array to uint8 for Pillow 12.0.0 compatibility
## 1.1.1
* Add NotImplementedError when trying to single index a TextRegions, reflecting the fact that it won't behave correctly at the moment.
## 1.1.0
* Enhancement: Add `TextSource` to track where the text of an element came from
* Enhancement: Refactor `__post_init__` of `TextRegions` and `LayoutElement` slightly to automate initialization
## 1.0.10
* Remove merging logic that's no longer used
## 1.0.9
* Make OD model loading thread safe
## 1.0.8
* Enhancement: Optimized `zoom_image` (codeflash)
* Enhancement: Optimized `cells_to_html` for an 8% speedup in some cases (codeflash)
* Enhancement: Optimized `outputs_to_objects` for an 88% speedup in some cases (codeflash)
## 1.0.7
* Fix a hardcoded file extension causing confusion in the logs
## 1.0.6
* Add slicing through indexing for vectorized elements
## 1.0.5
* feat: add thread lock to prevent racing condition when instantiating singletons
* feat: parametrize edge config for `DetrImageProcessor` with env variables
## 1.0.4
* feat: use singleton instead of `global` to store shared variables
## 1.0.3
* setting longest_edge=1333 to the table image processor
## 1.0.2
* adding parameter to table image preprocessor related to the image size
## 1.0.1
* fix: moving the table transformer model to device when loading the model instead of once the model is loaded.
## 1.0.0
* feat: support for Python 3.10+; drop support for Python 3.9
## 0.8.11
* feat: remove `donut` model
## 0.8.10
* feat: unpin `numpy` and bump minimum for `onnxruntime` to be compatible with `numpy>=2`
## 0.8.9
* chore: unpin `pdfminer-six` version
## 0.8.8
* fix: pdfminer-six dependencies
* feat: `PageLayout.elements` is now a `cached_property` to reduce unnecessary memory and cpu costs
## 0.8.7
* fix: add `password` for PDF
## 0.8.6
* feat: add back `source` to `TextRegions` and `LayoutElements` for backward compatibility
## 0.8.5
* fix: remove `pdfplumber` but include `pdfminer-six==20240706` to update `pdfminer`
## 0.8.4
* feat: add `text_as_html` and `table_as_cells` to `LayoutElements` class as new attributes
* feat: replace the single-valued `source` attribute from `TextRegions` and `LayoutElements` with an array attribute `sources`
## 0.8.3
* fix: removed `layoutelement.from_lp_textblock()` and related tests as it's not used
* fix: update requirements to drop `layoutparser` lib
* fix: update `README.md` to remove layoutparser model zoo support note
## 0.8.2
* fix: fix bug when an empty list is passed into `TextRegions.from_list` triggers `IndexError`
* fix: fix bug where concatenating a list of `LayoutElements` does not properly update the
class id mapping
## 0.8.1
* fix: fix list index out of range error caused by calling LayoutElements.from_list() with empty list
## 0.8.0
* fix: fix missing source after cleaning layout elements
* **BREAKING** Remove chipper model
## 0.7.41
* fix: fix incorrect type casting with higher versions of `numpy` when subtracting a `float` from an `int` array
* fix: fix a bug where class id 0 becomes class type `None` when calling `LayoutElements.as_list()`
## 0.7.40
* fix: store probabilities with `float` data type instead of `int`
## 0.7.39
* fix: Correctly assign mutable default value to variable in `LayoutElements` class
## 0.7.38
* fix: Correctly assign mutable default value to variable in `TextRegions` class
## 0.7.37
* refactor: remove layout analysis related code
* enhancement: Hide warning about table transformer weights not being loaded
* fix(layout): Use TemporaryDirectory instead of NamedTemporaryFile for Windows support
* refactor: use `numpy` array to store layout elements' information in one single `LayoutElements`
object instead of using a list of `LayoutElement`
## 0.7.36
fix: add input parameter validation to `fill_cells()` when converting cells to html
## 0.7.35
Fix syntax for generated HTML tables
## 0.7.34
* Reduce excessive logging
## 0.7.33
* BREAKING CHANGE: removes legacy detectron2 model
* deps: remove layoutparser optional dependencies
## 0.7.32
* refactor: remove all code related to filling inferred elements text from embedded text (pdfminer).
* bug: set the Chipper max_length variable
## 0.7.31
* refactor: remove all `cid` related code that was originally added to filter out invalid `pdfminer` text
* enhancement: Wrapped hf_hub_download with a function that checks for local file before checking HF
## 0.7.30
* fix: table transformer doesn't return multiple cells with same coordinates
*
## 0.7.29
* fix: table transformer predictions are now removed if confidence is below threshold
## 0.7.28
* feat: allow table transformer agent to return table prediction in not parsed format
## 0.7.27
* fix: remove pin from `onnxruntime` dependency.
## 0.7.26
* feat: add a set of new `ElementType`s to extend future element types recognition
* feat: allow registering of new models for inference using `unstructured_inference.models.base.register_new_model` function
## 0.7.25
* fix: replace `Rectangle.is_in()` with `Rectangle.is_almost_subregion_of()` when filling in an inferred element with embedded text
* bug: check for None in Chipper bounding box reduction
* chore: removes `install-detectron2` from the `Makefile`
* fix: convert label_map keys read from os.environment `UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH` to int type
* feat: removes supergradients references
## 0.7.24
* fix: assign value to `text_as_html` element attribute only if `text` attribute contains HTML tags.
## 0.7.23
* fix: added handling in `UnstructuredTableTransformerModel` for if `recognize` returns an empty
list in `run_prediction`.
## 0.7.22
* fix: add logic to handle computation of intersections between 2 `Rectangle`s when a `Rectangle` has `None` value in its coordinates
## 0.7.21
* fix: fix a bug where chipper, or any element extraction model based `PageLayout` object, lack `image_metadata` and other attributes that are required for downstream processing; this fix also reduces the memory overhead of using chipper model
## 0.7.20
* chipper-v3: improved table prediction
## 0.7.19
* refactor: remove all OCR related code
## 0.7.18
* refactor: remove all image extraction related code
## 0.7.17
* refactor: remove all `pdfminer` related code
* enhancement: improved Chipper bounding boxes
## 0.7.16
* bug: Allow supplied ONNX models to use label_map dictionary from json file
## 0.7.15
* enhancement: Enable env variables for model definition
## 0.7.14
* enhancement: Remove Super-Gradients Dependency and Allow General Onnx Models Instead
## 0.7.13
* refactor: add a class `ElementType` for the element type constants and use the constants to replace element type strings
* enhancement: support extracting elements with types `Picture` and `Figure`
* fix: update logger in table initialization where the logger info was not showing
* chore: suppress UserWarning about specified model providers
## 0.7.12
* change the default model to yolox, as table output appears to be better and speed is similar to `yolox_quantized`
## 0.7.11
* chore: remove logger info for chipper since its private
* fix: update broken slack invite link in chipper logger info
* enhancement: Improve error message when # images extracted doesn't match # page layouts.
* fix: use automatic mixed precision on GPU for Chipper
* fix: chipper Table elements now match other layout models' Table element format: html representation is stored in `text_as_html` attribute and `text` attribute stores text without html tags
## 0.7.10
* Handle kwargs explicitly when needed, suppress otherwise
* fix: Reduce Chipper memory consumption on x86_64 cpus
* fix: Skips ordering elements coming from Chipper
* fix: After refactoring to introduce Chipper, annotate() wasn't able to show text with extra info from elements, this is fixed now.
* feat: add table cell and dataframe output formats to table transformer's `run_prediction` call
* breaking change: function `unstructured_inference.models.tables.recognize` no longer takes `out_html` parameter and it now only returns table cell data format (lists of dictionaries)
## 0.7.9
* Allow table model to accept optional OCR tokens
## 0.7.8
* Fix: include onnx as base dependency.
## 0.7.7
* Fix a memory leak in DonutProcessor when using large images in numpy format
* Set the right settings for beam search size > 1
* Fix a bug that in very rare cases made the last element predicted by Chipper to have a bbox = None
## 0.7.6
* fix a bug where invalid zoom factor lead to exceptions; now invalid zoom factors results in no scaling of the image
## 0.7.5
* Improved packaging
## 0.7.4
* Dynamic beam search size has been implemented for Chipper, the decoding process starts with a size = 1 and changes to size = 3 if repetitions appear.
* Fixed bug when PDFMiner predicts that an image text occupies the full page and removes annotations by Chipper.
* Added random seed to Chipper text generation to avoid differences between calls to Chipper.
* Allows user to use super-gradients model if they have a callback predict function, a yaml file with names field corresponding to classes and a path to the model weights
## 0.7.3
* Integration of Chipperv2 and additional Chipper functionality, which includes automatic detection of GPU,
bounding box prediction and hierarchical representation.
* Remove control characters from the text of all layout elements
## 0.7.2
* Sort elements extracted by `pdfminer` to get consistent result from `aggregate_by_block()`
## 0.7.1
* Download yolox_quantized from HF
## 0.7.0
* Remove all OCR related code except the table OCR code
## 0.6.6
* Stop passing ocr_languages parameter into paddle to avoid invalid paddle language code error, this will be fixed until
we have the mapping from standard language code to paddle language code.
## 0.6.5
* Add functionality to keep extracted image elements while merging inferred layout with extracted layout
* Fix `source` property for elements generated by pdfminer.
* Add 'OCR-tesseract' and 'OCR-paddle' as sources for elements generated by OCR.
## 0.6.4
* add a function to automatically scale table crop images based on text height so the text height is optimum for `tesseract` OCR task
* add the new image auto scaling parameters to `config.py`
## 0.6.3
* fix a bug where padded table structure bounding boxes are not shifted back into the original image coordinates correctly
## 0.6.2
* move the confidence threshold for table transformer to config
## 0.6.1
* YoloX_quantized is now the default model. This models detects most diverse types and detect tables better than previous model.
* Since detection models tend to nest elements inside others(specifically in Tables), an algorithm has been added for reducing this
behavior. Now all the elements produced by detection models are disjoint and they don't produce overlapping regions, which helps
reduce duplicated content.
* Add `source` property to our elements, so you can know where the information was generated (OCR or detection model)
## 0.6.0
* add a config class to handle parameter configurations for inference tasks; parameters in the config class can be set via environment variables
* update behavior of `pad_image_with_background_color` so that input `pad` is applied to all sides
## 0.5.31
* Add functionality to extract and save images from the page
* Add functionality to get only "true" embedded images when extracting elements from PDF pages
* Update the layout visualization script to be able to show only image elements if need
* add an evaluation metric for table comparison based on token similarity
* fix paddle unit tests where `make test` fails since paddle doesn't work on M1/M2 chip locally
## 0.5.28
* add env variable `ENTIRE_PAGE_OCR` to specify using paddle or tesseract on entire page OCR
## 0.5.27
* table structure detection now pads the input image by 25 pixels in all 4 directions to improve its recall
## 0.5.26
* support paddle with both cpu and gpu and assumed it is pre-installed
## 0.5.25
* fix a bug where `cells_to_html` doesn't handle cells spanning multiple rows properly
## 0.5.24
* remove `cv2` preprocessing step before OCR step in table transformer
## 0.5.23
* Add functionality to bring back embedded images in PDF
## 0.5.22
* Add object-detection classification probabilities to LayoutElement for all currently implemented object detection models
## 0.5.21
* adds `safe_division` to replace 0 with machine epsilon for `float` to avoid division by 0
* apply `safe_division` to area overlap calculations in `unstructured_inference/inference/elements.py`
## 0.5.20
* Adds YoloX quantized model
## 0.5.19
* Add functionality to supplement detected layout with elements from the full page OCR
* Add functionality to annotate any layout(extracted, inferred, OCR) on a page
## 0.5.18
* Fix for incorrect type assignation at ingest test
## 0.5.17
* Use `OMP_THREAD_LIMIT` to improve tesseract performance
## 0.5.16
* Fix to no longer create a directory for storing processed images
* Hot-load images for annotation
## 0.5.15
* Handle an uncaught TesseractError
## 0.5.14
* Add TIFF test file and TIFF filetype to `test_from_image_file` in `test_layout`
## 0.5.13
* Fix extracted image elements being included in layout merge
## 0.5.12
* Add multipage TIFF extraction support
* Fix a pdfminer error when using `process_data_with_model`
## 0.5.11
* Add warning when chipper is used with < 300 DPI
* Use None default for dpi so defaults can be properly handled upstream
## 0.5.10
* Implement full-page OCR
## 0.5.9
* Handle exceptions from Tesseract
## 0.5.8
* Add alternative architecture for detectron2 (but default is unchanged)
* Updates:
| Library | From | To |
|---------------|-----------|----------|
| transformers | 4.29.2 | 4.30.2 |
| opencv-python | 4.7.0.72 | 4.8.0.74 |
| ipython | 8.12.2 | 8.14.0 |
* Cache named models that have been loaded
## 0.5.7
* hotfix to handle issue storing images in a new dir when the pdf has no file extension
## 0.5.6
* Update the `annotate` and `_get_image_array` methods of `PageLayout` to get the image from the `image_path` property if the `image` property is `None`.
* Add functionality to store pdf images for later use.
* Add `image_metadata` property to `PageLayout` & set `page.image` to None to reduce memory usage.
* Update `DocumentLayout.from_file` to open only one image.
* Update `load_pdf` to return either Image objects or Image paths.
* Warns users that Chipper is a beta model.
* Exposed control over dpi when converting PDF to an image.
* Updated detectron2 version to avoid errors related to deprecated PIL reference
## 0.5.5
* Rename large model to chipper
* Added functionality to write images to computer storage temporarily instead of keeping them in memory for `pdf2image.convert_from_path`
* Added functionality to convert a PDF in small chunks of pages at a time for `pdf2image.convert_from_path`
* Table processing check for the area of the package to fix division by zero bug
* Added CUDA and TensorRT execution providers for yolox and detectron2onnx model.
* Warning for onnx version of detectron2 for empty pages suppressed.
## 0.5.4
* Tweak to element ordering to make it more deterministic
## 0.5.3
* Refactor for large model
## 0.5.2
* Combine inferred elements with extracted elements
* Add ruff to keep code consistent with unstructured
* Configure fallback for OCR token if paddleocr doesn't work to use tesseract
## 0.5.1
* Add annotation for pages
* Store page numbers when processing PDFs
* Hotfix to handle inference of blank pages using ONNX detectron2
* Revert ordering change to investigate examples of misordering
## 0.5.0
* Preserve image format in PIL.Image.Image when loading
* Added ONNX version of Detectron2 and make default model
* Remove API code, we don't serve this as a standalone API any more
* Update ordering logic to account for multicolumn documents.
## 0.4.4
* Fixed patches not being a package.
## 0.4.3
* Patch pdfminer.six to fix parsing bug
## 0.4.2
* Output of table extraction is now stored in `text_as_html` property rather than `text` property
## 0.4.1
* Added the ability to pass `ocr_languages` to the OCR agent for users who need
non-English language packs.
## 0.4.0
* Added logic to partition granular elements (words, characters) by proximity
* Text extraction is now delegated to text regions rather than being handled centrally
* Fixed embedded image coordinates being interpreted differently than embedded text coordinates
* Update to how dependencies are being handled
* Update detectron2 version
## 0.3.2
* Allow extracting tables from higher level functions
## 0.3.1
* Pin protobuf version to avoid errors
* Make paddleocr an extra again
## 0.3.0
* Fix for text block detection
* Add paddleocr dependency to setup for x86_64 machines
## 0.2.14
* Suppressed processing progress bars
## 0.2.13
* Add table processing
* Change OCR logic to be aware of PDF image elements
## 0.2.12
* Fix for processing RGBA images
## 0.2.11
* Fixed some cases where image elements were not being OCR'd
## 0.2.10
* Removed control characters from tesseract output
## 0.2.9
* Removed multithreading from OCR (DocumentLayout.get_elements_from_layout)
## 0.2.8
* Refactored YoloX inference code to integrate better with framework
* Improved testing time
## 0.2.7
* Fixed duplicated load_pdf call
## 0.2.6
* Add donut model script for image prediction
* Add sample receipt and test for donut prediction
## 0.2.5
* Add YoloX model for images and PDFs
* Add generic model interface
## 0.2.4
* Download default model from huggingface
* Clarify error when trying to open file that doesn't exist as an image
## 0.2.3
* Pins the version of `opencv-python` for linux compatibility
## 0.2.2
* Add capability to process image files
* Add logic to use OCR when layout text is full of unknown characters
## 0.2.1
* Refactor to facilitate local inference
* Removes BasicConfig from logger configuration
* Implement auto model downloading
## 0.2.0
* Initial release of unstructured-inference
================================================
FILE: Dockerfile
================================================
# syntax=docker/dockerfile:experimental
ARG PYTHON_VERSION=3.12
FROM python:${PYTHON_VERSION}-slim AS base
# Set up environment
ENV HOME=/home/
WORKDIR ${HOME}
# Trust github.com's host key so SSH-based fetches work non-interactively.
RUN mkdir ${HOME}/.ssh && chmod go-rwx ${HOME}/.ssh \
&& ssh-keyscan -t rsa github.com >> /home/.ssh/known_hosts
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
# deps stage: install dependencies only, so this layer is cached
# independently of application-code changes.
FROM base AS deps
# Copy project files needed for dependency resolution
COPY pyproject.toml uv.lock ./
# NOTE(review): __version__.py is copied before the rest of the package —
# presumably pyproject reads the version from it during `uv sync`; confirm.
COPY unstructured_inference/__version__.py unstructured_inference/__version__.py
# --no-install-project: install third-party deps only at this stage.
RUN uv sync --locked --all-groups --no-install-project
# Ensure venv binaries are on PATH so pytest/etc. are directly accessible
ENV PATH="/home/.venv/bin:${PATH}"
# code stage: add the package source and install the project itself.
FROM deps AS code
COPY unstructured_inference unstructured_inference
RUN uv sync --locked --all-groups
CMD ["/bin/bash"]
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Makefile
================================================
PACKAGE_NAME := unstructured_inference
CURRENT_DIR := $(shell pwd)

.PHONY: help
# Prints every `## target: description` line below as the help text.
help: Makefile
	@sed -n 's/^\(## \)\([a-zA-Z]\)/\2/p' $<

###########
# Install #
###########

## install: install all dependencies via uv
.PHONY: install
install:
	@uv sync --locked --all-groups

## install-lint: install only lint dependencies (no project deps)
.PHONY: install-lint
install-lint:
	@uv sync --locked --only-group lint

## lock: update and lock all dependencies
.PHONY: lock
lock:
	@uv lock --upgrade

#################
# Test and Lint #
#################

# Propagated into test runs so tests can detect CI environments.
export CI ?= false

## test: runs all unittests (excluding slow)
.PHONY: test
test:
	CI=$(CI) uv run --locked --no-sync pytest -n auto -m "not slow" test_${PACKAGE_NAME} --cov=${PACKAGE_NAME} --cov-report term-missing

## test-slow: runs all unittests (including slow)
.PHONY: test-slow
test-slow:
	CI=$(CI) uv run --locked --no-sync pytest -n auto test_${PACKAGE_NAME} --cov=${PACKAGE_NAME} --cov-report term-missing

## check: runs all linters and checks
.PHONY: check
check: check-ruff check-version

## check-ruff: runs ruff linter
.PHONY: check-ruff
check-ruff:
	uv run --locked --no-sync ruff check .
	uv run --locked --no-sync ruff format --check .

## check-scripts: run shellcheck
.PHONY: check-scripts
check-scripts:
	scripts/shellcheck.sh

## check-version: run check to ensure version in CHANGELOG.md matches version in package
.PHONY: check-version
check-version:
	# Fail if syncing version would produce changes
	scripts/version-sync.sh -c \
	-s CHANGELOG.md \
	-f ${PACKAGE_NAME}/__version__.py semver

## tidy: auto-format and fix lint issues
.PHONY: tidy
tidy:
	uv run --locked --no-sync ruff format .
	uv run --locked --no-sync ruff check --fix-only --show-fixes .

## version-sync: update __version__.py with most recent version from CHANGELOG.md
.PHONY: version-sync
version-sync:
	scripts/version-sync.sh \
	-s CHANGELOG.md \
	-f ${PACKAGE_NAME}/__version__.py semver

## check-coverage: check test coverage meets threshold
.PHONY: check-coverage
check-coverage:
	uv run --locked --no-sync coverage report --fail-under=90

##########
# Docker #
##########

DOCKER_IMAGE ?= unstructured-inference:dev

.PHONY: docker-build
docker-build:
	DOCKER_IMAGE=${DOCKER_IMAGE} ./scripts/docker-build.sh

# Mounts tests and sample docs into the built image and runs pytest inside it.
# Set TEST_NAME to filter tests with pytest's -k expression.
.PHONY: docker-test
docker-test: docker-build
	docker run --rm \
	-v ${CURRENT_DIR}/test_unstructured_inference:/home/test_unstructured_inference \
	-v ${CURRENT_DIR}/sample-docs:/home/sample-docs \
	$(DOCKER_IMAGE) \
	bash -c "pytest -n auto $(if $(TEST_NAME),-k $(TEST_NAME),) test_unstructured_inference"
================================================
FILE: README.md
================================================
Open-Source Pre-Processing Tools for Unstructured Data
The `unstructured-inference` repo contains hosted model inference code for layout parsing models.
These models are invoked via API as part of the partitioning bricks in the `unstructured` package.
**Requires Python >=3.11, <3.14.**
## Installation
### Package
```shell
pip install unstructured-inference
```
### Detectron2
[Detectron2](https://github.com/facebookresearch/detectron2) is required for using models from the [layoutparser model zoo](#using-models-from-the-layoutparser-model-zoo)
but is not automatically installed with this package.
For MacOS and Linux, build from source with:
```shell
pip install 'git+https://github.com/facebookresearch/detectron2.git@57bdb21249d5418c130d54e2ebdc94dda7a4c01a'
```
Other install options can be found in the
[Detectron2 installation guide](https://detectron2.readthedocs.io/en/latest/tutorials/install.html).
Windows is not officially supported by Detectron2, but some users are able to install it anyway.
See discussion [here](https://layout-parser.github.io/tutorials/installation#for-windows-users) for
tips on installing Detectron2 on Windows.
### Development Setup
This project uses [uv](https://docs.astral.sh/uv/) for dependency management.
```shell
# Clone and install all dependencies (including dev/test/lint groups)
git clone https://github.com/Unstructured-IO/unstructured-inference.git
cd unstructured-inference
make install
```
Run `make help` for a full list of available targets.
## Getting Started
To get started with the layout parsing model, use the following commands:
```python
from unstructured_inference.inference.layout import DocumentLayout
layout = DocumentLayout.from_file("sample-docs/loremipsum.pdf")
print(layout.pages[0].elements)
```
Once the model has detected the layout and OCR'd the document, the text extracted from the first
page of the sample document will be displayed.
You can convert a given element to a `dict` by running the `.to_dict()` method.
## Models
The inference pipeline operates by finding text elements in a document page using a detection model, then extracting the contents of the elements using direct extraction (if available), OCR, and optionally table inference models.
We offer several detection models including [Detectron2](https://github.com/facebookresearch/detectron2) and [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX).
### Using a non-default model
When doing inference, an alternate model can be used by passing the model object to the ingestion method via the `model` parameter. The `get_model` function can be used to construct one of our out-of-the-box models from a keyword, e.g.:
```python
from unstructured_inference.models.base import get_model
from unstructured_inference.inference.layout import DocumentLayout
model = get_model("yolox")
layout = DocumentLayout.from_file("sample-docs/layout-parser-paper.pdf", detection_model=model)
```
### Using your own model
Any detection model can be used in the `unstructured_inference` pipeline by wrapping the model in the `UnstructuredObjectDetectionModel` class. To integrate with the `DocumentLayout` class, a subclass of `UnstructuredObjectDetectionModel` must have a `predict` method that accepts a `PIL.Image.Image` and returns a list of `LayoutElement`s, and an `initialize` method, which loads the model and prepares it for inference.
## Security Policy
See our [security policy](https://github.com/Unstructured-IO/unstructured-inference/security/policy) for
information on how to report security vulnerabilities.
## Learn more
| Section | Description |
|-|-|
| [Unstructured Community Github](https://github.com/Unstructured-IO/community) | Information about Unstructured.io community projects |
| [Unstructured Github](https://github.com/Unstructured-IO) | Unstructured.io open source repositories |
| [Company Website](https://unstructured.io) | Unstructured.io product and company info |
================================================
FILE: benchmarks/__init__.py
================================================
================================================
FILE: benchmarks/test_benchmark_yolox.py
================================================
"""Benchmark for YoloX image_processing() memory optimization.
Uses a fake ONNX session to isolate the memory behavior of image_processing()
without requiring the real model weights. The fake session allocates a realistic
35 MiB workspace to simulate ONNX inference memory pressure.
"""
import numpy as np
from PIL import Image as PILImage
from unstructured_inference.models.yolox import UnstructuredYoloXModel
class _FakeInput:
def __init__(self) -> None:
self.name = "input"
class _FakeSession:
"""Simulates an ONNX inference session with realistic memory allocation."""
def get_inputs(self):
return [_FakeInput()]
def run(self, _names, _inputs):
workspace = np.empty((35 * 1024 * 1024,), dtype=np.uint8) # 35 MiB # noqa: F841
# input_shape (1024,768), strides [8,16,32] → 128*96 + 64*48 + 32*24 = 16128
return [np.random.randn(1, 16128, 16).astype(np.float32)]
def make_model() -> UnstructuredYoloXModel:
    """Build an UnstructuredYoloXModel backed by the fake session (no real weights)."""
    class_names = [
        "Caption",
        "Footnote",
        "Formula",
        "List-item",
        "Page-footer",
        "Page-header",
        "Picture",
        "Section-header",
        "Table",
        "Text",
        "Title",
    ]
    # Bypass __init__ so no model weights are loaded from disk.
    model = object.__new__(UnstructuredYoloXModel)
    model.model = _FakeSession()
    model.model_path = "yolox_fake"
    model.layout_classes = dict(enumerate(class_names))
    return model
def make_letter_200dpi() -> PILImage.Image:
    """Random RGB image sized like a letter page at 200 DPI (2200x1700 px)."""
    noise = np.random.randint(0, 255, (2200, 1700, 3), dtype=np.uint8)
    return PILImage.fromarray(noise)
def run_image_processing():
    """Benchmark body: preprocess one synthetic page through the fake model."""
    return make_model().image_processing(make_letter_200dpi())
def test_benchmark_yolox_image_processing(benchmark):
    # pytest-benchmark fixture: repeatedly times image_processing() on a
    # letter-size 200 DPI page backed by the fake ONNX session.
    benchmark(run_image_processing)
================================================
FILE: examples/ocr/engine.py
================================================
import os
import re
import time
from typing import List, cast
import cv2
import numpy as np
import pytesseract
from pytesseract import Output
from unstructured_inference.inference import layout
from unstructured_inference.inference.elements import Rectangle, TextRegion
def remove_non_printable(s):
    """Replace characters outside printable ASCII (0x20-0x7E) with spaces,
    then collapse every run of whitespace into a single space."""
    ascii_only = "".join(ch if "\x20" <= ch <= "\x7e" else " " for ch in s)
    return " ".join(ascii_only.split())
def run_ocr_with_layout_detection(
    images,
    detection_model=None,
    element_extraction_model=None,
    mode="individual_blocks",
    output_dir="",
    drawable=True,
    printable=True,
):
    """Run layout detection plus OCR over a sequence of page images.

    Args:
        images: iterable of page images (PIL images, given they are passed to
            PageLayout and np.array).
        detection_model: layout detection model forwarded to PageLayout.
        element_extraction_model: element extraction model forwarded to PageLayout.
        mode: "individual_blocks" OCRs each detected layout element separately;
            "entire_page" OCRs the whole page and keeps only the words that fall
            (almost) inside an inferred layout region.
        output_dir: directory where annotated page images are written when drawable.
        drawable: when True, draw word/region boxes on a copy of the page and
            save it as ocr_<mode>_page<n>.jpg.
        printable: when True, print per-page element counts and timing.

    Returns:
        Tuple (total_text_extraction_infer_time, total_extracted_text) where the
        second item maps "page<n>" to the filtered text extracted for that page.

    Raises:
        ValueError: if mode is not one of the two supported values.
    """
    total_text_extraction_infer_time = 0
    total_extracted_text = {}
    for i, image in enumerate(images):
        page_num = i + 1
        page_num_str = f"page{page_num}"
        page = layout.PageLayout(
            number=i + 1,
            image=image,
            layout=None,
            detection_model=detection_model,
            element_extraction_model=element_extraction_model,
        )
        # Run the detection model directly to get the inferred text regions.
        inferred_layout: List[TextRegion] = cast(List[TextRegion], page.detection_model(page.image))
        # Numpy copy of the page used as the drawing canvas for annotations.
        # NOTE(review): this array is RGB (from PIL) while cv2.imwrite assumes
        # BGR, so saved box colors appear channel-swapped — confirm intent.
        cv_img = np.array(image)
        if mode == "individual_blocks":
            # OCR'ing individual blocks (current approach)
            text_extraction_start_time = time.time()
            elements = page.get_elements_from_layout(inferred_layout)
            text_extraction_infer_time = time.time() - text_extraction_start_time
            total_text_extraction_infer_time += text_extraction_infer_time
            page_text = ""
            for el in elements:
                page_text += el.text
            filtered_page_text = remove_non_printable(page_text)
            total_extracted_text[page_num_str] = filtered_page_text
        elif mode == "entire_page":
            # OCR'ing entire page (new approach to implement)
            text_extraction_start_time = time.time()
            ocr_data = pytesseract.image_to_data(image, lang="eng", output_type=Output.DICT)
            boxes = ocr_data["level"]
            extracted_text_list = []
            for k in range(len(boxes)):
                (x, y, w, h) = (
                    ocr_data["left"][k],
                    ocr_data["top"][k],
                    ocr_data["width"][k],
                    ocr_data["height"][k],
                )
                extracted_text = ocr_data["text"][k]
                if not extracted_text:
                    continue
                extracted_region = Rectangle(x1=x, y1=y, x2=x + w, y2=y + h)
                # Keep the word only if it is almost a subregion of some
                # inferred region; regions are padded by 12px so words at the
                # region edge still qualify.
                extracted_is_subregion_of_inferred = False
                for inferred_region in inferred_layout:
                    extracted_is_subregion_of_inferred = extracted_region.is_almost_subregion_of(
                        inferred_region.pad(12),
                        subregion_threshold=0.75,
                    )
                    if extracted_is_subregion_of_inferred:
                        break
                if extracted_is_subregion_of_inferred:
                    extracted_text_list.append(extracted_text)
                if drawable:
                    # (0, 255, 0) box: word kept; (255, 0, 0) box: word dropped.
                    if extracted_is_subregion_of_inferred:
                        cv2.rectangle(cv_img, (x, y), (x + w, y + h), (0, 255, 0), 2, None)
                    else:
                        cv2.rectangle(cv_img, (x, y), (x + w, y + h), (255, 0, 0), 2, None)
            text_extraction_infer_time = time.time() - text_extraction_start_time
            total_text_extraction_infer_time += text_extraction_infer_time
            page_text = " ".join(extracted_text_list)
            filtered_page_text = remove_non_printable(page_text)
            total_extracted_text[page_num_str] = filtered_page_text
        else:
            raise ValueError("Invalid mode")
        if drawable:
            # Draw every inferred layout region and save the annotated page.
            for el in inferred_layout:
                pt1 = [int(el.x1), int(el.y1)]
                pt2 = [int(el.x2), int(el.y2)]
                cv2.rectangle(
                    img=cv_img,
                    pt1=pt1,
                    pt2=pt2,
                    color=(0, 0, 255),
                    thickness=4,
                    lineType=None,
                )
            f_path = os.path.join(output_dir, f"ocr_{mode}_{page_num_str}.jpg")
            cv2.imwrite(f_path, cv_img)
        if printable:
            print(
                f"page: {i + 1} - n_layout_elements: {len(inferred_layout)} - "
                f"text_extraction_infer_time: {text_extraction_infer_time}"
            )
    return total_text_extraction_infer_time, total_extracted_text
def run_ocr(
    images,
    printable=True,
):
    """OCR every page image with tesseract.

    Returns a tuple (total_text_extraction_infer_time, total_text) where the
    second item is the concatenation of all page texts in order.
    """
    total_infer_time = 0
    all_text = ""
    for page_index, image in enumerate(images):
        started = time.time()
        page_text = pytesseract.image_to_string(image)
        elapsed = time.time() - started
        if printable:
            print(f"page: {page_index + 1} - text_extraction_infer_time: {elapsed}")
        total_infer_time += elapsed
        all_text += page_text
    return total_infer_time, all_text
================================================
FILE: examples/ocr/requirements.txt
================================================
unstructured[local-inference]
nltk
================================================
FILE: examples/ocr/validate_ocr_performance.py
================================================
import json
import os
import time
from datetime import datetime
from difflib import SequenceMatcher
import nltk
import pdf2image
from unstructured_inference.inference.layout import (
DocumentLayout,
create_image_output_dir,
process_file_with_model,
)
# Download the required resources (run this once)
nltk.download("punkt")
def validate_performance(
    f_name,
    validation_mode,
    is_image_file=False,
):
    """Compare OCR performance between "individual_blocks" and "entire_page" modes.

    Processes a sample document once per OCR mode, measures inference time,
    compares the extracted text, and writes a JSON report via write_report().

    Args:
        f_name: file name inside the module-level `example_docs_dir`.
        validation_mode: "pdf" to process the PDF directly, or "image" to
            convert pages to images first (or use the file itself when
            `is_image_file` is True). Any other value skips processing.
        is_image_file: when True and validation_mode == "image", treat `f_name`
            as an image file instead of a PDF to convert.
    """
    print(
        f">>> Start performance comparison - filename: {f_name}"
        f" - validation_mode: {validation_mode}"
        f" - is_image_file: {is_image_file}"
    )
    # NOTE: datetime.utcnow() is deprecated in Python 3.12; kept for parity.
    now_dt = datetime.utcnow()
    now_str = now_dt.strftime("%Y_%m_%d-%H_%M_%S")
    f_path = os.path.join(example_docs_dir, f_name)

    image_f_paths = []
    if validation_mode == "pdf":
        pdf_info = pdf2image.pdfinfo_from_path(f_path)
        n_pages = pdf_info["Pages"]
    elif validation_mode == "image":
        if is_image_file:
            image_f_paths.append(f_path)
        else:
            # Render PDF pages to image files and process those instead.
            image_output_dir = create_image_output_dir(f_path)
            images = pdf2image.convert_from_path(f_path, output_folder=image_output_dir)
            image_f_paths = [image.filename for image in images]
        n_pages = len(image_f_paths)
    else:
        n_pages = 0

    processing_result = {}
    for ocr_mode in ["individual_blocks", "entire_page"]:
        start_time = time.time()
        if validation_mode == "pdf":
            layout = process_file_with_model(
                f_path,
                model_name=None,
                ocr_mode=ocr_mode,
            )
        elif validation_mode == "image":
            # Process each page image separately, then stitch the pages back
            # into a single DocumentLayout with sequential page numbers.
            pages = []
            for image_f_path in image_f_paths:
                _layout = process_file_with_model(
                    image_f_path,
                    model_name=None,
                    ocr_mode=ocr_mode,
                    is_image=True,
                )
                pages += _layout.pages
            for i, page in enumerate(pages):
                page.number = i + 1
            layout = DocumentLayout.from_pages(pages)
        else:
            layout = None
        infer_time = time.time() - start_time

        if layout is None:
            print("Layout is None")
            return

        full_text = str(layout)
        page_text = {}
        for page in layout.pages:
            page_text[page.number] = str(page)
        processing_result[ocr_mode] = {
            "infer_time": infer_time,
            "full_text": full_text,
            "page_text": page_text,
        }

    individual_mode_page_text = processing_result["individual_blocks"]["page_text"]
    # BUG FIX: this previously read processing_result["individual_blocks"],
    # so the report's entire_page page_text duplicated the individual text.
    entire_mode_page_text = processing_result["entire_page"]["page_text"]
    individual_mode_full_text = processing_result["individual_blocks"]["full_text"]
    entire_mode_full_text = processing_result["entire_page"]["full_text"]
    compare_result = compare_processed_text(individual_mode_full_text, entire_mode_full_text)

    report = {
        "validation_mode": validation_mode,
        "file_info": {
            "filename": f_name,
            "n_pages": n_pages,
        },
        "processing_time": {
            "individual_blocks": processing_result["individual_blocks"]["infer_time"],
            "entire_page": processing_result["entire_page"]["infer_time"],
        },
        "text_similarity": compare_result,
        "extracted_text": {
            "individual_blocks": {
                "page_text": individual_mode_page_text,
                "full_text": individual_mode_full_text,
            },
            "entire_page": {
                "page_text": entire_mode_page_text,
                "full_text": entire_mode_full_text,
            },
        },
    }
    write_report(report, now_str, validation_mode)
    print("<<< End performance comparison", f_name)
def compare_processed_text(individual_mode_full_text, entire_mode_full_text, delimiter=" "):
    """Compare the texts produced by the two OCR modes.

    Prints a similarity ratio, word/unique-word counts per mode, and the words
    unique to each mode, then returns the same data as a dict suitable for the
    JSON report.

    Args:
        individual_mode_full_text: full text from "individual_blocks" mode.
        entire_mode_full_text: full text from "entire_page" mode.
        delimiter: separator used to join unique words in the returned dict.
    """
    # Calculate similarity ratio
    similarity_ratio = SequenceMatcher(
        None, individual_mode_full_text, entire_mode_full_text
    ).ratio()
    print(f"similarity_ratio: {similarity_ratio}")

    # Tokenize the text into words
    word_list_individual = nltk.word_tokenize(individual_mode_full_text)
    n_word_list_individual = len(word_list_individual)
    print("n_word_list_in_text_individual:", n_word_list_individual)
    word_sets_individual = set(word_list_individual)
    n_word_sets_individual = len(word_sets_individual)
    print(f"n_word_sets_in_text_individual: {n_word_sets_individual}")

    word_list_entire = nltk.word_tokenize(entire_mode_full_text)
    n_word_list_entire = len(word_list_entire)
    # BUG FIX: these two labels previously said "individual" while printing
    # the entire-page counts (copy-paste error).
    print("n_word_list_in_text_entire:", n_word_list_entire)
    word_sets_entire = set(word_list_entire)
    n_word_sets_entire = len(word_sets_entire)
    print(f"n_word_sets_in_text_entire: {n_word_sets_entire}")

    # Find unique elements using difference
    print("diff_elements:")
    unique_words_individual = word_sets_individual - word_sets_entire
    unique_words_entire = word_sets_entire - word_sets_individual
    print(f"unique_words_in_text_individual: {unique_words_individual}\n")
    print(f"unique_words_in_text_entire: {unique_words_entire}")

    return {
        "similarity_ratio": similarity_ratio,
        "individual_blocks": {
            "n_word_list": n_word_list_individual,
            "n_word_sets": n_word_sets_individual,
            # sorted() makes the report deterministic (set order is not).
            "unique_words": delimiter.join(sorted(unique_words_individual)),
        },
        "entire_page": {
            "n_word_list": n_word_list_entire,
            "n_word_sets": n_word_sets_entire,
            "unique_words": delimiter.join(sorted(unique_words_entire)),
        },
    }
def write_report(report, now_str, validation_mode):
    """Serialize the comparison report as JSON into the module-level output_dir.

    NOTE(review): `output_dir` is a module-level global assigned in __main__.
    """
    report_file = f"validate-ocr-{validation_mode}-{now_str}.json"
    destination = os.path.join(output_dir, report_file)
    with open(destination, "w", encoding="utf-8-sig") as f:
        json.dump(report, f, indent=4)
def run():
    """Validate OCR performance for each configured sample document."""
    test_files = [
        {"name": "layout-parser-paper-fast.pdf", "mode": "image", "is_image_file": False},
        {"name": "loremipsum_multipage.pdf", "mode": "image", "is_image_file": False},
        {"name": "2023-Jan-economic-outlook.pdf", "mode": "image", "is_image_file": False},
        {"name": "recalibrating-risk-report.pdf", "mode": "image", "is_image_file": False},
        {"name": "Silent-Giant.pdf", "mode": "image", "is_image_file": False},
    ]
    for spec in test_files:
        validate_performance(spec["name"], spec["mode"], spec["is_image_file"])
if __name__ == "__main__":
    # Resolve the repository root two levels up from the current working
    # directory; these globals are read by the functions above.
    cur_dir = os.getcwd()
    base_dir = os.path.join(cur_dir, os.pardir, os.pardir)
    example_docs_dir = os.path.join(base_dir, "sample-docs")
    # folder path to save temporary outputs
    output_dir = os.path.join(cur_dir, "output")
    os.makedirs(output_dir, exist_ok=True)
    run()
================================================
FILE: logger_config.yaml
================================================
version: 1
# Keep loggers configured elsewhere (e.g. by libraries) active.
disable_existing_loggers: false
formatters:
  default_format:
    "()": uvicorn.logging.DefaultFormatter
    format: '%(asctime)s %(name)s %(levelname)s %(message)s'
  access:
    "()": uvicorn.logging.AccessFormatter
    format: '%(asctime)s %(client_addr)s %(request_line)s - %(status_code)s'
handlers:
  access_handler:
    formatter: access
    class: logging.StreamHandler
    stream: ext://sys.stderr
  standard_handler:
    formatter: default_format
    class: logging.StreamHandler
    stream: ext://sys.stderr
loggers:
  uvicorn.error:
    level: INFO
    handlers:
      - standard_handler
    # Do not bubble records up to the root logger as well.
    propagate: false
  uvicorn.access:
    level: INFO
    handlers:
      - access_handler
    propagate: false
  unstructured:
    level: INFO
    handlers:
      - standard_handler
    propagate: false
  unstructured_inference:
    level: DEBUG
    handlers:
      - standard_handler
    propagate: false
================================================
FILE: pyproject.toml
================================================
[project]
name = "unstructured_inference"
description = "A library for performing inference using trained models."
requires-python = ">=3.11, <3.14"
authors = [{name = "Unstructured Technologies", email = "devops@unstructuredai.io"}]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
readme = "README.md"
license = "Apache-2.0"
keywords = ["NLP", "PDF", "HTML", "CV", "XML", "parsing", "preprocessing"]
dynamic = ["version"]
dependencies = [
"huggingface-hub>=0.22.0",
"numpy>=1.26.0",
"opencv-python>=4.13.0.90",
"onnx>=1.20.1",
"onnxruntime>=1.25.0",
"matplotlib>=3.10.8",
"torch>=2.10.0",
"timm>=1.0.24",
# NOTE(alan): Pinned because this is when the most recent module we import appeared
"transformers>=4.25.1",
# Required by transformers[torch] for model loading with torch
"accelerate>=1.12.0",
"rapidfuzz>=3.14.3",
"pandas>=1.5.0",
"scipy>=1.17.0",
"pypdfium2>=5.0.0",
]
[project.urls]
Homepage = "https://github.com/Unstructured-IO/unstructured-inference"
[tool.hatch.version]
path = "unstructured_inference/__version__.py"
[dependency-groups]
lint = [
"ruff>=0.15.0",
]
test = [
"pytest>=9.0.2",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"pytest-xdist>=3.5.0",
"coverage>=7.13.3",
"httpx>=0.28.1",
"pdf2image>=1.16.2",
]
dev = [
"jupyter>=1.1.1",
"ipython>=9.10.0",
]
release = [
"twine>=6.2.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
constraint-dependencies = [
# Security: CVE fix for fonttools
"fonttools>=4.60.2",
# Security: CVE fix for urllib3
"urllib3>=2.6.0",
# Security: CVE fix for Pillow (out-of-bounds write loading PSD images)
"pillow>=12.1.1",
]
[tool.hatch.build.targets.wheel]
packages = ["/unstructured_inference"]
[tool.hatch.build.targets.sdist]
packages = ["/unstructured_inference"]
[tool.ruff]
line-length = 100
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# flake8-comprehensions
"C4",
# flake8-commas
"COM",
# isort
"I",
# flake8-simplify
"SIM",
# pyupgrade
"UP015",
"UP018",
"UP032",
"UP034",
# pylint refactor
"PLR0402",
# flake8-pytest-style
"PT",
]
ignore = [
"COM812",
"PT011",
"PT012",
]
[tool.ruff.lint.per-file-ignores]
"test_*/**" = ["D"]
[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
]
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.codeflash]
benchmarks-root = "benchmarks"
[tool.coverage.report]
fail_under = 90
================================================
FILE: renovate.json
================================================
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": ["github>Unstructured-IO/renovate-config:python-uv"]
}
================================================
FILE: sample-docs/loremipsum.tiff
================================================
[File too large to display: 14.3 MB]
================================================
FILE: scripts/docker-build.sh
================================================
#!/usr/bin/env bash
set -euo pipefail
# Build the local dev image; override the tag with DOCKER_IMAGE=<name:tag>.
DOCKER_IMAGE="${DOCKER_IMAGE:-unstructured-inference:dev}"
# Build command assembled as an array so arguments survive quoting intact.
DOCKER_BUILD_CMD=(docker buildx build --load -f Dockerfile \
    --build-arg BUILDKIT_INLINE_CACHE=1 \
    --progress plain \
    -t "$DOCKER_IMAGE" .)
DOCKER_BUILDKIT=1 "${DOCKER_BUILD_CMD[@]}"
================================================
FILE: scripts/shellcheck.sh
================================================
#!/usr/bin/env bash
# Lint every shell script under scripts/ with shellcheck.
find scripts -name "*.sh" -exec shellcheck {} +
================================================
FILE: scripts/test-unstructured-ingest-helper.sh
================================================
#!/usr/bin/env bash
# This is intended to be run from an unstructured checkout, not in this repo
# The goal here is to see what changes the current branch would introduce to unstructured
# fixtures
INGEST_COMMANDS=(
    test_unstructured_ingest/src/azure.sh
    test_unstructured_ingest/src/biomed-api.sh
    test_unstructured_ingest/src/biomed-path.sh
    test_unstructured_ingest/src/box.sh
    test_unstructured_ingest/src/dropbox.sh
    test_unstructured_ingest/src/gcs.sh
    test_unstructured_ingest/src/onedrive.sh
    test_unstructured_ingest/src/s3.sh
)
EXIT_STATUSES=()
# Run each command and capture its exit status
for INGEST_COMMAND in "${INGEST_COMMANDS[@]}"; do
    # FIX: quote the path to avoid word splitting / globbing (shellcheck SC2086).
    "$INGEST_COMMAND"
    EXIT_STATUSES+=($?)
done
# Check for failures
for STATUS in "${EXIT_STATUSES[@]}"; do
    if [[ $STATUS -ne 0 ]]; then
        echo "At least one ingest command failed! Scroll up to see which"
        exit 1
    fi
done
echo "No diffs resulted from any ingest commands"
================================================
FILE: scripts/version-sync.sh
================================================
#!/usr/bin/env bash
function usage {
    # Print CLI usage and option help to stdout as a single heredoc.
    cat <<EOF
Usage: $(basename "$0") [-c] -f FILE_TO_CHANGE REPLACEMENT_FORMAT [-f FILE_TO_CHANGE REPLACEMENT_FORMAT ...]
Synchronize files to latest version in source file
 -s Specifies source file for version (default is CHANGELOG.md)
 -f Specifies a file to change and the format for searching and replacing versions
      FILE_TO_CHANGE is the file to be updated/checked for updates
      REPLACEMENT_FORMAT is one of (semver, release, api-release)
          semver indicates to look for a full semver version and replace with the latest full version
          release indicates to look for a release semver version (x.x.x) and replace with the latest release version
          api-release indicates to look for a release semver version in the context of an api route and replace with the latest release version
 -c Compare versions and output proposed changes without changing anything.
EOF
}
# Collect extra (non-option) arguments following the current option into the
# OPTARG array, advancing the caller's OPTIND past them.
function getopts-extra () {
    declare -i i=1
    # if the next argument is not an option, then append it to array OPTARG
    # NOTE: ${!OPTIND} is indirect expansion — it reads the positional
    # parameter whose index is the current value of OPTIND.
    while [[ ${OPTIND} -le $# && ${!OPTIND:0:1} != '-' ]]; do
        OPTARG[i]=${!OPTIND}
        ((i += 1))
        ((OPTIND += 1))
    done
}
# Parse input options
declare CHECK=0
declare SOURCE_FILE="CHANGELOG.md"
declare -a FILES_TO_CHECK=()
declare -a REPLACEMENT_FORMATS=()
declare args
declare OPTIND OPTARG opt
while getopts ":hcs:f:" opt; do
    case $opt in
        h)
            usage
            exit 0
            ;;
        c)
            CHECK=1
            ;;
        s)
            SOURCE_FILE="$OPTARG"
            ;;
        f)
            # -f takes two trailing arguments: FILE_TO_CHANGE and REPLACEMENT_FORMAT;
            # getopts-extra gathers them into the OPTARG array.
            getopts-extra "$@"
            args=( "${OPTARG[@]}" )
            # validate length of args, should be 2
            if [ ${#args[@]} -eq 2 ]; then
                FILES_TO_CHECK+=( "${args[0]}" )
                REPLACEMENT_FORMATS+=( "${args[1]}" )
            else
                echo "Exactly 2 arguments must follow -f option." >&2
                exit 1
            fi
            ;;
        \?)
            echo "Invalid option: -$OPTARG." >&2
            usage
            exit 1
            ;;
    esac
done
# Parse REPLACEMENT_FORMATS
# Full semver with optional prerelease/build metadata (e.g. 1.2.3-dev0+abc).
RE_SEMVER_FULL="(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)(-((0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*)(\.(0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*))*))?(\+([0-9a-zA-Z-]+(\.[0-9a-zA-Z-]+)*))?"
# Bare x.y.z release version.
RE_RELEASE="(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)"
# Release version prefixed with "v", as used in api routes.
RE_API_RELEASE="v(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)"
# Pull out semver appearing earliest in SOURCE_FILE.
LAST_VERSION=$(grep -o -m 1 -E "${RE_SEMVER_FULL}" "$SOURCE_FILE")
LAST_RELEASE=$(grep -o -m 1 -E "${RE_RELEASE}($|[^-+])" "$SOURCE_FILE" | grep -o -m 1 -E "${RE_RELEASE}")
LAST_API_RELEASE="v$(grep -o -m 1 -E "${RE_RELEASE}($|[^-+])$" "$SOURCE_FILE" | grep -o -m 1 -E "${RE_RELEASE}")"
declare -a RE_SEMVERS=()
declare -a UPDATED_VERSIONS=()
# Map each requested replacement format to the regex to search for and the
# version string to substitute, keeping indices parallel to FILES_TO_CHECK.
for i in "${!REPLACEMENT_FORMATS[@]}"; do
    REPLACEMENT_FORMAT=${REPLACEMENT_FORMATS[$i]}
    case $REPLACEMENT_FORMAT in
        semver)
            RE_SEMVERS+=( "$RE_SEMVER_FULL" )
            UPDATED_VERSIONS+=( "$LAST_VERSION" )
            ;;
        release)
            RE_SEMVERS+=( "$RE_RELEASE" )
            UPDATED_VERSIONS+=( "$LAST_RELEASE" )
            ;;
        api-release)
            RE_SEMVERS+=( "$RE_API_RELEASE" )
            UPDATED_VERSIONS+=( "$LAST_API_RELEASE" )
            ;;
        *)
            echo "Invalid replacement format: \"${REPLACEMENT_FORMAT}\". Use semver, release, or api-release" >&2
            exit 1
            ;;
    esac
done
if [ -z "$LAST_VERSION" ];
then
    # No match to semver regex in SOURCE_FILE, so no version to go from.
    printf "Error: Unable to find latest version from %s.\n" "$SOURCE_FILE"
    exit 1
fi
# Search files in FILES_TO_CHECK and change (or get diffs)
declare FAILED_CHECK=0
for i in "${!FILES_TO_CHECK[@]}"; do
    FILE_TO_CHANGE=${FILES_TO_CHECK[$i]}
    RE_SEMVER=${RE_SEMVERS[$i]}
    UPDATED_VERSION=${UPDATED_VERSIONS[$i]}
    FILE_VERSION=$(grep -o -m 1 -E "${RE_SEMVER}" "$FILE_TO_CHANGE")
    if [ -z "$FILE_VERSION" ];
    then
        # No match to semver regex in VERSIONFILE, so nothing to replace
        printf "Error: No semver version found in file %s.\n" "$FILE_TO_CHANGE"
        exit 1
    else
        # Replace semver in VERSIONFILE with semver obtained from SOURCE_FILE
        TMPFILE=$(mktemp /tmp/new_version.XXXXXX)
        # Check sed version, exit if version < 4.3
        if ! sed --version > /dev/null 2>&1; then
            # BSD/macOS sed has no --version flag; treat it as too old.
            CURRENT_VERSION=1.archaic
        else
            CURRENT_VERSION=$(sed --version | head -n1 | cut -d" " -f4)
        fi
        REQUIRED_VERSION="4.3"
        # sort -V yields the lower version first; if that is not
        # REQUIRED_VERSION, the installed sed is older than required.
        if [ "$(printf '%s\n' "$REQUIRED_VERSION" "$CURRENT_VERSION" | sort -V | head -n1)" != "$REQUIRED_VERSION" ]; then
            echo "sed version must be >= ${REQUIRED_VERSION}" && exit 1
        fi
        sed -E -r "s/$RE_SEMVER/$UPDATED_VERSION/" "$FILE_TO_CHANGE" > "$TMPFILE"
        if [ $CHECK == 1 ];
        then
            # Check mode: report the diff but leave the file untouched.
            DIFF=$(diff "$FILE_TO_CHANGE" "$TMPFILE" )
            if [ -z "$DIFF" ];
            then
                printf "version sync would make no changes to %s.\n" "$FILE_TO_CHANGE"
                rm "$TMPFILE"
            else
                FAILED_CHECK=1
                printf "version sync would make the following changes to %s:\n%s\n" "$FILE_TO_CHANGE" "$DIFF"
                rm "$TMPFILE"
            fi
        else
            cp "$TMPFILE" "$FILE_TO_CHANGE"
            rm "$TMPFILE"
        fi
    fi
done
# Exit with code determined by whether changes were needed in a check.
if [ ${FAILED_CHECK} -ne 0 ]; then
    exit 1
else
    exit 0
fi
================================================
FILE: test_unstructured_inference/conftest.py
================================================
import numpy as np
import pytest
from PIL import Image
from unstructured_inference.inference.elements import (
EmbeddedTextRegion,
Rectangle,
TextRegion,
)
from unstructured_inference.inference.layoutelement import LayoutElement
@pytest.fixture
def mock_pil_image():
    """50x50 blank RGB PIL image."""
    return Image.new("RGB", (50, 50))


@pytest.fixture
def mock_numpy_image():
    """50x50x3 all-zero uint8 numpy array image."""
    return np.zeros((50, 50, 3), np.uint8)


@pytest.fixture
def mock_rectangle():
    """Rectangle spanning (100, 100) to (300, 300)."""
    return Rectangle(100, 100, 300, 300)


@pytest.fixture
def mock_text_region():
    """TextRegion over (100, 100)-(300, 300) containing sample text."""
    return TextRegion.from_coords(100, 100, 300, 300, text="Sample text")


@pytest.fixture
def mock_layout_element():
    """LayoutElement of type "Text" over (100, 100)-(300, 300)."""
    return LayoutElement.from_coords(
        100,
        100,
        300,
        300,
        text="Sample text",
        source=None,
        type="Text",
    )
@pytest.fixture
def mock_embedded_text_regions():
    """EmbeddedTextRegions for the words of a two-line paper title.

    Built from (x1, y1, x2, y2, text) tuples to keep the coordinate data
    compact; the resulting regions are identical to spelling each
    ``from_coords`` call out longhand.
    """
    word_boxes = [
        (453.00277777777774, 317.319341111111, 711.5338541666665, 358.28571222222206, "LayoutParser:"),
        (726.4778125, 317.319341111111, 760.3308594444444, 357.1698966666667, "A"),
        (775.2748177777777, 317.319341111111, 917.3579885555555, 357.1698966666667, "Unified"),
        (932.3019468888888, 317.319341111111, 1071.8426522222221, 357.1698966666667, "Toolkit"),
        (1086.7866105555556, 317.319341111111, 1141.2105142777777, 357.1698966666667, "for"),
        (1156.154472611111, 317.319341111111, 1256.334784222222, 357.1698966666667, "Deep"),
        (437.83888888888885, 367.13322999999986, 610.0171992222222, 406.9837855555556, "Learning"),
        (624.9611575555555, 367.13322999999986, 741.6754646666665, 406.9837855555556, "Based"),
        (756.619423, 367.13322999999986, 958.3867708333332, 406.9837855555556, "Document"),
        (973.3307291666665, 367.13322999999986, 1092.0535042777776, 406.9837855555556, "Image"),
    ]
    return [
        EmbeddedTextRegion.from_coords(x1=x1, y1=y1, x2=x2, y2=y2, text=text)
        for x1, y1, x2, y2, text in word_boxes
    ]
# TODO(alan): Make a better test layout
@pytest.fixture
def mock_layout(mock_embedded_text_regions):
    """UncategorizedText LayoutElements built from the embedded text regions."""
    return [
        LayoutElement(text=r.text, type="UncategorizedText", bbox=r.bbox)
        for r in mock_embedded_text_regions
    ]
@pytest.fixture
def example_table_cells():
    """Cell grid for a 6-row accessibility-study table, wrapped in a one-element list.

    Each cell dict carries its text plus the row/column indices it spans;
    every cell is marked as a non-header cell ("column header": False).
    """
    cells = [
        {"cell text": "Disability Category", "row_nums": [0, 1], "column_nums": [0]},
        {"cell text": "Participants", "row_nums": [0, 1], "column_nums": [1]},
        {"cell text": "Ballots Completed", "row_nums": [0, 1], "column_nums": [2]},
        {"cell text": "Ballots Incomplete/Terminated", "row_nums": [0, 1], "column_nums": [3]},
        {"cell text": "Results", "row_nums": [0], "column_nums": [4, 5]},
        {"cell text": "Accuracy", "row_nums": [1], "column_nums": [4]},
        {"cell text": "Time to complete", "row_nums": [1], "column_nums": [5]},
        {"cell text": "Blind", "row_nums": [2], "column_nums": [0]},
        {"cell text": "Low Vision", "row_nums": [3], "column_nums": [0]},
        {"cell text": "Dexterity", "row_nums": [4], "column_nums": [0]},
        {"cell text": "Mobility", "row_nums": [5], "column_nums": [0]},
        {"cell text": "5", "row_nums": [2], "column_nums": [1]},
        {"cell text": "5", "row_nums": [3], "column_nums": [1]},
        {"cell text": "5", "row_nums": [4], "column_nums": [1]},
        {"cell text": "3", "row_nums": [5], "column_nums": [1]},
        {"cell text": "1", "row_nums": [2], "column_nums": [2]},
        {"cell text": "2", "row_nums": [3], "column_nums": [2]},
        {"cell text": "4", "row_nums": [4], "column_nums": [2]},
        {"cell text": "3", "row_nums": [5], "column_nums": [2]},
        {"cell text": "4", "row_nums": [2], "column_nums": [3]},
        {"cell text": "3", "row_nums": [3], "column_nums": [3]},
        {"cell text": "1", "row_nums": [4], "column_nums": [3]},
        {"cell text": "0", "row_nums": [5], "column_nums": [3]},
        {"cell text": "34.5%, n=1", "row_nums": [2], "column_nums": [4]},
        {"cell text": "98.3% n=2 (97.7%, n=3)", "row_nums": [3], "column_nums": [4]},
        {"cell text": "98.3%, n=4", "row_nums": [4], "column_nums": [4]},
        {"cell text": "95.4%, n=3", "row_nums": [5], "column_nums": [4]},
        {"cell text": "1199 sec, n=1", "row_nums": [2], "column_nums": [5]},
        {"cell text": "1716 sec, n=3 (1934 sec, n=2)", "row_nums": [3], "column_nums": [5]},
        {"cell text": "1672.1 sec, n=4", "row_nums": [4], "column_nums": [5]},
        {"cell text": "1416 sec, n=3", "row_nums": [5], "column_nums": [5]},
    ]
    # Mark every cell as a body (non-header) cell; iterate the dicts directly
    # instead of the unidiomatic `for i in range(len(cells))` index loop.
    for cell in cells:
        cell["column header"] = False
    return [cells]
================================================
FILE: test_unstructured_inference/inference/test_layout.py
================================================
import os
import os.path
import tempfile
from unittest.mock import MagicMock, mock_open, patch
import numpy as np
import pytest
from PIL import Image
import unstructured_inference.models.base as models
from unstructured_inference.constants import IsExtracted
from unstructured_inference.inference import elements, layout, layoutelement, pdf_image
from unstructured_inference.inference.elements import (
EmbeddedTextRegion,
ImageTextRegion,
)
from unstructured_inference.models.unstructuredmodel import (
UnstructuredElementExtractionModel,
UnstructuredObjectDetectionModel,
)
# True when not running in CI (CI env var unset or falsey).
skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
@pytest.fixture
def mock_image():
    """1x1 bilevel (mode "1") PIL image."""
    return Image.new("1", (1, 1))


@pytest.fixture
def mock_initial_layout():
    """Two extracted EmbeddedTextRegions: a narrative block and a title block."""
    text_block = EmbeddedTextRegion.from_coords(
        2,
        4,
        6,
        8,
        text="A very repetitive narrative. " * 10,
        is_extracted=IsExtracted.TRUE,
    )
    title_block = EmbeddedTextRegion.from_coords(
        1,
        2,
        3,
        4,
        text="A Catchy Title",
        is_extracted=IsExtracted.TRUE,
    )
    return [text_block, title_block]
@pytest.fixture
def mock_final_layout():
    """LayoutElements with a "Mock"-sourced NarrativeText block and Title block."""
    text_block = layoutelement.LayoutElement.from_coords(
        2,
        4,
        6,
        8,
        source="Mock",
        text="A very repetitive narrative. " * 10,
        type="NarrativeText",
    )
    title_block = layoutelement.LayoutElement.from_coords(
        1,
        2,
        3,
        4,
        source="Mock",
        text="A Catchy Title",
        type="Title",
    )
    return layoutelement.LayoutElements.from_list([text_block, title_block])
def test_pdf_page_converts_images_to_array(mock_image):
    """_get_image_array materializes self.image (or self.image_path) as an ndarray."""
    def verify_image_array():
        assert page.image_array is None
        image_array = page._get_image_array()
        assert isinstance(image_array, np.ndarray)
        assert page.image_array.all() == image_array.all()

    # Scenario 1: where self.image exists
    page = layout.PageLayout(number=0, image=mock_image)
    verify_image_array()

    # Scenario 2: where self.image is None, but self.image_path exists
    page.image_array = None
    page.image = None
    page.image_path = "mock_path_to_image"
    with patch.object(Image, "open", return_value=mock_image):
        verify_image_array()
class MockLayoutModel:
    """Detection-model double whose __call__ returns a fixed layout."""

    def __init__(self, layout):
        self.layout_return = layout

    def __call__(self, *args):
        return self.layout_return

    def initialize(self, *args, **kwargs):
        pass

    def deduplicate_detected_elements(self, elements, *args, **kwargs):
        # Pass-through: no deduplication performed.
        return elements
def test_get_page_elements(monkeypatch, mock_final_layout):
    """inplace=False returns the elements; inplace=True stores them on the page."""
    image = Image.fromarray(
        np.random.randint(12, 14, size=(40, 10, 3)).astype(np.uint8), mode="RGB"
    )
    page = layout.PageLayout(
        number=0,
        image=image,
        detection_model=MockLayoutModel(mock_final_layout),
    )
    elements = page.get_elements_with_detection_model(inplace=False)
    page.get_elements_with_detection_model(inplace=True)
    assert elements == page.elements_array


class MockPool:
    """Synchronous stand-in for multiprocessing.Pool (maps in-process)."""

    def map(self, f, xs):
        return [f(x) for x in xs]

    def close(self):
        pass

    def join(self):
        pass
@pytest.mark.parametrize("model_name", [None, "checkbox", "fake"])
def test_process_data_with_model(monkeypatch, mock_final_layout, model_name):
    """process_data_with_model returns a truthy DocumentLayout for any model name.

    The model lookup and DocumentLayout.from_file are patched so no real
    model loading or PDF parsing happens.
    """
    monkeypatch.setattr(layout, "get_model", lambda x: MockLayoutModel(mock_final_layout))
    monkeypatch.setattr(
        layout.DocumentLayout,
        "from_file",
        lambda *args, **kwargs: layout.DocumentLayout.from_pages([]),
    )
    # NOTE: the isinstance check inside layout is satisfied by patching the
    # UnstructuredObjectDetectionModel symbol with MockLayoutModel below.
    # (A previously-defined local `new_isinstance` helper was never used and
    # has been removed.)
    with (
        patch("builtins.open", mock_open(read_data=b"000000")),
        patch(
            "unstructured_inference.inference.layout.UnstructuredObjectDetectionModel",
            MockLayoutModel,
        ),
        open("") as fp,
    ):
        assert layout.process_data_with_model(fp, model_name=model_name)
def test_process_data_with_model_raises_on_invalid_model_name():
    """An unknown model name raises UnknownModelException."""
    with (
        patch("builtins.open", mock_open(read_data=b"000000")),
        pytest.raises(
            models.UnknownModelException,
        ),
        open("") as fp,
    ):
        layout.process_data_with_model(fp, model_name="fake")
@pytest.mark.parametrize("model_name", [None, "yolox"])
def test_process_file_with_model(monkeypatch, mock_final_layout, model_name):
    """process_file_with_model returns a truthy DocumentLayout for known model names."""
    def mock_initialize(self, *args, **kwargs):
        self.model = MockLayoutModel(mock_final_layout)

    monkeypatch.setattr(
        layout.DocumentLayout,
        "from_file",
        lambda *args, **kwargs: layout.DocumentLayout.from_pages([]),
    )
    monkeypatch.setattr(models.UnstructuredDetectronONNXModel, "initialize", mock_initialize)
    filename = ""
    assert layout.process_file_with_model(filename, model_name=model_name)


def test_process_file_no_warnings(monkeypatch, mock_final_layout, recwarn):
    """Default-model processing must not emit the ONNX provider UserWarning."""
    def mock_initialize(self, *args, **kwargs):
        self.model = MockLayoutModel(mock_final_layout)

    monkeypatch.setattr(
        layout.DocumentLayout,
        "from_file",
        lambda *args, **kwargs: layout.DocumentLayout.from_pages([]),
    )
    monkeypatch.setattr(models.UnstructuredDetectronONNXModel, "initialize", mock_initialize)
    filename = ""
    layout.process_file_with_model(filename, model_name=None)
    # There should be no UserWarning, but if there is one it should not have the following message
    with pytest.raises(AssertionError, match="not found in warning list"):
        user_warning = recwarn.pop(UserWarning)
        assert "not in available provider names" not in str(user_warning.message)


def test_process_file_with_model_raises_on_invalid_model_name():
    """An unknown model name raises UnknownModelException."""
    with pytest.raises(models.UnknownModelException):
        layout.process_file_with_model("", model_name="fake")
class MockPoints:
    """Stand-in for a points array exposing only tolist()."""

    def tolist(self):
        return [1, 2, 3, 4]


class MockEmbeddedTextRegion(EmbeddedTextRegion):
    """EmbeddedTextRegion double with fixed points and no bbox computation."""

    def __init__(self, type=None, text=None):
        self.type = type
        self.text = text

    @property
    def points(self):
        return MockPoints()


class MockPageLayout(layout.PageLayout):
    """PageLayout subclass that bypasses PageLayout.__init__ entirely."""

    def __init__(
        self,
        number=1,
        image=None,
        model=None,
        detection_model=None,
    ):
        self.image = image
        # NOTE(review): `layout` here is the imported module, not a parameter —
        # presumably a placeholder value; confirm nothing reads this attribute.
        self.layout = layout
        self.model = model
        self.number = number
        self.detection_model = detection_model
class MockLayout:
    """Minimal layout-collection double exposing iteration/sort/filter APIs."""

    def __init__(self, *elements):
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def sort(self, key, inplace):
        # Ignores key/inplace; returns the elements unchanged.
        return self.elements

    def __iter__(self):
        return iter(self.elements)

    def get_texts(self):
        return [el.text for el in self.elements]

    def filter_by(self, *args, **kwargs):
        # Always returns an empty MockLayout, regardless of the filter.
        return MockLayout()
@pytest.mark.parametrize("element_extraction_model", [None, "foo"])
@pytest.mark.parametrize("filetype", ["png", "jpg", "tiff"])
def test_from_image_file(monkeypatch, mock_final_layout, filetype, element_extraction_model):
    """from_image_file records image metadata/path without keeping the image in memory."""
    def mock_get_elements(self, *args, **kwargs):
        self.elements = [mock_final_layout]

    monkeypatch.setattr(layout.PageLayout, "get_elements_with_detection_model", mock_get_elements)
    monkeypatch.setattr(layout.PageLayout, "get_elements_using_image_extraction", mock_get_elements)
    filename = f"sample-docs/loremipsum.{filetype}"
    image = Image.open(filename)
    image_metadata = {
        "format": image.format,
        "width": image.width,
        "height": image.height,
        "pdf_rotation": 0,
    }
    doc = layout.DocumentLayout.from_image_file(
        filename,
        element_extraction_model=element_extraction_model,
    )
    page = doc.pages[0]
    assert page.elements[0] == mock_final_layout
    assert page.image is None
    assert page.image_path == os.path.abspath(filename)
    assert page.image_metadata == image_metadata


def test_from_file(monkeypatch, mock_final_layout):
    """from_file converts PDF pages to images and records their metadata."""
    def mock_get_elements(self, *args, **kwargs):
        self.elements = [mock_final_layout]

    monkeypatch.setattr(layout.PageLayout, "get_elements_with_detection_model", mock_get_elements)
    with tempfile.TemporaryDirectory() as tmpdir:
        image_path = os.path.join(tmpdir, "loremipsum.ppm")
        image = Image.open("sample-docs/loremipsum.jpg")
        image.save(image_path)
        image_metadata = {
            "format": "PPM",
            "width": image.width,
            "height": image.height,
            "pdf_rotation": 0,
        }
        with patch.object(
            layout,
            "convert_pdf_to_image",
            lambda *args, **kwargs: ([image_path]),
        ):
            doc = layout.DocumentLayout.from_file("fake-file.pdf")
            page = doc.pages[0]
            assert page.elements[0] == mock_final_layout
            assert page.image_metadata == image_metadata
            assert page.image is None
def test_from_file_rotated_pdf_stores_rotation_in_metadata(monkeypatch, mock_final_layout):
    """image_metadata includes pdf_rotation for rotated PDF pages."""
    def mock_get_elements(self, *args, **kwargs):
        self.elements = [mock_final_layout]

    monkeypatch.setattr(layout.PageLayout, "get_elements_with_detection_model", mock_get_elements)
    doc = layout.DocumentLayout.from_file("sample-docs/rotated-page-90.pdf")
    page = doc.pages[0]
    assert page.image_metadata["pdf_rotation"] == 90
    assert page.image is None


@pytest.mark.slow
def test_from_file_with_password(monkeypatch, mock_final_layout):
    """Password-protected PDFs are processed when the password is supplied."""
    doc = layout.DocumentLayout.from_file("sample-docs/password.pdf", password="password")
    assert doc
    monkeypatch.setattr(layout, "get_model", lambda x: MockLayoutModel(mock_final_layout))
    with (
        patch(
            "unstructured_inference.inference.layout.UnstructuredObjectDetectionModel",
            MockLayoutModel,
        ),
        open("sample-docs/password.pdf", mode="rb") as fp,
    ):
        doc = layout.process_data_with_model(fp, model_name="fake", password="password")
        assert doc


def test_from_image_file_raises_with_empty_fn():
    """An empty filename raises FileNotFoundError."""
    with pytest.raises(FileNotFoundError):
        layout.DocumentLayout.from_image_file("")


def test_from_image_file_raises_isadirectoryerror_with_dir():
    """Passing a directory path raises IsADirectoryError."""
    with tempfile.TemporaryDirectory() as tempdir, pytest.raises(IsADirectoryError):
        layout.DocumentLayout.from_image_file(tempdir)
def test_page_numbers_in_page_objects():
    """Pages are numbered 1..N in document order."""
    with patch(
        "unstructured_inference.inference.layout.PageLayout.get_elements_with_detection_model",
    ) as mock_get_elements:
        doc = layout.DocumentLayout.from_file("sample-docs/layout-parser-paper.pdf")
        mock_get_elements.assert_called()
    assert [page.number for page in doc.pages] == list(range(1, len(doc.pages) + 1))


# Module-level regions shared by parametrized tests.
no_text_region = EmbeddedTextRegion.from_coords(0, 0, 100, 100)
text_region = EmbeddedTextRegion.from_coords(0, 0, 100, 100, text="test")
overlapping_rect = ImageTextRegion.from_coords(50, 50, 150, 150)
nonoverlapping_rect = ImageTextRegion.from_coords(150, 150, 200, 200)
populated_text_region = EmbeddedTextRegion.from_coords(50, 50, 60, 60, text="test")
unpopulated_text_region = EmbeddedTextRegion.from_coords(50, 50, 60, 60, text=None)
@pytest.mark.parametrize(
    ("colors", "add_details", "threshold"),
    [("red", False, 0.992), (None, False, 0.992), ("red", True, 0.8)],
)
def test_annotate(colors, add_details, threshold):
    """annotate draws red borders around elements and leaves most pixels unchanged."""
    def check_annotated_image():
        annotated_array = np.array(annotated_image)
        for coords in [coords1, coords2]:
            x1, y1, x2, y2 = coords
            # Make sure the pixels on the edge of the box are red
            for i, expected in zip(range(3), [255, 0, 0]):
                assert all(annotated_array[y1, x1:x2, i] == expected)
                assert all(annotated_array[y2, x1:x2, i] == expected)
                assert all(annotated_array[y1:y2, x1, i] == expected)
                assert all(annotated_array[y1:y2, x2, i] == expected)
        # Make sure almost all the pixels are not changed
        assert ((annotated_array[:, :, 0] == 1).mean()) > threshold
        assert ((annotated_array[:, :, 1] == 1).mean()) > threshold
        assert ((annotated_array[:, :, 2] == 1).mean()) > threshold

    test_image_arr = np.ones((100, 100, 3), dtype="uint8")
    image = Image.fromarray(test_image_arr)
    page = layout.PageLayout(number=1, image=image)
    coords1 = (21, 30, 37, 41)
    rect1 = elements.TextRegion.from_coords(*coords1)
    coords2 = (1, 10, 7, 11)
    rect2 = elements.TextRegion.from_coords(*coords2)
    page.elements = [rect1, rect2]
    annotated_image = page.annotate(colors=colors, add_details=add_details, sources=None)
    check_annotated_image()

    # Scenario 1: where self.image exists
    annotated_image = page.annotate(colors=colors, add_details=add_details)
    check_annotated_image()

    # Scenario 2: where self.image is None, but self.image_path exists
    with patch.object(Image, "open", return_value=image):
        page.image = None
        page.image_path = "mock_path_to_image"
        annotated_image = page.annotate(colors=colors, add_details=add_details)
        check_annotated_image()
class MockDetectionModel(layout.UnstructuredObjectDetectionModel):
    """Detection-model double returning seven fixed elements labeled "0"-"6"."""

    def initialize(self, *args, **kwargs):
        pass

    def predict(self, x):
        # Elements are listed in reading order; tests assert this order survives.
        return layoutelement.LayoutElements.from_list(
            [
                layout.LayoutElement.from_coords(x1=447.0, y1=315.0, x2=1275.7, y2=413.0, text="0"),
                layout.LayoutElement.from_coords(x1=380.6, y1=473.4, x2=1334.8, y2=533.9, text="1"),
                layout.LayoutElement.from_coords(x1=578.6, y1=556.8, x2=1109.0, y2=874.4, text="2"),
                layout.LayoutElement.from_coords(
                    x1=444.5,
                    y1=942.3,
                    x2=1261.1,
                    y2=1584.1,
                    text="3",
                ),
                layout.LayoutElement.from_coords(
                    x1=444.8,
                    y1=1609.4,
                    x2=1257.2,
                    y2=1665.2,
                    text="4",
                ),
                layout.LayoutElement.from_coords(
                    x1=414.0,
                    y1=1718.8,
                    x2=635.0,
                    y2=1755.2,
                    text="5",
                ),
                layout.LayoutElement.from_coords(
                    x1=372.6,
                    y1=1786.9,
                    x2=1333.6,
                    y2=1848.7,
                    text="6",
                ),
            ],
        )
def test_layout_order(mock_image):
    """Detected elements come back in reading order ("0" through "6")."""
    with tempfile.TemporaryDirectory() as tmpdir:
        mock_image_path = os.path.join(tmpdir, "mock.jpg")
        mock_image.save(mock_image_path)
        with (
            patch.object(layout, "get_model", lambda: MockDetectionModel()),
            patch.object(
                layout,
                "convert_pdf_to_image",
                lambda *args, **kwargs: ([mock_image_path]),
            ),
        ):
            doc = layout.DocumentLayout.from_file("sample-docs/layout-parser-paper.pdf")
        page = doc.pages[0]
        for n, element in enumerate(page.elements):
            assert element.text == str(n)


def test_page_layout_raises_when_multiple_models_passed(mock_image, mock_initial_layout):
    """Supplying both a detection and an extraction model raises ValueError."""
    with pytest.raises(ValueError):
        layout.PageLayout(
            0,
            mock_image,
            mock_initial_layout,
            detection_model="something",
            element_extraction_model="something else",
        )
class MockElementExtractionModel:
    """Extraction-model double returning a fixed element list."""

    def __call__(self, x):
        return [1, 2, 3]


@pytest.mark.parametrize(("inplace", "expected"), [(True, None), (False, [1, 2, 3])])
def test_get_elements_using_image_extraction(mock_image, inplace, expected):
    """inplace=True stores the elements and returns None; inplace=False returns them."""
    page = layout.PageLayout(
        1,
        mock_image,
        None,
        element_extraction_model=MockElementExtractionModel(),
    )
    assert page.get_elements_using_image_extraction(inplace=inplace) == expected


def test_get_elements_using_image_extraction_raises_with_no_extraction_model(
    mock_image,
):
    """Image extraction without an extraction model raises ValueError."""
    page = layout.PageLayout(1, mock_image, None, element_extraction_model=None)
    with pytest.raises(ValueError):
        page.get_elements_using_image_extraction()
def test_get_elements_with_detection_model_raises_with_wrong_default_model(
    monkeypatch, mock_image, mock_final_layout
):
    """A default model that is not an UnstructuredObjectDetectionModel raises.

    FIX: ``mock_image`` and ``mock_final_layout`` were previously referenced as
    the module-level fixture *functions* (the test only declared
    ``monkeypatch``), so a function object was passed as the page image. They
    are now requested as fixture parameters so pytest injects the real values.
    """
    monkeypatch.setattr(layout, "get_model", lambda *x: MockLayoutModel(mock_final_layout))
    page = layout.PageLayout(1, mock_image, None)
    with pytest.raises(NotImplementedError):
        page.get_elements_with_detection_model()
@pytest.mark.parametrize(
    (
        "detection_model",
        "element_extraction_model",
        "detection_model_called",
        "element_extraction_model_called",
    ),
    [(None, "asdf", False, True), ("asdf", None, True, False)],
)
def test_from_image(
    mock_image,
    detection_model,
    element_extraction_model,
    detection_model_called,
    element_extraction_model_called,
):
    """from_image routes to the detection or extraction path based on the model given."""
    with (
        patch.object(
            layout.PageLayout,
            "get_elements_using_image_extraction",
        ) as mock_image_extraction,
        patch.object(
            layout.PageLayout,
            "get_elements_with_detection_model",
        ) as mock_detection,
    ):
        layout.PageLayout.from_image(
            mock_image,
            image_path=None,
            detection_model=detection_model,
            element_extraction_model=element_extraction_model,
        )
        assert mock_image_extraction.called == element_extraction_model_called
        assert mock_detection.called == detection_model_called
class MockUnstructuredElementExtractionModel(UnstructuredElementExtractionModel):
def initialize(self, *args, **kwargs):
return super().initialize(*args, **kwargs)
def predict(self, x: Image):
return super().predict(x)
class MockUnstructuredDetectionModel(UnstructuredObjectDetectionModel):
def initialize(self, *args, **kwargs):
return super().initialize(*args, **kwargs)
def predict(self, x: Image):
return super().predict(x)
@pytest.mark.parametrize(
    ("model_type", "is_detection_model"),
    [
        (MockUnstructuredElementExtractionModel, False),
        (MockUnstructuredDetectionModel, True),
    ],
)
def test_process_file_with_model_routing(monkeypatch, model_type, is_detection_model):
    """process_file_with_model passes the model to the correct DocumentLayout kwarg."""
    model = model_type()
    monkeypatch.setattr(layout, "get_model", lambda *x: model)
    with patch.object(layout.DocumentLayout, "from_file") as mock_from_file:
        layout.process_file_with_model("asdf", model_name="fake", is_image=False)
    # Detection models go to detection_model; extraction models to
    # element_extraction_model — never both at once.
    if is_detection_model:
        detection_model = model
        element_extraction_model = None
    else:
        detection_model = None
        element_extraction_model = model
    mock_from_file.assert_called_once_with(
        "asdf",
        detection_model=detection_model,
        element_extraction_model=element_extraction_model,
        fixed_layouts=None,
        password=None,
        pdf_image_dpi=200,
        pdf_render_max_pixels_per_page=None,
    )
@pytest.mark.parametrize(("pdf_image_dpi", "expected"), [(200, 2200), (100, 1100)])
def test_exposed_pdf_image_dpi(pdf_image_dpi, expected, monkeypatch):
    """pdf_image_dpi is honored: rendered page height scales linearly with DPI."""
    with patch.object(layout.PageLayout, "from_image") as mock_from_image:
        layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf", pdf_image_dpi=pdf_image_dpi)
    assert mock_from_image.call_args[0][0].height == expected
def test_convert_pdf_to_image_no_output_folder():
    """Without an output folder, conversion returns in-memory PIL images."""
    result = layout.convert_pdf_to_image(filename="sample-docs/loremipsum.pdf", dpi=72)
    assert len(result) == 1
    assert isinstance(result[0], Image.Image)
def _install_mock_pdfium(monkeypatch, *, width=720, height=720):
    """Replace the pdfium module with a MagicMock one-page document.

    The single mocked page reports the given width/height (PDF units), no
    rotation, and renders to a 1x1 RGB image. Returns the mock page so tests
    can assert on `page.render` calls.
    """
    page = MagicMock()
    page.get_width.return_value = width
    page.get_height.return_value = height
    page.get_rotation.return_value = 0
    page.render.return_value.to_pil.return_value = Image.new("RGB", (1, 1))
    pdf = MagicMock()
    pdf.__len__.return_value = 1
    pdf.__getitem__.return_value = page
    pdfium = MagicMock()
    pdfium.PdfDocument.return_value = pdf
    monkeypatch.setattr(pdf_image, "_get_pdfium_module", lambda: pdfium)
    return page
def test_convert_pdf_to_image_rejects_oversized_page_before_render(monkeypatch):
    """An over-budget page raises PdfRenderTooLargeError without ever calling render()."""
    page = _install_mock_pdfium(monkeypatch)
    with pytest.raises(pdf_image.PdfRenderTooLargeError, match="too many pixels"):
        pdf_image.convert_pdf_to_image(
            filename="mock.pdf",
            dpi=100,
            pdf_render_max_pixels_per_page=999_999,
        )
    # The guard must fire before any rendering work happens.
    page.render.assert_not_called()
def test_convert_pdf_to_image_allows_render_guard_to_be_disabled(monkeypatch):
    """pdf_render_max_pixels_per_page=0 disables the pixel guard entirely."""
    page = _install_mock_pdfium(monkeypatch)
    result = pdf_image.convert_pdf_to_image(
        filename="mock.pdf",
        dpi=100,
        pdf_render_max_pixels_per_page=0,
    )
    page.render.assert_called_once()
    assert len(result) == 1
    assert isinstance(result[0], Image.Image)
def test_page_hotload_preserves_render_max_pixels_per_page(monkeypatch, tmp_path):
    """PageLayout._get_image forwards its stored pixel budget when re-rendering."""
    image_path = tmp_path / "page_1.png"
    Image.new("RGB", (1, 1)).save(image_path)
    calls = []
    def fake_convert_pdf_to_image(**kwargs):
        # Capture kwargs so we can assert what _get_image forwarded.
        calls.append(kwargs)
        return [str(image_path)]
    monkeypatch.setattr(layout, "convert_pdf_to_image", fake_convert_pdf_to_image)
    page = layout.PageLayout(
        number=1,
        image=Image.new("RGB", (1, 1)),
        document_filename="mock.pdf",
        pdf_render_max_pixels_per_page=None,
    )
    image = page._get_image("mock.pdf", 1, pdf_image_dpi=123)
    assert image.size == (1, 1)
    assert calls[0]["dpi"] == 123
    # The explicit None set on the page must be preserved, not replaced by a default.
    assert calls[0]["pdf_render_max_pixels_per_page"] is None
def test_convert_pdf_to_image_output_folder_returns_images(tmp_path):
    """With an output folder and path_only=False, images are returned AND saved to disk."""
    result = layout.convert_pdf_to_image(
        filename="sample-docs/loremipsum.pdf",
        dpi=72,
        output_folder=tmp_path,
        path_only=False,
    )
    assert len(result) == 1
    assert isinstance(result[0], Image.Image)
    saved = list(tmp_path.glob("*.png"))
    assert len(saved) == 1
def test_convert_pdf_to_image_path_only(tmp_path):
    """path_only=True returns on-disk PNG paths rather than PIL images."""
    result = layout.convert_pdf_to_image(
        filename="sample-docs/loremipsum.pdf",
        dpi=72,
        output_folder=tmp_path,
        path_only=True,
    )
    assert len(result) == 1
    assert all(isinstance(p, str) for p in result)
    for p in result:
        assert os.path.exists(p)
        assert p.endswith(".png")
    # Returned paths must match exactly what landed in the output folder.
    saved = sorted(tmp_path.glob("*.png"))
    assert [str(s) for s in saved] == sorted(result)
def test_convert_pdf_to_image_applies_rotation_path_only(tmp_path):
    """Rotation is also applied when saving to disk (path_only mode)."""
    result = layout.convert_pdf_to_image(
        filename="sample-docs/rotated-page-90.pdf",
        dpi=72,
        output_folder=tmp_path,
        path_only=True,
    )
    assert len(result) == 1
    # The fixture is a landscape page with /Rotate=90; the saved file must be portrait.
    saved = Image.open(result[0])
    assert saved.height > saved.width, f"Expected portrait after rotation, got {saved.size}"
def test_convert_pdf_to_image_no_rotation_on_normal_pdf():
    """Non-rotated PDFs are unchanged by the rotation handling."""
    result = layout.convert_pdf_to_image(filename="sample-docs/loremipsum.pdf", dpi=72)
    assert len(result) == 1
    img = result[0]
    # loremipsum.pdf is a standard portrait page - should stay portrait
    assert img.height > img.width, f"Expected portrait, got {img.size}"
def test_convert_pdf_to_image_save_not_under_pdfium_lock(tmp_path):
    """Verify that PIL save (disk I/O) is NOT performed while holding _pdfium_lock."""
    original_save = Image.Image.save
    lock_held_during_save = []
    def spy_save(self, *args, **kwargs):
        # Record the lock state at the moment of each save call, then save for real.
        lock_held_during_save.append(layout._pdfium_lock.locked())
        return original_save(self, *args, **kwargs)
    with patch.object(Image.Image, "save", spy_save):
        layout.convert_pdf_to_image(
            filename="sample-docs/loremipsum.pdf",
            dpi=72,
            output_folder=tmp_path,
            path_only=True,
        )
    assert lock_held_during_save, "save was never called"
    assert not any(lock_held_during_save), "pil_image.save() was called while _pdfium_lock was held"
def test_convert_pdf_to_image_concurrent_saves_not_serialized(tmp_path):
    """Two concurrent callers must be able to overlap their disk writes.
    Uses a threading.Barrier to verify both threads are inside save()
    simultaneously. If saves are serialized under _pdfium_lock, the second
    thread can never reach save() while the first is there, so the barrier
    times out and the test fails.
    """
    import threading
    original_save = Image.Image.save
    barrier = threading.Barrier(2, timeout=5)
    overlap_detected = threading.Event()
    def barrier_save(self, *args, **kwargs):
        # Both threads must rendezvous here; a broken barrier (timeout) means
        # the saves were serialized and overlap_detected stays unset.
        try:
            barrier.wait()
            overlap_detected.set()
        except threading.BrokenBarrierError:
            pass
        return original_save(self, *args, **kwargs)
    errors: list[str] = []
    def run(folder):
        try:
            layout.convert_pdf_to_image(
                filename="sample-docs/loremipsum.pdf",
                dpi=72,
                output_folder=folder,
                path_only=True,
            )
        except Exception as exc:
            errors.append(str(exc))
    dir_a = tmp_path / "a"
    dir_b = tmp_path / "b"
    dir_a.mkdir()
    dir_b.mkdir()
    with patch.object(Image.Image, "save", barrier_save):
        t1 = threading.Thread(target=run, args=(dir_a,))
        t2 = threading.Thread(target=run, args=(dir_b,))
        t1.start()
        t2.start()
        t1.join(timeout=10)
        t2.join(timeout=10)
    assert not errors, f"threads raised: {errors}"
    assert overlap_detected.is_set(), (
        "saves were serialized under _pdfium_lock — threads could not overlap"
    )
    assert list(dir_a.glob("*.png")), "thread A produced no output"
    assert list(dir_b.glob("*.png")), "thread B produced no output"
def test_render_can_proceed_while_other_thread_saves(tmp_path):
    """Thread B can acquire _pdfium_lock and render while thread A is in save().
    Blocks thread A inside save() (outside the lock), then starts thread B.
    If B completes entirely while A is still blocked, the lock was not held
    during save — rendering and saving can overlap across callers.
    """
    import threading
    original_save = Image.Image.save
    a_in_save = threading.Event()
    b_done = threading.Event()
    dir_a = tmp_path / "a"
    dir_b = tmp_path / "b"
    dir_a.mkdir()
    dir_b.mkdir()
    def gated_save(self, *args, **kwargs):
        # Only thread A (writing into dir_a) is held hostage; thread B saves normally.
        fp = str(args[0]) if args else ""
        if str(dir_a) in fp:
            a_in_save.set()
            b_done.wait(timeout=5)
        return original_save(self, *args, **kwargs)
    errors: list[str] = []
    def run(folder, done_event=None):
        try:
            layout.convert_pdf_to_image(
                filename="sample-docs/loremipsum.pdf",
                dpi=72,
                output_folder=folder,
                path_only=True,
            )
        except Exception as exc:
            errors.append(str(exc))
        finally:
            if done_event:
                done_event.set()
    with patch.object(Image.Image, "save", gated_save):
        t_a = threading.Thread(target=run, args=(dir_a,))
        t_b = threading.Thread(target=run, args=(dir_b, b_done))
        t_a.start()
        a_in_save.wait(timeout=5)
        # A is now blocked in save (outside lock). B should render + save freely.
        t_b.start()
        t_b.join(timeout=10)
        t_a.join(timeout=10)
    assert not errors, f"threads raised: {errors}"
    assert b_done.is_set(), "Thread B could not complete while A was saving"
    assert list(dir_a.glob("*.png")), "thread A produced no output"
    assert list(dir_b.glob("*.png")), "thread B produced no output"
def test_multi_page_concurrent_output_complete(tmp_path):
    """Two threads processing a multi-page PDF both produce correct, complete output."""
    import threading
    errors: list[str] = []
    def run(folder):
        try:
            layout.convert_pdf_to_image(
                filename="sample-docs/loremipsum_multipage.pdf",
                dpi=72,
                output_folder=folder,
                path_only=True,
            )
        except Exception as exc:
            errors.append(str(exc))
    dir_a = tmp_path / "a"
    dir_b = tmp_path / "b"
    dir_a.mkdir()
    dir_b.mkdir()
    t1 = threading.Thread(target=run, args=(dir_a,))
    t2 = threading.Thread(target=run, args=(dir_b,))
    t1.start()
    t2.start()
    t1.join(timeout=60)
    t2.join(timeout=60)
    assert not errors, f"threads raised: {errors}"
    a_files = sorted(dir_a.glob("*.png"))
    b_files = sorted(dir_b.glob("*.png"))
    # Fixture contract: loremipsum_multipage.pdf has 10 pages — each thread
    # must write all of them, named page_1.png … page_10.png.
    assert len(a_files) == 10, f"thread A produced {len(a_files)} files, expected 10"
    assert len(b_files) == 10, f"thread B produced {len(b_files)} files, expected 10"
    for i in range(1, 11):
        assert (dir_a / f"page_{i}.png").exists(), f"thread A missing page_{i}.png"
        assert (dir_b / f"page_{i}.png").exists(), f"thread B missing page_{i}.png"
def test_error_in_one_thread_does_not_block_other(tmp_path):
    """If one thread fails mid-processing, the other still completes."""
    import threading
    original_save = Image.Image.save
    dir_a = tmp_path / "a"
    dir_b = tmp_path / "b"
    dir_a.mkdir()
    dir_b.mkdir()
    def failing_save(self, *args, **kwargs):
        # Only thread A's saves (into dir_a) fail; thread B is untouched.
        fp = str(args[0]) if args else ""
        if str(dir_a) in fp:
            raise OSError("simulated disk failure")
        return original_save(self, *args, **kwargs)
    a_error: list[Exception] = []
    b_result: list[str] = []
    b_error: list[Exception] = []
    def run_a():
        try:
            layout.convert_pdf_to_image(
                filename="sample-docs/loremipsum.pdf",
                dpi=72,
                output_folder=dir_a,
                path_only=True,
            )
        except Exception as exc:
            a_error.append(exc)
    def run_b():
        try:
            result = layout.convert_pdf_to_image(
                filename="sample-docs/loremipsum.pdf",
                dpi=72,
                output_folder=dir_b,
                path_only=True,
            )
            b_result.extend(result)
        except Exception as exc:
            b_error.append(exc)
    with patch.object(Image.Image, "save", failing_save):
        t_a = threading.Thread(target=run_a)
        t_b = threading.Thread(target=run_b)
        t_a.start()
        t_b.start()
        t_a.join(timeout=10)
        t_b.join(timeout=10)
    assert a_error, "Thread A should have failed"
    assert not b_error, f"Thread B should have succeeded: {b_error}"
    assert b_result, "Thread B produced no result"
    assert list(dir_b.glob("*.png")), "Thread B produced no output files"
@pytest.mark.parametrize(
    ("filename", "img_num", "should_complete"),
    [
        ("sample-docs/empty-document.pdf", 0, True),
        ("sample-docs/empty-document.pdf", 10, False),
    ],
)
def test_get_image(filename, img_num, should_complete):
    """_get_image returns a blank page for a valid index and raises ValueError otherwise."""
    doc = layout.DocumentLayout.from_file(filename)
    page = doc.pages[0]
    try:
        img = page._get_image(filename, img_num)
        # transform img to numpy array
        img = np.array(img)
        # is a blank image with all pixels white
        assert img.mean() == 255.0
        # Bug fix: previously the should_complete=False case passed silently if
        # no ValueError was raised; reaching this line must mean success was expected.
        assert should_complete, "expected ValueError for out-of-range image number"
    except ValueError:
        assert not should_complete
================================================
FILE: test_unstructured_inference/inference/test_layout_element.py
================================================
from unstructured_inference.constants import IsExtracted, Source
from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion
def test_layout_element_to_dict(mock_layout_element):
    """to_dict serializes all public fields of a LayoutElement."""
    expected = {
        "coordinates": ((100, 100), (100, 300), (300, 300), (300, 100)),
        "text": "Sample text",
        "is_extracted": None,
        "type": "Text",
        "prob": None,
        "source": None,
    }
    assert mock_layout_element.to_dict() == expected
def test_layout_element_from_region(mock_rectangle):
    """from_region builds a LayoutElement equal to one made from the same coords."""
    expected = LayoutElement.from_coords(100, 100, 300, 300)
    region = TextRegion(bbox=mock_rectangle)
    assert LayoutElement.from_region(region) == expected
def test_layoutelement_inheritance_works_correctly():
    """Test that LayoutElement properly inherits from TextRegion without conflicts:
    source and is_extracted must survive conversion, serialization, and reassignment.
    """
    from unstructured_inference.inference.elements import TextRegion
    # Create a TextRegion with both source and text_source
    region = TextRegion.from_coords(
        0, 0, 10, 10, text="test", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
    )
    # Convert to LayoutElement
    element = LayoutElement.from_region(region)
    # Check that both properties are preserved
    assert element.source == Source.YOLOX, "LayoutElement should inherit source from TextRegion"
    assert element.is_extracted == IsExtracted.TRUE, (
        "LayoutElement should inherit is_extracted from TextRegion"
    )
    # Check that to_dict() works correctly
    d = element.to_dict()
    assert d["source"] == Source.YOLOX
    assert d["is_extracted"] == IsExtracted.TRUE
    # Check that we can set source directly on LayoutElement
    element.source = Source.DETECTRON2_ONNX
    assert element.source == Source.DETECTRON2_ONNX
================================================
FILE: test_unstructured_inference/inference/test_layout_rotation.py
================================================
from __future__ import annotations
import numpy as np
from unstructured_inference.inference import pdf_image
def test_convert_pdf_to_image_applies_rotation():
    """Pages with /Rotate metadata are rendered upright."""
    result = pdf_image.convert_pdf_to_image(filename="sample-docs/rotated-page-90.pdf", dpi=72)
    assert len(result) == 1
    img = result[0]
    # The PDF has /Rotate=90 on a landscape page (width > height in PDF units).
    # Without rotation fix the rendered image would be landscape; with the fix it's portrait.
    assert img.height > img.width, f"Expected portrait after rotation, got {img.size}"
    # Fixture contract: rotated-page-90.pdf has visible dark text in the upper half when upright.
    # Use relative dark-pixel counts to reduce sensitivity to minor renderer differences.
    gray = np.array(img.convert("L"))
    split = gray.shape[0] // 2
    # "Dark" = luminance below 245, so anti-aliased glyph pixels count as text.
    top_dark_pixels = int(np.count_nonzero(gray[:split] < 245))
    bottom_dark_pixels = int(np.count_nonzero(gray[split:] < 245))
    assert top_dark_pixels > 0, "Expected text pixels in upper half of upright page"
    assert top_dark_pixels > max(bottom_dark_pixels * 10, 50), (
        "Expected substantially more dark pixels in upper half for upright orientation; "
        f"got top={top_dark_pixels}, bottom={bottom_dark_pixels}"
    )
================================================
FILE: test_unstructured_inference/models/test_detectron2onnx.py
================================================
import os
from unittest.mock import patch
import pytest
from PIL import Image
import unstructured_inference.models.base as models
import unstructured_inference.models.detectron2onnx as detectron2
class MockDetectron2ONNXLayoutModel:
    """Stand-in for an onnxruntime InferenceSession used by the detectron2 tests."""

    def __init__(self, *args, **kwargs):
        # Record constructor arguments so tests can inspect how the session was built.
        self.args = args
        self.kwargs = kwargs

    def run(self, *args):
        """Return one fixed detection: bboxes, class ids, image sizes, confidences."""
        bboxes = [(1, 2, 3, 4)]
        class_ids = [0]
        image_sizes = [(4, 5)]
        confidences = [0.818]
        return (bboxes, class_ids, image_sizes, confidences)

    def get_inputs(self):
        """Mimic InferenceSession.get_inputs() with a single input named 'Bernard'."""
        input_thing = type("input_thing", (), {"name": "Bernard"})
        return [input_thing()]
def test_load_default_model(monkeypatch):
    """get_model builds the detectron2 model via onnxruntime.InferenceSession."""
    # Clear the model cache so a fresh instance is constructed under the patch.
    monkeypatch.setattr(models, "models", {})
    with patch.object(
        detectron2.onnxruntime,
        "InferenceSession",
        new=MockDetectron2ONNXLayoutModel,
    ):
        model = models.get_model("detectron2_mask_rcnn")
        assert isinstance(model.model, MockDetectron2ONNXLayoutModel)
@pytest.mark.parametrize(("model_path", "label_map"), [("asdf", "diufs"), ("dfaw", "hfhfhfh")])
def test_load_model(model_path, label_map):
    """initialize() passes model_path to InferenceSession and stores the label map."""
    with patch.object(detectron2.onnxruntime, "InferenceSession", return_value=True):
        model = detectron2.UnstructuredDetectronONNXModel()
        model.initialize(model_path=model_path, label_map=label_map)
        args, _ = detectron2.onnxruntime.InferenceSession.call_args
        assert args == (model_path,)
        assert label_map == model.label_map
def test_unstructured_detectron_model():
    """Calling the model object delegates to predict() and returns its result."""
    model = detectron2.UnstructuredDetectronONNXModel()
    # Any non-None value marks the model as initialized.
    model.model = 1
    with patch.object(detectron2.UnstructuredDetectronONNXModel, "predict", return_value=[]):
        result = model(None)
    assert isinstance(result, list)
    assert len(result) == 0
def test_inference():
    """End-to-end predict through the mocked ONNX session yields one scaled element."""
    with patch.object(
        detectron2.onnxruntime,
        "InferenceSession",
        return_value=MockDetectron2ONNXLayoutModel(),
    ):
        model = detectron2.UnstructuredDetectronONNXModel()
        model.initialize(model_path="test_path", label_map={0: "test_class"})
        assert isinstance(model.model, MockDetectron2ONNXLayoutModel)
        with open(os.path.join("sample-docs", "receipt-sample.jpg"), mode="rb") as fp:
            image = Image.open(fp)
            image.load()
        elements = model(image)
        assert len(elements) == 1
        element = elements[0]
        (x1, y1), _, (x2, y2), _ = element.bbox.coordinates
        assert hasattr(
            element,
            "prob",
        )  # NOTE(pravin) New Assertion to Make Sure element has probabilities
        assert isinstance(
            element.prob,
            float,
        )  # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float
        # NOTE(alan): The bbox coordinates get resized, so check their relative proportions
        assert x2 / x1 == pytest.approx(3.0)  # x1 == 1, x2 == 3 before scaling
        assert y2 / y1 == pytest.approx(2.0)  # y1 == 2, y2 == 4 before scaling
        assert element.type == "test_class"
================================================
FILE: test_unstructured_inference/models/test_eval.py
================================================
import pytest
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
from unstructured_inference.models.eval import compare_contents_as_df, default_tokenizer
@pytest.fixture
def actual_cells():
    """Ground-truth table cells (2-row header, 6 columns) for the eval comparison tests."""
    return [
        {
            "column_nums": [0],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "Disability Category",
        },
        {
            "column_nums": [1],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "Participants",
        },
        {
            "column_nums": [2],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "Ballots Completed",
        },
        {
            "column_nums": [3],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "Ballots Incomplete/Terminated",
        },
        {"column_nums": [4, 5], "row_nums": [0], "column header": True, "cell text": "Results"},
        {"column_nums": [4], "row_nums": [1], "column header": False, "cell text": "Accuracy"},
        {
            "column_nums": [5],
            "row_nums": [1],
            "column header": False,
            "cell text": "Time to complete",
        },
        {"column_nums": [0], "row_nums": [2], "column header": False, "cell text": "Blind"},
        {"column_nums": [0], "row_nums": [3], "column header": False, "cell text": "Low Vision"},
        {"column_nums": [0], "row_nums": [4], "column header": False, "cell text": "Dexterity"},
        {"column_nums": [0], "row_nums": [5], "column header": False, "cell text": "Mobility"},
        {"column_nums": [1], "row_nums": [2], "column header": False, "cell text": "5"},
        {"column_nums": [1], "row_nums": [3], "column header": False, "cell text": "5"},
        {"column_nums": [1], "row_nums": [4], "column header": False, "cell text": "5"},
        {"column_nums": [1], "row_nums": [5], "column header": False, "cell text": "3"},
        {"column_nums": [2], "row_nums": [2], "column header": False, "cell text": "1"},
        {"column_nums": [2], "row_nums": [3], "column header": False, "cell text": "2"},
        {"column_nums": [2], "row_nums": [4], "column header": False, "cell text": "4"},
        {"column_nums": [2], "row_nums": [5], "column header": False, "cell text": "3"},
        {"column_nums": [3], "row_nums": [2], "column header": False, "cell text": "4"},
        {"column_nums": [3], "row_nums": [3], "column header": False, "cell text": "3"},
        {"column_nums": [3], "row_nums": [4], "column header": False, "cell text": "1"},
        {"column_nums": [3], "row_nums": [5], "column header": False, "cell text": "0"},
        {"column_nums": [4], "row_nums": [2], "column header": False, "cell text": "34.5%, n=1"},
        {
            "column_nums": [4],
            "row_nums": [3],
            "column header": False,
            "cell text": "98.3% n=2 (97.7%, n=3)",
        },
        {"column_nums": [4], "row_nums": [4], "column header": False, "cell text": "98.3%, n=4"},
        {"column_nums": [4], "row_nums": [5], "column header": False, "cell text": "95.4%, n=3"},
        {"column_nums": [5], "row_nums": [2], "column header": False, "cell text": "1199 sec, n=1"},
        {
            "column_nums": [5],
            "row_nums": [3],
            "column header": False,
            "cell text": "1716 sec, n=3 (1934 sec, n=2)",
        },
        {
            "column_nums": [5],
            "row_nums": [4],
            "column header": False,
            "cell text": "1672.1 sec, n=4",
        },
        {"column_nums": [5], "row_nums": [5], "column header": False, "cell text": "1416 sec, n=3"},
    ]
@pytest.fixture
def pred_cells():
    """Predicted cells for the same table, with deliberate OCR noise in the header texts."""
    return [
        {"column_nums": [0], "row_nums": [2], "column header": False, "cell text": "Blind"},
        {"column_nums": [0], "row_nums": [3], "column header": False, "cell text": "Low Vision"},
        {"column_nums": [0], "row_nums": [4], "column header": False, "cell text": "Dexterity"},
        {"column_nums": [0], "row_nums": [5], "column header": False, "cell text": "Mobility"},
        {"column_nums": [1], "row_nums": [2], "column header": False, "cell text": "5"},
        {"column_nums": [1], "row_nums": [3], "column header": False, "cell text": "5"},
        {"column_nums": [1], "row_nums": [4], "column header": False, "cell text": "5"},
        {"column_nums": [1], "row_nums": [5], "column header": False, "cell text": "3"},
        {"column_nums": [2], "row_nums": [2], "column header": False, "cell text": "1"},
        {"column_nums": [2], "row_nums": [3], "column header": False, "cell text": "2"},
        {"column_nums": [2], "row_nums": [4], "column header": False, "cell text": "4"},
        {"column_nums": [2], "row_nums": [5], "column header": False, "cell text": "3"},
        {"column_nums": [3], "row_nums": [2], "column header": False, "cell text": "4"},
        {"column_nums": [3], "row_nums": [3], "column header": False, "cell text": "3"},
        {"column_nums": [3], "row_nums": [4], "column header": False, "cell text": "1"},
        {"column_nums": [3], "row_nums": [5], "column header": False, "cell text": "0"},
        {"column_nums": [4], "row_nums": [1], "column header": False, "cell text": "Accuracy"},
        {"column_nums": [4], "row_nums": [2], "column header": False, "cell text": "34.5%, n=1"},
        {
            "column_nums": [4],
            "row_nums": [3],
            "column header": False,
            "cell text": "98.3% n=2 (97.7%, n=3)",
        },
        {"column_nums": [4], "row_nums": [4], "column header": False, "cell text": "98.3%, n=4"},
        {"column_nums": [4], "row_nums": [5], "column header": False, "cell text": "95.4%, n=3"},
        {
            "column_nums": [5],
            "row_nums": [1],
            "column header": False,
            "cell text": "Time to complete",
        },
        {"column_nums": [5], "row_nums": [2], "column header": False, "cell text": "1199 sec, n=1"},
        {
            "column_nums": [5],
            "row_nums": [3],
            "column header": False,
            "cell text": "1716 sec, n=3 | (1934 sec, n=2)",
        },
        {
            "column_nums": [5],
            "row_nums": [4],
            "column header": False,
            "cell text": "1672.1 sec, n=4",
        },
        {"column_nums": [5], "row_nums": [5], "column header": False, "cell text": "1416 sec, n=3"},
        {
            "column_nums": [0],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "soa etealeiliay Category",
        },
        {"column_nums": [4, 5], "row_nums": [0], "column header": True, "cell text": "Results"},
        {
            "column_nums": [1],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "Participants P",
        },
        {
            "column_nums": [2],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "pallets Completed",
        },
        {
            "column_nums": [3],
            "row_nums": [0, 1],
            "column header": True,
            "cell text": "Ballot: incom lete/ Ne Terminated",
        },
    ]
@pytest.fixture
def actual_df(actual_cells):
    # Empty cells become "" so string-similarity metrics never see NaN.
    return table_cells_to_dataframe(actual_cells).fillna("")
@pytest.fixture
def pred_df(pred_cells):
    # Empty cells become "" so string-similarity metrics never see NaN.
    return table_cells_to_dataframe(pred_cells).fillna("")
@pytest.mark.parametrize(
    ("eval_func", "processor"),
    [
        ("token_ratio", default_tokenizer),
        ("token_ratio", None),
        ("partial_token_ratio", default_tokenizer),
        ("ratio", None),
        ("ratio", default_tokenizer),
        ("partial_ratio", default_tokenizer),
    ],
)
def test_compare_content_as_df(actual_df, pred_df, eval_func, processor):
    """Every supported eval_func yields a by-column score strictly between 0 and 100."""
    results = compare_contents_as_df(actual_df, pred_df, eval_func=eval_func, processor=processor)
    assert 0 < results.get(f"by_col_{eval_func}") < 100
def test_compare_content_as_df_with_invalid_input(actual_df, pred_df):
    """An unsupported eval_func raises a descriptive ValueError."""
    with pytest.raises(ValueError, match="eval_func must be one of"):
        compare_contents_as_df(actual_df, pred_df, eval_func="foo")
================================================
FILE: test_unstructured_inference/models/test_model.py
================================================
import json
import threading
import time
from typing import Any
from unittest import mock
import numpy as np
import pytest
import unstructured_inference.models.base as models
from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements
from unstructured_inference.models.unstructuredmodel import (
ModelNotInitializedError,
UnstructuredObjectDetectionModel,
)
class MockModel(UnstructuredObjectDetectionModel):
    """Detection-model stub that records initialize() calls via a MagicMock."""
    call_count = 0
    def __init__(self):
        # The MagicMock lets tests assert how many times (and with what
        # arguments) initialize() was invoked.
        self.initializer = mock.MagicMock()
        super().__init__()
    def initialize(self, *args, **kwargs):
        return self.initializer(self, *args, **kwargs)
    def predict(self, x: Any) -> Any:
        # Always predict an empty layout.
        return LayoutElements(element_coords=np.array([]))
# Config used by test_register_new_model to register a synthetic "foo" model type.
MOCK_MODEL_TYPES = {
    "foo": {
        "input_shape": (640, 640),
    },
}
def test_get_model(monkeypatch):
    """get_model instantiates the class registered for the requested name."""
    monkeypatch.setattr(models, "models", {})
    with mock.patch.dict(models.model_class_map, {"yolox": MockModel}):
        assert isinstance(models.get_model("yolox"), MockModel)
def test_get_model_threaded(monkeypatch):
    """Test that get_model works correctly when called from multiple threads
    simultaneously: every worker gets a MockModel and none raises."""
    monkeypatch.setattr(models, "models", {})
    # Results and exceptions from threads will be stored here
    results = []
    exceptions = []
    def get_model_worker(thread_id):
        """Worker function for each thread."""
        try:
            model = models.get_model("yolox")
            results.append((thread_id, model))
        except Exception as e:
            exceptions.append((thread_id, e))
    # Create and start multiple threads
    num_threads = 10
    threads = []
    with mock.patch.dict(models.model_class_map, {"yolox": MockModel}):
        for i in range(num_threads):
            thread = threading.Thread(target=get_model_worker, args=(i,))
            threads.append(thread)
            thread.start()
        # Wait for all threads to complete
        for thread in threads:
            thread.join()
    # Verify no exceptions occurred
    assert len(exceptions) == 0, f"Exceptions occurred in threads: {exceptions}"
    # Verify all threads got results
    assert len(results) == num_threads, f"Expected {num_threads} results, got {len(results)}"
    # Verify all results are MockModel instances
    for thread_id, model in results:
        assert isinstance(model, MockModel), (
            f"Thread {thread_id} got unexpected model type: {type(model)}"
        )
def test_get_model_concurrent_different_models(monkeypatch):
    """Test that different models can load in parallel without serialization."""
    monkeypatch.setattr(models, "models", {})
    # Track initialization timing
    init_events = []
    init_lock = threading.Lock()
    # Subclass records start/end events around a deliberately slow initialize()
    # so the event ordering reveals whether loads overlapped.
    class SlowMockModel(MockModel):
        def __init__(self):
            super().__init__()
            self.model_name = None
        def initialize(self, *args, **kwargs):
            with init_lock:
                init_events.append((self.model_name, "start"))
            time.sleep(0.1)  # Simulate slow loading
            with init_lock:
                init_events.append((self.model_name, "end"))
            return super().initialize(*args, **kwargs)
    # Store model names in instances
    def create_model_with_name(name):
        def factory():
            model = SlowMockModel()
            model.model_name = name
            return model
        return factory
    results = []
    def worker(model_name):
        models.get_model(model_name)  # Load the model
        results.append(model_name)
    # Load 2 different models concurrently
    threads = []
    mock_config = {"input_shape": (640, 640)}
    with (
        mock.patch.dict(
            models.model_class_map,
            {
                "yolox": create_model_with_name("yolox"),
                "detectron2": create_model_with_name("detectron2"),
            },
        ),
        mock.patch.dict(models.model_config_map, {"yolox": mock_config, "detectron2": mock_config}),
    ):
        for model_name in ["yolox", "detectron2"]:
            thread = threading.Thread(target=worker, args=(model_name,))
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()
    # Both models should load successfully
    assert len(results) == 2
    # Verify parallel execution (both start before either ends)
    assert len(init_events) == 4, f"Expected 4 events (2 starts + 2 ends), got {len(init_events)}"
    # True parallelism means both models start before either finishes
    # Find when the first model finishes
    first_end_idx = next(
        (i for i, (_, event_type) in enumerate(init_events) if event_type == "end"), None
    )
    assert first_end_idx is not None, "No 'end' event found"
    # Count how many models started before the first one finished
    starts_before_first_end = sum(
        1 for _, event_type in init_events[:first_end_idx] if event_type == "start"
    )
    assert starts_before_first_end == 2, (
        f"Expected both models to start before either finishes (parallel execution), "
        f"but only {starts_before_first_end} started before first completion. "
        f"Events: {init_events}"
    )
def test_register_new_model():
    """register_new_model adds 'foo' to both maps; resetting defaults removes it."""
    assert "foo" not in models.model_class_map
    assert "foo" not in models.model_config_map
    models.register_new_model(MOCK_MODEL_TYPES, MockModel)
    assert "foo" in models.model_class_map
    assert "foo" in models.model_config_map
    model = models.get_model("foo")
    # The registered config must be forwarded as initialize() kwargs.
    assert len(model.initializer.mock_calls) == 1
    assert model.initializer.mock_calls[0][-1] == MOCK_MODEL_TYPES["foo"]
    assert isinstance(model, MockModel)
    # unregister the new model by reset to default
    models.model_class_map, models.model_config_map = models.get_default_model_mappings()
    assert "foo" not in models.model_class_map
    assert "foo" not in models.model_config_map
def test_get_model_with_lazydict_config(monkeypatch):
    """get_model must unpack a LazyDict config into initialize() without
    depending on Mapping.keys() — prevents regression of
    'argument after ** must be a mapping, not LazyDict' in prod.
    """
    from unstructured_inference.utils import LazyDict, LazyEvaluateInfo
    monkeypatch.setattr(models, "models", {})
    evaluated = []
    def _fake_download(path):
        # Stands in for the weight-download step; records that it ran.
        evaluated.append(path)
        return path
    lazy_config = LazyDict(
        model_path=LazyEvaluateInfo(_fake_download, "/tmp/weights.onnx"),
        input_shape=(640, 640),
    )
    with (
        mock.patch.dict(models.model_class_map, {"lazy_mock": MockModel}),
        mock.patch.dict(models.model_config_map, {"lazy_mock": lazy_config}),
    ):
        model = models.get_model("lazy_mock")
    assert isinstance(model, MockModel)
    # The lazy value must have been evaluated exactly once, during get_model.
    assert evaluated == ["/tmp/weights.onnx"]
    model.initializer.assert_called_once_with(
        model,
        model_path="/tmp/weights.onnx",
        input_shape=(640, 640),
    )
def test_raises_invalid_model():
    """Unknown model names raise UnknownModelException."""
    with pytest.raises(models.UnknownModelException):
        models.get_model("fake_model")
def test_raises_uninitialized():
    """predict() before initialize() raises ModelNotInitializedError."""
    with pytest.raises(ModelNotInitializedError):
        models.UnstructuredDetectronONNXModel().predict(None)
def test_model_initializes_once():
    """Processing a document loads/initializes the detection model exactly once."""
    from unstructured_inference.inference import layout
    with (
        mock.patch.dict(models.model_class_map, {"yolox": MockModel}),
        mock.patch.object(
            models,
            "models",
            {},
        ),
    ):
        doc = layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf")
        doc.pages[0].detection_model.initializer.assert_called_once()
def test_deduplicate_detected_elements():
    """After dedup, no two categorized elements on a page intersect."""
    import numpy as np
    from unstructured_inference.inference.elements import intersections
    from unstructured_inference.inference.layout import DocumentLayout
    from unstructured_inference.models.base import get_model
    model = get_model("yolox_quantized")
    # model.confidence_threshold=0.5
    file = "sample-docs/example_table.jpg"
    doc = DocumentLayout.from_image_file(
        file,
        model,
    )
    known_elements = [e.bbox for e in doc.pages[0].elements if e.type != "UncategorizedText"]
    # Compute intersection matrix
    intersections_mtx = intersections(*known_elements)
    # Get rid off diagonal (cause an element will always intersect itself)
    np.fill_diagonal(intersections_mtx, False)
    # Now all the elements should be False, because any intersection remains
    assert not intersections_mtx.any()
def test_enhance_regions():
    """enhance_regions merges a pile of near-duplicate boxes into a single region."""
    from unstructured_inference.inference.elements import Rectangle
    from unstructured_inference.models.base import get_model
    elements = [
        LayoutElement(bbox=Rectangle(0, 0, 1, 1)),
        LayoutElement(bbox=Rectangle(0.01, 0.01, 1.01, 1.01)),
        LayoutElement(bbox=Rectangle(0.02, 0.02, 1.02, 1.02)),
        LayoutElement(bbox=Rectangle(0.03, 0.03, 1.03, 1.03)),
        LayoutElement(bbox=Rectangle(0.04, 0.04, 1.04, 1.04)),
        LayoutElement(bbox=Rectangle(0.05, 0.05, 1.05, 1.05)),
        LayoutElement(bbox=Rectangle(0.06, 0.06, 1.06, 1.06)),
        LayoutElement(bbox=Rectangle(0.07, 0.07, 1.07, 1.07)),
        LayoutElement(bbox=Rectangle(0.08, 0.08, 1.08, 1.08)),
        LayoutElement(bbox=Rectangle(0.09, 0.09, 1.09, 1.09)),
        LayoutElement(bbox=Rectangle(0.10, 0.10, 1.10, 1.10)),
    ]
    model = get_model("yolox_tiny")
    elements = model.enhance_regions(elements, 0.5)
    assert len(elements) == 1
    # Bug fix: the fourth tuple component previously re-read bbox.x2 instead of
    # bbox.y2, so the merged box's bottom edge was never actually checked.
    assert (
        elements[0].bbox.x1,
        elements[0].bbox.y1,
        elements[0].bbox.x2,
        elements[0].bbox.y2,
    ) == (
        0,
        0,
        1.10,
        1.10,
    )
def test_clean_type():
    """clean_type keeps only the outermost Table, dropping nested table regions.

    Fixes the original assertion, which compared ``bbox.x2`` twice and never
    checked ``bbox.y2``.
    """
    from unstructured_inference.inference.layout import LayoutElement
    from unstructured_inference.models.base import get_model

    elements = [
        LayoutElement.from_coords(
            0.6,
            0.6,
            0.65,
            0.65,
            type="Table",
        ),  # One little table nested inside all the others
        LayoutElement.from_coords(0.5, 0.5, 0.7, 0.7, type="Table"),  # One nested table
        LayoutElement.from_coords(0, 0, 1, 1, type="Table"),  # Big table
        LayoutElement.from_coords(0.01, 0.01, 1.01, 1.01),
        LayoutElement.from_coords(0.02, 0.02, 1.02, 1.02),
        LayoutElement.from_coords(0.03, 0.03, 1.03, 1.03),
        LayoutElement.from_coords(0.04, 0.04, 1.04, 1.04),
        LayoutElement.from_coords(0.05, 0.05, 1.05, 1.05),
    ]
    model = get_model("yolox_tiny")
    elements = model.clean_type(elements, type_to_clean="Table")
    assert len(elements) == 1
    # Only the big table's full (x1, y1, x2, y2) box should survive.
    assert (
        elements[0].bbox.x1,
        elements[0].bbox.y1,
        elements[0].bbox.x2,
        elements[0].bbox.y2,
    ) == (0, 0, 1, 1)
def test_env_variables_override_default_model(monkeypatch):
    """get_model() with no args must honor UNSTRUCTURED_DEFAULT_MODEL_NAME."""
    # Clear the module-level model cache so a fresh model is constructed.
    monkeypatch.setattr(models, "models", {})
    env_patch = mock.patch.dict(
        models.os.environ,
        {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "yolox"},
    )
    class_map_patch = mock.patch.dict(models.model_class_map, {"yolox": MockModel})
    with env_patch, class_map_patch:
        default_model = models.get_model()
        assert isinstance(default_model, MockModel)
def test_env_variables_override_initialization_params(monkeypatch):
    """Initialization kwargs from the JSON file named by the env var must be used."""
    # Clear the module-level model cache so a fresh model is constructed.
    monkeypatch.setattr(models, "models", {})
    fake_label_map = {"1": "label1", "2": "label2"}
    # The JSON payload the patched open() will hand back to get_model().
    fake_json_payload = (
        '{"model_path": "fakepath", "label_map": ' + json.dumps(fake_label_map) + "}"
    )
    with mock.patch.dict(
        models.os.environ,
        {"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH": "fake_json.json"},
    ), mock.patch.object(models, "DEFAULT_MODEL", "fake"), mock.patch.dict(
        models.model_class_map,
        {"fake": mock.MagicMock()},
    ), mock.patch(
        "builtins.open",
        mock.mock_open(read_data=fake_json_payload),
    ):
        model = models.get_model()
        # Note the label-map keys arrive as ints after loading.
        model.initialize.assert_called_once_with(
            model_path="fakepath",
            label_map={1: "label1", 2: "label2"},
        )
================================================
FILE: test_unstructured_inference/models/test_tables.py
================================================
import os
import threading
from copy import deepcopy
import numpy as np
import pytest
import torch
from PIL import Image
from transformers.models.table_transformer.modeling_table_transformer import (
TableTransformerDecoder,
)
import unstructured_inference.models.table_postprocess as postprocess
from unstructured_inference.models import tables
from unstructured_inference.models.tables import (
apply_thresholds_on_objects,
structure_to_cells,
)
skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
@pytest.fixture
def table_transformer():
    """Fixture: the module-level tables agent, loading it first if needed."""
    tables.load_agent()
    return tables.tables_agent
def test_load_agent(table_transformer):
    # load_agent must produce an agent exposing a `model` attribute.
    assert hasattr(table_transformer, "model")
@pytest.fixture
def example_image():
    """Fixture: the multi-row/column-cell sample table image as an RGB PIL image."""
    return Image.open("./sample-docs/table-multi-row-column-cells.png").convert("RGB")
@pytest.fixture
def mocked_ocr_tokens():
    """Static OCR tokens for sample-docs/table-multi-row-column-cells.png.

    Every token shares ``block_num=0`` and ``line_num=0``, and ``span_num`` is
    strictly sequential, so the data is stored compactly as (bbox, text) pairs
    and expanded into full token dicts on return.
    """
    bbox_text_pairs = [
        ([51.0, 37.0, 1333.0, 38.0], " "),
        ([1064.0, 47.0, 1161.0, 71.0], "Results"),
        ([891.0, 113.0, 1333.0, 114.0], " "),
        ([51.0, 236.0, 1333.0, 237.0], " "),
        ([51.0, 308.0, 1333.0, 309.0], " "),
        ([51.0, 450.0, 1333.0, 452.0], " "),
        ([51.0, 522.0, 1333.0, 524.0], " "),
        ([51.0, 37.0, 53.0, 596.0], " "),
        ([90.0, 89.0, 167.0, 93.0], "soa"),
        ([684.0, 68.0, 762.0, 91.0], "Ballot:"),
        ([69.0, 84.0, 196.0, 140.0], "etealeiliay"),
        ([283.0, 109.0, 446.0, 132.0], "Participants"),
        ([484.0, 84.0, 576.0, 140.0], "pallets"),
        ([684.0, 75.0, 776.0, 132.0], "incom"),
        ([788.0, 107.0, 853.0, 136.0], "lete/"),
        ([68.0, 121.0, 191.0, 162.0], "Category"),
        ([371.0, 115.0, 386.0, 137.0], "P"),
        ([483.0, 121.0, 632.0, 162.0], "Completed"),
        ([756.0, 115.0, 785.0, 154.0], "Ne"),
        ([930.0, 125.0, 1054.0, 152.0], "Accuracy"),
        ([1159.0, 124.0, 1227.0, 147.0], "Time"),
        ([1235.0, 126.0, 1264.0, 147.0], "to"),
        ([682.0, 149.0, 841.0, 173.0], "Terminated"),
        ([1147.0, 169.0, 1276.0, 198.0], "complete"),
        ([70.0, 245.0, 127.0, 266.0], "Blind"),
        ([361.0, 247.0, 373.0, 266.0], "5"),
        ([562.0, 247.0, 573.0, 266.0], "1"),
        ([772.0, 247.0, 786.0, 266.0], "4"),
        ([925.0, 246.0, 1005.0, 270.0], "34.5%,"),
        ([1017.0, 247.0, 1059.0, 266.0], "n=1"),
        ([1129.0, 246.0, 1187.0, 266.0], "1199"),
        ([1197.0, 251.0, 1241.0, 270.0], "sec,"),
        ([1253.0, 247.0, 1295.0, 266.0], "n=1"),
        ([70.0, 319.0, 117.0, 338.0], "Low"),
        ([125.0, 318.0, 198.0, 338.0], "Vision"),
        ([361.0, 319.0, 373.0, 338.0], "5"),
        ([561.0, 318.0, 573.0, 338.0], "2"),
        ([773.0, 318.0, 785.0, 338.0], "3"),
        ([928.0, 318.0, 1002.0, 339.0], "98.3%"),
        ([1013.0, 318.0, 1055.0, 338.0], "n=2"),
        ([1129.0, 318.0, 1188.0, 338.0], "1716"),
        ([1197.0, 323.0, 1242.0, 342.0], "sec,"),
        ([1253.0, 318.0, 1295.0, 338.0], "n=3"),
        ([916.0, 387.0, 1005.0, 413.0], "(97.7%,"),
        ([1016.0, 387.0, 1068.0, 413.0], "n=3)"),
        ([1086.0, 383.0, 1099.0, 418.0], "|"),
        ([1120.0, 387.0, 1188.0, 413.0], "(1934"),
        ([1197.0, 393.0, 1241.0, 412.0], "sec,"),
        ([1253.0, 387.0, 1305.0, 413.0], "n=2)"),
        ([70.0, 456.0, 181.0, 489.0], "Dexterity"),
        ([360.0, 461.0, 372.0, 480.0], "5"),
        ([560.0, 461.0, 574.0, 480.0], "4"),
        ([774.0, 461.0, 785.0, 480.0], "1"),
        ([924.0, 460.0, 1005.0, 484.0], "98.3%,"),
        ([1017.0, 461.0, 1060.0, 480.0], "n=4"),
        ([1118.0, 460.0, 1199.0, 480.0], "1672.1"),
        ([1209.0, 465.0, 1253.0, 484.0], "sec,"),
        ([1265.0, 461.0, 1308.0, 480.0], "n=4"),
        ([70.0, 527.0, 170.0, 561.0], "Mobility"),
        ([361.0, 532.0, 373.0, 552.0], "3"),
        ([561.0, 532.0, 573.0, 552.0], "3"),
        ([773.0, 532.0, 786.0, 552.0], "0"),
        ([924.0, 532.0, 1005.0, 556.0], "95.4%,"),
        ([1017.0, 532.0, 1059.0, 552.0], "n=3"),
        ([1129.0, 532.0, 1188.0, 552.0], "1416"),
        ([1197.0, 537.0, 1242.0, 556.0], "sec,"),
        ([1253.0, 532.0, 1295.0, 552.0], "n=3"),
        ([266.0, 37.0, 267.0, 596.0], " "),
        ([466.0, 37.0, 468.0, 596.0], " "),
        ([666.0, 37.0, 668.0, 596.0], " "),
        ([891.0, 37.0, 893.0, 596.0], " "),
        ([1091.0, 113.0, 1093.0, 596.0], " "),
        ([51.0, 595.0, 1333.0, 596.0], " "),
        ([1331.0, 37.0, 1333.0, 596.0], " "),
    ]
    return [
        {
            "bbox": bbox,
            "block_num": 0,
            "line_num": 0,
            "span_num": index,
            "text": text,
        }
        for index, (bbox, text) in enumerate(bbox_text_pairs)
    ]
@pytest.mark.parametrize(
    "model_path",
    [
        "invalid_table_path",
        "incorrect_table_path",
    ],
)
def test_load_table_model_raises_when_not_available(model_path):
    """Initializing the table model with a bogus checkpoint path raises OSError."""
    table_model = tables.UnstructuredTableTransformerModel()
    with pytest.raises(OSError):
        table_model.initialize(model=model_path)
@pytest.mark.parametrize(
    ("bbox1", "bbox2", "expected_result"),
    [
        # Partial overlap: intersection area over bbox1's area.
        ((0, 0, 5, 5), (2, 2, 7, 7), 0.36),
        # Disjoint boxes: no overlap at all.
        ((0, 0, 0, 0), (6, 6, 10, 10), 0),
    ],
)
def test_iob(bbox1, bbox2, expected_result):
    """iob computes intersection-over-box for two bounding boxes."""
    assert tables.iob(bbox1, bbox2) == expected_result
@pytest.mark.parametrize(
    "model_path",
    [
        "microsoft/table-transformer-structure-recognition",
    ],
)
def test_load_donut_model(model_path):
    """Loading the HF table-transformer checkpoint yields a TableTransformerDecoder."""
    table_model = tables.UnstructuredTableTransformerModel()
    table_model.initialize(model=model_path)
    decoder = table_model.model.model.decoder
    # `type(...) is` (not isinstance) on purpose: the exact class is expected.
    assert type(decoder) is TableTransformerDecoder
@pytest.mark.parametrize(
    ("input_test", "output_test"),
    [
        # Two identical boxes with different scores: only the higher-scoring
        # detection survives suppression.
        (
            [
                {
                    "label": "table column header",
                    "score": 0.9349299073219299,
                    "bbox": [
                        47.83147430419922,
                        116.8877944946289,
                        2557.79296875,
                        216.98883056640625,
                    ],
                },
                {
                    "label": "table column header",
                    "score": 0.934,
                    "bbox": [
                        47.83147430419922,
                        116.8877944946289,
                        2557.79296875,
                        216.98883056640625,
                    ],
                },
            ],
            [
                {
                    "label": "table column header",
                    "score": 0.9349299073219299,
                    "bbox": [
                        47.83147430419922,
                        116.8877944946289,
                        2557.79296875,
                        216.98883056640625,
                    ],
                },
            ],
        ),
        # Empty input passes through unchanged.
        ([], []),
    ],
)
def test_nms(input_test, output_test):
    """Non-maximum suppression keeps the higher-scoring of overlapping detections."""
    output = postprocess.nms(input_test)
    assert output == output_test
@pytest.mark.parametrize(
    ("supercells", "rows", "columns", "output_test"),
    [
        # All cases share the same pair of spanning-cell boxes and differ only
        # in the overlapping row/column index sets to be reconciled.
        (
            {
                "label": "table spanning cell",
                "score": 0.526617169380188,
                "bbox": [
                    1446.2801513671875,
                    1023.817138671875,
                    2114.3525390625,
                    1099.20166015625,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [3, 4],
                "column_numbers": [0, 4],
            },
            {
                "label": "table spanning cell",
                "score": 0.5199193954467773,
                "bbox": [
                    98.92312622070312,
                    676.1566772460938,
                    751.0982666015625,
                    938.5986938476562,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [3, 4, 6],
                "column_numbers": [0, 4],
            },
        ),
        (
            {
                "label": "table spanning cell",
                "score": 0.526617169380188,
                "bbox": [
                    1446.2801513671875,
                    1023.817138671875,
                    2114.3525390625,
                    1099.20166015625,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [3, 4],
                "column_numbers": [0, 4],
            },
            {
                "label": "table spanning cell",
                "score": 0.5199193954467773,
                "bbox": [
                    98.92312622070312,
                    676.1566772460938,
                    751.0982666015625,
                    938.5986938476562,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [4],
                "column_numbers": [0, 4, 6],
            },
        ),
        (
            {
                "label": "table spanning cell",
                "score": 0.526617169380188,
                "bbox": [
                    1446.2801513671875,
                    1023.817138671875,
                    2114.3525390625,
                    1099.20166015625,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [3, 4],
                "column_numbers": [1, 4],
            },
            {
                "label": "table spanning cell",
                "score": 0.5199193954467773,
                "bbox": [
                    98.92312622070312,
                    676.1566772460938,
                    751.0982666015625,
                    938.5986938476562,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [4],
                "column_numbers": [0, 4, 6],
            },
        ),
        (
            {
                "label": "table spanning cell",
                "score": 0.526617169380188,
                "bbox": [
                    1446.2801513671875,
                    1023.817138671875,
                    2114.3525390625,
                    1099.20166015625,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [3, 4],
                "column_numbers": [1, 4],
            },
            {
                "label": "table spanning cell",
                "score": 0.5199193954467773,
                "bbox": [
                    98.92312622070312,
                    676.1566772460938,
                    751.0982666015625,
                    938.5986938476562,
                ],
                "projected row header": False,
                "header": False,
                "row_numbers": [2, 4, 5, 6, 7, 8],
                "column_numbers": [0, 4, 6],
            },
        ),
    ],
)
def test_remove_supercell_overlap(supercell1, supercell2):
    """remove_supercell_overlap mutates its arguments in place and returns None."""
    assert postprocess.remove_supercell_overlap(supercell1, supercell2) is None
@pytest.mark.parametrize(
    ("supercells", "rows", "columns", "output_test"),
    [
        (
            # One spanning cell covering (nearly) the whole table.
            [
                {
                    "label": "table spanning cell",
                    "score": 0.9,
                    "bbox": [
                        98.92312622070312,
                        143.11549377441406,
                        2115.197265625,
                        1238.27587890625,
                    ],
                    "projected row header": True,
                    "header": True,
                    "span": True,
                },
            ],
            # Three rows, two of them identical header rows.
            [
                {
                    "label": "table row",
                    "score": 0.9299452900886536,
                    "bbox": [0, 0, 10, 10],
                    "column header": True,
                    "header": True,
                },
                {
                    "label": "table row",
                    "score": 0.9299452900886536,
                    "bbox": [
                        98.92312622070312,
                        143.11549377441406,
                        2114.3525390625,
                        193.67681884765625,
                    ],
                    "column header": True,
                    "header": True,
                },
                {
                    "label": "table row",
                    "score": 0.9299452900886536,
                    "bbox": [
                        98.92312622070312,
                        143.11549377441406,
                        2114.3525390625,
                        193.67681884765625,
                    ],
                    "column header": True,
                    "header": True,
                },
            ],
            # Two columns.
            [
                {
                    "label": "table column",
                    "score": 0.9996132254600525,
                    "bbox": [
                        98.92312622070312,
                        143.11549377441406,
                        517.6508178710938,
                        1616.48779296875,
                    ],
                },
                {
                    "label": "table column",
                    "score": 0.9935646653175354,
                    "bbox": [
                        520.0474853515625,
                        143.11549377441406,
                        751.0982666015625,
                        1616.48779296875,
                    ],
                },
            ],
            # Expected: the supercell snapped to row/column extents plus a
            # propagated copy for row 0.
            [
                {
                    "label": "table spanning cell",
                    "score": 0.9,
                    "bbox": [
                        98.92312622070312,
                        143.11549377441406,
                        751.0982666015625,
                        193.67681884765625,
                    ],
                    "projected row header": True,
                    "header": True,
                    "span": True,
                    "row_numbers": [1, 2],
                    "column_numbers": [0, 1],
                },
                {
                    "row_numbers": [0],
                    "column_numbers": [0, 1],
                    "score": 0.9,
                    "propagated": True,
                    "bbox": [
                        98.92312622070312,
                        143.11549377441406,
                        751.0982666015625,
                        193.67681884765625,
                    ],
                },
            ],
        ),
    ],
)
def test_align_supercells(supercells, rows, columns, output_test):
    """align_supercells snaps spanning cells to row/column grid extents."""
    assert postprocess.align_supercells(supercells, rows, columns) == output_test
@pytest.mark.parametrize(("rows", "bbox", "output"), [([1.0], [0.0], [1.0])])
def test_align_rows(rows, bbox, output):
    # Degenerate single-row input passes through align_rows unchanged.
    assert postprocess.align_rows(rows, bbox) == output
# NOTE(review): several `expectation` literals below look garbled — they read as
# pipe-separated table text, and the final string is split across physical lines
# (a syntax error as written). Confirm the intended expectation strings against
# the project's version-control history.
@pytest.mark.parametrize(
    ("output_format", "expectation"),
    [
        ("html", "| Blind | 5 | 1 | 4 | 34.5%, n=1 | "),
        (
            "cells",
            {
                "column_nums": [0],
                "row_nums": [2],
                "column header": False,
                "cell text": "Blind",
            },
        ),
        ("dataframe", ["Blind", "5", "1", "4", "34.5%, n=1", "1199 sec, n=1"]),
        (None, "
| Blind | 5 | 1 | 4 | 34.5%, n=1 | "),
    ],
)
def test_table_prediction_output_format(
    output_format,
    expectation,
    table_transformer,
    example_image,
    mocker,
    example_table_cells,
    mocked_ocr_tokens,
):
    """run_prediction must honor each supported result_format (default when None)."""
    # Stub out the heavy recognition stages so only output formatting is exercised.
    mocker.patch.object(tables, "recognize", return_value=example_table_cells)
    mocker.patch.object(
        tables.UnstructuredTableTransformerModel,
        "get_structure",
        return_value=None,
    )
    if output_format:
        result = table_transformer.run_prediction(
            example_image,
            result_format=output_format,
            ocr_tokens=mocked_ocr_tokens,
        )
    else:
        result = table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens)
    if output_format == "dataframe":
        assert expectation in result.values
    elif output_format == "cells":
        # other output like bbox are flakey to test since they depend on OCR and it may change
        # slightly when OCR package changes or even on different machines
        validation_fields = ("column_nums", "row_nums", "column header", "cell text")
        assert expectation in [{key: cell[key] for key in validation_fields} for cell in result]
    else:
        assert expectation in result
def test_table_prediction_output_format_when_wrong_type_then_value_error(
    table_transformer,
    example_image,
    mocker,
    example_table_cells,
    mocked_ocr_tokens,
):
    """An unrecognized result_format must raise ValueError."""
    # Stub out the heavy recognition stages; only format validation matters here.
    mocker.patch.object(tables, "recognize", return_value=example_table_cells)
    mocker.patch.object(
        tables.UnstructuredTableTransformerModel,
        "get_structure",
        return_value=None,
    )
    unsupported_format = "Wrong format"
    with pytest.raises(ValueError):
        table_transformer.run_prediction(
            example_image,
            result_format=unsupported_format,
            ocr_tokens=mocked_ocr_tokens,
        )
def test_table_prediction_runs_with_empty_recognize(
    table_transformer,
    example_image,
    mocker,
    mocked_ocr_tokens,
):
    """When recognize() yields no cells, prediction degrades to an empty string."""
    mocker.patch.object(tables, "recognize", return_value=[])
    mocker.patch.object(
        tables.UnstructuredTableTransformerModel,
        "get_structure",
        return_value=None,
    )
    prediction = table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens)
    assert prediction == ""
# NOTE(review): the expected substrings below appear garbled by extraction — the
# second literal is split across physical lines (a syntax error as written).
# Confirm the intended assertions against the project's version-control history.
def test_table_prediction_with_ocr_tokens(table_transformer, example_image, mocked_ocr_tokens):
    # End-to-end predict() using pre-computed OCR tokens for the sample table image.
    prediction = table_transformer.predict(example_image, ocr_tokens=mocked_ocr_tokens)
    assert '| ' in prediction
    assert " |
|---|
| Blind | 5 | 1 | 4 | 34.5%, n=1 | " in prediction
def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
    # predict() without OCR tokens is unsupported and must raise ValueError.
    with pytest.raises(ValueError):
        table_transformer.predict(example_image)
@pytest.mark.parametrize(
    ("thresholds", "expected_object_number"),
    [
        ({"0": 0.5}, 1),
        ({"0": 0.1}, 3),
        ({"0": 0.9}, 0),
    ],
)
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
    thresholds,
    expected_object_number,
):
    """Objects scoring below their class threshold are dropped; count the survivors."""
    objects = [{"label": "0", "score": score} for score in (0.2, 0.4, 0.55)]
    filtered = apply_thresholds_on_objects(objects, thresholds)
    assert len(filtered) == expected_object_number
@pytest.mark.parametrize(
    ("thresholds", "expected_object_number"),
    [
        ({"0": 0.5, "1": 0.1}, 4),
        ({"0": 0.1, "1": 0.9}, 3),
        ({"0": 0.9, "1": 0.5}, 1),
    ],
)
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
    thresholds,
    expected_object_number,
):
    """Per-class thresholds are applied independently to each object's label."""
    objects = [
        {"label": label, "score": score}
        for label in ("0", "1")
        for score in (0.2, 0.4, 0.55)
    ]
    filtered = apply_thresholds_on_objects(objects, thresholds)
    assert len(filtered) == expected_object_number
def test_objects_filtering_when_missing_threshold():
    """A label with no configured threshold surfaces as a KeyError naming that label."""
    unconfigured_label = "class_name"
    objects = [{"label": unconfigured_label, "score": 0.2}]
    with pytest.raises(KeyError, match=unconfigured_label):
        apply_thresholds_on_objects(objects, {"1": 0.5})
def test_intersect():
    """Intersecting the default Rect with [1, 2, 3, 4] yields an area of 4.0."""
    default_rect = postprocess.Rect()
    other_rect = postprocess.Rect([1, 2, 3, 4])
    overlap = default_rect.intersect(other_rect)
    assert overlap.get_area() == 4.0
def test_include_rect():
    """Growing the default Rect to include [1, 2, 3, 4] yields an area of 4.0."""
    rect = postprocess.Rect()
    expanded = rect.include_rect([1, 2, 3, 4])
    assert expanded.get_area() == 4.0
@pytest.mark.parametrize(
    ("spans", "join_with_space", "expected"),
    [
        # A lone integer span is removed (remove_integer_superscripts=True below).
        (
            [
                {
                    "flags": 2**0,
                    "text": "5",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 0,
                },
            ],
            True,
            "",
        ),
        # A single non-integer span is kept as-is.
        (
            [
                {
                    "flags": 2**0,
                    "text": "p",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 0,
                },
            ],
            True,
            "p",
        ),
        # Two spans in the same block join with a space.
        (
            [
                {
                    "flags": 2**0,
                    "text": "p",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 0,
                },
                {
                    "flags": 2**0,
                    "text": "p",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 0,
                },
            ],
            True,
            "p p",
        ),
        # Two spans in different blocks also join with a space.
        (
            [
                {
                    "flags": 2**0,
                    "text": "p",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 0,
                },
                {
                    "flags": 2**0,
                    "text": "p",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 1,
                },
            ],
            True,
            "p p",
        ),
        # join_with_space=False still yields a space across block boundaries.
        (
            [
                {
                    "flags": 2**0,
                    "text": "p",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 0,
                },
                {
                    "flags": 2**0,
                    "text": "p",
                    "superscript": False,
                    "span_num": 0,
                    "line_num": 0,
                    "block_num": 1,
                },
            ],
            False,
            "p p",
        ),
    ],
)
def test_extract_text_from_spans(spans, join_with_space, expected):
    """Span texts are joined in order; integer superscript-candidates are removed."""
    res = postprocess.extract_text_from_spans(
        spans,
        join_with_space=join_with_space,
        remove_integer_superscripts=True,
    )
    assert res == expected
@pytest.mark.parametrize(
    ("supercells", "expected_len"),
    [
        # A single header supercell is kept.
        ([{"header": "hi", "row_numbers": [0, 1, 2], "score": 0.9}], 1),
        # Non-conflicting cells across distinct columns all survive.
        (
            [
                {
                    "header": "hi",
                    "row_numbers": [0],
                    "column_numbers": [1, 2, 3],
                    "score": 0.9,
                },
                {
                    "header": "hi",
                    "row_numbers": [1],
                    "column_numbers": [1],
                    "score": 0.9,
                },
                {
                    "header": "hi",
                    "row_numbers": [1],
                    "column_numbers": [2],
                    "score": 0.9,
                },
                {
                    "header": "hi",
                    "row_numbers": [1],
                    "column_numbers": [3],
                    "score": 0.9,
                },
            ],
            4,
        ),
        # Overlapping row coverage in the same column: one conflicting cell is pruned.
        (
            [
                {
                    "header": "hi",
                    "row_numbers": [0],
                    "column_numbers": [0],
                    "score": 0.9,
                },
                {
                    "header": "hi",
                    "row_numbers": [1],
                    "column_numbers": [0],
                    "score": 0.9,
                },
                {
                    "header": "hi",
                    "row_numbers": [1, 2],
                    "column_numbers": [0],
                    "score": 0.9,
                },
                {
                    "header": "hi",
                    "row_numbers": [3],
                    "column_numbers": [0],
                    "score": 0.9,
                },
            ],
            3,
        ),
    ],
)
def test_header_supercell_tree(supercells, expected_len):
    """header_supercell_tree prunes conflicting header supercells in place."""
    postprocess.header_supercell_tree(supercells)
    assert len(supercells) == expected_len
@pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0])
def test_zoom_image(example_image, zoom):
    """zoom_image scales both dimensions by `zoom`; non-positive zoom means no scaling."""
    original_width, original_height = example_image.size
    zoomed_image = tables.zoom_image(example_image, zoom)
    zoomed_width, zoomed_height = zoomed_image.size
    # Non-positive zoom factors are treated as a no-op (factor 1).
    effective_zoom = zoom if zoom > 0 else 1
    assert zoomed_width == np.round(original_width * effective_zoom, 0)
    assert zoomed_height == np.round(original_height * effective_zoom, 0)
# NOTE(review): every `expected_html` literal below appears garbled by extraction —
# the strings read as pipe-separated table text and are split across physical lines
# (syntax errors as written). Confirm the intended markup strings against the
# project's version-control history.
@pytest.mark.parametrize(
    ("input_cells", "expected_html"),
    [
        # +----------+---------------------+
        # | row1col1 | row1col2 | row1col3 |
        # |----------|----------+----------|
        # | row2col1 | row2col2 | row2col3 |
        # +----------+----------+----------+
        pytest.param(
            [
                {
                    "row_nums": [0],
                    "column_nums": [0],
                    "cell text": "row1col1",
                    "column header": False,
                },
                {
                    "row_nums": [0],
                    "column_nums": [1],
                    "cell text": "row1col2",
                    "column header": False,
                },
                {
                    "row_nums": [0],
                    "column_nums": [2],
                    "cell text": "row1col3",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [0],
                    "cell text": "row2col1",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [1],
                    "cell text": "row2col2",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [2],
                    "cell text": "row2col3",
                    "column header": False,
                },
            ],
            (
                "| row1col1 | row1col2 | row1col3 |
"
                "| row2col1 | row2col2 | row2col3 |
"
            ),
            id="simple table without header",
        ),
        # +----------+---------------------+
        # | h1col1   | h1col2   | h1col3   |
        # |----------|----------+----------|
        # | row1col1 | row1col2 | row1col3 |
        # |----------|----------+----------|
        # | row2col1 | row2col2 | row2col3 |
        # +----------+----------+----------+
        pytest.param(
            [
                {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
                {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
                {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
                {
                    "row_nums": [1],
                    "column_nums": [0],
                    "cell text": "row1col1",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [1],
                    "cell text": "row1col2",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [2],
                    "cell text": "row1col3",
                    "column header": False,
                },
                {
                    "row_nums": [2],
                    "column_nums": [0],
                    "cell text": "row2col1",
                    "column header": False,
                },
                {
                    "row_nums": [2],
                    "column_nums": [1],
                    "cell text": "row2col2",
                    "column header": False,
                },
                {
                    "row_nums": [2],
                    "column_nums": [2],
                    "cell text": "row2col3",
                    "column header": False,
                },
            ],
            (
                "| h1col1 | h1col2 | h1col2 |
"
                "| row1col1 | row1col2 | row1col3 |
"
                "| row2col1 | row2col2 | row2col3 |
"
            ),
            id="simple table with header",
        ),
        # +----------+---------------------+
        # | h1col1   | h1col2   | h1col3   |
        # |----------|----------+----------|
        # | row1col1 | row1col2 | row1col3 |
        # |----------|----------+----------|
        # | row2col1 | row2col2 | row2col3 |
        # +----------+----------+----------+
        pytest.param(
            [
                {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
                {
                    "row_nums": [2],
                    "column_nums": [0],
                    "cell text": "row2col1",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [0],
                    "cell text": "row1col1",
                    "column header": False,
                },
                {
                    "row_nums": [2],
                    "column_nums": [1],
                    "cell text": "row2col2",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [1],
                    "cell text": "row1col2",
                    "column header": False,
                },
                {
                    "row_nums": [2],
                    "column_nums": [2],
                    "cell text": "row2col3",
                    "column header": False,
                },
                {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
                {
                    "row_nums": [1],
                    "column_nums": [2],
                    "cell text": "row1col3",
                    "column header": False,
                },
                {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
            ],
            (
                "| h1col1 | h1col2 | h1col2 |
"
                "| row1col1 | row1col2 | row1col3 |
"
                "| row2col1 | row2col2 | row2col3 |
"
            ),
            id="simple table with header, mixed elements",
        ),
        # +----------+---------------------+
        # |   two    |      two columns    |
        # |          |----------+----------|
        # |   rows   |sub cell 1|sub cell 2|
        # +----------+----------+----------+
        pytest.param(
            [
                {
                    "row_nums": [0, 1],
                    "column_nums": [0],
                    "cell text": "two row",
                    "column header": False,
                },
                {
                    "row_nums": [0],
                    "column_nums": [1, 2],
                    "cell text": "two cols",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [1],
                    "cell text": "sub cell 1",
                    "column header": False,
                },
                {
                    "row_nums": [1],
                    "column_nums": [2],
                    "cell text": "sub cell 2",
                    "column header": False,
                },
            ],
            (
                '| two row | two '
                "cols |
| sub cell 1 | sub cell 2 |
"
                "
"
            ),
            id="various spans, no headers",
        ),
        # +----------+---------------------+----------+
        # |          |       h1col23       |  h1col4  |
        # | h12col1  |----------+----------+----------|
        # |          |  h2col2  |       h2col34       |
        # |----------|----------+----------+----------+
        # |  r3col1  |  r3col2  |                     |
        # |----------+----------|       r34col34      |
        # |       r4col12       |                     |
        # +----------+----------+----------+----------+
        pytest.param(
            [
                {
                    "row_nums": [0, 1],
                    "column_nums": [0],
                    "cell text": "h12col1",
                    "column header": True,
                },
                {
                    "row_nums": [0],
                    "column_nums": [1, 2],
                    "cell text": "h1col23",
                    "column header": True,
                },
                {"row_nums": [0], "column_nums": [3], "cell text": "h1col4", "column header": True},
                {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True},
                {
                    "row_nums": [1],
                    "column_nums": [2, 3],
                    "cell text": "h2col34",
                    "column header": True,
                },
                {
                    "row_nums": [2],
                    "column_nums": [0],
                    "cell text": "r3col1",
                    "column header": False,
                },
                {
                    "row_nums": [2],
                    "column_nums": [1],
                    "cell text": "r3col2",
                    "column header": False,
                },
                {
                    "row_nums": [2, 3],
                    "column_nums": [2, 3],
                    "cell text": "r34col34",
                    "column header": False,
                },
                {
                    "row_nums": [3],
                    "column_nums": [0, 1],
                    "cell text": "r4col12",
                    "column header": False,
                },
            ],
            (
                '| h12col1 | '
                'h1col23 | h1col4 |
'
                '| h2col2 | h2col34 |
'
                '| r3col1 | r3col2 | r34col34 |
'
                '| r4col12 |
'
            ),
            id="various spans, with 2 row header",
        ),
    ],
)
def test_cells_to_html(input_cells, expected_html):
    """cells_to_html renders a cell grid (with spans and headers) to markup."""
    assert tables.cells_to_html(input_cells) == expected_html
@pytest.mark.parametrize(
("input_cells", "expected_cells"),
[
pytest.param(
[
{"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
{"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
{"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
{
"row_nums": [1],
"column_nums": [0],
"cell text": "row1col1",
"column header": False,
},
{
"row_nums": [1],
"column_nums": [1],
"cell text": "row1col2",
"column header": False,
},
{
"row_nums": [1],
"column_nums": [2],
"cell text": "row1col3",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [0],
"cell text": "row2col1",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [1],
"cell text": "row2col2",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [2],
"cell text": "row2col3",
"column header": False,
},
],
[
{"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
{"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
{"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
{
"row_nums": [1],
"column_nums": [0],
"cell text": "row1col1",
"column header": False,
},
{
"row_nums": [1],
"column_nums": [1],
"cell text": "row1col2",
"column header": False,
},
{
"row_nums": [1],
"column_nums": [2],
"cell text": "row1col3",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [0],
"cell text": "row2col1",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [1],
"cell text": "row2col2",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [2],
"cell text": "row2col3",
"column header": False,
},
],
id="identical tables, no changes expected",
),
pytest.param(
[
{"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
{"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
{
"row_nums": [1],
"column_nums": [0],
"cell text": "row1col1",
"column header": False,
},
{
"row_nums": [1],
"column_nums": [1],
"cell text": "row1col2",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [0],
"cell text": "row2col1",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [1],
"cell text": "row2col2",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [2],
"cell text": "row2col3",
"column header": False,
},
],
[
{"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
{"row_nums": [0], "column_nums": [1], "cell text": "", "column header": True},
{"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
{
"row_nums": [1],
"column_nums": [0],
"cell text": "row1col1",
"column header": False,
},
{
"row_nums": [1],
"column_nums": [1],
"cell text": "row1col2",
"column header": False,
},
{"row_nums": [1], "column_nums": [2], "cell text": "", "column header": False},
{
"row_nums": [2],
"column_nums": [0],
"cell text": "row2col1",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [1],
"cell text": "row2col2",
"column header": False,
},
{
"row_nums": [2],
"column_nums": [2],
"cell text": "row2col3",
"column header": False,
},
],
id="missing column in header and in the middle",
),
pytest.param(
[
{
"row_nums": [0, 1],
"column_nums": [0],
"cell text": "h12col1",
"column header": True,
},
{
"row_nums": [0],
"column_nums": [1, 2],
"cell text": "h1col23",
"column header": True,
},
{"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True},
{
"row_nums": [1],
"column_nums": [2, 3],
"cell text": "h2col34",
"column header": True,
},
{
"row_nums": [2],
"column_nums": [0],
"cell text": "r3col1",
"column header": False,
},
{
"row_nums": [2, 3],
"column_nums": [2, 3],
"cell text": "r34col34",
"column header": False,
},
{
"row_nums": [3],
"column_nums": [0, 1],
"cell text": "r4col12",
"column header": False,
},
],
[
{
"row_nums": [0, 1],
"column_nums": [0],
"cell text": "h12col1",
"column header": True,
},
{
"row_nums": [0],
"column_nums": [1, 2],
"cell text": "h1col23",
"column header": True,
},
{"row_nums": [0], "column_nums": [3], "cell text": "", "column header": True},
{"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True},
{
"row_nums": [1],
"column_nums": [2, 3],
"cell text": "h2col34",
"column header": True,
},
{
"row_nums": [2],
"column_nums": [0],
"cell text": "r3col1",
"column header": False,
},
{"row_nums": [2], "column_nums": [1], "cell text": "", "column header": False},
{
"row_nums": [2, 3],
"column_nums": [2, 3],
"cell text": "r34col34",
"column header": False,
},
{
"row_nums": [3],
"column_nums": [0, 1],
"cell text": "r4col12",
"column header": False,
},
],
id="missing column in header and in the middle in table with spans",
),
],
)
def test_fill_cells(input_cells, expected_cells):
    """fill_cells pads a sparse grid so every (row, column) slot has a cell."""

    def sort_cells(cells):
        return sorted(cells, key=lambda x: (x["row_nums"], x["column_nums"]))

    filled = tables.fill_cells(input_cells)
    assert sort_cells(filled) == sort_cells(expected_cells)
def test_padded_results_has_right_dimensions(table_transformer, example_image):
    """Structure boxes detected on a padded image map back to original coordinates.

    Pads the image before structure detection, then checks that
    ``outputs_to_objects`` shifts the predicted boxes back by the pad amount.
    """
    str_class_name2idx = tables.get_class_map("structure")
    # a simpler mapping so we keep all structure in the returned objs below for test
    str_class_idx2name = dict.fromkeys(str_class_name2idx.values(), "table cell")
    # pad size is no more than 10% of the original image so we can setup test below easier
    pad = int(min(example_image.size) / 10)
    structure = table_transformer.get_structure(example_image, pad_for_structure_detection=pad)
    # boxes detected OUTSIDE of the original image; this shouldn't happen but we want to make
    # sure the code handles it as expected
    structure["pred_boxes"][0][0, :2] = 0.5
    structure["pred_boxes"][0][0, 2:] = 1.0
    # mock a box we know is safely inside the original image with known positions
    width, height = example_image.size
    padded_width = width + pad * 2
    padded_height = height + pad * 2
    original = [1, 3, 101, 53]
    # NOTE(review): the values below imply boxes are (center_x, center_y, width, height)
    # normalized by the padded size (51-50=1, 28-25=3, 51+50=101, 28+25=53) -- confirm
    # against outputs_to_objects.
    structure["pred_boxes"][0][1, :] = torch.tensor(
        [
            (51 + pad) / padded_width,
            (28 + pad) / padded_height,
            100 / padded_width,
            50 / padded_height,
        ],
    )
    objs = tables.outputs_to_objects(structure, example_image.size, str_class_idx2name)
    # the out-of-bounds box maps to the full padded extent around the original image
    np.testing.assert_almost_equal(objs[0]["bbox"], [-pad, -pad, width + pad, height + pad], 4)
    np.testing.assert_almost_equal(objs[1]["bbox"], original, 4)
    # a more strict test would be to constrain the actual detected boxes to be within the original
    # image but that requires the table transformer to behave in certain ways and do not
    # actually test the padding math; so here we use the relaxed condition
    for obj in objs[2:]:
        x1, y1, x2, y2 = obj["bbox"]
        assert max(x1, x2) < width + pad
        assert max(y1, y2) < height + pad
def test_compute_confidence_score_zero_division_error_handling():
    """An empty score list must yield 0 instead of raising ZeroDivisionError."""
    score = tables.compute_confidence_score([])
    assert score == 0
@pytest.mark.parametrize(
    ("column_span_score", "row_span_score", "expected_text_to_indexes"),
    [
        # column spanning cell wins: "one" and "three" merge down the first column
        (
            0.9,
            0.8,
            (
                {
                    "one three": {"row_nums": [0, 1], "column_nums": [0]},
                    "two": {"row_nums": [0], "column_nums": [1]},
                    "four": {"row_nums": [1], "column_nums": [1]},
                }
            ),
        ),
        # row spanning cell wins: "one" and "two" merge across the first row
        (
            0.8,
            0.9,
            (
                {
                    "one two": {"row_nums": [0], "column_nums": [0, 1]},
                    "three": {"row_nums": [1], "column_nums": [0]},
                    "four": {"row_nums": [1], "column_nums": [1]},
                }
            ),
        ),
    ],
)
def test_subcells_filtering_when_overlapping_spanning_cells(
    column_span_score,
    row_span_score,
    expected_text_to_indexes,
):
    """When two spanning cells overlap, the higher-scored one wins.

    # table
    # +-----------+----------+
    # | one | two |
    # |-----------+----------|
    # | three | four |
    # +-----------+----------+
    spanning cells over first row and over first column
    """
    table_structure = {
        "rows": [
            {"bbox": [0, 0, 10, 20]},
            {"bbox": [10, 0, 20, 20]},
        ],
        "columns": [
            {"bbox": [0, 0, 20, 10]},
            {"bbox": [0, 10, 20, 20]},
        ],
        # two overlapping spanning cells; only the higher score should survive
        "spanning cells": [
            {"bbox": [0, 0, 20, 10], "score": column_span_score},
            {"bbox": [0, 0, 10, 20], "score": row_span_score},
        ],
    }
    tokens = [
        {
            "text": "one",
            "bbox": [0, 0, 10, 10],
        },
        {
            "text": "two",
            "bbox": [0, 10, 10, 20],
        },
        {
            "text": "three",
            "bbox": [10, 0, 20, 10],
        },
        {"text": "four", "bbox": [10, 10, 20, 20]},
    ]
    # structure_to_cells expects span/line/block metadata on every token
    token_args = {"span_num": 1, "line_num": 1, "block_num": 1}
    for token in tokens:
        token.update(token_args)
    for spanning_cell in table_structure["spanning cells"]:
        spanning_cell["projected row header"] = False
    # table structure is edited inside structure_to_cells, save copy for future runs
    saved_table_structure = deepcopy(table_structure)
    predicted_cells, _ = structure_to_cells(table_structure, tokens=tokens)
    predicted_text_to_indexes = {
        cell["cell text"]: {
            "row_nums": cell["row_nums"],
            "column_nums": cell["column_nums"],
        }
        for cell in predicted_cells
    }
    assert predicted_text_to_indexes == expected_text_to_indexes
    # swap spanning cells to ensure the highest prob spanning cell is used
    spans = saved_table_structure["spanning cells"]
    spans[0], spans[1] = spans[1], spans[0]
    saved_table_structure["spanning cells"] = spans
    predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens)
    # result must not depend on the order of the spanning cells
    assert predicted_cells_after_reorder == predicted_cells
def test_model_init_is_thread_safe():
    """Concurrent load_agent() calls must leave exactly one initialized model."""
    tables.tables_agent.model = None
    workers = [threading.Thread(target=tables.load_agent) for _ in range(5)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    assert tables.tables_agent.model is not None
================================================
FILE: test_unstructured_inference/models/test_yolox.py
================================================
import os
import pytest
from unstructured_inference.inference.layout import process_file_with_model
@pytest.mark.slow
def test_layout_yolox_local_parsing_image():
    """Full yolox parsing of the sample image yields 13 known-type regions."""
    filename = os.path.join("sample-docs", "test-image.jpg")
    document_layout = process_file_with_model(filename, model_name="yolox", is_image=True)
    # A single image input should produce exactly one page.
    assert len(document_layout.pages) == 1
    types_known = ["Text", "Section-header", "Page-header"]
    elements = document_layout.pages[0].elements_array
    known_regions = [
        class_id
        for class_id in elements.element_class_ids
        if elements.element_class_id_map[class_id] in types_known
    ]
    # The sample image is expected to contain 13 detections of known types.
    assert len(known_regions) == 13
    # Detections must carry per-element probabilities as floats.
    assert hasattr(elements, "element_probs")
    assert isinstance(elements.element_probs[0], float)
@pytest.mark.slow
def test_layout_yolox_local_parsing_pdf():
    """Full yolox model finds 5 Text elements in the lorem-ipsum sample PDF."""
    filename = os.path.join("sample-docs", "loremipsum.pdf")
    document_layout = process_file_with_model(filename, model_name="yolox")
    assert len(document_layout.pages) == 1
    page = document_layout.pages[0]
    # The sample PDF is expected to contain exactly 5 text detections.
    text_elements = [element for element in page.elements if element.type == "Text"]
    assert len(text_elements) == 5
    # Every detection carries a float probability.
    first = page.elements[0]
    assert hasattr(first, "prob")
    assert isinstance(first.prob, float)
@pytest.mark.slow
def test_layout_yolox_local_parsing_empty_pdf():
    """An empty sample PDF parses to a single page with zero detections."""
    layout = process_file_with_model(
        os.path.join("sample-docs", "empty-document.pdf"),
        model_name="yolox",
    )
    assert len(layout.pages) == 1
    assert len(layout.pages[0].elements) == 0
########################
# ONLY SHORT TESTS BELOW
########################
def test_layout_yolox_local_parsing_image_soft():
    """Quantized yolox finds at least one element with a float probability.

    Soft version of the full-model test; run `make test-long` for the full model.
    """
    layout = process_file_with_model(
        os.path.join("sample-docs", "example_table.jpg"),
        model_name="yolox_quantized",
        is_image=True,
    )
    # A single image input should produce exactly one page.
    assert len(layout.pages) == 1
    first_page = layout.pages[0]
    assert len(first_page.elements) > 0
    first = first_page.elements[0]
    assert hasattr(first, "prob")
    assert isinstance(first.prob, float)
def test_layout_yolox_local_parsing_pdf_soft():
    """Tiny yolox detects at least one element in the lorem-ipsum PDF.

    Soft version of the full-model test; run `make test-long` for the full model.
    """
    layout = process_file_with_model(
        os.path.join("sample-docs", "loremipsum.pdf"),
        model_name="yolox_tiny",
    )
    assert len(layout.pages) == 1
    assert len(layout.pages[0].elements) > 0
    # Detections expose a probability attribute.
    assert hasattr(layout.pages[0].elements[0], "prob")
def test_layout_yolox_local_parsing_empty_pdf_soft():
    """Tiny yolox yields no non-Image elements for an empty PDF."""
    layout = process_file_with_model(
        os.path.join("sample-docs", "empty-document.pdf"),
        model_name="yolox_tiny",
    )
    assert len(layout.pages) == 1
    non_image = [el for el in layout.pages[0].elements if el.type != "Image"]
    assert len(non_image) == 0
================================================
FILE: test_unstructured_inference/test_config.py
================================================
def test_default_config():
    """The table-transformer confidence threshold defaults to 0.5."""
    from unstructured_inference.config import inference_config

    threshold = inference_config.TT_TABLE_CONF
    assert threshold == 0.5
def test_env_override(monkeypatch):
    """Setting TT_TABLE_CONF in the environment overrides the default.

    NOTE: ``monkeypatch.setenv`` requires a string value; the original passed
    the int ``1``, which newer pytest versions warn about or reject.
    """
    monkeypatch.setenv("TT_TABLE_CONF", "1")
    from unstructured_inference.config import inference_config

    assert inference_config.TT_TABLE_CONF == 1
================================================
FILE: test_unstructured_inference/test_elements.py
================================================
import os
from random import randint
from unittest.mock import PropertyMock, patch
import numpy as np
import pytest
from unstructured_inference.constants import IsExtracted, Source
from unstructured_inference.inference import elements
from unstructured_inference.inference.elements import (
Rectangle,
TextRegion,
TextRegions,
)
from unstructured_inference.inference.layoutelement import (
LayoutElements,
clean_layoutelements,
clean_layoutelements_for_class,
partition_groups_from_regions,
separate,
)
# True when the CI env var is unset or falsy (i.e. running locally);
# presumably used to skip CI-only tests -- no usage visible in this chunk.
skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
def intersect_brute(rect1, rect2):
    """Brute-force overlap oracle: scan every integer point of rect1.

    Returns True as soon as any lattice point of rect1 (corners inclusive)
    falls inside rect2. Used to validate the analytic intersection methods.
    """
    for px in range(rect1.x1, rect1.x2 + 1):
        for py in range(rect1.y1, rect1.y2 + 1):
            if rect2.x1 <= px <= rect2.x2 and rect2.y1 <= py <= rect2.y2:
                return True
    return False
def rand_rect(size=10):
    """Return a random axis-aligned square of side ``size`` within [0, 30]."""
    left = randint(0, 30 - size)
    top = randint(0, 30 - size)
    return elements.Rectangle(left, top, left + size, top + size)
@pytest.fixture
def test_layoutelements():
    """LayoutElements fixture with heavily nested/overlapping boxes.

    Exercises the cleaning logic: several near-duplicate "type0" boxes,
    nested "type1" tables, and one disjoint "type2" box.
    """
    coords = np.array(
        [
            [0.6, 0.6, 0.65, 0.65],  # One little table nested inside all the others
            [0.5, 0.5, 0.7, 0.7],  # One nested table
            [0, 0, 1, 1],  # Big table
            [0.01, 0.01, 0.09, 0.09],
            [0.02, 0.02, 1.02, 1.02],
            [0.03, 0.03, 1.03, 1.03],
            [0.04, 0.04, 1.04, 1.04],
            [0.05, 0.05, 1.05, 1.05],
            [2, 2, 3, 3],  # Big table
        ],
    )
    element_class_ids = np.array([1, 1, 1, 0, 0, 0, 0, 0, 2])
    class_map = {0: "type0", 1: "type1", 2: "type2"}
    return LayoutElements(
        element_coords=coords,
        element_class_ids=element_class_ids,
        element_class_id_map=class_map,
        source=Source.YOLOX,
    )
@pytest.mark.parametrize(
    ("rect1", "rect2", "expected"),
    [
        (Rectangle(0, 0, 1, 1), Rectangle(0, 0, None, None), None),
        (Rectangle(0, 0, None, None), Rectangle(0, 0, 1, 1), None),
    ],
)
def test_unhappy_intersection(rect1, rect2, expected):
    """Rectangles with None coordinates never intersect anything."""
    result = rect1.intersection(rect2)
    assert result == expected
    assert not rect1.intersects(rect2)
@pytest.mark.parametrize("second_size", [10, 20])
def test_intersects(second_size):
    """Randomized check that intersects/intersection agree with brute force."""
    for _ in range(1000):
        rect1 = rand_rect()
        rect2 = rand_rect(second_size)
        # The brute-force point scan must agree with the analytic test, symmetrically.
        assert intersect_brute(rect1, rect2) == rect1.intersects(rect2) == rect2.intersects(rect1)
        if rect1.intersects(rect2):
            if rect1.is_in(rect2):
                # Full containment: intersection is the inner rectangle.
                assert rect1.intersection(rect2) == rect1 == rect2.intersection(rect1)
            elif rect2.is_in(rect1):
                assert rect2.intersection(rect1) == rect2
            else:
                # Partial overlap: the intersection is the clipped box.
                x1 = max(rect1.x1, rect2.x1)
                x2 = min(rect1.x2, rect2.x2)
                y1 = max(rect1.y1, rect2.y1)
                y2 = min(rect1.y2, rect2.y2)
                intersection = elements.Rectangle(x1, y1, x2, y2)
                assert rect1.intersection(rect2) == intersection == rect2.intersection(rect1)
        else:
            # Disjoint rectangles have no intersection at all.
            assert rect1.intersection(rect2) is None
            assert rect2.intersection(rect1) is None
def test_intersection_of_lots_of_rects():
    """The pairwise intersections matrix matches brute force and is symmetric."""
    n_rects = 10
    for _ in range(1000):
        rects = [rand_rect(6) for _ in range(n_rects)]
        intersection_mtx = elements.intersections(*rects)
        for i in range(n_rects):
            for j in range(i, n_rects):
                expected = intersect_brute(rects[i], rects[j])
                assert expected == intersection_mtx[i, j]
                assert expected == intersection_mtx[j, i]
def test_rectangle_width_height():
    """width and height are derived from the corner coordinates."""
    for _ in range(1000):
        left = randint(0, 50)
        right = randint(left + 1, 100)
        top = randint(0, 50)
        bottom = randint(top + 1, 100)
        rect = elements.Rectangle(left, top, right, bottom)
        assert rect.width == right - left
        assert rect.height == bottom - top
def test_minimal_containing_rect():
    """minimal_containing_region is the smallest rect containing both inputs."""
    for _ in range(1000):
        rect1 = rand_rect()
        rect2 = rand_rect()
        big_rect = elements.minimal_containing_region(rect1, rect2)
        # Minimality proof, edge by edge: shrinking any single edge of the
        # containing rect by one must evict at least one of the inputs.
        for decrease_attr in ["x1", "y1", "x2", "y2"]:
            almost_as_big_rect = rand_rect()
            # "1" edges shrink inward by +1; "2" edges shrink inward by -1.
            mod = 1 if decrease_attr.endswith("1") else -1
            for attr in ["x1", "y1", "x2", "y2"]:
                if attr == decrease_attr:
                    setattr(almost_as_big_rect, attr, getattr(big_rect, attr) + mod)
                else:
                    setattr(almost_as_big_rect, attr, getattr(big_rect, attr))
            assert not rect1.is_in(almost_as_big_rect) or not rect2.is_in(almost_as_big_rect)
        # Containment: both inputs fit inside the computed region.
        assert rect1.is_in(big_rect)
        assert rect2.is_in(big_rect)
@pytest.mark.parametrize("coord_type", [int, float])
def test_partition_groups_from_regions(mock_embedded_text_regions, coord_type):
    """Grouping the mock regions yields one group whose text starts with 'Layout'."""
    words = TextRegions.from_list(mock_embedded_text_regions)
    words.element_coords = words.element_coords.astype(coord_type)
    groups = partition_groups_from_regions(words)
    assert len(groups) == 1
    last_group = groups[-1]
    assert "".join(last_group.texts).startswith("Layout")
    # Backward compatibility: as_list() round-trips to TextRegion objects.
    legacy_text = "".join(str(region) for region in last_group.as_list())
    assert legacy_text.startswith("Layout")
def test_rectangle_padding():
    """pad() returns a grown copy and leaves the original untouched."""
    rect = Rectangle(x1=0, y1=1, x2=3, y2=4)
    padded = rect.pad(1)
    assert (padded.x1, padded.y1, padded.x2, padded.y2) == (-1, 0, 4, 5)
    # The source rectangle must not be mutated.
    assert (rect.x1, rect.y1, rect.x2, rect.y2) == (0, 1, 3, 4)
def test_rectangle_area():
    """Rectangle.area equals width * height (both properties mocked).

    NOTE: the original requested the ``monkeypatch`` fixture but never used
    it; the parameter has been removed.
    """
    for _ in range(1000):
        width = randint(0, 20)
        height = randint(0, 20)
        with (
            patch(
                "unstructured_inference.inference.elements.Rectangle.height",
                new_callable=PropertyMock,
            ) as mockheight,
            patch(
                "unstructured_inference.inference.elements.Rectangle.width",
                new_callable=PropertyMock,
            ) as mockwidth,
        ):
            # Coordinates are irrelevant because both properties are mocked.
            rect = elements.Rectangle(0, 0, 0, 0)
            mockheight.return_value = height
            mockwidth.return_value = width
            assert rect.area == width * height
def test_rectangle_iou():
    """IoU is 1.0 on self, symmetric, and matches the closed-form ratio."""
    for _ in range(1000):
        rect1 = rand_rect()
        assert rect1.intersection_over_union(rect1) == 1.0
        rect2 = rand_rect(20)
        # IoU must be symmetric in its arguments.
        assert rect1.intersection_over_union(rect2) == rect2.intersection_over_union(rect1)
        if rect1.is_in(rect2):
            # Containment reduces IoU to the area ratio inner/outer.
            assert rect1.intersection_over_union(rect2) == rect1.area / rect2.area
        elif rect2.is_in(rect1):
            assert rect1.intersection_over_union(rect2) == rect2.area / rect1.area
        else:
            if rect1.intersection(rect2) is None:
                assert rect1.intersection_over_union(rect2) == 0.0
            else:
                # IoU = |A n B| / (|A| + |B| - |A n B|)
                intersection = rect1.intersection(rect2).area
                assert rect1.intersection_over_union(rect2) == intersection / (
                    rect1.area + rect2.area - intersection
                )
def test_midpoints():
    """Midpoints are coordinate averages and translate with offsets."""
    for _ in range(1000):
        x2 = randint(0, 100)
        y2 = randint(0, 100)
        anchored = elements.Rectangle(0, 0, x2, y2)
        assert anchored.x_midpoint == x2 / 2.0
        assert anchored.y_midpoint == y2 / 2.0
        dx = randint(0, 50)
        dy = randint(0, 50)
        shifted = elements.Rectangle(dx, dy, x2 + dx, y2 + dy)
        assert shifted.x_midpoint == (x2 / 2.0) + dx
        assert shifted.y_midpoint == (y2 / 2.0) + dy
def test_is_disjoint():
    """Rectangles separated along one axis are mutually disjoint."""
    for _ in range(1000):
        # a < b < c < d gives two non-overlapping intervals [a, b] and [c, d].
        a = randint(0, 100)
        b = randint(a + 1, 200)
        c = randint(b + 1, 300)
        d = randint(c + 1, 400)
        e = randint(0, 100)
        f = randint(e, 200)
        g = randint(0, 100)
        h = randint(g, 200)
        # Separated horizontally.
        left_rect = elements.Rectangle(a, e, b, f)
        right_rect = elements.Rectangle(c, g, d, h)
        assert left_rect.is_disjoint(right_rect)
        assert right_rect.is_disjoint(left_rect)
        # Same intervals transposed: separated vertically.
        top_rect = elements.Rectangle(e, a, f, b)
        bottom_rect = elements.Rectangle(g, c, h, d)
        assert top_rect.is_disjoint(bottom_rect)
        assert bottom_rect.is_disjoint(top_rect)
@pytest.mark.parametrize(
    ("rect1", "rect2", "expected"),
    [
        (elements.Rectangle(0, 0, 100, 200), elements.Rectangle(0, 0, 60, 150), 1.0),
        (elements.Rectangle(0, 0, 100, 100), elements.Rectangle(150, 150, 200, 200), 0.0),
        (elements.Rectangle(0, 0, 100, 100), elements.Rectangle(50, 50, 150, 150), 0.25),
        (elements.Rectangle(0, 0, 100, 100), elements.Rectangle(20, 20, 120, 40), 0.8),
    ],
)
def test_intersection_over_min(
    rect1: elements.Rectangle,
    rect2: elements.Rectangle,
    expected: float,
):
    """intersection_over_minimum is symmetric and matches hand-computed ratios."""
    forward = rect1.intersection_over_minimum(rect2)
    backward = rect2.intersection_over_minimum(rect1)
    assert forward == backward
    assert forward == expected
def test_grow_region_to_match_region():
    """grow_region_to_match_region expands the first region in place."""
    from unstructured_inference.inference.elements import (
        Rectangle,
        grow_region_to_match_region,
    )

    small = Rectangle(1, 1, 2, 2)
    target = Rectangle(1, 1, 5, 5)
    grow_region_to_match_region(small, target)
    assert small == Rectangle(1, 1, 5, 5)
@pytest.mark.parametrize(
    ("rect1", "rect2"),
    [
        (elements.Rectangle(0, 0, 5, 5), elements.Rectangle(3, 3, 6, 6)),
        (elements.Rectangle(0, 0, 5, 5), elements.Rectangle(6, 6, 8, 8)),
        (elements.Rectangle(3, 3, 7, 7), elements.Rectangle(2, 2, 4, 4)),
        (elements.Rectangle(2, 2, 4, 11), elements.Rectangle(3, 3, 7, 10)),
        (elements.Rectangle(2, 2, 4, 4), elements.Rectangle(3, 3, 7, 10)),
        (elements.Rectangle(2, 2, 4, 4), elements.Rectangle(2.5, 2.5, 3.5, 4.5)),
        (elements.Rectangle(2, 2, 4, 4), elements.Rectangle(3, 1, 4, 3.5)),
        (elements.Rectangle(2, 2, 4, 4), elements.Rectangle(3, 1, 4.5, 3.5)),
    ],
)
def test_separate(rect1, rect2):
    # NOTE(review): this currently only verifies that separate() runs without
    # raising on overlapping inputs; the disjointness assertion below is
    # disabled pending a fix.
    separate(rect1, rect2)
    # assert not rect1.intersects(rect2) #TODO: fix this test
def test_clean_layoutelements(test_layoutelements):
    """Cleaning the fixture layout keeps only the two outermost boxes.

    The fixture contains many nested/overlapping boxes; after cleaning only
    the big box at (0, 0, 1, 1) and the disjoint box at (2, 2, 3, 3) remain,
    both retaining their YOLOX source.
    """
    elements = clean_layoutelements(test_layoutelements).as_list()
    assert len(elements) == 2
    # BUG FIX: the original compared bbox.x2 twice and never checked bbox.y2;
    # compare the full (x1, y1, x2, y2) tuple instead.
    assert (
        elements[0].bbox.x1,
        elements[0].bbox.y1,
        elements[0].bbox.x2,
        elements[0].bbox.y2,
    ) == (0, 0, 1, 1)
    assert (
        elements[1].bbox.x1,
        elements[1].bbox.y1,
        elements[1].bbox.x2,
        elements[1].bbox.y2,
    ) == (2, 2, 3, 3)
    assert elements[0].source == elements[1].source == Source.YOLOX
@pytest.mark.parametrize(
    ("coords", "class_ids", "expected_coords", "expected_ids"),
    [
        ([[0, 0, 1, 1], [0, 0, 1, 1]], [0, 1], [[0, 0, 1, 1]], [0]),  # one box
        (
            [[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 2, 2]],
            [0, 1, 0],
            [[0, 0, 1, 1], [1, 1, 2, 2]],
            [0, 0],
        ),
        (
            [[0, 0, 1.4, 1.4], [0, 0, 1, 1], [0.4, 0, 1.4, 1], [1.2, 0, 1.4, 1]],
            [0, 1, 1, 1],
            [[0, 0, 1.4, 1.4]],
            [0],
        ),
    ],
)
def test_clean_layoutelements_cases(
    coords,
    class_ids,
    expected_coords,
    expected_ids,
):
    """Duplicate and nested boxes collapse to the expected surviving set."""
    coords = np.array(coords)
    element_class_ids = np.array(class_ids)
    elements = LayoutElements(element_coords=coords, element_class_ids=element_class_ids)
    elements = clean_layoutelements(elements)
    np.testing.assert_array_equal(elements.element_coords, expected_coords)
    np.testing.assert_array_equal(elements.element_class_ids, expected_ids)
@pytest.mark.parametrize(
    ("coords", "class_ids", "class_to_filter", "expected_coords", "expected_ids"),
    [
        ([[0, 0, 1, 1], [0, 0, 1, 1]], [0, 1], 1, [[0, 0, 1, 1]], [1]),  # one box
        (
            [[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 2, 2]],  # one box
            [0, 1, 0],
            1,
            [[0, 0, 1, 1], [1, 1, 2, 2]],
            [1, 0],
        ),
        (
            # a -> b, b -> c, but a not -> c
            [[0, 0, 1.4, 1.4], [0, 0, 1, 1], [0.4, 0, 1.4, 1], [1.2, 0, 1.4, 1]],
            [0, 1, 1, 1],
            1,
            [[0, 0, 1, 1], [1.2, 0, 1.4, 1], [0, 0, 1.4, 1.4]],
            [1, 1, 0],
        ),
        (
            # like the case above but a different filtering element type changes the results
            [[0, 0, 1.4, 1.4], [0, 0, 1, 1], [0.4, 0, 1.4, 1], [1.2, 0, 1.4, 1]],
            [0, 1, 1, 1],
            0,
            [[0, 0, 1.4, 1.4]],
            [0],
        ),
    ],
)
def test_clean_layoutelements_for_class(
    coords,
    class_ids,
    class_to_filter,
    expected_coords,
    expected_ids,
):
    """Cleaning targeted at one class keeps/merges boxes relative to that class."""
    coords = np.array(coords)
    element_class_ids = np.array(class_ids)
    elements = LayoutElements(element_coords=coords, element_class_ids=element_class_ids)
    elements = clean_layoutelements_for_class(elements, element_class=class_to_filter)
    np.testing.assert_array_equal(elements.element_coords, expected_coords)
    np.testing.assert_array_equal(elements.element_class_ids, expected_ids)
def test_layoutelements_to_list_and_back(test_layoutelements):
    """as_list() -> from_list() round-trips coords, texts, and class labels."""
    round_tripped = LayoutElements.from_list(test_layoutelements.as_list())
    np.testing.assert_array_equal(
        test_layoutelements.element_coords, round_tripped.element_coords
    )
    np.testing.assert_array_equal(test_layoutelements.texts, round_tripped.texts)
    # Probabilities were never set, so they come back as NaN.
    assert all(np.isnan(round_tripped.element_probs))
    original_labels = [
        test_layoutelements.element_class_id_map[idx]
        for idx in test_layoutelements.element_class_ids
    ]
    round_tripped_labels = [
        round_tripped.element_class_id_map[idx] for idx in round_tripped.element_class_ids
    ]
    assert original_labels == round_tripped_labels
def test_layoutelements_from_list_no_elements():
    """from_list() on an empty list yields empty arrays and no source."""
    empty = LayoutElements.from_list(elements=[])
    assert empty.sources.size == 0
    assert empty.source is None
    assert empty.element_coords.size == 0
def test_textregions_from_list_no_elements():
    """from_list() on an empty list yields empty arrays and no is_extracted."""
    empty = TextRegions.from_list(regions=[])
    assert empty.is_extracted_array.size == 0
    assert empty.is_extracted is None
    assert empty.element_coords.size == 0
def test_layoutelements_concatenate():
    """concatenate() merges coords/texts/sources and remaps class ids.

    The inputs share the label "type1" under different ids, so the joint
    mapping must deduplicate labels and rewrite ids accordingly.
    """
    layout1 = LayoutElements(
        element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
        texts=np.array(["a", "two"]),
        source=Source.YOLOX,
        element_class_ids=np.array([0, 1]),
        element_class_id_map={0: "type0", 1: "type1"},
    )
    layout2 = LayoutElements(
        element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]),
        texts=np.array(["three", "4"]),
        sources=np.array([Source.DETECTRON2_ONNX, Source.DETECTRON2_ONNX]),
        element_class_ids=np.array([0, 1]),
        element_class_id_map={0: "type1", 1: "type2"},
    )
    joint = LayoutElements.concatenate([layout1, layout2])
    assert joint.texts.tolist() == ["a", "two", "three", "4"]
    # The scalar source on layout1 is broadcast to its two elements.
    assert [s.value for s in joint.sources.tolist()] == [
        "yolox",
        "yolox",
        "detectron2_onnx",
        "detectron2_onnx",
    ]
    # "type1" maps to id 1 for both inputs after remapping.
    assert joint.element_class_ids.tolist() == [0, 1, 1, 2]
    assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"}
@pytest.mark.parametrize(
    "test_elements",
    [
        TextRegions(
            element_coords=np.array(
                [
                    [0.0, 0.0, 1.0, 1.0],
                    [1.0, 0.0, 1.5, 1.0],
                    [2.0, 0.0, 2.5, 1.0],
                    [3.0, 0.0, 4.0, 1.0],
                    [4.0, 0.0, 5.0, 1.0],
                ]
            ),
            texts=np.array(["0", "1", "2", "3", "4"]),
            is_extracted_array=np.array([IsExtracted.TRUE] * 5),
            is_extracted=IsExtracted.TRUE,
        ),
        LayoutElements(
            element_coords=np.array(
                [
                    [0.0, 0.0, 1.0, 1.0],
                    [1.0, 0.0, 1.5, 1.0],
                    [2.0, 0.0, 2.5, 1.0],
                    [3.0, 0.0, 4.0, 1.0],
                    [4.0, 0.0, 5.0, 1.0],
                ]
            ),
            texts=np.array(["0", "1", "2", "3", "4"]),
            sources=np.array([Source.YOLOX] * 5),
            source=Source.YOLOX,
            # BUG FIX: was np.array([] * 5), which is just an empty array;
            # a per-element array matching the TextRegions param was intended.
            is_extracted_array=np.array([IsExtracted.TRUE] * 5),
            is_extracted=IsExtracted.TRUE,
            element_probs=np.array([0.0, 0.1, 0.2, 0.3, 0.4]),
        ),
    ],
)
def test_textregions_support_numpy_slicing(test_elements):
    """Slices, fancy indexing, and boolean masks all work on (Text|Layout)Elements."""
    np.testing.assert_equal(test_elements[1:4].texts, np.array(["1", "2", "3"]))
    np.testing.assert_equal(test_elements[0::2].texts, np.array(["0", "2", "4"]))
    np.testing.assert_equal(test_elements[[1, 2, 4]].texts, np.array(["1", "2", "4"]))
    np.testing.assert_equal(test_elements[np.array([1, 2, 4])].texts, np.array(["1", "2", "4"]))
    np.testing.assert_equal(
        test_elements[np.array([True, False, False, True, False])].texts, np.array(["0", "3"])
    )
    # Probabilities (only present on LayoutElements) must slice in lockstep.
    if isinstance(test_elements, LayoutElements):
        np.testing.assert_almost_equal(test_elements[1:4].element_probs, np.array([0.1, 0.2, 0.3]))
def test_textregions_from_list_collects_sources():
    """Test that TextRegions.from_list() collects both source and text_source from regions"""
    from unstructured_inference.inference.elements import TextRegion

    first = TextRegion.from_coords(
        0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
    )
    second = TextRegion.from_coords(
        10,
        10,
        20,
        20,
        text="second",
        source=Source.DETECTRON2_ONNX,
        is_extracted=IsExtracted.TRUE,
    )
    text_regions = TextRegions.from_list([first, second])
    # Each region's source must survive the conversion, in order.
    assert text_regions.sources.size > 0, "sources array should not be empty"
    assert text_regions.sources[0] == Source.YOLOX
    assert text_regions.sources[1] == Source.DETECTRON2_ONNX
def test_textregions_has_sources_field():
    """TextRegions exposes both the plural and singular source attributes."""
    regions = TextRegions(element_coords=np.array([[0, 0, 10, 10]]))
    assert hasattr(regions, "sources"), "TextRegions should have a sources field"
    assert hasattr(regions, "source"), "TextRegions should have a source field"
def test_textregions_iter_elements_preserves_source():
    """iter_elements() must hand the source through to each TextRegion."""
    from unstructured_inference.inference.elements import TextRegion

    region = TextRegion.from_coords(
        0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
    )
    text_regions = TextRegions.from_list([region])
    first_element = next(iter(text_regions.iter_elements()))
    assert first_element.source == Source.YOLOX, "iter_elements() should preserve source"
def test_textregions_slice_preserves_sources():
    """Slicing TextRegions must keep the sources and is_extracted arrays."""
    from unstructured_inference.inference.elements import TextRegion

    first = TextRegion.from_coords(
        0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
    )
    second = TextRegion.from_coords(
        10,
        10,
        20,
        20,
        text="second",
        source=Source.DETECTRON2_ONNX,
        is_extracted=IsExtracted.TRUE,
    )
    sliced = TextRegions.from_list([first, second])[0:1]
    assert sliced.sources.size > 0, "Sliced TextRegions should have sources"
    assert sliced.sources[0] == Source.YOLOX
    assert sliced.is_extracted_array[0] is IsExtracted.TRUE
def test_textregions_post_init_handles_sources():
    """__post_init__ broadcasts a scalar source into the per-element array."""
    coords = np.array([[0, 0, 10, 10], [10, 10, 20, 20]])
    text_regions = TextRegions(element_coords=coords, source=Source.YOLOX)
    assert text_regions.sources.size > 0, "sources should be initialized from source"
    assert text_regions.sources[0] == Source.YOLOX
    assert text_regions.sources[1] == Source.YOLOX
def test_textregions_from_coords_accepts_source():
    """TextRegion.from_coords() accepts and stores a source argument."""
    expected_source = Source.YOLOX
    region = TextRegion.from_coords(
        0, 0, 10, 10, text="test", source=expected_source, is_extracted=IsExtracted.TRUE
    )
    assert region.source == expected_source
    assert region.is_extracted
@pytest.mark.skip(reason="Not implemented")
def test_textregions_allows_for_single_element_access_and_returns_textregion_with_correct_values():
    """Test that TextRegions allows for single element access and returns a TextRegion with the
    correct values"""
    regions = [
        TextRegion.from_coords(
            0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
        ),
        TextRegion.from_coords(
            0,
            0,
            20,
            20,
            text="second",
            source=Source.DETECTRON2_ONNX,
            is_extracted=IsExtracted.PARTIAL,
        ),
    ]
    text_regions = TextRegions.from_list(regions)
    # Integer indexing (text_regions[i]) should return a scalar TextRegion,
    # not a length-1 TextRegions -- unskip once __getitem__ supports this.
    for i, region in enumerate(regions):
        sliced = text_regions[i]
        assert isinstance(sliced, TextRegion)
        assert sliced.text == region.text
        assert sliced.source == region.source
        assert sliced.is_extracted is region.is_extracted
================================================
FILE: test_unstructured_inference/test_logger.py
================================================
import logging
import pytest
from unstructured_inference import logger
@pytest.mark.parametrize("level", range(50))
def test_translate_log_level(level):
    """translate_log_level maps stdlib log levels onto the expected codes."""
    name = logging.getLevelName(level)
    if name in ("ERROR", "CRITICAL"):
        expected = 3
    elif name in ("WARNING", "INFO", "DEBUG", "NOTSET", "WARN"):
        expected = 4
    else:
        # Unnamed in-between levels translate to 0.
        expected = 0
    assert logger.translate_log_level(level) == expected
================================================
FILE: test_unstructured_inference/test_math.py
================================================
import numpy as np
import pytest
from unstructured_inference.math import FLOAT_EPSILON, safe_division
@pytest.mark.parametrize(
    ("a", "b", "expected"),
    [(0, 0, 0), (0, 1, 0), (1, 0, np.round(1 / FLOAT_EPSILON, 1)), (2, 3, 0.7)],
)
def test_safe_division(a, b, expected):
    """safe_division never raises on zero denominators."""
    result = safe_division(a, b)
    assert np.round(result, 1) == expected
================================================
FILE: test_unstructured_inference/test_utils.py
================================================
import numpy as np
import pytest
from unstructured_inference.inference.layout import DocumentLayout
from unstructured_inference.utils import (
LazyDict,
LazyEvaluateInfo,
pad_image_with_background_color,
strip_tags,
)
# Mocking the DocumentLayout and Page classes
class MockPageLayout:
    # Minimal page stand-in: annotate() returns a sentinel instead of an image.
    def annotate(self, annotation_data):
        return "mock_image"
class MockDocumentLayout(DocumentLayout):
    # DocumentLayout with two stub pages, bypassing any real parsing.
    @property
    def pages(self):
        return [MockPageLayout(), MockPageLayout()]
def test_dict_same():
    """LazyDict mirrors a plain dict's keys, values, and length."""
    source = {"a": 1, "b": 2, "c": 3}
    lazy = LazyDict(**source)
    assert all(plain_key == lazy_key for plain_key, lazy_key in zip(source, lazy))
    assert all(source[key] == lazy[key] for key in source)
    assert len(lazy) == len(source)
def test_lazy_evaluate():
    """LazyEvaluateInfo defers calling the wrapped function until lookup."""
    calls = 0

    def func(x):
        nonlocal calls
        calls += 1
        return x

    lazy_info = LazyEvaluateInfo(func, 3)
    assert calls == 0
    lazy = LazyDict(a=lazy_info)
    # Construction alone must not trigger evaluation.
    assert calls == 0
    assert lazy["a"] == 3
    assert calls == 1
@pytest.mark.parametrize(("cache", "expected"), [(True, 1), (False, 2)])
def test_caches(cache, expected):
    """With caching on, a lazy value is computed once; off, once per access."""
    calls = 0

    def func(x):
        nonlocal calls
        calls += 1
        return x

    lazy_info = LazyEvaluateInfo(func, 3)
    assert calls == 0
    lazy = LazyDict(cache=cache, a=lazy_info)
    assert calls == 0
    assert lazy["a"] == 3
    assert lazy["a"] == 3
    assert calls == expected
def test_pad_image_with_background_color(mock_pil_image):
    """Padded image grows by 2*pad per axis and keeps the original pixels.

    BUG FIX: PIL's ``Image.size`` is ``(width, height)``; the original
    unpacked it as ``height, width``, which made the crop box wrong for
    non-square images (crop boxes are ``(left, upper, right, lower)``).
    """
    pad = 10
    width, height = mock_pil_image.size
    padded = pad_image_with_background_color(mock_pil_image, pad, "black")
    assert padded.size == (width + 2 * pad, height + 2 * pad)
    # The center crop of the padded image must be identical to the original.
    np.testing.assert_array_almost_equal(
        np.array(padded.crop((pad, pad, width + pad, height + pad))),
        np.array(mock_pil_image),
    )
    # A pixel inside the padding band is the background color (black).
    assert padded.getpixel((1, 1)) == (0, 0, 0)
def test_pad_image_with_invalid_input(mock_pil_image):
    """Negative padding is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Can not pad an image with negative space!"):
        pad_image_with_background_color(mock_pil_image, -1)
@pytest.mark.parametrize(
    ("html", "text"),
    [
        # NOTE(review): these HTML inputs look garbled -- an empty string cannot
        # strip to "Table", so the original markup (e.g. "<b>Table</b>") appears
        # to have been lost in extraction. TODO: restore the intended fixtures.
        ("", "Table"),
        # test escaped character
        ("", "yz"),
        # test tag with parameters
        ("Table", "Table"),
    ],
)
def test_strip_tags(html, text):
    # Stripping tags should leave only the visible text content.
    assert strip_tags(html) == text
================================================
FILE: test_unstructured_inference/test_visualization.py
================================================
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from PIL import Image
from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.visualize import draw_bbox, show_plot
def test_draw_bbox():
    """draw_bbox paints a red rectangle along the region's edges only."""
    base = Image.fromarray(np.ones((100, 100, 3), dtype="uint8"))
    x1, y1, x2, y2 = 1, 10, 7, 11
    region = TextRegion.from_coords(x1, y1, x2, y2)
    annotated = np.array(draw_bbox(image=base, element=region, details=False))
    # Edge pixels must be pure red (255, 0, 0).
    for channel, expected in enumerate([255, 0, 0]):
        assert all(annotated[y1, x1:x2, channel] == expected)
        assert all(annotated[y2, x1:x2, channel] == expected)
        assert all(annotated[y1:y2, x1, channel] == expected)
        assert all(annotated[y1:y2, x2, channel] == expected)
    # Almost every other pixel stays untouched.
    for channel in range(3):
        assert (annotated[:, :, channel] == 1).mean() > 0.995
def test_show_plot_with_pil_image(mock_pil_image):
    """show_plot renders a PIL image through ax.imshow and calls plt.show."""
    fig_stub = MagicMock()
    ax_stub = MagicMock()
    subplots_patch = patch("matplotlib.pyplot.subplots", return_value=(fig_stub, ax_stub))
    show_patch = patch("matplotlib.pyplot.show")
    imshow_patch = patch.object(ax_stub, "imshow")
    with subplots_patch as mock_subplots, show_patch as mock_show, imshow_patch as mock_imshow:
        show_plot(mock_pil_image, desired_width=100)
        mock_subplots.assert_called()
        mock_imshow.assert_called_with(mock_pil_image)
        mock_show.assert_called()
def test_show_plot_with_numpy_image(mock_numpy_image):
    """show_plot renders a numpy array through ax.imshow and calls plt.show."""
    fig_stub = MagicMock()
    ax_stub = MagicMock()
    subplots_patch = patch("matplotlib.pyplot.subplots", return_value=(fig_stub, ax_stub))
    show_patch = patch("matplotlib.pyplot.show")
    imshow_patch = patch.object(ax_stub, "imshow")
    with subplots_patch as mock_subplots, show_patch as mock_show, imshow_patch as mock_imshow:
        show_plot(mock_numpy_image)
        mock_subplots.assert_called()
        mock_imshow.assert_called_with(mock_numpy_image)
        mock_show.assert_called()
def test_show_plot_with_unsupported_image_type():
    """show_plot rejects inputs that are neither PIL images nor numpy arrays."""
    # Idiomatic pytest: assert on the message via `match` instead of inspecting
    # the ExceptionInfo after the block — same check, less machinery.
    with pytest.raises(ValueError, match="Unsupported Image Type"):
        show_plot("unsupported_image_type")
================================================
FILE: unstructured_inference/__init__.py
================================================
================================================
FILE: unstructured_inference/__version__.py
================================================
__version__ = "1.6.11" # pragma: no cover
================================================
FILE: unstructured_inference/config.py
================================================
"""
This module contains variables that can permitted to be tweaked by the system environment. For
example, model parameters that changes the output of an inference call. Constants do NOT belong in
this module. Constants are values that are usually names for common options (e.g., color names) or
settings that should not be altered without making a code change (e.g., definition of 1Gb of memory
in bytes). Constants should go into `./constants.py`
"""
import os
from dataclasses import dataclass
@dataclass
class InferenceConfig:
    """Class for configuring inference parameters.

    Every public property reads an environment variable of the same name on each
    access and falls back to a hard-coded default, so settings can be changed at
    runtime without restarting the process.
    """

    def _get_string(self, var: str, default_value: str = "") -> str:
        """attempt to get the value of var from the os environment; if not present return the
        default_value"""
        return os.environ.get(var, default_value)

    def _get_int(self, var: str, default_value: int) -> int:
        """Read var from the environment as an int, or default_value if unset/empty."""
        if value := self._get_string(var):
            return int(value)
        return default_value

    def _get_float(self, var: str, default_value: float) -> float:
        """Read var from the environment as a float, or default_value if unset/empty."""
        if value := self._get_string(var):
            return float(value)
        return default_value

    @property
    def TABLE_IMAGE_BACKGROUND_PAD(self) -> int:
        """number of pixels to pad around an table image with a white background color

        The padding adds NO image data around an identified table bounding box; it simply adds
        white background around the image
        """
        return self._get_int("TABLE_IMAGE_BACKGROUND_PAD", 20)

    @property
    def TT_TABLE_CONF(self) -> float:
        """confidence threshold for table identified by table transformer"""
        return self._get_float("TT_TABLE_CONF", 0.5)

    @property
    def TABLE_COLUMN_CONF(self) -> float:
        """confidence threshold for column identified by table transformer"""
        return self._get_float("TABLE_COLUMN_CONF", 0.5)

    @property
    def TABLE_ROW_CONF(self) -> float:
        """confidence threshold for row identified by table transformer"""
        # NOTE: docstring previously said "column" — this threshold applies to rows.
        return self._get_float("TABLE_ROW_CONF", 0.5)

    @property
    def TABLE_COLUMN_HEADER_CONF(self) -> float:
        """confidence threshold for column header identified by table transformer"""
        return self._get_float("TABLE_COLUMN_HEADER_CONF", 0.5)

    @property
    def TABLE_PROJECTED_ROW_HEADER_CONF(self) -> float:
        """confidence threshold for projected row header identified by table transformer"""
        return self._get_float("TABLE_PROJECTED_ROW_HEADER_CONF", 0.5)

    @property
    def TABLE_SPANNING_CELL_CONF(self) -> float:
        """confidence threshold for table spanning cells identified by table transformer"""
        return self._get_float("TABLE_SPANNING_CELL_CONF", 0.5)

    @property
    def TABLE_IOB_THRESHOLD(self) -> float:
        """minimum intersection over box area ratio for a box to be considered part of a larger
        box it intersects"""
        return self._get_float("TABLE_IOB_THRESHOLD", 0.5)

    @property
    def LAYOUT_SAME_REGION_THRESHOLD(self) -> float:
        """threshold for two layouts' bounding boxes to be considered as the same region

        When the intersection area over union area of the two is larger than this threshold the
        two boxes are considered the same region
        """
        return self._get_float("LAYOUT_SAME_REGION_THRESHOLD", 0.75)

    @property
    def LAYOUT_SUBREGION_THRESHOLD(self) -> float:
        """threshold for one bounding box to be considered as a sub-region of another bounding box

        When the intersection region area divided by self area is larger than this threshold self
        is considered a subregion of the other
        """
        return self._get_float("LAYOUT_SUBREGION_THRESHOLD", 0.75)

    @property
    def ELEMENTS_H_PADDING_COEF(self) -> float:
        """When extending the boundaries of a PDF object for the purpose of determining which
        other elements should be considered in the same text region, we use a relative distance
        based on some fraction of the block height (typically character height). This is the
        fraction used for the horizontal extension applied to the left and right sides.
        """
        return self._get_float("ELEMENTS_H_PADDING_COEF", 0.4)

    @property
    def ELEMENTS_V_PADDING_COEF(self) -> float:
        """Same as ELEMENTS_H_PADDING_COEF but the vertical extension."""
        return self._get_float("ELEMENTS_V_PADDING_COEF", 0.3)

    @property
    def IMG_PROCESSOR_LONGEST_EDGE(self) -> int:
        """configuration for DetrImageProcessor to scale images"""
        return self._get_int("IMG_PROCESSOR_LONGEST_EDGE", 1333)

    @property
    def IMG_PROCESSOR_SHORTEST_EDGE(self) -> int:
        """configuration for DetrImageProcessor to scale images"""
        return self._get_int("IMG_PROCESSOR_SHORTEST_EDGE", 800)

    @property
    def PDF_RENDER_MAX_PIXELS_PER_PAGE(self) -> int:
        """maximum number of pixels (width * height) a single PDF page may render to

        Pages whose rendered bitmap would exceed this value are rejected before allocation.
        Set to 0 to disable the guard.
        """
        return self._get_int("PDF_RENDER_MAX_PIXELS_PER_PAGE", 1_000_000_000)


# Module-level singleton used throughout the package.
inference_config = InferenceConfig()
================================================
FILE: unstructured_inference/constants.py
================================================
from enum import Enum
class Source(Enum):
    """Identifies which detection backend produced a region/element.

    NOTE(review): values correspond to model names; exact attachment semantics
    are defined where `source` fields are set — confirm with callers.
    """

    YOLOX = "yolox"
    DETECTRON2_ONNX = "detectron2_onnx"
    DETECTRON2_LP = "detectron2_lp"
class IsExtracted(Enum):
    """Tri-state flag recording whether a region's text was extracted.

    PARTIAL marks mixed cases. NOTE(review): the precise meaning of "extracted"
    (e.g., embedded PDF text vs OCR) is set by callers — confirm before relying.
    """

    TRUE = "true"
    FALSE = "false"
    PARTIAL = "partial"
class ElementType:
    """String constants naming the document layout element types used across
    detection models and downstream consumers. Values are the canonical labels;
    note the mixed naming styles (e.g. "List-item" vs "ListItem") are intentional
    to match the producing models' vocabularies.
    """

    PARAGRAPH = "Paragraph"
    IMAGE = "Image"
    PARAGRAPH_IN_IMAGE = "ParagraphInImage"
    FIGURE = "Figure"
    PICTURE = "Picture"
    TABLE = "Table"
    PARAGRAPH_IN_TABLE = "ParagraphInTable"
    LIST = "List"
    FORM = "Form"
    PARAGRAPH_IN_FORM = "ParagraphInForm"
    CHECK_BOX_CHECKED = "CheckBoxChecked"
    CHECK_BOX_UNCHECKED = "CheckBoxUnchecked"
    RADIO_BUTTON_CHECKED = "RadioButtonChecked"
    RADIO_BUTTON_UNCHECKED = "RadioButtonUnchecked"
    LIST_ITEM = "List-item"
    FORMULA = "Formula"
    CAPTION = "Caption"
    PAGE_HEADER = "Page-header"
    SECTION_HEADER = "Section-header"
    PAGE_FOOTER = "Page-footer"
    FOOTNOTE = "Footnote"
    TITLE = "Title"
    TEXT = "Text"
    UNCATEGORIZED_TEXT = "UncategorizedText"
    PAGE_BREAK = "PageBreak"
    CODE_SNIPPET = "CodeSnippet"
    PAGE_NUMBER = "PageNumber"
    OTHER = "Other"
# NOTE(review): presumably the area-ratio cutoff above which a detected region
# is treated as covering the full page — confirm at usage sites.
FULL_PAGE_REGION_THRESHOLD = 0.99
# this field is defined by pytesseract/unstructured.pytesseract
TESSERACT_TEXT_HEIGHT = "height"
# PDF user-space unit: 1 point = 1/72 inch.
PDF_POINTS_PER_INCH = 72
================================================
FILE: unstructured_inference/inference/__init__.py
================================================
================================================
FILE: unstructured_inference/inference/elements.py
================================================
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from functools import cached_property
from typing import Optional, Union
import numpy as np
from unstructured_inference.constants import IsExtracted, Source
from unstructured_inference.math import safe_division
@dataclass
class Rectangle:
    """Axis-aligned rectangle defined by two corner points.

    Per the pad/vpad docstrings below, y1 is the top edge and y2 the bottom
    edge (image-style coordinates: y grows downward), with x1 the left edge
    and x2 the right edge.
    """

    x1: Union[int, float]  # left
    y1: Union[int, float]  # top
    x2: Union[int, float]  # right
    y2: Union[int, float]  # bottom

    def pad(self, padding: Union[int, float]):
        """Increases (or decreases, if padding is negative) the size of the rectangle by extending
        the boundary outward (resp. inward). Returns a new Rectangle; self is unchanged."""
        out_object = self.hpad(padding).vpad(padding)
        return out_object

    def hpad(self, padding: Union[int, float]):
        """Increases (or decreases, if padding is negative) the size of the rectangle by extending
        the left and right sides of the boundary outward (resp. inward). Returns a deep copy."""
        out_object = deepcopy(self)
        out_object.x1 -= padding
        out_object.x2 += padding
        return out_object

    def vpad(self, padding: Union[int, float]):
        """Increases (or decreases, if padding is negative) the size of the rectangle by extending
        the top and bottom of the boundary outward (resp. inward). Returns a deep copy."""
        out_object = deepcopy(self)
        out_object.y1 -= padding
        out_object.y2 += padding
        return out_object

    @property
    def width(self) -> Union[int, float]:
        """Width of rectangle"""
        return self.x2 - self.x1

    @property
    def height(self) -> Union[int, float]:
        """Height of rectangle"""
        return self.y2 - self.y1

    @property
    def x_midpoint(self) -> Union[int, float]:
        """Finds the horizontal midpoint of the object."""
        return (self.x2 + self.x1) / 2

    @property
    def y_midpoint(self) -> Union[int, float]:
        """Finds the vertical midpoint of the object."""
        return (self.y2 + self.y1) / 2

    def is_disjoint(self, other: Rectangle) -> bool:
        """Checks whether this rectangle is disjoint from another rectangle."""
        return not self.intersects(other)

    def intersects(self, other: Rectangle) -> bool:
        """Checks whether this rectangle intersects another rectangle.

        Returns False if either rectangle has an undefined (None) coordinate.
        Delegates to the pairwise `intersections` matrix helper, so touching
        edges count as intersecting (strict > comparisons there).
        """
        if self._has_none() or other._has_none():
            return False
        return intersections(self, other)[0, 1]

    def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = None) -> bool:
        """Checks whether this rectangle is contained within another rectangle,
        optionally allowing `other` to be padded by `error_margin` first."""
        padded_other = other.pad(error_margin) if error_margin is not None else other
        return all(
            [
                (self.x1 >= padded_other.x1),
                (self.x2 <= padded_other.x2),
                (self.y1 >= padded_other.y1),
                (self.y2 <= padded_other.y2),
            ],
        )

    def _has_none(self) -> bool:
        """Return True when any coordinate is None (bounds are undefined)."""
        return any((self.x1 is None, self.x2 is None, self.y1 is None, self.y2 is None))

    @property
    def coordinates(self):
        """Gets coordinates of the rectangle as the four corners, in the order
        top-left, bottom-left, bottom-right, top-right."""
        return ((self.x1, self.y1), (self.x1, self.y2), (self.x2, self.y2), (self.x2, self.y1))

    def intersection(self, other: Rectangle) -> Optional[Rectangle]:
        """Gives the rectangle that is the intersection of two rectangles, or None if the
        rectangles are disjoint (or either has a None coordinate)."""
        if self._has_none() or other._has_none():
            return None
        x1 = max(self.x1, other.x1)
        x2 = min(self.x2, other.x2)
        y1 = max(self.y1, other.y1)
        y2 = min(self.y2, other.y2)
        if x1 > x2 or y1 > y2:
            return None
        return Rectangle(x1, y1, x2, y2)

    @property
    def area(self) -> float:
        """Gives the area of the rectangle."""
        return self.width * self.height

    def intersection_over_union(self, other: Rectangle) -> float:
        """Gives the intersection-over-union of two rectangles. This tends to be a good metric of
        how similar the regions are. Returns 0 for disjoint rectangles, 1 for two identical
        rectangles -- area of intersection / area of union."""
        intersection = self.intersection(other)
        intersection_area = 0.0 if intersection is None else intersection.area
        union_area = self.area + other.area - intersection_area
        # safe_division guards the degenerate case where both areas are zero.
        return safe_division(intersection_area, union_area)

    def intersection_over_minimum(self, other: Rectangle) -> float:
        """Gives the area-of-intersection over the minimum of the areas of the rectangles. Useful
        for identifying when one rectangle is almost-a-subset of the other. Returns 0 for disjoint
        rectangles, 1 when either is a subset of the other."""
        intersection = self.intersection(other)
        intersection_area = 0.0 if intersection is None else intersection.area
        min_area = min(self.area, other.area)
        return safe_division(intersection_area, min_area)

    def is_almost_subregion_of(self, other: Rectangle, subregion_threshold: float = 0.75) -> bool:
        """Returns whether this region is almost a subregion of other. This is determined by
        comparing the intersection area over self area to some threshold, and checking whether
        self is the smaller rectangle."""
        intersection = self.intersection(other)
        intersection_area = 0.0 if intersection is None else intersection.area
        return (subregion_threshold < safe_division(intersection_area, self.area)) and (
            self.area <= other.area
        )
def minimal_containing_region(*regions: Rectangle) -> Rectangle:
    """Return the smallest Rectangle that contains every region passed in."""
    corners = [(region.x1, region.y1, region.x2, region.y2) for region in regions]
    xs1, ys1, xs2, ys2 = zip(*corners)
    return Rectangle(min(xs1), min(ys1), max(xs2), max(ys2))
def intersections(*rects: Rectangle):
    """Return a square boolean matrix for an arbitrary number of rectangles: the
    (i, j) entry is True if and only if the ith and jth Rectangle intersect."""
    # NOTE(alan): Rewrite using line scan
    box_coords = np.array([[box.x1, box.y1, box.x2, box.y2] for box in rects])
    return coords_intersections(box_coords)
def coords_intersections(coords: np.ndarray) -> np.ndarray:
    """Return a square boolean matrix of pairwise intersections for a stack of
    box coords: entry (i, j) is True iff the ith and jth boxes intersect.

    `coords` is an (n, 4) array of [x1, y1, x2, y2] rows. Touching edges count
    as intersecting because the disjointness comparisons below are strict.
    """
    x1s = coords[:, 0]
    y1s = coords[:, 1]
    x2s = coords[:, 2]
    y2s = coords[:, 3]
    # Two boxes are disjoint iff one starts strictly after the other ends on
    # either axis; broadcasting builds all pairwise comparisons at once, and the
    # transposes cover the symmetric "other starts after self ends" cases.
    x_disjoint = x1s[None] > x2s[..., None]
    y_disjoint = y1s[None] > y2s[..., None]
    disjoint = x_disjoint | y_disjoint | x_disjoint.T | y_disjoint.T
    # Intersection is the complement of disjointness.
    return ~disjoint
@dataclass
class TextRegion:
    """A bounding box paired with its (optional) text content, the source that
    produced it, and an optional extraction flag."""

    bbox: Rectangle
    text: Optional[str] = None
    source: Optional[Source] = None
    is_extracted: Optional[IsExtracted] = None

    def __str__(self) -> str:
        # Render as the region's text (the string "None" when text is unset).
        return str(self.text)

    @classmethod
    def from_coords(
        cls,
        x1: Union[int, float],
        y1: Union[int, float],
        x2: Union[int, float],
        y2: Union[int, float],
        text: Optional[str] = None,
        source: Optional[Source] = None,
        is_extracted: Optional[IsExtracted] = None,
        **kwargs,
    ) -> TextRegion:
        """Build a region directly from corner coordinates; extra kwargs are
        forwarded to the constructor (useful for subclasses)."""
        return cls(
            bbox=Rectangle(x1, y1, x2, y2),
            text=text,
            source=source,
            is_extracted=is_extracted,
            **kwargs,
        )
@dataclass
class TextRegions:
    """Vectorized container for many text regions.

    `element_coords` is an (n, 4) array of [x1, y1, x2, y2] rows; the parallel
    arrays (`texts`, `sources`, `is_extracted_array`) hold one entry per row.
    Scalar `source`/`is_extracted` values are broadcast into their array
    counterparts in __post_init__. Kept as arrays (rather than a list of
    TextRegion) for lower memory overhead and vectorized operations.
    """

    element_coords: np.ndarray
    # Parallel per-element arrays; empty means "not provided" and is filled
    # with None placeholders in __post_init__.
    texts: np.ndarray = field(default_factory=lambda: np.array([]))
    sources: np.ndarray = field(default_factory=lambda: np.array([]))
    # Scalar convenience value; broadcast into `sources` when that is empty.
    source: Source | None = None
    is_extracted_array: np.ndarray = field(default_factory=lambda: np.array([]))
    # Scalar convenience value; broadcast into `is_extracted_array` when empty.
    is_extracted: IsExtracted | None = None
    # Names of array attributes that may be empty and get None-filled.
    _optional_array_attributes: list[str] = field(
        init=False, default_factory=lambda: ["texts", "sources", "is_extracted_array"]
    )
    # Maps each scalar attribute to the array attribute it broadcasts into.
    _scalar_to_array_mappings: dict[str, str] = field(
        init=False,
        default_factory=lambda: {
            "source": "sources",
            "is_extracted": "is_extracted_array",
        },
    )

    def __post_init__(self):
        # Order matters: scalars are broadcast into arrays BEFORE the generic
        # None-fill below, so a provided scalar wins over the None placeholder.
        element_size = self.element_coords.shape[0]
        for scalar, array in self._scalar_to_array_mappings.items():
            if (
                getattr(self, scalar) is not None
                and getattr(self, array).size == 0
                and element_size
            ):
                # Scalar given, array missing: broadcast scalar to every element.
                setattr(self, array, np.array([getattr(self, scalar)] * element_size))
            elif getattr(self, scalar) is None and getattr(self, array).size > 0:
                # Array given, scalar missing: derive the scalar from the first
                # entry (assumes the array is homogeneous — see from_list).
                setattr(self, scalar, getattr(self, array)[0])
        for attr in self._optional_array_attributes:
            if getattr(self, attr).size == 0 and element_size:
                setattr(self, attr, np.array([None] * element_size))
        # we convert to float so data type is more consistent (e.g., None will be np.nan)
        self.element_coords = self.element_coords.astype(float)

    def __getitem__(self, indices) -> TextRegions:
        return self.slice(indices)

    def slice(self, indices) -> TextRegions:
        """slice text regions based on indices

        NOTE(review): this always constructs a TextRegions (not type(self)), so
        subclass slicing would drop subclass fields unless overridden — confirm
        subclasses override slice.
        """
        # NOTE(alan): I would expect if I try to access a single element, it should return a
        # TextRegion, not a TextRegions. Currently, you get an error when trying to access a
        # single element.
        if self.element_coords[indices].ndim == 1:
            # We've indexed a single element. For now this isn't implemented.
            raise NotImplementedError("Slicing a single element is not implemented")
        return TextRegions(
            element_coords=self.element_coords[indices],
            texts=self.texts[indices],
            sources=self.sources[indices],
            is_extracted_array=self.is_extracted_array[indices],
        )

    def iter_elements(self):
        """iter text regions as one TextRegion per iteration; this returns a generator and has
        less memory impact than the as_list method"""
        for (x1, y1, x2, y2), text, source, is_extracted in zip(
            self.element_coords,
            self.texts,
            self.sources,
            self.is_extracted_array,
        ):
            yield TextRegion.from_coords(x1, y1, x2, y2, text, source, is_extracted)

    def as_list(self):
        """return a list of LayoutElement for backward compatibility"""
        return list(self.iter_elements())

    @classmethod
    def from_list(cls, regions: list):
        """create TextRegions from a list of TextRegion objects; the objects must have the same
        is_extracted"""
        coords, texts, sources, is_extracted_array = [], [], [], []
        for region in regions:
            coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2))
            texts.append(region.text)
            sources.append(region.source)
            is_extracted_array.append(region.is_extracted)
        return cls(
            element_coords=np.array(coords),
            texts=np.array(texts),
            sources=np.array(sources),
            is_extracted_array=np.array(is_extracted_array),
        )

    def __len__(self):
        return self.element_coords.shape[0]

    @property
    def x1(self):
        """left coordinate"""
        return self.element_coords[:, 0]

    @property
    def y1(self):
        """top coordinate"""
        return self.element_coords[:, 1]

    @property
    def x2(self):
        """right coordinate"""
        return self.element_coords[:, 2]

    @property
    def y2(self):
        """bottom coordinate"""
        return self.element_coords[:, 3]

    @cached_property
    def areas(self) -> np.ndarray:
        """areas of each region; only compute it when it is needed"""
        return (self.x2 - self.x1) * (self.y2 - self.y1)
class EmbeddedTextRegion(TextRegion):
    # Marker subclass: adds no behavior; used to distinguish region origin by
    # type. NOTE(review): presumably text embedded in the document (vs OCR) —
    # confirm with callers.
    pass
class ImageTextRegion(TextRegion):
    # Marker subclass: adds no behavior; used to distinguish region origin by
    # type. NOTE(review): presumably text recognized from an image (e.g. OCR) —
    # confirm with callers.
    pass
def region_bounding_boxes_are_almost_the_same(
    region1: Rectangle,
    region2: Rectangle,
    same_region_threshold: float = 0.75,
) -> bool:
    """Decide whether two bounding boxes cover essentially the same region:
    True when their intersection-over-union exceeds `same_region_threshold`."""
    iou = region1.intersection_over_union(region2)
    return iou > same_region_threshold
def grow_region_to_match_region(region_to_grow: Rectangle, region_to_match: Rectangle):
    """Mutate `region_to_grow` in place to the minimal rectangle containing both
    it and `region_to_match`."""
    combined = minimal_containing_region(region_to_grow, region_to_match)
    region_to_grow.x1 = combined.x1
    region_to_grow.y1 = combined.y1
    region_to_grow.x2 = combined.x2
    region_to_grow.y2 = combined.y2
================================================
FILE: unstructured_inference/inference/layout.py
================================================
from __future__ import annotations
import os
import tempfile
from functools import cached_property
from pathlib import PurePath
from typing import Any, BinaryIO, Collection, List, Optional, Union, cast
import numpy as np
from PIL import Image, ImageSequence
from unstructured_inference.inference import pdf_image as pdf_image_utils
from unstructured_inference.inference.elements import (
TextRegion,
)
from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements
from unstructured_inference.logger import logger
from unstructured_inference.models.base import get_model
from unstructured_inference.models.unstructuredmodel import (
UnstructuredElementExtractionModel,
UnstructuredObjectDetectionModel,
)
from unstructured_inference.visualize import draw_bbox
# Module-level aliases into pdf_image utilities. NOTE(review): presumably
# re-exported here for backward compatibility with callers importing them from
# this module — confirm external usage before removing.
convert_pdf_to_image = pdf_image_utils.convert_pdf_to_image
_pdfium_lock = pdf_image_utils._pdfium_lock
class DocumentLayout:
    """Class for handling documents that are saved as .pdf files. For .pdf files, a
    document image analysis (DIA) model detects the layout of the page prior to extracting
    element."""

    def __init__(self, pages=None):
        self._pages = pages

    def __str__(self) -> str:
        # Pages separated by a blank line, each rendered via its own __str__.
        return "\n\n".join([str(page) for page in self.pages])

    @property
    def pages(self) -> List[PageLayout]:
        """Gets all elements from pages in sequential order."""
        return self._pages

    @classmethod
    def from_pages(cls, pages: List[PageLayout]) -> DocumentLayout:
        """Generates a new instance of the class from a list of `PageLayouts`s"""
        doc_layout = cls()
        doc_layout._pages = pages
        return doc_layout

    @classmethod
    def from_file(
        cls,
        filename: str,
        fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None,
        pdf_image_dpi: int = 200,
        pdf_render_max_pixels_per_page: Optional[int] = None,
        password: Optional[str] = None,
        **kwargs,
    ) -> DocumentLayout:
        """Creates a DocumentLayout from a pdf file.

        Renders each page to an image in a temp dir, then builds one PageLayout
        per page. `fixed_layouts` (one entry per page, or None) bypasses model
        detection for that page. Extra kwargs are forwarded to
        PageLayout.from_image.
        """
        # Fixed: the log message previously contained no placeholder, so the
        # filename was never logged.
        logger.info(f"Reading PDF for file: {filename} ...")
        with tempfile.TemporaryDirectory() as temp_dir:
            _image_paths = convert_pdf_to_image(
                filename=filename,
                dpi=pdf_image_dpi,
                output_folder=temp_dir,
                path_only=True,
                password=password,
                pdf_render_max_pixels_per_page=pdf_render_max_pixels_per_page,
            )
            image_paths = cast(List[str], _image_paths)
            number_of_pages = len(image_paths)
            pages: List[PageLayout] = []
            if fixed_layouts is None:
                fixed_layouts = [None for _ in range(0, number_of_pages)]
            for i, (image_path, fixed_layout) in enumerate(zip(image_paths, fixed_layouts)):
                # NOTE(robinson) - In the future, maybe we detect the page number and default
                # to the index if it is not detected
                with Image.open(image_path) as image:
                    page = PageLayout.from_image(
                        image,
                        number=i + 1,
                        document_filename=filename,
                        fixed_layout=fixed_layout,
                        pdf_render_max_pixels_per_page=pdf_render_max_pixels_per_page,
                        **kwargs,
                    )
                    pages.append(page)
            return cls.from_pages(pages)

    @classmethod
    def from_image_file(
        cls,
        filename: str,
        detection_model: Optional[UnstructuredObjectDetectionModel] = None,
        element_extraction_model: Optional[UnstructuredElementExtractionModel] = None,
        fixed_layout: Optional[List[TextRegion]] = None,
        **kwargs,
    ) -> DocumentLayout:
        """Creates a DocumentLayout from an image file (multi-frame formats such
        as TIFF yield one page per frame)."""
        logger.info(f"Reading image file: {filename} ...")
        try:
            image = Image.open(filename)
            # Renamed from `format` to avoid shadowing the builtin; frames lose
            # their format on convert(), so it is restored per frame below.
            image_format = image.format
            images: list[Image.Image] = []
            for im in ImageSequence.Iterator(image):
                im = im.convert("RGB")
                im.format = image_format
                images.append(im)
        except Exception as e:
            if os.path.isdir(filename) or os.path.isfile(filename):
                raise e
            else:
                raise FileNotFoundError(f'File "{filename}" not found!') from e
        pages = []
        for i, image in enumerate(images):  # type: ignore
            # NOTE: page numbers here start at 0, unlike from_file (starts at 1);
            # preserved as-is for backward compatibility.
            page = PageLayout.from_image(
                image,
                image_path=filename,
                number=i,
                detection_model=detection_model,
                element_extraction_model=element_extraction_model,
                fixed_layout=fixed_layout,
                **kwargs,
            )
            pages.append(page)
        return cls.from_pages(pages)
class PageLayout:
    """Class for an individual PDF page.

    Holds the page image (or a path to it), the detected/extracted elements,
    and the model used to obtain them. Exactly one of detection_model /
    element_extraction_model may be set.
    """

    def __init__(
        self,
        number: int,
        image: Image.Image,
        image_metadata: Optional[dict] = None,
        image_path: Optional[Union[str, PurePath]] = None,  # TODO: Deprecate
        document_filename: Optional[Union[str, PurePath]] = None,
        detection_model: Optional[UnstructuredObjectDetectionModel] = None,
        element_extraction_model: Optional[UnstructuredElementExtractionModel] = None,
        pdf_render_max_pixels_per_page: Optional[int] = None,
        password: Optional[str] = None,
    ):
        # A page is processed by exactly one model kind; both at once is invalid.
        if detection_model is not None and element_extraction_model is not None:
            raise ValueError("Only one of detection_model and extraction_model should be passed.")
        self.image: Optional[Image.Image] = image
        if image_metadata is None:
            image_metadata = {}
        self.image_metadata = image_metadata
        self.image_path = image_path
        # Lazily-populated numpy copy of the image; see _get_image_array.
        self.image_array: Union[np.ndarray[Any, Any], None] = None
        self.document_filename = document_filename
        self.number = number
        self.detection_model = detection_model
        self.element_extraction_model = element_extraction_model
        self.pdf_render_max_pixels_per_page = pdf_render_max_pixels_per_page
        # Filled by get_elements_with_detection_model when inplace=True.
        self.elements_array: LayoutElements | None = None
        self.password = password
        # NOTE(alan): Dropped LocationlessLayoutElement that was created for chipper - chipper has
        # locations now and if we need to support LayoutElements without bounding boxes we can
        # make the bbox property optional

    def __str__(self) -> str:
        return "\n\n".join([str(element) for element in self.elements])

    @cached_property
    def elements(self) -> Collection[LayoutElement]:
        """return a list of layout elements from the array data structure; intended for backward
        compatibility

        NOTE(review): this cached_property is sometimes shadowed by direct
        assignment (see get_elements_using_image_extraction / from_image), which
        works because cached_property defines no __set__ — the assignment lands
        in the instance __dict__ and takes precedence.
        """
        if self.elements_array is None:
            return []
        return self.elements_array.as_list()

    def get_elements_using_image_extraction(
        self,
        inplace=True,
    ) -> Optional[list[LayoutElement]]:
        """Uses end-to-end text element extraction model to extract the elements on the page.

        When inplace is True, stores the result on self.elements and returns
        None; otherwise returns the extracted elements.
        """
        if self.element_extraction_model is None:
            raise ValueError(
                "Cannot get elements using image extraction, no image extraction model defined",
            )
        assert self.image is not None
        elements = self.element_extraction_model(self.image)
        if inplace:
            # Shadows the `elements` cached_property via the instance __dict__.
            self.elements = elements
            return None
        return elements

    def get_elements_with_detection_model(
        self,
        inplace: bool = True,
    ) -> Optional[LayoutElements]:
        """Uses specified model to detect the elements on the page.

        Falls back to the default model from get_model() when no detection model
        was set. When inplace is True, stores the result on self.elements_array
        and returns None; otherwise returns the LayoutElements.
        """
        if self.detection_model is None:
            model = get_model()
            if isinstance(model, UnstructuredObjectDetectionModel):
                self.detection_model = model
            else:
                raise NotImplementedError("Default model should be a detection model")
        # NOTE(mrobinson) - We'll want make this model inference step some kind of
        # remote call in the future.
        assert self.image is not None
        inferred_layout: LayoutElements = self.detection_model(self.image)
        # Preserve routing metadata across deduplication, which does not carry
        # these scalar attributes through.
        routing = inferred_layout.routing
        routing_score = inferred_layout.routing_score
        inferred_layout = self.detection_model.deduplicate_detected_elements(
            inferred_layout,
        )
        inferred_layout.routing = routing
        inferred_layout.routing_score = routing_score
        if inplace:
            self.elements_array = inferred_layout
            return None
        return inferred_layout

    def _get_image_array(self) -> Union[np.ndarray[Any, Any], None]:
        """Converts the raw image into a numpy array (cached on first call)."""
        if self.image_array is None:
            if self.image:
                self.image_array = np.array(self.image)
            else:
                # Image was cleared to save memory; reload from disk.
                image = Image.open(self.image_path)  # type: ignore
                self.image_array = np.array(image)
        return self.image_array

    def annotate(
        self,
        colors: Optional[Union[List[str], str]] = None,
        image_dpi: int = 200,
        annotation_data: Optional[dict[str, dict]] = None,
        add_details: bool = False,
        sources: Optional[List[str]] = None,
    ) -> Image.Image:
        """Annotates the elements on the page image.

        if add_details is True, and the elements contain type and source attributes, then
        the type and source will be added to the image.
        sources is a list of sources to annotate. If sources is ["all"], then all sources will be
        annotated. Current sources allowed are "yolox","detectron2_onnx" and "detectron2_lp"
        """
        if colors is None:
            colors = ["red" for _ in self.elements]
        if isinstance(colors, str):
            colors = [colors]
        # If there aren't enough colors, just cycle through the colors a few times
        if len(colors) < len(self.elements):
            n_copies = (len(self.elements) // len(colors)) + 1
            colors = colors * n_copies
        # Hotload image if it hasn't been loaded yet
        if self.image:
            img = self.image.copy()
        elif self.image_path:
            img = Image.open(self.image_path)
        else:
            img = self._get_image(self.document_filename, self.number, image_dpi)
        if annotation_data is None:
            # Default path: draw every element (filtered by source) in its color.
            for el, color in zip(self.elements, colors):
                if sources is None or el.source in sources:
                    img = draw_bbox(img, el, color=color, details=add_details)
        else:
            # Custom path: annotation_data maps attribute names on self to a
            # {"color": ..., "width": ...} style for drawing that attribute's
            # regions.
            for attribute, style in annotation_data.items():
                if hasattr(self, attribute) and getattr(self, attribute):
                    color = style["color"]
                    width = style["width"]
                    for region in getattr(self, attribute):
                        required_source = getattr(region, "source", None)
                        if (sources is None) or (required_source in sources):
                            img = draw_bbox(
                                img,
                                region,
                                color=color,
                                width=width,
                                details=add_details,
                            )
        return img

    def _get_image(self, filename, page_number, pdf_image_dpi: int = 200) -> Image.Image:
        """Hotloads a page image from a pdf file.

        Re-renders the whole PDF into a temp dir and returns a copy of the page
        image (page_number is 1-based here).
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            _image_paths = convert_pdf_to_image(
                filename=filename,
                dpi=pdf_image_dpi,
                output_folder=temp_dir,
                path_only=True,
                pdf_render_max_pixels_per_page=self.pdf_render_max_pixels_per_page,
            )
            image_paths = cast(List[str], _image_paths)
            if page_number > len(image_paths):
                raise ValueError(
                    f"Page number {page_number} is greater than the number of pages in the PDF.",
                )
            with Image.open(image_paths[page_number - 1]) as image:
                # Copy so the image survives the temp dir cleanup.
                return image.copy()

    @classmethod
    def from_image(
        cls,
        image: Image.Image,
        image_path: Optional[Union[str, PurePath]] = None,
        document_filename: Optional[Union[str, PurePath]] = None,
        number: int = 1,
        detection_model: Optional[UnstructuredObjectDetectionModel] = None,
        element_extraction_model: Optional[UnstructuredElementExtractionModel] = None,
        fixed_layout: Optional[List[TextRegion]] = None,
        pdf_render_max_pixels_per_page: Optional[int] = None,
    ):
        """Creates a PageLayout from an already-loaded PIL Image.

        Runs the configured model (or uses fixed_layout), records image
        metadata, then drops the image reference to save memory.
        """
        page = cls(
            number=number,
            image=image,
            detection_model=detection_model,
            element_extraction_model=element_extraction_model,
            pdf_render_max_pixels_per_page=pdf_render_max_pixels_per_page,
        )
        # FIXME (yao): refactor the other methods so they all return elements like the third
        # route
        if page.element_extraction_model is not None:
            page.get_elements_using_image_extraction()
        elif fixed_layout is None:
            page.get_elements_with_detection_model()
        else:
            # Shadows the `elements` cached_property via the instance __dict__.
            page.elements = []
        page.image_metadata = {
            "format": page.image.format if page.image else None,
            "width": page.image.width if page.image else None,
            "height": page.image.height if page.image else None,
            "pdf_rotation": int(page.image.info.get("pdf_rotation", 0)) if page.image else 0,
        }
        page.image_path = os.path.abspath(image_path) if image_path else None
        page.document_filename = os.path.abspath(document_filename) if document_filename else None
        # Clear the image to save memory
        page.image = None
        return page
def process_data_with_model(
    data: BinaryIO,
    model_name: Optional[str],
    password: Optional[str] = None,
    **kwargs: Any,
) -> DocumentLayout:
    """Process PDF or image as file-like object `data` into a `DocumentLayout`.

    Uses the model identified by `model_name`. The stream is materialized to a
    temp file so the path-based pipeline can consume it.
    """
    # Note: We use a temp dir, not a temp file,
    # because the latter fails on Windows
    # https://github.com/Unstructured-IO/unstructured-inference/pull/376
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        file_path = os.path.join(tmp_dir_path, "document")
        with open(file_path, "wb") as temp_file:
            temp_file.write(data.read())
            temp_file.flush()
        return process_file_with_model(
            file_path,
            model_name,
            password=password,
            **kwargs,
        )
def process_file_with_model(
    filename: str,
    model_name: Optional[str],
    is_image: bool = False,
    fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None,
    pdf_image_dpi: int = 200,
    pdf_render_max_pixels_per_page: Optional[int] = None,
    password: Optional[str] = None,
    **kwargs: Any,
) -> DocumentLayout:
    """Processes pdf or image file with name filename into a DocumentLayout by using
    a model identified by model_name."""
    model = get_model(model_name, **kwargs)
    # Route the loaded model to the matching keyword slot; exactly one is set.
    detection_model = None
    element_extraction_model = None
    if isinstance(model, UnstructuredObjectDetectionModel):
        detection_model = model
    elif isinstance(model, UnstructuredElementExtractionModel):
        element_extraction_model = model
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")
    if is_image:
        return DocumentLayout.from_image_file(
            filename,
            detection_model=detection_model,
            element_extraction_model=element_extraction_model,
            **kwargs,
        )
    return DocumentLayout.from_file(
        filename,
        detection_model=detection_model,
        element_extraction_model=element_extraction_model,
        fixed_layouts=fixed_layouts,
        pdf_image_dpi=pdf_image_dpi,
        pdf_render_max_pixels_per_page=pdf_render_max_pixels_per_page,
        password=password,
        **kwargs,
    )
================================================
FILE: unstructured_inference/inference/layoutelement.py
================================================
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Optional, Union
import numpy as np
from pandas import DataFrame
from scipy.sparse.csgraph import connected_components
from unstructured_inference.config import inference_config
from unstructured_inference.constants import IsExtracted, Source
from unstructured_inference.inference.elements import (
Rectangle,
TextRegion,
TextRegions,
coords_intersections,
)
# Tiny area floor used as a denominator guard so area-ratio computations never
# divide by zero for degenerate (zero-area) boxes.
EPSILON_AREA = 1e-7
@dataclass
class LayoutElements(TextRegions):
    """Vectorized collection of layout elements for a page.

    Extends ``TextRegions`` with per-element model outputs (class ids and
    probabilities) and table-extraction payloads. All per-element attributes
    are parallel numpy arrays indexed the same way as ``element_coords``.
    """

    # per-element model confidence; NaN marks "no probability available"
    element_probs: np.ndarray = field(default_factory=lambda: np.array([]))
    # per-element integer class id, resolved to a name via element_class_id_map
    element_class_ids: np.ndarray = field(default_factory=lambda: np.array([]))
    # class id -> class name for this collection
    element_class_id_map: dict[int, str] = field(default_factory=dict)
    # per-element HTML rendering of table content (only set for tables)
    text_as_html: np.ndarray = field(default_factory=lambda: np.array([]))
    # per-element serialized table cell data (only set for tables)
    table_as_cells: np.ndarray = field(default_factory=lambda: np.array([]))
    # per-element name of the method used to extract the table, if any
    table_extraction_method: np.ndarray = field(default_factory=lambda: np.array([]))
    # collection-level (not per-element) routing metadata
    routing: str | None = None
    routing_score: float | None = None
    # NOTE(review): consumed by the TextRegions base class; presumably names
    # the per-element arrays that may be empty/absent — confirm in elements.py.
    _optional_array_attributes: list[str] = field(
        init=False,
        default_factory=lambda: [
            "texts",
            "sources",
            "is_extracted_array",
            "element_probs",
            "element_class_ids",
            "text_as_html",
            "table_as_cells",
            "table_extraction_method",
        ],
    )
    # maps scalar attribute names to their per-element array counterparts
    _scalar_to_array_mappings: dict[str, str] = field(
        init=False,
        default_factory=lambda: {
            "source": "sources",
            "is_extracted": "is_extracted_array",
        },
    )

    def __post_init__(self):
        """Run base-class setup, then cast probabilities to float so missing
        values are representable as NaN."""
        super().__post_init__()
        self.element_probs = self.element_probs.astype(float)

    def __eq__(self, other: object) -> bool:
        """Compare two collections element-wise.

        NOTE: attributes other than coords, texts, and class names are only
        compared at positions where ``element_probs`` is not NaN.
        """
        if not isinstance(other, LayoutElements):
            return NotImplemented
        # positions with a real (non-NaN) probability
        mask = ~np.isnan(self.element_probs)
        other_mask = ~np.isnan(other.element_probs)
        return (
            np.array_equal(self.element_coords, other.element_coords)
            and np.array_equal(self.texts, other.texts)
            and np.array_equal(mask, other_mask)
            and np.array_equal(self.element_probs[mask], other.element_probs[mask])
            and (
                # compare class *names* so differing id assignments still match
                [self.element_class_id_map[idx] for idx in self.element_class_ids]
                == [other.element_class_id_map[idx] for idx in other.element_class_ids]
            )
            and np.array_equal(self.sources[mask], other.sources[mask])
            and np.array_equal(self.is_extracted_array[mask], other.is_extracted_array[mask])
            and np.array_equal(self.text_as_html[mask], other.text_as_html[mask])
            and np.array_equal(self.table_as_cells[mask], other.table_as_cells[mask])
            and np.array_equal(
                self.table_extraction_method[mask], other.table_extraction_method[mask]
            )
        )

    def __getitem__(self, indices):
        """Index/slice the collection; delegates to `slice`."""
        return self.slice(indices)

    def slice(self, indices) -> LayoutElements:
        """slice and return only selected indices"""
        return LayoutElements(
            element_coords=self.element_coords[indices],
            texts=self.texts[indices],
            is_extracted_array=self.is_extracted_array[indices],
            sources=self.sources[indices],
            element_probs=self.element_probs[indices],
            element_class_ids=self.element_class_ids[indices],
            # the id map is shared, not sliced
            element_class_id_map=self.element_class_id_map,
            text_as_html=self.text_as_html[indices],
            table_as_cells=self.table_as_cells[indices],
            table_extraction_method=self.table_extraction_method[indices],
        )

    @classmethod
    def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
        """concatenate a sequence of LayoutElements in order as one LayoutElements"""
        coords, texts, probs, class_ids, sources, is_extracted_array = [], [], [], [], [], []
        text_as_html, table_as_cells, table_extraction_method = [], [], []
        # class name -> id in the merged map; built up across all groups
        class_id_reverse_map: dict[str, int] = {}
        for group in groups:
            coords.append(group.element_coords)
            texts.append(group.texts)
            probs.append(group.element_probs)
            sources.append(group.sources)
            is_extracted_array.append(group.is_extracted_array)
            text_as_html.append(group.text_as_html)
            table_as_cells.append(group.table_as_cells)
            table_extraction_method.append(group.table_extraction_method)
            # remap this group's class ids into the shared id space
            idx = group.element_class_ids.copy()
            if group.element_class_id_map:
                for class_id, class_name in group.element_class_id_map.items():
                    if class_name in class_id_reverse_map:
                        idx[group.element_class_ids == class_id] = class_id_reverse_map[class_name]
                        continue
                    new_id = len(class_id_reverse_map)
                    class_id_reverse_map[class_name] = new_id
                    idx[group.element_class_ids == class_id] = new_id
            class_ids.append(idx)
        return cls(
            element_coords=np.concatenate(coords),
            texts=np.concatenate(texts),
            element_probs=np.concatenate(probs),
            element_class_ids=np.concatenate(class_ids),
            element_class_id_map={v: k for k, v in class_id_reverse_map.items()},
            sources=np.concatenate(sources),
            is_extracted_array=np.concatenate(is_extracted_array),
            text_as_html=np.concatenate(text_as_html),
            table_as_cells=np.concatenate(table_as_cells),
            table_extraction_method=np.concatenate(table_extraction_method),
        )

    def iter_elements(self):
        """iter elements as one LayoutElement per iteration; this returns a generator and has less
        memory impact than the as_list method"""
        for (
            (x1, y1, x2, y2),
            text,
            prob,
            class_id,
            source,
            is_extracted,
            text_as_html,
            table_as_cells,
            table_extraction_method,
        ) in zip(
            self.element_coords,
            self.texts,
            self.element_probs,
            self.element_class_ids,
            self.sources,
            self.is_extracted_array,
            self.text_as_html,
            self.table_as_cells,
            self.table_extraction_method,
        ):
            yield LayoutElement.from_coords(
                x1,
                y1,
                x2,
                y2,
                text=text,
                type=(
                    self.element_class_id_map[class_id]
                    if class_id is not None and self.element_class_id_map
                    else None
                ),
                # NaN prob means "not available" and becomes None
                prob=None if np.isnan(prob) else prob,
                source=source,
                is_extracted=is_extracted,
                text_as_html=text_as_html,
                table_as_cells=table_as_cells,
                table_extraction_method=table_extraction_method,
            )

    @classmethod
    def from_list(cls, elements: list):
        """create LayoutElements from a list of LayoutElement objects; the objects must have the
        same source"""
        len_ele = len(elements)
        coords = np.empty((len_ele, 4), dtype=float)
        # text and probs can be Nones so use lists first then convert into array to avoid them being
        # filled as nan
        (
            texts,
            text_as_html,
            table_as_cells,
            table_extraction_method,
            sources,
            is_extracted_array,
            class_probs,
        ) = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )
        class_types = np.empty((len_ele,), dtype="object")
        for i, element in enumerate(elements):
            coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
            texts.append(element.text)
            sources.append(element.source)
            is_extracted_array.append(element.is_extracted)
            text_as_html.append(element.text_as_html)
            table_as_cells.append(element.table_as_cells)
            table_extraction_method.append(getattr(element, "table_extraction_method", None))
            class_probs.append(element.prob)
            # use the string "None" as a sentinel so np.unique can handle it
            class_types[i] = element.type or "None"
        unique_ids, class_ids = np.unique(class_types, return_inverse=True)
        # restore real None for elements that had no type
        unique_ids[unique_ids == "None"] = None
        return cls(
            element_coords=coords,
            texts=np.array(texts),
            element_probs=np.array(class_probs),
            element_class_ids=class_ids,
            element_class_id_map=dict(zip(range(len(unique_ids)), unique_ids)),
            sources=np.array(sources),
            is_extracted_array=np.array(is_extracted_array),
            text_as_html=np.array(text_as_html),
            table_as_cells=np.array(table_as_cells),
            table_extraction_method=np.array(table_extraction_method),
        )
@dataclass
class LayoutElement(TextRegion):
    """A single detected layout element: a text region plus its predicted
    type, confidence, and optional table payloads."""

    type: Optional[str] = None
    prob: Optional[float] = None
    image_path: Optional[str] = None
    parent: Optional[LayoutElement] = None
    text_as_html: Optional[str] = None
    table_as_cells: Optional[str] = None
    table_extraction_method: Optional[str] = None

    def to_dict(self) -> dict:
        """Converts the class instance to dictionary form.

        Note: table payloads (text_as_html, table_as_cells) and parent/image
        references are not included in the dict.
        """
        out_dict = {
            "coordinates": None if self.bbox is None else self.bbox.coordinates,
            "text": self.text,
            "type": self.type,
            "prob": self.prob,
            "source": self.source,
            "is_extracted": self.is_extracted,
        }
        return out_dict

    @classmethod
    def from_region(cls, region: TextRegion):
        """Create LayoutElement from superclass.

        Attributes absent on ``region`` default to None; table payloads are
        not carried over.
        """
        text = region.text if hasattr(region, "text") else None
        type = region.type if hasattr(region, "type") else None
        prob = region.prob if hasattr(region, "prob") else None
        source = region.source if hasattr(region, "source") else None
        is_extracted = region.is_extracted if hasattr(region, "is_extracted") else None
        return cls(
            bbox=region.bbox,
            text=text,
            source=source,
            is_extracted=is_extracted,
            type=type,
            prob=prob,
        )

    @classmethod
    def from_coords(
        cls,
        x1: Union[int, float],
        y1: Union[int, float],
        x2: Union[int, float],
        y2: Union[int, float],
        text: Optional[str] = None,
        source: Optional[Source] = None,
        is_extracted: Optional[IsExtracted] = None,
        type: Optional[str] = None,
        prob: Optional[float] = None,
        text_as_html: Optional[str] = None,
        table_as_cells: Optional[str] = None,
        table_extraction_method: Optional[str] = None,
        **kwargs,
    ) -> LayoutElement:
        """Constructs a LayoutElement from coordinates.

        The four coordinates become the element's bounding box; remaining
        keyword arguments are forwarded to the dataclass constructor.
        """
        bbox = Rectangle(x1, y1, x2, y2)
        return cls(
            text=text,
            is_extracted=is_extracted,
            type=type,
            prob=prob,
            source=source,
            text_as_html=text_as_html,
            table_as_cells=table_as_cells,
            table_extraction_method=table_extraction_method,
            bbox=bbox,
            **kwargs,
        )
def separate(region_a: Rectangle, region_b: Rectangle):
    """Shrink the smaller of two overlapping rectangles, in place, so the two
    no longer overlap. No-op when the rectangles do not intersect."""

    def _shrink(keep: Rectangle, shrink: Rectangle):
        # Assumes the two rectangles intersect. Mutates `shrink` in place based
        # on where it lies relative to `keep`.
        if shrink.y2 > keep.y2 and shrink.x1 < keep.x2:
            # `shrink` extends below `keep`
            if shrink.x2 > keep.x2 and shrink.y2 > keep.y2:
                # down-right: push the left and top edges just past `keep`
                shrink.x1 = keep.x2 * 1.01
                shrink.y1 = keep.y2 * 1.01
            elif shrink.x1 < keep.x1 and shrink.y1 < keep.y2:
                # down-left: clip the top edge to the bottom of `keep`
                shrink.y1 = keep.y2
            else:
                # centered below: clip the top edge to the bottom of `keep`
                shrink.y1 = keep.y2
        else:
            # `shrink` lies above `keep`
            if shrink.x2 > keep.x2 and shrink.y1 < keep.y1:
                # up-right: clip the bottom edge to the top of `keep`
                shrink.y2 = keep.y1
            elif shrink.x1 < keep.x1 and shrink.y1 < keep.y1:
                # up-left: clip the bottom edge to the top of `keep`
                shrink.y2 = keep.y1
            else:
                # centered above: clip the bottom edge to the top of `keep`
                shrink.y2 = keep.y1

    if not region_a.intersects(region_b):
        return
    # keep the larger rectangle untouched and shrink the smaller one
    if region_a.area > region_b.area:
        _shrink(region_a, region_b)
    else:
        _shrink(region_b, region_a)
def table_cells_to_dataframe(
    cells: List[dict],
    nrows: int = 1,
    ncols: int = 1,
    header=None,
) -> DataFrame:
    """convert table-transformer's cells data into a pandas dataframe

    Each cell dict supplies its first row index (``row_nums``), first column
    index (``column_nums``), and text (``"cell text"``). The grid starts at
    ``nrows`` x ``ncols`` and grows on demand when a cell lands outside it;
    never-filled positions remain ``None``.
    """
    grid = np.empty((nrows, ncols), dtype=object)
    for cell in cells:
        row = cell["row_nums"][0]
        col = cell["column_nums"][0]
        n_r, n_c = grid.shape
        if row >= n_r or col >= n_c:
            # grow the grid just enough to hold this cell, keeping prior content
            grown = np.empty((max(row + 1, n_r), max(col + 1, n_c)), dtype=object)
            grown[:n_r, :n_c] = grid
            grid = grown
        grid[row, col] = cell["cell text"]
    return DataFrame(grid, columns=header)
def partition_groups_from_regions(regions: TextRegions) -> List[TextRegions]:
    """Partition regions into proximity groups.

    Each region's box is padded by config-controlled fractions of its own
    height and width; regions whose padded boxes intersect — directly or
    transitively — fall into the same group. Returns one ``TextRegions``
    slice per connected group (empty list for empty input).
    """
    if len(regions) == 0:
        return []
    # pad each box proportionally to its own size
    v_pad = (regions.y2 - regions.y1) * inference_config.ELEMENTS_V_PADDING_COEF
    h_pad = (regions.x2 - regions.x1) * inference_config.ELEMENTS_H_PADDING_COEF
    padded = regions.element_coords.copy().astype(float)
    padded[:, 0] -= h_pad
    padded[:, 1] -= v_pad
    padded[:, 2] += h_pad
    padded[:, 3] += v_pad
    # group by connectivity of the pairwise-intersection graph
    n_groups, labels = connected_components(coords_intersections(padded))
    return [regions.slice(np.where(labels == group)[0]) for group in range(n_groups)]
def intersection_areas_between_coords(
    coords1: np.ndarray,
    coords2: np.ndarray,
    threshold: float = 0.5,
):
    """Pairwise intersection areas between two sets of bounding boxes.

    ``coords1`` is (n, 4) and ``coords2`` is (m, 4), each row (x1, y1, x2, y2);
    returns an (n, m) matrix of overlap areas, 0 where boxes do not overlap.
    The ``threshold`` parameter is accepted for interface compatibility but is
    not used by this computation.
    """
    lefts1, tops1, rights1, bottoms1 = np.split(coords1, 4, axis=1)
    lefts2, tops2, rights2, bottoms2 = (
        np.transpose(part) for part in np.split(coords2, 4, axis=1)
    )
    # overlap rectangle edges via broadcasting: (n, 1) against (1, m)
    overlap_x1 = np.maximum(lefts1, lefts2)
    overlap_y1 = np.maximum(tops1, tops2)
    overlap_x2 = np.minimum(rights1, rights2)
    overlap_y2 = np.minimum(bottoms1, bottoms2)
    # negative extents mean "no overlap" and are clipped to zero
    widths = np.clip(overlap_x2 - overlap_x1, 0, None)
    heights = np.clip(overlap_y2 - overlap_y1, 0, None)
    return widths * heights
def clean_layoutelements(elements: LayoutElements, subregion_threshold: float = 0.5):
    """Remove elements that are (almost) contained inside a larger element.

    An element is dropped when its overlap with a larger element exceeds
    ``subregion_threshold`` of its own area. Survivors are returned sorted by
    their top (y1) coordinate.
    """
    # Sort elements from biggest to smallest
    if len(elements) < 2:
        return elements
    sorted_by_area = np.argsort(-elements.areas)
    sorted_coords = elements.element_coords[sorted_by_area]
    # First check if targets contains each other
    self_intersection = intersection_areas_between_coords(sorted_coords, sorted_coords)
    areas = elements.areas[sorted_by_area]
    # check from largest to smallest regions to find if it contains any other regions
    # (EPSILON_AREA guards the division for degenerate zero-area boxes)
    is_almost_subregion_of = (
        self_intersection / np.maximum(areas, EPSILON_AREA) > subregion_threshold
    ) & (areas <= areas.T)
    n_candidates = len(elements)
    mask = np.ones_like(areas, dtype=bool)
    current_candidate = 0
    # Greedy sweep over the area-sorted elements: for each surviving candidate,
    # mark all smaller elements that are its near-subregions for removal, then
    # advance to the next still-surviving candidate.
    while n_candidates > 1:
        plus_one = current_candidate + 1
        remove = (
            np.where(is_almost_subregion_of[current_candidate, plus_one:])[0]
            + current_candidate
            + 1
        )
        if not remove.sum():
            break
        mask[remove] = 0
        n_candidates -= len(remove) + 1
        remaining_candidates = np.where(mask[plus_one:])[0]
        if not len(remaining_candidates):
            break
        current_candidate = remaining_candidates[0] + plus_one
    final_coords = sorted_coords[mask]
    # restore reading order (top to bottom) for the surviving elements
    sorted_by_y1 = np.argsort(final_coords[:, 1])
    final_attrs: dict[str, Any] = {
        "element_class_id_map": elements.element_class_id_map,
    }
    # carry every per-element attribute through the same sort/filter pipeline
    for attr in (
        "element_class_ids",
        "element_probs",
        "texts",
        "sources",
        "is_extracted_array",
        "text_as_html",
        "table_as_cells",
        "table_extraction_method",
    ):
        if (original_attr := getattr(elements, attr)) is None:
            continue
        final_attrs[attr] = original_attr[sorted_by_area][mask][sorted_by_y1]
    final_elements = LayoutElements(element_coords=final_coords[sorted_by_y1], **final_attrs)
    return final_elements
def clean_layoutelements_for_class(
    elements: LayoutElements,
    element_class: int,
    subregion_threshold: float = 0.5,
):
    """After this function, the list of elements will not contain any element inside
    of the type specified

    Elements of class ``element_class`` are deduplicated against each other,
    and elements of other classes that are near-subregions of a surviving
    target element are dropped.
    """
    # Sort elements from biggest to smallest
    sorted_by_area = np.argsort(-elements.areas)
    sorted_coords = elements.element_coords[sorted_by_area]
    target_indices = elements.element_class_ids[sorted_by_area] == element_class
    # skip trivial result
    len_target = target_indices.sum()
    if len_target == 0 or len_target == len(elements):
        return elements
    target_coords = sorted_coords[target_indices]
    other_coords = sorted_coords[~target_indices]
    # First check if targets contains each other
    target_self_intersection = intersection_areas_between_coords(target_coords, target_coords)
    target_areas = elements.areas[sorted_by_area][target_indices]
    # check from largest to smallest regions to find if it contains any other regions
    # (EPSILON_AREA guards the division for degenerate zero-area boxes)
    is_almost_subregion_of = (
        target_self_intersection / np.maximum(target_areas, EPSILON_AREA) > subregion_threshold
    ) & (target_areas <= target_areas.T)
    n_candidates = len_target
    mask = np.ones_like(target_areas, dtype=bool)
    current_candidate = 0
    # Greedy sweep over area-sorted target elements: each surviving candidate
    # removes its smaller near-subregions, then we move to the next survivor.
    while n_candidates > 1:
        plus_one = current_candidate + 1
        remove = (
            np.where(is_almost_subregion_of[current_candidate, plus_one:])[0]
            + current_candidate
            + 1
        )
        if not remove.sum():
            break
        mask[remove] = 0
        n_candidates -= len(remove) + 1
        remaining_candidates = np.where(mask[plus_one:])[0]
        if not len(remaining_candidates):
            break
        current_candidate = remaining_candidates[0] + plus_one
    target_coords_to_keep = target_coords[mask]
    # Now drop non-target elements that are near-subregions of a kept target.
    other_to_target_intersection = intersection_areas_between_coords(
        other_coords,
        target_coords_to_keep,
    )
    # check from largest to smallest regions to find if it contains any other regions
    other_areas = elements.areas[sorted_by_area][~target_indices]
    other_is_almost_subregion_of_target = (
        other_to_target_intersection / np.maximum(other_areas, EPSILON_AREA) > subregion_threshold
    ) & (other_areas.reshape((-1, 1)) <= target_areas[mask].T)
    # keep a non-target element only if it is a subregion of no kept target
    other_mask = ~other_is_almost_subregion_of_target.sum(axis=1).astype(bool)
    final_coords = np.vstack([target_coords[mask], other_coords[other_mask]])
    final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map}
    # carry every per-element attribute through the same split/filter pipeline
    for attr in (
        "element_class_ids",
        "element_probs",
        "texts",
        "sources",
        "is_extracted_array",
        "text_as_html",
        "table_as_cells",
        "table_extraction_method",
    ):
        if (original_attr := getattr(elements, attr)) is None:
            continue
        sorted_attr = original_attr[sorted_by_area]
        final_attrs[attr] = np.concatenate(
            (sorted_attr[target_indices][mask], sorted_attr[~target_indices][other_mask]),
        )
    final_elements = LayoutElements(element_coords=final_coords, **final_attrs)
    return final_elements
================================================
FILE: unstructured_inference/inference/pdf_image.py
================================================
from __future__ import annotations
import math
import os
from functools import lru_cache
from pathlib import Path, PurePath
from threading import Lock
from typing import BinaryIO, Optional, Union
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from unstructured_inference.config import inference_config
from unstructured_inference.constants import PDF_POINTS_PER_INCH
# Serializes all pdfium calls (document open, page load/render, close).
# NOTE(review): presumably because pypdfium2 is not thread-safe — confirm.
_pdfium_lock = Lock()
class PdfRenderTooLargeError(ValueError):
    """Raised when a PDF page would render to more pixels than the configured maximum."""

    pass
def _check_pdf_render_max_pixels(page, page_number: int, scale: float, maximum: int) -> None:
if maximum <= 0:
return
rendered_width = math.ceil(page.get_width() * scale)
rendered_height = math.ceil(page.get_height() * scale)
rendered_pixels = rendered_width * rendered_height
if rendered_pixels > maximum:
raise PdfRenderTooLargeError(
"PDF page would render to too many pixels for safe processing: "
f"page={page_number}, pixels={rendered_pixels}, maximum={maximum}. "
"Try splitting the PDF, reducing the page dimensions, or using a lower render DPI.",
)
@lru_cache(maxsize=1)
def _get_pdfium_module():
    """Import pypdfium2 lazily and cache the module after the first import."""
    import pypdfium2 as pdfium

    return pdfium
def convert_pdf_to_image(
    filename: Optional[str] = None,
    file: Optional[Union[bytes, BinaryIO]] = None,
    dpi: int = 200,
    output_folder: Optional[Union[str, PurePath]] = None,
    path_only: bool = False,
    first_page: Optional[int] = None,
    last_page: Optional[int] = None,
    password: Optional[str] = None,
    pdf_render_max_pixels_per_page: Optional[int] = None,
) -> Union[list[Image.Image], list[str]]:
    """Render PDF pages to PIL images or saved PNGs using pypdfium2.
    This is the single source of truth for PDF→image rendering across unstructured
    and unstructured-inference. Callers should pass their own DPI value explicitly.

    Exactly one of `filename` or `file` must be provided. Pages outside the
    optional `first_page`..`last_page` (1-based, inclusive) window are skipped.
    When `path_only` is true, returns the saved PNG paths (requires
    `output_folder`); otherwise returns the rendered PIL images.

    Raises:
        ValueError: if neither `filename` nor `file` is given, or if
            `path_only` is set without `output_folder`.
        PdfRenderTooLargeError: if a page would exceed the pixel cap.
    """
    if path_only and not output_folder:
        raise ValueError("output_folder must be specified if path_only is true")
    if filename is None and file is None:
        raise ValueError("Either filename or file must be provided")
    if output_folder:
        assert Path(output_folder).exists()
        assert Path(output_folder).is_dir()
    # pdfium scales relative to PDF points (72/inch)
    scale = dpi / PDF_POINTS_PER_INCH
    if pdf_render_max_pixels_per_page is None:
        pdf_render_max_pixels_per_page = inference_config.PDF_RENDER_MAX_PIXELS_PER_PAGE
    pdfium = _get_pdfium_module()
    with _pdfium_lock:
        pdf = pdfium.PdfDocument(filename or file, password=password)
        n_pages = len(pdf)
    try:
        images: dict[int, Image.Image] = {}
        filenames: list[str] = []
        for i in range(n_pages):
            page_num = i + 1
            if first_page is not None and page_num < first_page:
                continue
            if last_page is not None and page_num > last_page:
                break
            # all pdfium page operations happen under the module-wide lock
            with _pdfium_lock:
                page = pdf[i]
                try:
                    # reject pages that would render too large before rendering
                    _check_pdf_render_max_pixels(
                        page=page,
                        page_number=page_num,
                        scale=scale,
                        maximum=pdf_render_max_pixels_per_page,
                    )
                    bitmap = page.render(
                        scale=scale,
                        no_smoothtext=False,
                        no_smoothimage=False,
                        no_smoothpath=False,
                        optimize_mode="print",
                    )
                    try:
                        pil_image = bitmap.to_pil()
                    finally:
                        # free the native bitmap as soon as the PIL copy exists
                        bitmap.close()
                    rotation = page.get_rotation()
                    if rotation:
                        # apply the page's declared rotation and record it for
                        # downstream consumers
                        pil_image = pil_image.rotate(rotation, expand=True)
                        pil_image.info["pdf_rotation"] = rotation
                finally:
                    page.close()
            if output_folder:
                fn: str = os.path.join(str(output_folder), f"page_{page_num}.png")
                # persist the rotation in PNG metadata so it survives the save
                png_meta = PngInfo()
                png_meta.add_text("pdf_rotation", str(rotation))
                pil_image.save(
                    fn,
                    format="PNG",
                    compress_level=1,
                    optimize=False,
                    pnginfo=png_meta,
                )
                filenames.append(fn)
                if not path_only:
                    images[page_num] = pil_image
            else:
                images[page_num] = pil_image
    finally:
        with _pdfium_lock:
            pdf.close()
    if path_only:
        return filenames
    return list(images.values())
================================================
FILE: unstructured_inference/logger.py
================================================
import logging
def translate_log_level(level: int) -> int:
    """Map a Python logging level to an ONNX runtime logging severity.

    Blank-page errors surface from ONNX runtime at severity 3, so verbose
    Python levels (NOTSET through WARNING) map to severity 4 (suppressing
    them) while ERROR and CRITICAL map to severity 3. Levels with no
    registered name map to 0.
    """
    name = logging.getLevelName(level)
    if name in ("NOTSET", "DEBUG", "INFO", "WARNING"):
        return 4
    if name in ("ERROR", "CRITICAL"):
        return 3
    return 0
# Package-wide logger; its level is configured by the consuming application.
logger = logging.getLogger("unstructured_inference")
# Dedicated logger for onnxruntime messages; its level is stored on the ONNX
# severity scale, derived from the package logger's effective level.
logger_onnx = logging.getLogger("unstructured_inference_onnxruntime")
logger_onnx.setLevel(translate_log_level(logger.getEffectiveLevel()))
================================================
FILE: unstructured_inference/math.py
================================================
"""a lightweight module that provides helpers to common math operations"""
import numpy as np
# Machine epsilon for float64 (~2.22e-16), used as a division-by-zero floor.
FLOAT_EPSILON = np.finfo(float).eps


def safe_division(a, b) -> float:
    """Divide ``a`` by ``b``, guarding against division by zero.

    Any denominator not greater than machine epsilon — including 0 and
    negative values — is clamped to ``FLOAT_EPSILON`` (about 2.2e-16), so the
    call never raises ``ZeroDivisionError``.

    Parameters:
    - a (int/float): numerator
    - b (int/float): denominator

    Returns:
        float: ``a / b``, or ``a / FLOAT_EPSILON`` when ``b`` is clamped
    """
    denominator = b if b > FLOAT_EPSILON else FLOAT_EPSILON
    return a / denominator
================================================
FILE: unstructured_inference/models/__init__.py
================================================
================================================
FILE: unstructured_inference/models/base.py
================================================
from __future__ import annotations
import json
import os
import threading
from typing import Dict, Optional, Tuple, Type
from unstructured_inference.models.detectron2onnx import (
MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES,
)
from unstructured_inference.models.detectron2onnx import UnstructuredDetectronONNXModel
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES
from unstructured_inference.models.yolox import UnstructuredYoloXModel
from unstructured_inference.utils import LazyDict
DEFAULT_MODEL = "yolox"
class Models(object):
    """Singleton container for loaded models.
    Thread Safety:
    - Singleton initialization protected by _lock (double-check pattern)
    - Dict operations (__contains__, __getitem__, __setitem__) rely on CPython's GIL
      for atomicity. Individual dict operations are atomic in CPython.
    - Per-model locks in get_model() prevent concurrent initialization of same model
    - This implementation is CPython-specific and may need changes for Python 3.13+
      free-threaded mode or alternative Python implementations without GIL
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """return an instance if one already exists otherwise create an instance"""
        if cls._instance is None:
            with cls._lock:
                # re-check inside the lock so only one thread creates the instance
                if cls._instance is None:
                    cls._instance = super(Models, cls).__new__(cls)
                    # NOTE(review): `models` is assigned on the *class*, not the
                    # instance; harmless for a singleton but worth confirming it
                    # is intentional.
                    cls.models: Dict[str, UnstructuredModel] = {}
        return cls._instance

    def __contains__(self, key):
        """Check if model exists. Atomic operation under CPython GIL."""
        return key in self.models

    def __getitem__(self, key: str):
        """Get model by name. Atomic operation under CPython GIL."""
        return self.models.__getitem__(key)

    def __setitem__(self, key: str, value: UnstructuredModel):
        """Store model. Atomic operation under CPython GIL."""
        self.models[key] = value
# Module-level singleton that caches every loaded model by name.
models: Models = Models()
# Per-model locks for parallel loading of different models
# Current implementation: Unbounded dictionary grows with unique model names
# Memory impact: ~200 bytes per lock. Acceptable for <100 models (~20KB).
# For >1000 models: Consider lock striping (fixed 128 locks, ~25KB, 0.8% collision rate)
# Note: WeakValueDictionary is NOT suitable - locks would be GC'd immediately
_models_locks: Dict[str, threading.Lock] = {}
# Guards creation of entries in _models_locks.
_models_locks_lock = threading.Lock()
def get_default_model_mappings() -> Tuple[
    Dict[str, Type[UnstructuredModel]],
    Dict[str, dict | LazyDict],
]:
    """default model mappings for models that are in `unstructured_inference` repo

    Returns a (name -> model class, name -> init config) pair covering the
    bundled detectron2-ONNX and YOLOX model variants.
    """
    class_map: Dict[str, Type[UnstructuredModel]] = {}
    for name in DETECTRON2_ONNX_MODEL_TYPES:
        class_map[name] = UnstructuredDetectronONNXModel
    for name in YOLOX_MODEL_TYPES:
        class_map[name] = UnstructuredYoloXModel
    config_map: Dict[str, dict | LazyDict] = {}
    config_map.update(DETECTRON2_ONNX_MODEL_TYPES)
    config_map.update(YOLOX_MODEL_TYPES)
    return class_map, config_map
# Mutable module-level registries, seeded with the built-in model mappings and
# extended at runtime via register_new_model().
model_class_map, model_config_map = get_default_model_mappings()
def register_new_model(model_config: dict, model_class: UnstructuredModel):
    """Register this model in model_config_map and model_class_map.

    Those maps are updated with the new model class information, making every
    model name in `model_config` resolvable by get_model().
    """
    model_config_map.update(model_config)
    # every name in model_config maps to the same class
    model_class_map.update(dict.fromkeys(model_config, model_class))
def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
    """Gets the model object by model name.
    Thread-safe with per-model locks to allow parallel loading of different models
    while preventing duplicate initialization of the same model.
    Thread-safety maintained:
    - _models_locks_lock protects lock dictionary operations
    - Per-model locks protect model initialization
    - Double-check pattern prevents duplicate loads

    Raises:
        UnknownModelException: if `model_name` has no registered config and no
            override JSON path is set.
    """
    if model_name is None:
        # fall back to the env-var override, then the hard-coded default
        default_name_from_env = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_NAME")
        model_name = default_name_from_env if default_name_from_env is not None else DEFAULT_MODEL
    # Fast path: model already loaded
    if model_name in models:
        return models[model_name]
    # Get or create lock for this specific model
    with _models_locks_lock:
        if model_name not in _models_locks:
            _models_locks[model_name] = threading.Lock()
        model_lock = _models_locks[model_name]
    # Double-check pattern with per-model lock
    with model_lock:
        if model_name in models:
            return models[model_name]
        # A JSON file of initialize params, when configured, overrides the
        # built-in config; its "label_map" keys arrive as strings and must be
        # coerced back to ints.
        initialize_param_json = os.environ.get(
            "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH"
        )
        if initialize_param_json is not None:
            with open(initialize_param_json) as fp:
                initialize_params = json.load(fp)
                label_map_int_keys = {
                    int(key): value for key, value in initialize_params["label_map"].items()
                }
                initialize_params["label_map"] = label_map_int_keys
        else:
            if model_name in model_config_map:
                initialize_params = model_config_map[model_name]
            else:
                raise UnknownModelException(f"Unknown model type: {model_name}")
        model: UnstructuredModel = model_class_map[model_name]()
        # Normalize to a plain dict via __iter__ + __getitem__. `**` unpacking
        # calls `.keys()` on the mapping, which LazyDict inherits from
        # collections.abc.Mapping — but we've seen environments where that
        # inherited method isn't found at call time, surfacing as
        # "argument after ** must be a mapping, not LazyDict".
        initialize_params = {k: initialize_params[k] for k in initialize_params}
        model.initialize(**initialize_params)
        models[model_name] = model
        return model
class UnknownModelException(Exception):
    """A model was requested with an unrecognized identifier."""
================================================
FILE: unstructured_inference/models/detectron2onnx.py
================================================
import os
from typing import Dict, Final, List, Optional, Union, cast
import cv2
import numpy as np
import onnxruntime
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from onnxruntime.capi import _pybind_state as C
from onnxruntime.quantization import QuantType, quantize_dynamic
from PIL import Image
from unstructured_inference.constants import Source
from unstructured_inference.inference.layoutelement import LayoutElement
from unstructured_inference.logger import logger, logger_onnx
from unstructured_inference.models.unstructuredmodel import (
UnstructuredObjectDetectionModel,
)
from unstructured_inference.utils import (
LazyDict,
LazyEvaluateInfo,
download_if_needed_and_get_local_path,
)
# Keep onnxruntime's global logging severity in step with the dedicated ONNX
# logger. NOTE(review): logger_onnx's level is expected to already be on the
# ONNX severity scale (see translate_log_level in logger.py) — confirm.
onnxruntime.set_default_logger_severity(logger_onnx.getEffectiveLevel())
# Class id -> element type name mapping shared by the detectron2 model variants.
DEFAULT_LABEL_MAP: Final[Dict[int, str]] = {
    0: "Text",
    1: "Title",
    2: "List",
    3: "Table",
    4: "Figure",
}
# NOTE(alan): Entries are implemented as LazyDicts so that models aren't downloaded until they are
# needed.
MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = {
    "detectron2_onnx": LazyDict(
        model_path=LazyEvaluateInfo(
            download_if_needed_and_get_local_path,
            "unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x",
            "model.onnx",
        ),
        label_map=DEFAULT_LABEL_MAP,
        confidence_threshold=0.8,
    ),
    # The quantized variant is generated locally on first use (see
    # UnstructuredDetectronONNXModel.initialize), so its entry is a plain dict
    # pointing at the local cache path.
    # NOTE(review): "detectrin2_quantized.onnx" is misspelled, but renaming it
    # would orphan already-generated caches — left as-is.
    "detectron2_quantized": {
        "model_path": os.path.join(
            HUGGINGFACE_HUB_CACHE,
            "detectron2_quantized",
            "detectrin2_quantized.onnx",
        ),
        "label_map": DEFAULT_LABEL_MAP,
        "confidence_threshold": 0.8,
    },
    "detectron2_mask_rcnn": LazyDict(
        model_path=LazyEvaluateInfo(
            download_if_needed_and_get_local_path,
            "unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x",
            "model.onnx",
        ),
        label_map=DEFAULT_LABEL_MAP,
        confidence_threshold=0.8,
    ),
}
class UnstructuredDetectronONNXModel(UnstructuredObjectDetectionModel):
    """Unstructured model wrapper for detectron2 ONNX model."""

    # The model was trained and exported with this shape
    required_w = 800
    required_h = 1035

    def predict(self, image: Image.Image) -> List[LayoutElement]:
        """Makes a prediction using detectron2 model.

        Returns detected layout elements in original-image coordinates, or an
        empty list when the ONNX runtime raises (e.g. on a blank page).
        """
        super().predict(image)
        prepared_input = self.preprocess(image)
        try:
            result = self.model.run(None, prepared_input)
            bboxes = result[0]
            labels = result[1]
            # Previous model detectron2_onnx stored confidence scores at index 2,
            # bigger model stores it at index 3
            confidence_scores = result[2] if "R_50" in self.model_path else result[3]
        except onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException:
            logger_onnx.debug(
                "Ignoring runtime error from onnx (likely due to encountering blank page).",
            )
            return []
        input_w, input_h = image.size
        regions = self.postprocess(bboxes, labels, confidence_scores, input_w, input_h)
        return regions

    def initialize(
        self,
        model_path: str,
        label_map: Dict[int, str],
        confidence_threshold: Optional[float] = None,
    ):
        """Loads the detectron2 model using the specified parameters

        When the quantized variant is requested and its file does not exist
        yet, it is generated on the fly by quantizing the base ONNX model.
        """
        if not os.path.exists(model_path) and "detectron2_quantized" in model_path:
            logger.info("Quantized model don't currently exists, quantizing now...")
            # NOTE(review): creates the parent directory of model_path; fails
            # if it already exists — presumably only ever hit once per cache.
            os.mkdir("".join(os.path.split(model_path)[:-1]))
            source_path = MODEL_TYPES["detectron2_onnx"]["model_path"]
            quantize_dynamic(source_path, model_path, weight_type=QuantType.QUInt8)
        available_providers = C.get_available_providers()
        # prefer GPU providers when available, falling back to CPU
        ordered_providers = [
            "TensorrtExecutionProvider",
            "CUDAExecutionProvider",
            "CPUExecutionProvider",
        ]
        providers = [provider for provider in ordered_providers if provider in available_providers]
        self.model = onnxruntime.InferenceSession(
            model_path,
            providers=providers,
        )
        self.model_path = model_path
        self.label_map = label_map
        if confidence_threshold is None:
            confidence_threshold = 0.5
        self.confidence_threshold = confidence_threshold

    def preprocess(self, image: Image.Image) -> Dict[str, np.ndarray]:
        """Process input image into required format for ingestion into the Detectron2 ONNX binary.
        This involves resizing to a fixed shape and converting to a specific numpy format.
        """
        # TODO (benjamin): check other shapes for inference
        img = np.array(image)
        # TODO (benjamin): We should use models.get_model() but currenly returns Detectron model
        session = self.model
        # onnx input expected
        # [3,1035,800]
        img = cv2.resize(
            img,
            (self.required_w, self.required_h),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.float32)
        # HWC -> CHW, as the exported model expects channels first
        img = img.transpose(2, 0, 1)
        ort_inputs = {session.get_inputs()[0].name: img}
        return ort_inputs

    def postprocess(
        self,
        bboxes: np.ndarray,
        labels: np.ndarray,
        confidence_scores: np.ndarray,
        input_w: float,
        input_h: float,
    ) -> List[LayoutElement]:
        """Process output into Unstructured class. Bounding box coordinates are converted to
        original image resolution.

        Detections below `confidence_threshold` are discarded; survivors are
        returned sorted top-to-bottom by their y1 coordinate.
        """
        regions = []
        # scale factors from the model's fixed input shape back to the original
        width_conversion = input_w / self.required_w
        height_conversion = input_h / self.required_h
        for (x1, y1, x2, y2), label, conf in zip(bboxes, labels, confidence_scores):
            detected_class = self.label_map[int(label)]
            if conf >= self.confidence_threshold:
                region = LayoutElement.from_coords(
                    x1 * width_conversion,
                    y1 * height_conversion,
                    x2 * width_conversion,
                    y2 * height_conversion,
                    text=None,
                    type=detected_class,
                    prob=conf,
                    source=Source.DETECTRON2_ONNX,
                )
                regions.append(region)
        regions.sort(key=lambda element: element.bbox.y1)
        return cast(List[LayoutElement], regions)
================================================
FILE: unstructured_inference/models/eval.py
================================================
from functools import partial
from typing import Callable, Dict, List, Optional
import pandas as pd
from rapidfuzz import fuzz
# rapidfuzz similarity scorers selectable by name in compare_contents_as_df;
# "token_*" variants compare word-level tokens (better suited for tables per
# the compare_contents_as_df docstring) while the others compare characters;
# "partial_*" variants score the best-matching substring instead of the whole.
EVAL_FUNCTIONS = {
    "token_ratio": fuzz.token_ratio,
    "ratio": fuzz.ratio,
    "partial_token_ratio": fuzz.partial_token_ratio,
    "partial_ratio": fuzz.partial_ratio,
}
def _join_df_content(df, tab_token="\t", row_break_token="\n") -> str:
"""joining dataframe's table content as one long string"""
return row_break_token.join([tab_token.join(row) for row in df.values])
def default_tokenizer(text: str) -> List[str]:
    """A simple whitespace tokenizer: splits ``text`` into words at runs of
    whitespace, discarding empty tokens."""
    tokens = text.split()
    return tokens
def compare_contents_as_df(
    actual_df: pd.DataFrame,
    pred_df: pd.DataFrame,
    eval_func: str = "token_ratio",
    processor: Optional[Callable] = None,
    tab_token: str = "\t",
    row_break_token: str = "\n",
) -> Dict[str, float]:
    r"""ravel the table as string then use text distance to compare the prediction against true
    table

    The comparison is run twice: once on the dataframe as-is ("by column") and
    once on its transpose ("by row"), so the result captures agreement along
    both axes.

    Parameters
    ----------
    actual_df: pd.DataFrame
        actual table as pandas dataframe
    pred_df: pd.DataFrame
        predicted table as pandas dataframe
    eval_func: str, default to "token_ratio"
        the eval_func should be one of "token_ratio", "ratio", "partial_token_ratio",
        "partial_ratio". Those are functions provided by rapidfuzz to evaluate text distances
        using either tokens or characters. In general token is better than characters for
        evaluating tables.
    processor: Callable, default to None
        processor to tokenize the text; by default None means no processing (using characters).
        For tokens eval functions we recommend using the `default_tokenizer` or some other
        functions to break down the text into words
    tab_token: str, default to "\t"
        the string to join cells together
    row_break_token: str, default to "\n"
        the string to join rows together

    Returns
    -------
    Dict[str, float]
        mapping of by column and by row scores to the scores as float numbers

    Raises
    ------
    ValueError
        if ``eval_func`` is not one of the supported function names
    """
    func = EVAL_FUNCTIONS.get(eval_func)
    if func is None:
        raise ValueError(
            'eval_func must be one of "token_ratio", "ratio", "partial_token_ratio", '
            f'"partial_ratio" but got {eval_func}',
        )
    join_func = partial(_join_df_content, tab_token=tab_token, row_break_token=row_break_token)
    return {
        f"by_col_{eval_func}": func(
            join_func(actual_df),
            join_func(pred_df),
            processor=processor,
        ),
        # Transposing before joining turns row-order into cell-order, giving a
        # row-wise view of the same comparison.
        f"by_row_{eval_func}": func(
            join_func(actual_df.T),
            join_func(pred_df.T),
            processor=processor,
        ),
    }
================================================
FILE: unstructured_inference/models/table_postprocess.py
================================================
# https://github.com/microsoft/table-transformer/blob/main/src/postprocess.py
"""
Copyright (C) 2021 Microsoft Corporation
"""
from collections import defaultdict
class Rect:
    """Mutable axis-aligned rectangle used for bbox intersection/union math.

    Coordinates are (x_min, y_min, x_max, y_max). A default-constructed Rect
    is the degenerate all-zero rectangle, which acts as an "empty" value for
    ``intersect`` and ``include_rect``.
    """

    def __init__(self, bbox=None):
        if bbox is None:
            self.x_min = 0
            self.y_min = 0
            self.x_max = 0
            self.y_max = 0
        else:
            self.x_min = bbox[0]
            self.y_min = bbox[1]
            self.x_max = bbox[2]
            self.y_max = bbox[3]

    def get_area(self):
        """Area of the rectangle; degenerate or inverted boxes count as 0.0."""
        area = (self.x_max - self.x_min) * (self.y_max - self.y_min)
        if area > 0:
            return area
        return 0.0

    def intersect(self, other):
        """Shrink this rectangle, in place, to its overlap with ``other``.

        An empty rectangle instead adopts ``other``'s bounds. Returns self.
        """
        if self.get_area() == 0:
            self.x_min, self.y_min = other.x_min, other.y_min
            self.x_max, self.y_max = other.x_max, other.y_max
        else:
            self.x_min = max(self.x_min, other.x_min)
            self.y_min = max(self.y_min, other.y_min)
            self.x_max = min(self.x_max, other.x_max)
            self.y_max = min(self.y_max, other.y_max)
            # Disjoint inputs leave an inverted/degenerate box; collapse it.
            if self.x_min > self.x_max or self.y_min > self.y_max or self.get_area() == 0:
                self.x_min, self.y_min, self.x_max, self.y_max = 0, 0, 0, 0
        return self

    def include_rect(self, bbox):
        """Grow this rectangle, in place, to also cover ``bbox``.

        An empty rectangle instead adopts ``bbox``. Returns self.
        """
        other = Rect(bbox)
        if self.get_area() == 0:
            self.x_min, self.y_min = other.x_min, other.y_min
            self.x_max, self.y_max = other.x_max, other.y_max
            return self
        self.x_min = min(self.x_min, other.x_min)
        self.y_min = min(self.y_min, other.y_min)
        self.x_max = max(self.x_max, other.x_max)
        self.y_max = max(self.y_max, other.y_max)
        return self

    def get_bbox(self):
        """Return [x_min, y_min, x_max, y_max]."""
        return [self.x_min, self.y_min, self.x_max, self.y_max]
def apply_threshold(objects, threshold):
    """Keep only the objects whose "score" is at least ``threshold``."""
    kept = []
    for obj in objects:
        if obj["score"] >= threshold:
            kept.append(obj)
    return kept
def refine_rows(rows, tokens, score_threshold):
    """
    Apply operations to the detected rows, such as
    thresholding, NMS, and alignment.

    NOTE: ``score_threshold`` is currently unused; it is kept for interface
    parity with callers and with ``refine_columns``.
    """
    if tokens:
        # With OCR tokens available, suppress rows by shared token containment
        # and drop rows that hold no text at all.
        rows = nms_by_containment(rows, tokens, overlap_threshold=0.5)
        remove_objects_without_content(tokens, rows)
    else:
        # Without tokens, fall back to geometric overlap-based suppression.
        rows = nms(rows, match_criteria="object2_overlap", match_threshold=0.5, keep_higher=True)
    if len(rows) > 1:
        return sort_objects_top_to_bottom(rows)
    return rows
def refine_columns(columns, tokens, score_threshold):
    """
    Apply operations to the detected columns, such as
    thresholding, NMS, and alignment.

    NOTE: ``score_threshold`` is currently unused; it is kept for interface
    parity with callers and with ``refine_rows``.
    """
    if tokens:
        # With OCR tokens available, suppress columns by shared token
        # containment and drop columns that hold no text at all.
        columns = nms_by_containment(columns, tokens, overlap_threshold=0.5)
        remove_objects_without_content(tokens, columns)
    else:
        # Without tokens, fall back to geometric overlap-based suppression.
        columns = nms(
            columns,
            match_criteria="object2_overlap",
            match_threshold=0.25,
            keep_higher=True,
        )
    if len(columns) > 1:
        return sort_objects_left_to_right(columns)
    return columns
def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5):
    """
    Non-maxima suppression (NMS) of objects based on shared containment of other objects.

    Containers are ranked by score; a lower-scoring container is suppressed
    when it contains no packages at all, or when it shares at least one
    package with a surviving higher-scoring container.
    """
    # Highest score first, so earlier containers win ties over later ones.
    container_objects = sort_objects_by_score(container_objects)
    num_objects = len(container_objects)
    suppression = [False for obj in container_objects]

    # packages_by_container[i] lists the package indices slotted into container i.
    packages_by_container, _, _ = slot_into_containers(
        container_objects,
        package_objects,
        overlap_threshold=overlap_threshold,
        forced_assignment=False,
    )

    for object2_num in range(1, num_objects):
        object2_packages = set(packages_by_container[object2_num])
        # A container with no content is suppressed outright.
        if len(object2_packages) == 0:
            suppression[object2_num] = True
        # Suppress if any surviving higher-scored container shares a package.
        for object1_num in range(object2_num):
            if not suppression[object1_num]:
                object1_packages = set(packages_by_container[object1_num])
                if len(object2_packages.intersection(object1_packages)) > 0:
                    suppression[object2_num] = True

    final_objects = [obj for idx, obj in enumerate(container_objects) if not suppression[idx]]
    return final_objects
def slot_into_containers(
    container_objects,
    package_objects,
    overlap_threshold=0.5,
    forced_assignment=False,
):
    """
    Slot a collection of objects into the container they occupy most (the container which holds
    the largest fraction of the object).

    Both arguments are lists of dicts carrying a "bbox" entry. Returns a 3-tuple:
    - container_assignments: for each container, indices of the packages slotted into it
    - package_assignments: for each package, the index of its best container (empty when no
      container reaches ``overlap_threshold`` and ``forced_assignment`` is False)
    - best_match_scores: for each package with positive area, the best overlap fraction found
    """
    best_match_scores = []

    container_assignments = [[] for container in container_objects]
    package_assignments = [[] for package in package_objects]

    if len(container_objects) == 0 or len(package_objects) == 0:
        return container_assignments, package_assignments, best_match_scores

    # NOTE: a stray module-scope `match_scores = defaultdict(dict)` used to
    # live here; it was dead code (immediately shadowed inside the loop below)
    # and has been removed.
    for package_num, package in enumerate(package_objects):
        match_scores = []
        package_rect = Rect(package["bbox"])
        package_area = package_rect.get_area()
        for container_num, container in enumerate(container_objects):
            container_rect = Rect(container["bbox"])
            intersect_area = container_rect.intersect(Rect(package["bbox"])).get_area()
            # Zero-area packages cannot yield a meaningful overlap fraction;
            # they receive no assignment and contribute no best score.
            if package_area > 0:
                overlap_fraction = intersect_area / package_area
                match_scores.append(
                    {
                        "container": container,
                        "container_num": container_num,
                        "score": overlap_fraction,
                    },
                )
        if len(match_scores) > 0:
            sorted_match_scores = sort_objects_by_score(match_scores)
            best_match_score = sorted_match_scores[0]

            best_match_scores.append(best_match_score["score"])
            if forced_assignment or best_match_score["score"] >= overlap_threshold:
                container_assignments[best_match_score["container_num"]].append(package_num)
                package_assignments[package_num].append(best_match_score["container_num"])

    return container_assignments, package_assignments, best_match_scores
def sort_objects_by_score(objects, reverse=True):
    """Order objects by their "score" value, highest first by default."""

    def score_of(obj):
        return obj["score"]

    return sorted(objects, key=score_of, reverse=reverse)
def remove_objects_without_content(page_spans, objects):
    """
    Remove any objects (these can be rows, columns, supercells, etc.) that don't
    have any text associated with them.

    Mutates ``objects`` in place; iterates over a copy so removal is safe.
    """
    for obj in list(objects):
        object_text, _ = extract_text_inside_bbox(page_spans, obj["bbox"])
        if not object_text.strip():
            objects.remove(obj)
def extract_text_inside_bbox(spans, bbox):
    """
    Extract the text inside a bounding box.

    Returns a (text, spans) pair: the joined text and the subset of spans
    that fall within ``bbox``.
    """
    subset = get_bbox_span_subset(spans, bbox)
    text = extract_text_from_spans(subset, remove_integer_superscripts=True)
    return text, subset
def get_bbox_span_subset(spans, bbox, threshold=0.5):
    """
    Reduce the set of spans to those that fall within a bounding box.

    threshold: the fraction of the span that must overlap with the bbox.
    """
    return [span for span in spans if overlaps(span["bbox"], bbox, threshold)]
def overlaps(bbox1, bbox2, threshold=0.5):
    """
    Test if more than "threshold" fraction of bbox1 overlaps with bbox2.

    Returns False when bbox1 has zero area.
    """
    area1 = Rect(list(bbox1)).get_area()
    if area1 == 0:
        return False
    overlap_area = Rect(list(bbox1)).intersect(Rect(list(bbox2))).get_area()
    return overlap_area / area1 >= threshold
def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True):
    """
    Convert a collection of page tokens/words/spans into a single text string.

    Spans are dicts with "text", "span_num", "line_num", "block_num" and an
    optional "flags" bitfield; spans are ordered by block, then line, then
    span number before joining.

    join_with_space: if True, join span and line texts with a single space,
        otherwise concatenate them directly.
    remove_integer_superscripts: if True, drop spans whose superscript flag is
        set and whose text is purely digits (e.g. footnote markers); other
        superscript spans get span["superscript"] = True set on them.
    """
    join_char = " " if join_with_space else ""
    spans_copy = spans[:]

    if remove_integer_superscripts:
        for span in spans:
            if "flags" not in span:
                continue
            flags = span["flags"]
            if flags & 2**0:  # superscript flag
                if span["text"].strip().isdigit():
                    spans_copy.remove(span)
                else:
                    span["superscript"] = True

    if len(spans_copy) == 0:
        return ""

    # Stable sorts applied least-significant key first: the final order is by
    # block_num, then line_num, then span_num.
    spans_copy.sort(key=lambda span: span["span_num"])
    spans_copy.sort(key=lambda span: span["line_num"])
    spans_copy.sort(key=lambda span: span["block_num"])

    # Force the span at the end of every line within a block to have exactly one space
    # unless the line ends with a space or ends with a non-space followed by a hyphen
    line_texts = []
    line_span_texts = [spans_copy[0]["text"]]
    for span1, span2 in zip(spans_copy[:-1], spans_copy[1:]):
        if span1["block_num"] != span2["block_num"] or span1["line_num"] != span2["line_num"]:
            # Line/block boundary reached: flush the accumulated line.
            line_text = join_char.join(line_span_texts).strip()
            # NOTE(review): the trailing `and not join_with_space` means this
            # padding space is only ever added when join_char is "" — confirm
            # against upstream table-transformer postprocess.py; it reads as
            # intentional (space-joined lines never need extra padding).
            if (
                len(line_text) > 0
                and line_text[-1] != " "
                and not (len(line_text) > 1 and line_text[-1] == "-" and line_text[-2] != " ")
                and not join_with_space
            ):
                line_text += " "
            line_texts.append(line_text)
            line_span_texts = [span2["text"]]
        else:
            line_span_texts.append(span2["text"])
    # Flush the final line (not strip()-ed here; the final return strips ends).
    line_text = join_char.join(line_span_texts)
    line_texts.append(line_text)

    return join_char.join(line_texts).strip()
def sort_objects_left_to_right(objs):
    """Order objects left-to-right by the horizontal midpoint of their bbox
    (twice the midpoint, x_min + x_max, sorts identically)."""

    def doubled_center_x(obj):
        return obj["bbox"][0] + obj["bbox"][2]

    return sorted(objs, key=doubled_center_x)
def sort_objects_top_to_bottom(objs):
    """Order objects top-to-bottom by the vertical midpoint of their bbox
    (twice the midpoint, y_min + y_max, sorts identically)."""

    def doubled_center_y(obj):
        return obj["bbox"][1] + obj["bbox"][3]

    return sorted(objs, key=doubled_center_y)
def align_columns(columns, bbox):
    """
    For every column, align the top and bottom boundaries to the final
    table bounding box.

    Mutates the column bboxes in place; failures are logged and the (possibly
    partially updated) columns are returned unchanged.
    """
    try:
        for column in columns:
            column_bbox = column["bbox"]
            column_bbox[1] = bbox[1]
            column_bbox[3] = bbox[3]
    except Exception as err:
        print(f"Could not align columns: {err}")
    return columns
def align_rows(rows, bbox):
    """
    For every row, align the left and right boundaries to the final
    table bounding box.

    Mutates the row bboxes in place; failures are logged and the (possibly
    partially updated) rows are returned unchanged.
    """
    try:
        for row in rows:
            row_bbox = row["bbox"]
            row_bbox[0] = bbox[0]
            row_bbox[2] = bbox[2]
    except Exception as err:
        print(f"Could not align rows: {err}")
    return rows
def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_higher=True):
    """
    A customizable version of non-maxima suppression (NMS).

    Default behavior: If a lower-confidence object overlaps more than 5% of its area
    with a higher-confidence object, remove the lower-confidence object.

    objects: set of dicts; each object dict must have a 'bbox' and a 'score' field
    match_criteria: how to measure how much two objects "overlap": one of
        "object1_overlap", "object2_overlap", or "iou"
    match_threshold: the cutoff for determining that overlap requires suppression of one object
    keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower

    NOTE(review): an unrecognized ``match_criteria`` leaves ``metric`` unbound
    and raises NameError on first use (the except below only catches
    ZeroDivisionError) — confirm callers only pass the three known values.
    """
    if len(objects) == 0:
        return []

    # Sort so earlier objects are the ones to keep; later ones get suppressed.
    objects = sort_objects_by_score(objects, reverse=keep_higher)

    num_objects = len(objects)
    suppression = [False for obj in objects]

    for object2_num in range(1, num_objects):
        object2_rect = Rect(objects[object2_num]["bbox"])
        object2_area = object2_rect.get_area()
        for object1_num in range(object2_num):
            if not suppression[object1_num]:
                object1_rect = Rect(objects[object1_num]["bbox"])
                object1_area = object1_rect.get_area()
                # Rect.intersect mutates object1_rect, but it is rebuilt each
                # iteration so the mutation is harmless here.
                intersect_area = object1_rect.intersect(object2_rect).get_area()
                try:
                    if match_criteria == "object1_overlap":
                        metric = intersect_area / object1_area
                    elif match_criteria == "object2_overlap":
                        metric = intersect_area / object2_area
                    elif match_criteria == "iou":
                        metric = intersect_area / (object1_area + object2_area - intersect_area)
                    if metric >= match_threshold:
                        suppression[object2_num] = True
                        break
                except ZeroDivisionError:
                    # Intended to recover from divide-by-zero
                    pass

    return [obj for idx, obj in enumerate(objects) if not suppression[idx]]
def align_supercells(supercells, rows, columns):
    """
    For each supercell, align it to the rows it intersects 50% of the height of,
    and the columns it intersects 50% of the width of.
    Eliminate supercells for which there are no rows and columns it intersects 50% with.

    Mutates the surviving supercell dicts in place (sets "header", "bbox",
    "row_numbers", "column_numbers") and returns the list of aligned
    supercells, possibly including newly synthesized "propagated" header
    supercells.
    """
    aligned_supercells = []

    for supercell in supercells:
        supercell["header"] = False
        row_bbox_rect = None
        col_bbox_rect = None
        intersecting_header_rows = set()
        intersecting_data_rows = set()
        # Collect the rows whose vertical extent overlaps this supercell >= 50%.
        for row_num, row in enumerate(rows):
            row_height = row["bbox"][3] - row["bbox"][1]
            supercell_height = supercell["bbox"][3] - supercell["bbox"][1]
            min_row_overlap = max(row["bbox"][1], supercell["bbox"][1])
            max_row_overlap = min(row["bbox"][3], supercell["bbox"][3])
            overlap_height = max_row_overlap - min_row_overlap
            if "span" in supercell:
                # Span supercells qualify if the overlap covers half of either
                # the row's height or the supercell's own height.
                overlap_fraction = max(
                    overlap_height / row_height,
                    overlap_height / supercell_height,
                )
            else:
                overlap_fraction = overlap_height / row_height
            if overlap_fraction >= 0.5:
                if "header" in row and row["header"]:
                    intersecting_header_rows.add(row_num)
                else:
                    intersecting_data_rows.add(row_num)

        # Supercell cannot span across the header boundary; eliminate whichever
        # group of rows is the smallest
        supercell["header"] = False
        if len(intersecting_data_rows) > 0 and len(intersecting_header_rows) > 0:
            if len(intersecting_data_rows) > len(intersecting_header_rows):
                intersecting_header_rows = set()
            else:
                intersecting_data_rows = set()
        if len(intersecting_header_rows) > 0:
            supercell["header"] = True
        elif "span" in supercell:
            continue  # Require span supercell to be in the header
        intersecting_rows = intersecting_data_rows.union(intersecting_header_rows)
        # Determine vertical span of aligned supercell
        for row_num in intersecting_rows:
            if row_bbox_rect is None:
                row_bbox_rect = Rect(rows[row_num]["bbox"])
            else:
                row_bbox_rect = row_bbox_rect.include_rect(rows[row_num]["bbox"])
        if row_bbox_rect is None:
            continue

        # Collect the columns whose horizontal extent overlaps this supercell
        # >= 50% and accumulate their union as the horizontal span.
        intersecting_cols = []
        for col_num, col in enumerate(columns):
            col_width = col["bbox"][2] - col["bbox"][0]
            supercell_width = supercell["bbox"][2] - supercell["bbox"][0]
            min_col_overlap = max(col["bbox"][0], supercell["bbox"][0])
            max_col_overlap = min(col["bbox"][2], supercell["bbox"][2])
            overlap_width = max_col_overlap - min_col_overlap
            if "span" in supercell:
                overlap_fraction = max(overlap_width / col_width, overlap_width / supercell_width)
                # Multiply by 2 effectively lowers the threshold to 0.25
                if supercell["header"]:
                    overlap_fraction = overlap_fraction * 2
            else:
                overlap_fraction = overlap_width / col_width
            if overlap_fraction >= 0.5:
                intersecting_cols.append(col_num)
                if col_bbox_rect is None:
                    col_bbox_rect = Rect(col["bbox"])
                else:
                    col_bbox_rect = col_bbox_rect.include_rect(col["bbox"])
        if col_bbox_rect is None:
            continue

        # The aligned bbox is the intersection of the row union and column union.
        supercell_bbox = row_bbox_rect.intersect(col_bbox_rect).get_bbox()
        supercell["bbox"] = supercell_bbox

        # Only a true supercell if it joins across multiple rows or columns
        if (
            len(intersecting_rows) > 0
            and len(intersecting_cols) > 0
            and (len(intersecting_rows) > 1 or len(intersecting_cols) > 1)
        ):
            supercell["row_numbers"] = list(intersecting_rows)
            supercell["column_numbers"] = intersecting_cols
            aligned_supercells.append(supercell)

            # A span supercell in the header means there must be supercells above it in the header
            if "span" in supercell and supercell["header"] and len(supercell["column_numbers"]) > 1:
                for row_num in range(0, min(supercell["row_numbers"])):
                    new_supercell = {
                        "row_numbers": [row_num],
                        "column_numbers": supercell["column_numbers"],
                        "score": supercell["score"],
                        "propagated": True,
                    }
                    new_supercell_columns = [columns[idx] for idx in supercell["column_numbers"]]
                    new_supercell_rows = [rows[idx] for idx in supercell["row_numbers"]]
                    bbox = [
                        min([column["bbox"][0] for column in new_supercell_columns]),
                        min([row["bbox"][1] for row in new_supercell_rows]),
                        max([column["bbox"][2] for column in new_supercell_columns]),
                        max([row["bbox"][3] for row in new_supercell_rows]),
                    ]
                    new_supercell["bbox"] = bbox
                    aligned_supercells.append(new_supercell)

    return aligned_supercells
def nms_supercells(supercells):
    """
    A NMS scheme for supercells that first attempts to shrink supercells to
    resolve overlap.
    If two supercells overlap the same (sub)cell, shrink the lower confidence
    supercell to resolve the overlap. If shrunk supercell is empty, remove it.
    """
    supercells = sort_objects_by_score(supercells)
    suppression = [False] * len(supercells)

    for lower_num in range(1, len(supercells)):
        lower = supercells[lower_num]
        for higher_num in range(lower_num):
            # Shrink the lower-confidence supercell out of the higher one.
            remove_supercell_overlap(supercells[higher_num], lower)
            # Suppress it once shrinking leaves it degenerate: no longer
            # spanning multiple rows/columns, or empty on either axis.
            if (
                (len(lower["row_numbers"]) < 2 and len(lower["column_numbers"]) < 2)
                or len(lower["row_numbers"]) == 0
                or len(lower["column_numbers"]) == 0
            ):
                suppression[lower_num] = True

    return [cell for idx, cell in enumerate(supercells) if not suppression[idx]]
def header_supercell_tree(supercells):
    """
    Make sure no supercell in the header is below more than one supercell in any row above it.
    The cells in the header form a tree, but a supercell with more than one supercell in a row
    above it means that some cell has more than one parent, which is not allowed. Eliminate
    any supercell that would cause this to be violated.

    NOTE: mutates ``supercells`` in place (removals) and returns None.
    """
    header_supercells = [
        supercell for supercell in supercells if "header" in supercell and supercell["header"]
    ]
    header_supercells = sort_objects_by_score(header_supercells)

    # Iterate over a copy so removals from `supercells` don't affect iteration.
    for header_supercell in header_supercells[:]:
        ancestors_by_row = defaultdict(int)
        min_row = min(header_supercell["row_numbers"])
        # Count, per row above this supercell, how many header supercells
        # fully cover this supercell's columns (i.e. candidate ancestors).
        for header_supercell2 in header_supercells:
            max_row2 = max(header_supercell2["row_numbers"])
            if max_row2 < min_row and set(header_supercell["column_numbers"]).issubset(
                set(header_supercell2["column_numbers"]),
            ):
                for row2 in header_supercell2["row_numbers"]:
                    ancestors_by_row[row2] += 1
        # A valid tree requires exactly one ancestor in every row above.
        for row in range(0, min_row):
            if ancestors_by_row[row] != 1:
                supercells.remove(header_supercell)
                break
def remove_supercell_overlap(supercell1, supercell2):
    """
    This function resolves overlap between supercells (supercells must be
    disjoint) by iteratively shrinking supercells by the fewest grid cells
    necessary to resolve the overlap.

    Only ``supercell2`` (the less-confident supercell) is mutated: while the
    two share both a row and a column, one overlapping row or column index is
    removed from its edge — whichever axis removes fewer grid cells.
    ``supercell1`` is never modified.
    """
    common_rows = set(supercell1["row_numbers"]).intersection(set(supercell2["row_numbers"]))
    common_columns = set(supercell1["column_numbers"]).intersection(
        set(supercell2["column_numbers"]),
    )

    def _drop_edge(axis_key, common):
        # Remove one overlapping index from the edge of supercell2[axis_key],
        # preferring the max edge, then the min edge; if neither edge overlaps,
        # give up on this axis entirely (clears the axis and the common set).
        numbers = supercell2[axis_key]
        high = max(numbers)
        low = min(numbers)
        if high in common:
            common.remove(high)
            numbers.remove(high)
        elif low in common:
            common.remove(low)
            numbers.remove(low)
        else:
            supercell2[axis_key] = []
            common.clear()

    # Shrink one row or column at a time until no grid cell is shared.
    while len(common_rows) > 0 and len(common_columns) > 0:
        # Removing from the shorter axis removes fewer grid cells overall.
        if len(supercell2["row_numbers"]) < len(supercell2["column_numbers"]):
            _drop_edge("column_numbers", common_columns)
        else:
            _drop_edge("row_numbers", common_rows)
================================================
FILE: unstructured_inference/models/tables.py
================================================
# https://github.com/microsoft/table-transformer/blob/main/src/inference.py
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb
import threading
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
from PIL import Image as PILImage
from transformers import DetrImageProcessor, TableTransformerForObjectDetection, logging
from transformers.models.table_transformer.modeling_table_transformer import (
TableTransformerObjectDetectionOutput,
)
from unstructured_inference.config import inference_config
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
from unstructured_inference.logger import logger
from unstructured_inference.models.table_postprocess import Rect
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
from unstructured_inference.utils import pad_image_with_background_color
from . import table_postprocess as postprocess
# Hugging Face model id used when load_agent() initializes the table agent.
DEFAULT_MODEL = "microsoft/table-transformer-structure-recognition"
class UnstructuredTableTransformerModel(UnstructuredModel):
    """Unstructured model wrapper for table-transformer.

    Implemented as a process-wide singleton: ``__new__`` always returns the
    same instance, and ``initialize`` must be called (see ``load_agent``)
    before ``predict`` can be used.
    """

    # Singleton state shared across all constructions of this class.
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the existing instance if one exists, otherwise create it
        (double-checked locking so concurrent first calls stay single)."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(UnstructuredTableTransformerModel, cls).__new__(cls)
        return cls._instance

    def predict(
        self,
        x: PILImage.Image,
        ocr_tokens: Optional[List[Dict]] = None,
        result_format: str = "html",
    ):
        """Predict table structure deferring to run_prediction with ocr tokens

        Note:
        `ocr_tokens` is a list of dictionaries representing OCR tokens,
        where each dictionary has the following format:
        {
            "bbox": [int, int, int, int], # Bounding box coordinates of the token
            "block_num": int, # Block number
            "line_num": int, # Line number
            "span_num": int, # Span number
            "text": str, # Text content of the token
        }
        The bounding box coordinates should match the table structure.
        FIXME: refactor token data into a dataclass so we have clear expectations of the fields

        result_format: one of "html", "dataframe", "cells"; see run_prediction.
        """
        super().predict(x)
        return self.run_prediction(x, ocr_tokens=ocr_tokens, result_format=result_format)

    def initialize(
        self,
        model: Union[str, Path],
        device: Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Loads the table transformer model using the specified parameters.

        Device placement strategy:
        - Normalize device names (cuda -> cuda:0) for consistent caching
        - Load models WITHOUT device_map to avoid meta tensor errors
        - Use explicit .to(device, dtype=torch.float32) for proper placement

        Raises ImportError when the weights cannot be loaded.
        """
        # Device normalization for consistent caching
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if device.startswith("cuda") and ":" not in device:
            if torch.cuda.is_available():
                device = f"cuda:{torch.cuda.current_device()}"
            else:
                logger.warning("CUDA device requested but not available, falling back to CPU")
                device = "cpu"

        self.device = device

        # Load feature extractor WITHOUT device_map
        self.feature_extractor = DetrImageProcessor.from_pretrained(model)
        # value not set in the configuration and needed for newer models
        # https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all/discussions/1
        self.feature_extractor.size["shortest_edge"] = inference_config.IMG_PROCESSOR_SHORTEST_EDGE
        self.feature_extractor.size["longest_edge"] = inference_config.IMG_PROCESSOR_LONGEST_EDGE

        try:
            logger.info(f"Loading table structure model to {self.device}...")
            # Temporarily silence transformers' loading chatter; the caller's
            # verbosity is restored after from_pretrained returns.
            cached_current_verbosity = logging.get_verbosity()
            logging.set_verbosity_error()
            # Load model WITHOUT device_map (prevents meta tensor errors)
            self.model = TableTransformerForObjectDetection.from_pretrained(model)
            # Explicit device placement with dtype
            # NOTE: While nn.Module.to() modifies in-place, capturing return value is
            # recommended best practice per PyTorch docs for consistency and clarity
            self.model = self.model.to(self.device, dtype=torch.float32)
            logging.set_verbosity(cached_current_verbosity)
            self.model.eval()
            logger.info(f"Table model successfully loaded to {self.device}")
        except EnvironmentError:
            logger.critical("Failed to initialize the model.")
            logger.critical("Ensure that the model is correct")
            raise ImportError(
                "Review the parameters to initialize a UnstructuredTableTransformerModel obj",
            )

    def get_structure(
        self,
        x: PILImage.Image,
        pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
    ) -> TableTransformerObjectDetectionOutput:
        """get the table structure as a dictionary containing different types of elements as
        key-value pairs; check table-transformer documentation for more information"""
        with torch.no_grad():
            # The image is padded with background color before extraction; the
            # pad amount is recorded on the output so downstream code can shift
            # bounding boxes back (see outputs_to_objects).
            encoding = self.feature_extractor(
                pad_image_with_background_color(x, pad_for_structure_detection),
                return_tensors="pt",
            ).to(self.device)
            outputs_structure = self.model(**encoding)
            outputs_structure["pad_for_structure_detection"] = pad_for_structure_detection
            return outputs_structure

    def run_prediction(
        self,
        x: PILImage.Image,
        pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
        ocr_tokens: Optional[List[Dict]] = None,
        result_format: Optional[str] = "html",
    ):
        """Predict table structure

        Returns "" when no table is recognized; otherwise the first recognized
        table in the requested format. Raises ValueError when ocr_tokens is
        None or result_format is not "html"/"dataframe"/"cells".
        """
        outputs_structure = self.get_structure(x, pad_for_structure_detection)
        if ocr_tokens is None:
            raise ValueError("Cannot predict table structure with no OCR tokens")

        recognized_table = recognize(outputs_structure, x, tokens=ocr_tokens)
        if len(recognized_table) > 0:
            prediction = recognized_table[0]
        # NOTE(robinson) - This means that the table was not recognized
        else:
            return ""

        if result_format == "html":
            # Convert cells to HTML
            prediction = cells_to_html(prediction) or ""
        elif result_format == "dataframe":
            prediction = table_cells_to_dataframe(prediction)
        elif result_format == "cells":
            prediction = prediction
        else:
            raise ValueError(
                f"result_format {result_format} is not a valid format. "
                f'Valid formats are: "html", "dataframe", "cells"',
            )

        return prediction
# Module-level singleton (the class's __new__ enforces a single instance);
# constructed un-initialized so importing this module stays cheap — weights
# are only downloaded/loaded when load_agent() is called.
tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel()


def load_agent():
    """Loads the Table agent."""
    # Double-checked locking: only one thread runs the (slow) initialize();
    # the presence of a "model" attribute marks the agent as initialized.
    if getattr(tables_agent, "model", None) is None:
        with tables_agent._lock:
            if getattr(tables_agent, "model", None) is None:
                logger.info("Loading the Table agent ...")
                tables_agent.initialize(DEFAULT_MODEL)

    return
def get_class_map(data_type: str):
    """Return the label -> class-index mapping for the given model type.

    Args:
        data_type: "structure" (table structure recognition classes) or
            "detection" (table detection classes).

    Returns:
        Dict[str, int] mapping class labels to model output indices.

    Raises:
        ValueError: if ``data_type`` is not a known model type (previously an
            unknown value surfaced as an opaque ``UnboundLocalError``).
    """
    if data_type == "structure":
        return {
            "table": 0,
            "table column": 1,
            "table row": 2,
            "table column header": 3,
            "table projected row header": 4,
            "table spanning cell": 5,
            "no object": 6,
        }
    if data_type == "detection":
        return {"table": 0, "table rotated": 1, "no object": 2}
    raise ValueError(
        f'data_type must be either "structure" or "detection" but got {data_type}',
    )
# Per-class confidence thresholds applied to raw table-transformer detections
# (consumed by apply_thresholds_on_objects via recognize); values come from
# inference_config so they can be tuned via environment configuration.
structure_class_thresholds = {
    "table": inference_config.TT_TABLE_CONF,
    "table column": inference_config.TABLE_COLUMN_CONF,
    "table row": inference_config.TABLE_ROW_CONF,
    "table column header": inference_config.TABLE_COLUMN_HEADER_CONF,
    "table projected row header": inference_config.TABLE_PROJECTED_ROW_HEADER_CONF,
    "table spanning cell": inference_config.TABLE_SPANNING_CELL_CONF,
    # FIXME (yao) this parameter doesn't seem to be used at all in inference? Can we remove it
    "no object": 10,
}
def recognize(outputs: TableTransformerObjectDetectionOutput, img: PILImage.Image, tokens: list):
    """Recognize table elements: turn raw model outputs into one cell list per
    detected table."""
    idx_to_name = {index: name for name, index in get_class_map("structure").items()}

    # Post-process detected objects, assign class labels, then drop detections
    # whose confidence falls below the per-class thresholds.
    detected_objects = outputs_to_objects(outputs, img.size, idx_to_name)
    confident_objects = apply_thresholds_on_objects(detected_objects, structure_class_thresholds)

    # Resolve the detections into consistent table structures, then enumerate
    # every cell (grid cells and spanning cells) for each structure.
    tables_structure = objects_to_structures(confident_objects, tokens, structure_class_thresholds)
    return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
def outputs_to_objects(
    outputs: TableTransformerObjectDetectionOutput,
    img_size: Tuple[int, int],
    class_idx2name: Mapping[int, str],
):
    """Convert raw table-transformer outputs into a list of detection dicts
    with "label", "score" and "bbox" (xyxy, in unpadded-image coordinates),
    dropping "no object" detections.

    img_size: (width, height) of the *unpadded* input image.
    """
    # Best class and probability per query; [0] selects the single batch item.
    m = outputs["logits"].softmax(-1).max(-1)
    pred_labels = m.indices.detach().cpu().numpy()[0]
    pred_scores = m.values.detach().cpu().numpy()[0]
    pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
    # Boxes were predicted on the padded image (see get_structure), so rescale
    # against the padded size before removing the pad offset below.
    pad = outputs.get("pad_for_structure_detection", 0)
    scale_size = (img_size[0] + pad * 2, img_size[1] + pad * 2)
    rescaled = rescale_bboxes(pred_bboxes, scale_size)
    # unshift the padding; padding effectively shifted the bounding boxes of structures in the
    # original image with half of the total pad
    if pad != 0:
        rescaled = rescaled - pad
    pred_bboxes = rescaled.tolist()
    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = class_idx2name[int(label)]
        if class_label != "no object":
            objects.append(
                {
                    "label": class_label,
                    "score": float(score),
                    "bbox": bbox,
                },
            )
    return objects
def apply_thresholds_on_objects(
    objects: Sequence[Mapping[str, Any]],
    thresholds: Mapping[str, float],
) -> Sequence[Mapping[str, Any]]:
    """
    Filters predicted objects which the confidence scores below the thresholds

    Args:
        objects: Sequence of mappings for example:
        [
            {
                "label": "table row",
                "score": 0.55,
                "bbox": [...],
            },
            ...,
        ]
        thresholds: Mapping from labels to thresholds

    Returns:
        Filtered list of objects
    """
    kept = []
    for obj in objects:
        # Each object is kept only if it meets its own label's threshold.
        if obj["score"] >= thresholds[obj["label"]]:
            kept.append(obj)
    return kept
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert rectangle format from center-x, center-y, width, height to
    x-min, y-min, x-max, y-max."""
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h]
    return torch.stack(corners, dim=1)
def rescale_bboxes(out_bbox, size):
    """Rescale relative (0..1) center-format boxes to absolute xyxy boxes for
    an image of (width, height) given by ``size``."""
    img_w, img_h = size
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=out_bbox.device)
    return box_cxcywh_to_xyxy(out_bbox) * scale
def iob(bbox1, bbox2):
    """
    Compute the intersection area over box area, for bbox1.

    Returns 0 when bbox1 has no positive area.
    """
    bbox1_area = Rect(bbox1).get_area()
    if bbox1_area <= 0:
        return 0
    intersection_area = Rect(bbox1).intersect(Rect(bbox2)).get_area()
    return intersection_area / bbox1_area
def objects_to_structures(objects, tokens, class_thresholds):
    """
    Process the bounding boxes produced by the table structure recognition model into
    a *consistent* set of table structures (rows, columns, spanning cells, headers).
    This entails resolving conflicts/overlaps, and ensuring the boxes meet certain alignment
    conditions (for example: rows should all have the same width, etc.).

    Args:
        objects: model detections, each a dict with "label", "score", "bbox".
        tokens: OCR/text tokens, each a dict with at least a "bbox".
        class_thresholds: mapping from structure label to confidence threshold,
            forwarded to the postprocess refinement helpers.

    Returns:
        A list with one structure dict (keys "rows", "columns",
        "column headers", "spanning cells") per detected "table" object.
    """
    tables = [obj for obj in objects if obj["label"] == "table"]
    table_structures = []
    for table in tables:
        # Keep only the objects/tokens that mostly lie inside this table's bbox
        # (intersection-over-box >= TABLE_IOB_THRESHOLD).
        table_objects = [
            obj
            for obj in objects
            if iob(obj["bbox"], table["bbox"]) >= inference_config.TABLE_IOB_THRESHOLD
        ]
        table_tokens = [
            token
            for token in tokens
            if iob(token["bbox"], table["bbox"]) >= inference_config.TABLE_IOB_THRESHOLD
        ]
        structure = {}
        # Partition the table's objects by predicted class.
        columns = [obj for obj in table_objects if obj["label"] == "table column"]
        rows = [obj for obj in table_objects if obj["label"] == "table row"]
        column_headers = [obj for obj in table_objects if obj["label"] == "table column header"]
        spanning_cells = [obj for obj in table_objects if obj["label"] == "table spanning cell"]
        for obj in spanning_cells:
            obj["projected row header"] = False
        # Projected row headers are folded into the spanning-cell list, marked
        # with "projected row header" = True so later stages can tell them apart.
        projected_row_headers = [
            obj for obj in table_objects if obj["label"] == "table projected row header"
        ]
        for obj in projected_row_headers:
            obj["projected row header"] = True
        spanning_cells += projected_row_headers
        # Flag each row that substantially overlaps a detected column-header box.
        for obj in rows:
            obj["column header"] = False
            for header_obj in column_headers:
                if iob(obj["bbox"], header_obj["bbox"]) >= inference_config.TABLE_IOB_THRESHOLD:
                    obj["column header"] = True
        # Refine table structures
        rows = postprocess.refine_rows(rows, table_tokens, class_thresholds["table row"])
        columns = postprocess.refine_columns(
            columns,
            table_tokens,
            class_thresholds["table column"],
        )
        # Shrink table bbox to just the total height of the rows
        # and the total width of the columns
        row_rect = Rect()
        for obj in rows:
            row_rect.include_rect(obj["bbox"])
        column_rect = Rect()
        for obj in columns:
            column_rect.include_rect(obj["bbox"])
        table["row_column_bbox"] = [
            column_rect.x_min,
            row_rect.y_min,
            column_rect.x_max,
            row_rect.y_max,
        ]
        table["bbox"] = table["row_column_bbox"]
        # Process the rows and columns into a complete segmented table
        columns = postprocess.align_columns(columns, table["row_column_bbox"])
        rows = postprocess.align_rows(rows, table["row_column_bbox"])
        structure["rows"] = rows
        structure["columns"] = columns
        structure["column headers"] = column_headers
        structure["spanning cells"] = spanning_cells
        # Only refine when there is an actual grid to work with.
        if len(rows) > 0 and len(columns) > 1:
            structure = refine_table_structure(structure, class_thresholds)
        table_structures.append(structure)
    return table_structures
def refine_table_structure(table_structure, class_thresholds):
    """
    Apply operations to the detected table structure objects such as
    thresholding, NMS, and alignment.

    Mutates and returns ``table_structure`` (a dict with keys "rows",
    "columns", "column headers", "spanning cells") using the per-label
    confidence thresholds in ``class_thresholds``.
    """
    rows = table_structure["rows"]
    columns = table_structure["columns"]
    # Process the headers: threshold by confidence, suppress duplicates, then
    # snap the surviving header boxes to the rows they cover.
    column_headers = table_structure["column headers"]
    column_headers = postprocess.apply_threshold(
        column_headers,
        class_thresholds["table column header"],
    )
    column_headers = postprocess.nms(column_headers)
    column_headers = align_headers(column_headers, rows)
    # Process spanning cells; projected row headers use their own threshold.
    spanning_cells = [
        elem for elem in table_structure["spanning cells"] if not elem["projected row header"]
    ]
    projected_row_headers = [
        elem for elem in table_structure["spanning cells"] if elem["projected row header"]
    ]
    spanning_cells = postprocess.apply_threshold(
        spanning_cells,
        class_thresholds["table spanning cell"],
    )
    projected_row_headers = postprocess.apply_threshold(
        projected_row_headers,
        class_thresholds["table projected row header"],
    )
    spanning_cells += projected_row_headers
    # Align before NMS for spanning cells because alignment brings them into agreement
    # with rows and columns first; if spanning cells still overlap after this operation,
    # the threshold for NMS can basically be lowered to just above 0
    spanning_cells = postprocess.align_supercells(spanning_cells, rows, columns)
    spanning_cells = postprocess.nms_supercells(spanning_cells)
    # Return value unused — presumably mutates spanning_cells in place; confirm
    # against postprocess.header_supercell_tree.
    postprocess.header_supercell_tree(spanning_cells)
    table_structure["columns"] = columns
    table_structure["rows"] = rows
    table_structure["spanning cells"] = spanning_cells
    table_structure["column headers"] = column_headers
    return table_structure
def align_headers(headers, rows):
    """
    Adjust the header boundary to be the convex hull of the rows it intersects
    at least 50% of the height of.
    For now, we are not supporting tables with multiple headers, so we need to
    eliminate anything besides the top-most header.

    Side effect: sets "column header" = True on every row absorbed into the
    returned header, and False on all others.
    """
    aligned_headers = []
    # Reset the flag before deciding which rows belong to the header.
    for row in rows:
        row["column header"] = False
    header_row_nums = []
    # Collect every row whose vertical overlap with some header box is at
    # least half the row's height.
    for header in headers:
        for row_num, row in enumerate(rows):
            # NOTE(review): assumes row_height is nonzero — a zero-height row
            # bbox would raise ZeroDivisionError here; confirm upstream
            # refinement guarantees this.
            row_height = row["bbox"][3] - row["bbox"][1]
            min_row_overlap = max(row["bbox"][1], header["bbox"][1])
            max_row_overlap = min(row["bbox"][3], header["bbox"][3])
            overlap_height = max_row_overlap - min_row_overlap
            if overlap_height / row_height >= 0.5:
                header_row_nums.append(row_num)
    if len(header_row_nums) == 0:
        return aligned_headers
    header_rect = Rect()
    # If the first header row is not row 0, prepend rows 0..first so the
    # header region starts at the top of the table.
    if header_row_nums[0] > 0:
        header_row_nums = list(range(header_row_nums[0] + 1)) + header_row_nums
    last_row_num = -1
    # Absorb rows only while they form a contiguous run from the top.
    for row_num in header_row_nums:
        if row_num == last_row_num + 1:
            row = rows[row_num]
            row["column header"] = True
            header_rect = header_rect.include_rect(row["bbox"])
            last_row_num = row_num
        else:
            # Break as soon as a non-header row is encountered.
            # This ignores any subsequent rows in the table labeled as a header.
            # Having more than 1 header is not supported currently.
            break
    header = {"bbox": header_rect.get_bbox()}
    aligned_headers.append(header)
    return aligned_headers
def compute_confidence_score(cell_match_scores):
    """
    Score how well the page tokens slot into the cells reported by the model.

    The score is the average of the mean and the minimum per-cell match
    score; an empty score list yields 0.
    """
    try:
        mean_score = sum(cell_match_scores) / len(cell_match_scores)
        worst_score = min(cell_match_scores)
    except ZeroDivisionError:
        # No cells matched at all.
        return 0
    return (mean_score + worst_score) / 2
def structure_to_cells(table_structure, tokens):
    """
    Assuming the row, column, spanning cell, and header bounding boxes have
    been refined into a set of consistent table structures, process these
    table structures into table cells. This is a universal representation
    format for the table, which can later be exported to Pandas or CSV formats.
    Classify the cells as header/access cells or data cells
    based on if they intersect with the header bounding box.

    Returns:
        (cells, confidence_score) where cells is a list of dicts with keys
        "bbox", "row_nums", "column_nums", "column header", "cell text",
        "spans" (and "projected row header" for non-merged cells), and
        confidence_score summarizes token-to-cell match quality.
    """
    columns = table_structure["columns"]
    rows = table_structure["rows"]
    spanning_cells = table_structure["spanning cells"]
    # Highest-confidence spanning cells claim subcells first.
    spanning_cells = sorted(spanning_cells, reverse=True, key=lambda cell: cell["score"])
    cells = []
    subcells = []
    # Identify complete cells and subcells
    for column_num, column in enumerate(columns):
        for row_num, row in enumerate(rows):
            column_rect = Rect(list(column["bbox"]))
            row_rect = Rect(list(row["bbox"]))
            cell_rect = row_rect.intersect(column_rect)
            header = "column header" in row and row["column header"]
            cell = {
                "bbox": cell_rect.get_bbox(),
                "column_nums": [column_num],
                "row_nums": [row_num],
                "column header": header,
            }
            # A grid cell mostly covered by a spanning cell becomes a
            # "subcell" to be merged below.
            cell["subcell"] = False
            for spanning_cell in spanning_cells:
                spanning_cell_rect = Rect(list(spanning_cell["bbox"]))
                if (
                    spanning_cell_rect.intersect(cell_rect).get_area() / cell_rect.get_area()
                ) > inference_config.TABLE_IOB_THRESHOLD:
                    cell["subcell"] = True
                    cell["is_merged"] = False
                    break
            if cell["subcell"]:
                subcells.append(cell)
            else:
                # cell text = extract_text_inside_bbox(table_spans, cell['bbox'])
                # cell['cell text'] = cell text
                cell["projected row header"] = False
                cells.append(cell)
    # Merge subcells claimed by each spanning cell into a single multi-row /
    # multi-column cell.
    for spanning_cell in spanning_cells:
        spanning_cell_rect = Rect(list(spanning_cell["bbox"]))
        cell_columns = set()
        cell_rows = set()
        cell_rect = None
        header = True
        for subcell in subcells:
            subcell_rect = Rect(list(subcell["bbox"]))
            subcell_rect_area = subcell_rect.get_area()
            # Each subcell may be merged into at most one spanning cell
            # (is_merged guards against double assignment).
            if (
                subcell_rect.intersect(spanning_cell_rect).get_area() / subcell_rect_area
            ) > inference_config.TABLE_IOB_THRESHOLD and subcell["is_merged"] is False:
                if cell_rect is None:
                    cell_rect = Rect(list(subcell["bbox"]))
                else:
                    cell_rect.include_rect(list(subcell["bbox"]))
                cell_rows = cell_rows.union(set(subcell["row_nums"]))
                cell_columns = cell_columns.union(set(subcell["column_nums"]))
                # By convention here, all subcells must be classified
                # as header cells for a spanning cell to be classified as a header cell;
                # otherwise, this could lead to a non-rectangular header region
                header = header and "column header" in subcell and subcell["column header"]
                subcell["is_merged"] = True
        if len(cell_rows) > 0 and len(cell_columns) > 0:
            cell = {
                "bbox": cell_rect.get_bbox(),
                "column_nums": list(cell_columns),
                "row_nums": list(cell_rows),
                "column header": header,
                "projected row header": spanning_cell["projected row header"],
            }
            cells.append(cell)
    # Score how well tokens fit the cell grid (used as the table confidence).
    _, _, cell_match_scores = postprocess.slot_into_containers(cells, tokens)
    confidence_score = compute_confidence_score(cell_match_scores)
    # Dilate rows and columns before final extraction
    # dilated_columns = fill_column_gaps(columns, table_bbox)
    dilated_columns = columns
    # dilated_rows = fill_row_gaps(rows, table_bbox)
    dilated_rows = rows
    # Recompute each cell bbox as the intersection of its rows' and columns'
    # hulls.
    for cell in cells:
        column_rect = Rect()
        for column_num in cell["column_nums"]:
            column_rect.include_rect(list(dilated_columns[column_num]["bbox"]))
        row_rect = Rect()
        for row_num in cell["row_nums"]:
            row_rect.include_rect(list(dilated_rows[row_num]["bbox"]))
        cell_rect = column_rect.intersect(row_rect)
        cell["bbox"] = cell_rect.get_bbox()
    # Assign tokens to cells (non-forced, tiny overlap threshold) and extract
    # each cell's text from its assigned spans.
    span_nums_by_cell, _, _ = postprocess.slot_into_containers(
        cells,
        tokens,
        overlap_threshold=0.001,
        forced_assignment=False,
    )
    for cell, cell_span_nums in zip(cells, span_nums_by_cell):
        cell_spans = [tokens[num] for num in cell_span_nums]
        # TODO: Refine how text is extracted; should be character-based, not span-based;
        # but need to associate
        cell["cell text"] = postprocess.extract_text_from_spans(
            cell_spans,
            remove_integer_superscripts=False,
        )
        cell["spans"] = cell_spans
    # Adjust the row, column, and cell bounding boxes to reflect the extracted text
    num_rows = len(rows)
    rows = postprocess.sort_objects_top_to_bottom(rows)
    num_columns = len(columns)
    columns = postprocess.sort_objects_left_to_right(columns)
    # Collect extreme span coordinates per row/column to tighten boxes.
    min_y_values_by_row = defaultdict(list)
    max_y_values_by_row = defaultdict(list)
    min_x_values_by_column = defaultdict(list)
    max_x_values_by_column = defaultdict(list)
    for cell in cells:
        min_row = min(cell["row_nums"])
        max_row = max(cell["row_nums"])
        min_column = min(cell["column_nums"])
        max_column = max(cell["column_nums"])
        for span in cell["spans"]:
            min_x_values_by_column[min_column].append(span["bbox"][0])
            min_y_values_by_row[min_row].append(span["bbox"][1])
            max_x_values_by_column[max_column].append(span["bbox"][2])
            max_y_values_by_row[max_row].append(span["bbox"][3])
    # Rows: left/right edges come from the first/last column's text extent.
    for row_num, row in enumerate(rows):
        if len(min_x_values_by_column[0]) > 0:
            row["bbox"][0] = min(min_x_values_by_column[0])
        if len(min_y_values_by_row[row_num]) > 0:
            row["bbox"][1] = min(min_y_values_by_row[row_num])
        if len(max_x_values_by_column[num_columns - 1]) > 0:
            row["bbox"][2] = max(max_x_values_by_column[num_columns - 1])
        if len(max_y_values_by_row[row_num]) > 0:
            row["bbox"][3] = max(max_y_values_by_row[row_num])
    # Columns: top/bottom edges come from the first/last row's text extent.
    for column_num, column in enumerate(columns):
        if len(min_x_values_by_column[column_num]) > 0:
            column["bbox"][0] = min(min_x_values_by_column[column_num])
        if len(min_y_values_by_row[0]) > 0:
            column["bbox"][1] = min(min_y_values_by_row[0])
        if len(max_x_values_by_column[column_num]) > 0:
            column["bbox"][2] = max(max_x_values_by_column[column_num])
        if len(max_y_values_by_row[num_rows - 1]) > 0:
            column["bbox"][3] = max(max_y_values_by_row[num_rows - 1])
    # Final pass: rebuild each cell bbox from the adjusted rows/columns,
    # keeping the old bbox when the intersection collapses to zero area.
    for cell in cells:
        row_rect = None
        column_rect = None
        for row_num in cell["row_nums"]:
            if row_rect is None:
                row_rect = Rect(list(rows[row_num]["bbox"]))
            else:
                row_rect.include_rect(list(rows[row_num]["bbox"]))
        for column_num in cell["column_nums"]:
            if column_rect is None:
                column_rect = Rect(list(columns[column_num]["bbox"]))
            else:
                column_rect.include_rect(list(columns[column_num]["bbox"]))
        cell_rect = row_rect.intersect(column_rect)
        if cell_rect.get_area() > 0:
            cell["bbox"] = cell_rect.get_bbox()
        # NOTE(review): this `pass` is a no-op left in place.
        pass
    return cells, confidence_score
def fill_cells(cells: List[dict]) -> List[dict]:
    """Fill gaps in the table grid with empty-text cells.

    Every (row, column) position not covered by an existing cell gets a new
    single-cell entry with empty text; the new cell is marked as a column
    header when its row is part of any existing header cell.

    Each cell dict carries the keys relevant to html conversion:
        row_nums: List[int] — rows the cell belongs to (spanning cells have
            more than one)
        column_nums: List[int] — columns the cell belongs to
        cell text: str — the text in the cell
        column header: bool — whether the cell is a column header
    """
    if not cells:
        return []
    # Grid extent is defined by the largest row/column index seen.
    max_row = max(r for cell in cells for r in cell["row_nums"])
    max_col = max(c for cell in cells for c in cell["column_nums"])
    # Every (row, col) position already covered by some cell.
    occupied = {
        (r, c)
        for cell in cells
        for r in cell["row_nums"]
        for c in cell["column_nums"]
    }
    # Rows that participate in any header cell.
    header_rows = {r for cell in cells if cell["column header"] for r in cell["row_nums"]}
    result = cells.copy()
    for r in range(max_row + 1):
        for c in range(max_col + 1):
            if (r, c) in occupied:
                continue
            result.append(
                {
                    "row_nums": [r],
                    "column_nums": [c],
                    "cell text": "",
                    "column header": r in header_rows,
                }
            )
    return result
def cells_to_html(cells: List[dict]) -> str:
    """Convert table structure to html format.

    Args:
        cells: List of dictionaries representing table cells, where each
            dictionary has the following format:
            {
                "row_nums": List[int],
                "column_nums": List[int],
                "cell text": str,
                "column header": bool,
            }

    Returns:
        str: HTML table string with thead/tbody sections, colspan/rowspan
        attributes for spanning cells, and th/td tags per header status.
    """
    # Fill grid gaps, then order cells top-to-bottom, left-to-right.
    ordered = sorted(
        fill_cells(cells),
        key=lambda c: (min(c["row_nums"]), min(c["column_nums"])),
    )
    table = ET.Element("table")
    has_header = any(c["column header"] for c in ordered)
    thead = ET.SubElement(table, "thead") if has_header else None
    tbody = ET.SubElement(table, "tbody")
    last_row = -1
    tr = None
    for c in ordered:
        first_row = min(c["row_nums"])
        attrib = {}
        n_cols = len(c["column_nums"])
        if n_cols > 1:
            attrib["colspan"] = str(n_cols)
        n_rows = len(c["row_nums"])
        if n_rows > 1:
            attrib["rowspan"] = str(n_rows)
        # Open a new <tr> whenever the cell starts a later row; the first cell
        # of the row decides whether it lives in thead (th) or tbody (td), and
        # that choice persists for the rest of the row.
        if first_row > last_row:
            last_row = first_row
            if c["column header"]:
                section = thead
                cell_tag = "th"
            else:
                section = tbody
                cell_tag = "td"
            tr = ET.SubElement(section, "tr")  # type: ignore
        if tr is not None:
            elem = ET.SubElement(tr, cell_tag, attrib=attrib)
            elem.text = c["cell text"]
    return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))
def zoom_image(image: PILImage.Image, zoom: float) -> PILImage.Image:
    """Scale an image by ``zoom`` using cv2, then dilate and erode the result
    to improve edge sharpness for OCR tasks."""
    # Non-positive zoom factors are treated as 1 so the dilation/erosion
    # cleanup still runs on the unscaled image.
    factor = zoom if zoom > 0 else 1
    bgr = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    scaled = cv2.resize(
        bgr,
        None,
        fx=factor,
        fy=factor,
        interpolation=cv2.INTER_CUBIC,
    )
    kernel = np.ones((1, 1), np.uint8)
    scaled = cv2.dilate(scaled, kernel, iterations=1, dst=scaled)
    scaled = cv2.erode(scaled, kernel, iterations=1, dst=scaled)
    return PILImage.fromarray(scaled)
================================================
FILE: unstructured_inference/models/unstructuredmodel.py
================================================
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, cast
import numpy as np
from PIL.Image import Image
from unstructured_inference.constants import ElementType
from unstructured_inference.inference.elements import (
grow_region_to_match_region,
intersections,
)
from unstructured_inference.inference.layoutelement import (
LayoutElement,
LayoutElements,
clean_layoutelements,
partition_groups_from_regions,
separate,
)
class UnstructuredModel(ABC):
    """Wrapper class for the various models used by unstructured."""

    def __init__(self):
        """model should support inference of some sort, either by calling or by some method.
        UnstructuredModel doesn't provide any training interface, it's assumed the model is
        already trained.
        """
        # Populated by subclasses in initialize(); predict() guards against it
        # still being None.
        self.model = None

    @abstractmethod
    def predict(self, x: Any) -> Any:
        """Do inference using the wrapped model.

        Subclasses call super().predict(x) first so that an uninitialized
        model raises ModelNotInitializedError instead of failing obscurely.
        """
        if self.model is None:
            raise ModelNotInitializedError(
                "Model has not been initialized. Please call the initialize method with the "
                "appropriate arguments for loading the model.",
            )
        pass  # pragma: no cover

    def __call__(self, x: Any) -> Any:
        """Inference using function call interface."""
        return self.predict(x)

    @abstractmethod
    def initialize(self, *args, **kwargs):
        """Load the model for inference."""
        pass  # pragma: no cover
class UnstructuredObjectDetectionModel(UnstructuredModel):
    """Wrapper class for object detection models used by unstructured."""

    @abstractmethod
    def predict(self, x: Image) -> LayoutElements | list[LayoutElement]:
        """Do inference using the wrapped model."""
        super().predict(x)
        return []

    def __call__(self, x: Image) -> LayoutElements:
        """Inference using function call interface."""
        return super().__call__(x)

    @staticmethod
    def enhance_regions(
        elements: List[LayoutElement],
        iom_to_merge: float = 0.3,
    ) -> List[LayoutElement]:
        """This function traverses all the elements and either deletes nested elements,
        or merges or splits them depending on the iom score for both regions.

        Elements scheduled for deletion are set to None in place and stripped
        out at the end; iom_to_merge is the intersection-over-minimum above
        which two overlapping regions are merged rather than separated.
        """
        rects = [el.bbox for el in elements]
        intersections_mtx = intersections(*rects)
        for i, row in enumerate(intersections_mtx):
            first = elements[i]
            if first:
                # We get only the elements which intersected
                indices_to_check = np.where(row)[0]
                # Delete the first element, since it will always intersect with itself
                indices_to_check = indices_to_check[indices_to_check != i]
                if len(indices_to_check) == 0:
                    continue
                if len(indices_to_check) > 1:  # sort by iom
                    iom_to_check = [
                        (j, first.bbox.intersection_over_minimum(elements[j].bbox))
                        for j in indices_to_check
                        if elements[j] is not None
                    ]
                    iom_to_check.sort(
                        key=lambda x: x[1],
                        reverse=True,
                    )  # sort elements by iom, so we first check the greatest
                    indices_to_check = [x[0] for x in iom_to_check if x[0] != i]  # type:ignore
                for j in indices_to_check:
                    # Either element may have been deleted by an earlier pair.
                    if elements[j] is None or elements[i] is None:
                        continue
                    second = elements[j]
                    intersection = first.bbox.intersection(
                        second.bbox,
                    )  # we know it does, but need the region
                    first_inside_second = first.bbox.is_in(second.bbox)
                    second_inside_first = second.bbox.is_in(first.bbox)
                    if first_inside_second and not second_inside_first:
                        # first is nested inside second: drop first.
                        elements[i] = None  # type:ignore
                    elif second_inside_first and not first_inside_second:
                        # delete second element
                        elements[j] = None  # type:ignore
                    elif intersection:
                        iom = first.bbox.intersection_over_minimum(second.bbox)
                        if iom < iom_to_merge:  # small
                            separate(first.bbox, second.bbox)
                            # The rectangle could become too small, which is a
                            # good size to delete?
                        else:  # big
                            # merge: grow the larger region and drop the other.
                            if first.bbox.area > second.bbox.area:
                                grow_region_to_match_region(first.bbox, second.bbox)
                                elements[j] = None  # type:ignore
                            else:
                                grow_region_to_match_region(second.bbox, first.bbox)
                                elements[i] = None  # type:ignore
        elements = [e for e in elements if e is not None]
        return elements

    @staticmethod
    def clean_type(
        elements: list[LayoutElement],
        type_to_clean=ElementType.TABLE,
    ) -> List[LayoutElement]:
        """After this function, the list of elements will not contain any element inside
        of the type specified"""
        target_elements = [e for e in elements if e.type == type_to_clean]
        other_elements = [e for e in elements if e.type != type_to_clean]
        # Nothing to clean when either partition is empty.
        if len(target_elements) == 0 or len(other_elements) == 0:
            return elements
        # Sort elements from biggest to smallest
        target_elements.sort(key=lambda e: e.bbox.area, reverse=True)
        other_elements.sort(key=lambda e: e.bbox.area, reverse=True)
        # First check if targets contains each other
        for element in target_elements:  # Just handles containment or little overlap
            contains = [
                e
                for e in target_elements
                if e.bbox.is_almost_subregion_of(element.bbox) and e != element
            ]
            for contained in contains:
                target_elements.remove(contained)
        # Then check if remaining elements intersect with targets
        other_elements = filter(
            lambda e: (
                not any(e.bbox.is_almost_subregion_of(target.bbox) for target in target_elements)
            ),
            other_elements,
        )  # type:ignore
        final_elements = list(other_elements)
        final_elements.extend(target_elements)
        # Note(benjamin): could use bisect.insort if < operator is added to LayoutElement
        final_elements.sort(key=lambda e: e.bbox.y1)
        return final_elements

    def deduplicate_detected_elements(
        self,
        elements: LayoutElements,
        min_text_size: int = 15,
    ) -> LayoutElements:
        """Deletes overlapping elements in a list of elements.

        NOTE(review): min_text_size is currently unused in this
        implementation; it is kept for interface compatibility.
        """
        if len(elements) <= 1:
            return elements
        cleaned_elements = []
        # TODO: Delete nested elements with low or None probability
        # TODO: Keep most confident
        # TODO: Better to grow horizontally than vertically?
        # Clean each spatially-connected group independently, then recombine.
        groups = cast(list[LayoutElements], partition_groups_from_regions(elements))
        for group in groups:
            cleaned_elements.append(clean_layoutelements(group))
        return LayoutElements.concatenate(cleaned_elements)
class UnstructuredElementExtractionModel(UnstructuredModel):
    """Wrapper class for object extraction models used by unstructured."""

    @abstractmethod
    def predict(self, x: Image) -> List[LayoutElement]:
        """Do inference using the wrapped model.

        Calling super().predict first raises ModelNotInitializedError when the
        model has not been loaded yet.
        """
        super().predict(x)
        return []  # pragma: no cover

    def __call__(self, x: Image) -> List[LayoutElement]:
        """Inference using function call interface."""
        return super().__call__(x)
class ModelNotInitializedError(Exception):
    """Raised by UnstructuredModel.predict when the wrapped model is still
    None, i.e. initialize() has not been called yet."""

    pass
================================================
FILE: unstructured_inference/models/yolox.py
================================================
# Copyright (c) Megvii, Inc. and its affiliates.
# Unstructured modified the original source code found at:
# https://github.com/Megvii-BaseDetection/YOLOX/blob/237e943ac64aa32eb32f875faa93ebb18512d41d/yolox/data/data_augment.py
# https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/demo_utils.py
import cv2
import numpy as np
import onnxruntime
from onnxruntime.capi import _pybind_state as C
from PIL import Image as PILImage
from unstructured_inference.constants import ElementType, Source
from unstructured_inference.inference.layoutelement import LayoutElements
from unstructured_inference.models.unstructuredmodel import (
UnstructuredObjectDetectionModel,
)
from unstructured_inference.utils import (
LazyDict,
LazyEvaluateInfo,
download_if_needed_and_get_local_path,
)
# Mapping from YoloX class index to the unstructured element type it denotes.
YOLOX_LABEL_MAP = {
    0: ElementType.CAPTION,
    1: ElementType.FOOTNOTE,
    2: ElementType.FORMULA,
    3: ElementType.LIST_ITEM,
    4: ElementType.PAGE_FOOTER,
    5: ElementType.PAGE_HEADER,
    6: ElementType.PICTURE,
    7: ElementType.SECTION_HEADER,
    8: ElementType.TABLE,
    9: ElementType.TEXT,
    10: ElementType.TITLE,
}

# Known YoloX variants. model_path is a LazyEvaluateInfo so the ONNX weights
# are only downloaded (from the unstructuredio/yolo_x_layout repo) when the
# entry is first accessed through the LazyDict.
MODEL_TYPES = {
    "yolox": LazyDict(
        model_path=LazyEvaluateInfo(
            download_if_needed_and_get_local_path,
            "unstructuredio/yolo_x_layout",
            "yolox_l0.05.onnx",
        ),
        label_map=YOLOX_LABEL_MAP,
    ),
    "yolox_tiny": LazyDict(
        model_path=LazyEvaluateInfo(
            download_if_needed_and_get_local_path,
            "unstructuredio/yolo_x_layout",
            "yolox_tiny.onnx",
        ),
        label_map=YOLOX_LABEL_MAP,
    ),
    "yolox_quantized": LazyDict(
        model_path=LazyEvaluateInfo(
            download_if_needed_and_get_local_path,
            "unstructuredio/yolo_x_layout",
            "yolox_l0.05_quantized.onnx",
        ),
        label_map=YOLOX_LABEL_MAP,
    ),
}
class UnstructuredYoloXModel(UnstructuredObjectDetectionModel):
    """Layout detection model backed by a YoloX ONNX network run through
    onnxruntime."""

    def predict(self, x: PILImage.Image):
        """Predict using YoloX model."""
        # Base-class predict() raises ModelNotInitializedError when
        # initialize() has not been called yet.
        super().predict(x)
        return self.image_processing(x)

    def initialize(self, model_path: str, label_map: dict):
        """Start inference session for YoloX model.

        Args:
            model_path: path to the ONNX weights file.
            label_map: mapping from class id to element type (see
                YOLOX_LABEL_MAP).
        """
        self.model_path = model_path
        # Prefer TensorRT, then CUDA, then CPU — restricted to the providers
        # this onnxruntime build actually offers.
        available_providers = C.get_available_providers()
        ordered_providers = [
            "TensorrtExecutionProvider",
            "CUDAExecutionProvider",
            "CPUExecutionProvider",
        ]
        providers = [provider for provider in ordered_providers if provider in available_providers]
        self.model = onnxruntime.InferenceSession(
            model_path,
            providers=providers,
        )
        self.layout_classes = label_map

    def image_processing(
        self,
        image: PILImage.Image,
    ) -> LayoutElements:
        """Run YoloX layout detection on ``image``.

        Args:
            image: page image to process. It is converted to a numpy array and
                then closed, so callers must not reuse it afterwards.

        Returns:
            LayoutElements holding detection boxes (x1, y1, x2, y2),
            probabilities, and class ids, sorted top-to-bottom by y1.
        """
        # The model was trained and exported with this shape
        # TODO (benjamin): check other shapes for inference
        input_shape = (1024, 768)
        origin_img = np.array(image)
        image.close()
        img, ratio = preprocess(origin_img, input_shape)
        del origin_img  # Free full-size image array before ONNX inference
        session = self.model
        ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
        output = session.run(None, ort_inputs)
        del img, ort_inputs  # Free preprocessed inputs after inference
        # TODO(benjamin): check for p6
        predictions = demo_postprocess(output[0], input_shape, p6=False)[0]
        del output
        boxes = predictions[:, :4]
        # Per-class scores: column 4 times columns 5+ — presumably objectness
        # times class confidence per the YoloX head layout; confirm against
        # the model export.
        scores = predictions[:, 4:5] * predictions[:, 5:]
        # Convert (cx, cy, w, h) boxes to corner form and undo the preprocess
        # resize ratio to get original-image coordinates.
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
        boxes_xyxy /= ratio
        # Note (Benjamin): Distinct models (quantized and original) requires distincts
        # levels of thresholds
        if "quantized" in self.model_path:
            dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.0, score_thr=0.07)
        else:
            dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.1, score_thr=0.25)
        # dets columns are [x1, y1, x2, y2, score, class]; sort by y1.
        order = np.argsort(dets[:, 1])
        sorted_dets = dets[order]
        return LayoutElements(
            element_coords=sorted_dets[:, :4].astype(float),
            element_probs=sorted_dets[:, 4].astype(float),
            element_class_ids=sorted_dets[:, 5].astype(int),
            element_class_id_map=self.layout_classes,
            sources=np.array([Source.YOLOX] * sorted_dets.shape[0]),
        )
# Note: preprocess function was named preproc on original source
def preprocess(img, input_size, swap=(2, 0, 1)):
    """Letterbox-resize ``img`` for YoloX inference.

    The image is scaled (aspect-ratio preserved) to fit ``input_size``
    = (height, width), pasted onto a gray (114) canvas, transposed by
    ``swap`` (default HWC -> CHW), and returned as a contiguous float32
    array together with the scale ratio applied.
    """
    if len(img.shape) == 3:
        canvas = np.full((input_size[0], input_size[1], 3), 114, dtype=np.uint8)
    else:
        canvas = np.full(input_size, 114, dtype=np.uint8)
    # Largest scale that fits both dimensions inside the target.
    ratio = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    new_h = int(img.shape[0] * ratio)
    new_w = int(img.shape[1] * ratio)
    resized = cv2.resize(
        img,
        (new_w, new_h),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    canvas[:new_h, :new_w] = resized
    canvas = canvas.transpose(swap)
    return np.ascontiguousarray(canvas, dtype=np.float32), ratio
def demo_postprocess(outputs, img_size, p6=False):
    """Decode raw YoloX head outputs into image-space predictions.

    Builds the anchor grid for each stride, then (in place) converts the
    first two channels from grid-relative offsets to absolute centers and
    the next two from log-widths/heights to absolute sizes. Returns the
    mutated ``outputs`` array.
    """
    strides = [8, 16, 32, 64] if p6 else [8, 16, 32]
    grid_parts = []
    stride_parts = []
    for stride in strides:
        hsize = img_size[0] // stride
        wsize = img_size[1] // stride
        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
        grid_parts.append(grid)
        stride_parts.append(np.full((*grid.shape[:2], 1), stride))
    grids = np.concatenate(grid_parts, 1)
    expanded_strides = np.concatenate(stride_parts, 1)
    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
    return outputs
def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
    """Multiclass NMS implemented in Numpy.

    Only the class-agnostic variant is currently implemented; the
    ``class_agnostic`` flag is accepted for interface compatibility.
    """
    # TODO(benjamin): check for non-class agnostic
    # if class_agnostic:
    #     ... (class-aware variant not implemented)
    return multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr)
def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
    """Multiclass NMS implemented in Numpy. Class-agnostic version.

    Keeps each box's best-scoring class, drops boxes under ``score_thr``,
    then runs single-class NMS over the survivors. Returns an array with
    columns [x1, y1, x2, y2, score, class].
    """
    best_cls = scores.argmax(1)
    best_scores = scores[np.arange(len(best_cls)), best_cls]
    mask = best_scores > score_thr
    kept_boxes = boxes[mask]
    kept_scores = best_scores[mask]
    kept_cls = best_cls[mask]
    keep = nms(kept_boxes, kept_scores, nms_thr)
    return np.concatenate(
        [kept_boxes[keep], kept_scores[keep, None], kept_cls[keep, None]],
        1,
    )
def nms(boxes, scores, nms_thr):
    """Single class NMS implemented in Numpy.

    Greedy suppression by descending score: a box is dropped when its IoU
    with an already-kept box exceeds ``nms_thr``. Areas use the inclusive
    "+ 1" pixel convention. Returns the kept indices in selection order.
    """
    x1, y1 = boxes[:, 0], boxes[:, 1]
    x2, y2 = boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # Intersection of the best box with all remaining candidates.
        ix1 = np.maximum(x1[best], x1[rest])
        iy1 = np.maximum(y1[best], y1[rest])
        ix2 = np.minimum(x2[best], x2[rest])
        iy2 = np.minimum(y2[best], y2[rest])
        iw = np.maximum(0.0, ix2 - ix1 + 1)
        ih = np.maximum(0.0, iy2 - iy1 + 1)
        overlap = iw * ih
        iou = overlap / (areas[best] + areas[rest] - overlap)
        # Keep only candidates with acceptable overlap for the next round.
        order = rest[np.where(iou <= nms_thr)[0]]
    return keep
================================================
FILE: unstructured_inference/utils.py
================================================
import os
from collections.abc import Mapping
from html.parser import HTMLParser
from io import StringIO
from typing import Any, Callable, Hashable, Iterable, Iterator, Union
from huggingface_hub import hf_hub_download
from PIL import Image
from unstructured_inference.inference.layoutelement import LayoutElement
class LazyEvaluateInfo:
    """Class that stores the information needed to lazily evaluate a function with given arguments.
    The object stores the information needed for evaluation as a function and its arguments.
    """

    def __init__(self, evaluate: Callable, *args, **kwargs):
        # The deferred callable, and the (args, kwargs) it will be invoked
        # with when a LazyDict first accesses the corresponding key.
        self.evaluate = evaluate
        self.info = (args, kwargs)
class LazyDict(Mapping):
    """Mapping that defers computing values until first access.

    Values stored as LazyEvaluateInfo are evaluated (the captured function
    applied to its captured arguments) the first time their key is read.
    With cache=True (the default) the computed result replaces the
    LazyEvaluateInfo in the backing dict, so subsequent reads return the
    same value cheaply; with cache=False it is recomputed on every access.
    Plain values pass through unchanged.
    """

    def __init__(self, *args, cache=True, **kwargs):
        self.cache = cache
        self._raw_dict = dict(*args, **kwargs)

    def __getitem__(self, key: Hashable) -> Union[LazyEvaluateInfo, Any]:
        stored = self._raw_dict.__getitem__(key)
        if not isinstance(stored, LazyEvaluateInfo):
            return stored
        args, kwargs = stored.info
        result = stored.evaluate(*args, **kwargs)
        if self.cache:
            self._raw_dict[key] = result
        return result

    def __iter__(self) -> Iterator:
        return iter(self._raw_dict)

    def __len__(self) -> int:
        return len(self._raw_dict)
def tag(elements: Iterable[LayoutElement]):
    """Assign a numeric id (and a cycling display color) to each element and
    prefix its text with that id. Debugging aid only."""
    palette = ["red", "blue", "green", "magenta", "brown"]
    for idx, element in enumerate(elements):
        element.text = f"-{idx}-:{element.text}"
        # id and color are currently not properties on LayoutElement
        element.id = idx  # type:ignore
        element.color = palette[idx % len(palette)]  # type:ignore
def pad_image_with_background_color(
    image: Image.Image,
    pad: int = 10,
    background_color: str = "white",
) -> Image.Image:
    """Return a copy of ``image`` with ``pad`` pixels of ``background_color``
    added on all four sides. The original image is kept intact.

    Raises:
        ValueError: if ``pad`` is negative.
    """
    if pad < 0:
        raise ValueError(
            "Can not pad an image with negative space! Please use a positive value for `pad`.",
        )
    width, height = image.size
    padded = Image.new(image.mode, (width + pad * 2, height + pad * 2), background_color)
    padded.paste(image, (pad, pad))
    return padded
class MLStripper(HTMLParser):
    """HTML parser subclass that accumulates only the character data it sees,
    discarding all markup tags. Feed it HTML, then call `get_data()`."""

    def __init__(self):
        super().__init__()
        self.reset()
        # `strict` has no effect on modern Python parsers; kept for parity.
        self.strict = True
        # Convert character/entity references (e.g. &amp;) to text.
        self.convert_charrefs = True
        # Buffer collecting every text chunk the parser encounters.
        self.text = StringIO()

    def handle_data(self, data):
        """Collect a chunk of character data emitted by the parser."""
        self.text.write(data)

    def get_data(self):
        """Return all text gathered so far, with tags stripped."""
        return self.text.getvalue()
def strip_tags(html: str) -> str:
    """Remove all markup tags from `html` and return the remaining plain text."""
    stripper = MLStripper()
    stripper.feed(html)
    return stripper.get_data()
def download_if_needed_and_get_local_path(path_or_repo: str, filename: str, **kwargs) -> str:
    """Resolve `filename` under `path_or_repo` to a local file path.

    If `path_or_repo` is a local directory already containing `filename`,
    that path is returned directly; otherwise `path_or_repo` is treated as a
    huggingface hub repo id and the file is downloaded. Extra keyword
    arguments are forwarded to `hf_hub_download`.
    """
    local_candidate = os.path.join(path_or_repo, filename)
    if not os.path.exists(local_candidate):
        return hf_hub_download(path_or_repo, filename, **kwargs)
    return local_candidate
================================================
FILE: unstructured_inference/visualize.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
# Unstructured modified the original source code found at
# https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/visualize.py
import typing
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from PIL import ImageFont
from PIL.Image import Image
from PIL.ImageDraw import ImageDraw
from unstructured_inference.inference.elements import TextRegion
@typing.no_type_check
def draw_bbox(
    image: Image,
    element: TextRegion,
    color: str = "red",
    width=1,
    details: bool = False,
) -> Image:
    """Return a copy of `image` with `element`'s bounding box drawn on it.

    Args:
        image: source image; a copy is drawn on, the original is untouched.
        element: region whose `bbox.coordinates` supply the box corners.
        color: outline color, used unless the element carries its own `color`.
        width: outline width in pixels.
        details: when True, also render the element's `type` and `source`
            labels (requires the "Keyboard.ttf" font to be locatable).

    Falls back to a detail-free box if the font file is missing; on any
    other drawing failure the error is printed and the image returned as-is.
    """
    # Bind `img` before the try so the except-Exception path can never hit
    # an UnboundLocalError on `return img` (previously possible if
    # `image.copy()` itself raised).
    img = image
    try:
        img = image.copy()
        draw = ImageDraw(img)
        topleft, _, bottomright, _ = element.bbox.coordinates
        box_color = getattr(element, "color", color)
        if details:
            source = getattr(element, "source", "Unknown")
            # Renamed from `type` to avoid shadowing the builtin.
            element_type = getattr(element, "type", "")
            # Font lookup raises OSError when the face is not installed.
            kbd = ImageFont.truetype("Keyboard.ttf", 20)
            draw.text(topleft, text=f"{element_type} {source}", fill=box_color, font=kbd)
        draw.rectangle((topleft, bottomright), outline=box_color, width=width)
    except OSError:
        print("Failed to find font file. Skipping details.")
        img = draw_bbox(image, element, color, width)
    except Exception as e:
        print(f"Failed to draw bounding box: {e}")
    return img
def show_plot(
    image: Union[Image, np.ndarray],
    desired_width: Optional[int] = None,
):
    """
    Display an image using matplotlib with an optional desired width while
    maintaining the aspect ratio.

    Parameters:
    - image (Union[Image, np.ndarray]): An image in PIL Image format or a numpy ndarray format.
    - desired_width (Optional[int]): Desired width for the display size of the image.
        If provided, the height is calculated based on the original aspect ratio.
        If not provided, the image will be displayed with its original dimensions.

    Raises:
    - ValueError: If the provided image type is neither PIL Image nor numpy ndarray.

    Returns:
    - None: The function displays the image using matplotlib but does not return any value.
    """
    if isinstance(image, Image):
        image_width, image_height = image.size
    elif isinstance(image, np.ndarray):
        # assumes an (H, W, C) array — rank-2 grayscale would fail to unpack
        image_height, image_width, _ = image.shape
    else:
        raise ValueError("Unsupported Image Type")

    if desired_width:
        # Derive the figure height from the image's aspect ratio so the
        # displayed image is not distorted.
        aspect_ratio = image_width / image_height
        fig, ax = plt.subplots(figsize=(desired_width, desired_width / aspect_ratio))
    else:
        fig, ax = plt.subplots()

    ax.imshow(image)
    plt.show()
# Fixed palette of RGB colors (float values in [0, 1]) used to distinguish
# detected classes/boxes when visualizing; originally from the YOLOX project.
_COLORS = np.array(
    [
        [0.000, 0.447, 0.741],
        [0.850, 0.325, 0.098],
        [0.929, 0.694, 0.125],
        [0.494, 0.184, 0.556],
        [0.466, 0.674, 0.188],
        [0.301, 0.745, 0.933],
        [0.635, 0.078, 0.184],
        [0.300, 0.300, 0.300],
        [0.600, 0.600, 0.600],
        [1.000, 0.000, 0.000],
        [1.000, 0.500, 0.000],
        [0.749, 0.749, 0.000],
        [0.000, 1.000, 0.000],
        [0.000, 0.000, 1.000],
        [0.667, 0.000, 1.000],
        [0.333, 0.333, 0.000],
        [0.333, 0.667, 0.000],
        [0.333, 1.000, 0.000],
        [0.667, 0.333, 0.000],
        [0.667, 0.667, 0.000],
        [0.667, 1.000, 0.000],
        [1.000, 0.333, 0.000],
        [1.000, 0.667, 0.000],
        [1.000, 1.000, 0.000],
        [0.000, 0.333, 0.500],
        [0.000, 0.667, 0.500],
        [0.000, 1.000, 0.500],
        [0.333, 0.000, 0.500],
        [0.333, 0.333, 0.500],
        [0.333, 0.667, 0.500],
        [0.333, 1.000, 0.500],
        [0.667, 0.000, 0.500],
        [0.667, 0.333, 0.500],
        [0.667, 0.667, 0.500],
        [0.667, 1.000, 0.500],
        [1.000, 0.000, 0.500],
        [1.000, 0.333, 0.500],
        [1.000, 0.667, 0.500],
        [1.000, 1.000, 0.500],
        [0.000, 0.333, 1.000],
        [0.000, 0.667, 1.000],
        [0.000, 1.000, 1.000],
        [0.333, 0.000, 1.000],
        [0.333, 0.333, 1.000],
        [0.333, 0.667, 1.000],
        [0.333, 1.000, 1.000],
        [0.667, 0.000, 1.000],
        [0.667, 0.333, 1.000],
        [0.667, 0.667, 1.000],
        [0.667, 1.000, 1.000],
        [1.000, 0.000, 1.000],
        [1.000, 0.333, 1.000],
        [1.000, 0.667, 1.000],
        [0.333, 0.000, 0.000],
        [0.500, 0.000, 0.000],
        [0.667, 0.000, 0.000],
        [0.833, 0.000, 0.000],
        [1.000, 0.000, 0.000],
        [0.000, 0.167, 0.000],
        [0.000, 0.333, 0.000],
        [0.000, 0.500, 0.000],
        [0.000, 0.667, 0.000],
        [0.000, 0.833, 0.000],
        [0.000, 1.000, 0.000],
        [0.000, 0.000, 0.167],
        [0.000, 0.000, 0.333],
        [0.000, 0.000, 0.500],
        [0.000, 0.000, 0.667],
        [0.000, 0.000, 0.833],
        [0.000, 0.000, 1.000],
        [0.000, 0.000, 0.000],
        [0.143, 0.143, 0.143],
        [0.286, 0.286, 0.286],
        [0.429, 0.429, 0.429],
        [0.571, 0.571, 0.571],
        [0.714, 0.714, 0.714],
        [0.857, 0.857, 0.857],
        [0.000, 0.447, 0.741],
        [0.314, 0.717, 0.741],
        [0.50, 0.5, 0],
    ],
).astype(np.float32)