Repository: facebookresearch/nougat
Branch: main
Commit: 5a92920d342f
Files: 47
Total size: 2.2 MB

Directory structure:
gitextract_rzpn8tp_/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE-MODEL.md
├── MANIFEST.in
├── NOTICE
├── README.md
├── app.py
├── config/
│   └── train_nougat.yaml
├── docker/
│   ├── Dockerfile
│   └── README.md
├── lightning_module.py
├── nougat/
│   ├── __init__.py
│   ├── _version.py
│   ├── dataset/
│   │   ├── __init__.py
│   │   ├── create_index.py
│   │   ├── gen_seek.py
│   │   ├── parser/
│   │   │   ├── __init__.py
│   │   │   ├── document.py
│   │   │   ├── html2md.py
│   │   │   ├── latexml_parser.py
│   │   │   └── markdown.py
│   │   ├── pdffigures.py
│   │   ├── rasterize.py
│   │   ├── split_htmls_to_pages.py
│   │   ├── split_md_to_pages.py
│   │   ├── splitter.py
│   │   ├── staircase.py
│   │   ├── tokenizer.json
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── latex_conversion.py
│   │       ├── pdf_text_extract.py
│   │       └── utils.py
│   ├── metrics.py
│   ├── model.py
│   ├── postprocessing.py
│   ├── transforms.py
│   └── utils/
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── dataset.py
│       └── device.py
├── predict.py
├── setup.cfg
├── setup.py
├── test.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
core.*
*.bin
.nfs*
.vscode/*
result/*
!result/extract.py
misc/*
wandb/
!misc/*.png
!dataset/gen_seek.py
!result/.gitkeep
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

ckpt*/

# Misc
pdfs


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to Nougat

## Pull Requests

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## License
By contributing to this repo, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: LICENSE-MODEL.md
================================================
# Creative Commons Attribution-NonCommercial 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

## Section 1 – Definitions.

a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.

b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.

c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.

d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.

e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.

f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.

g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.

h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.

i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.

j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.

k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.

l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

## Section 2 – Scope.

a. License grant.
	1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
		A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
		B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.

	2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
	3. Term. The term of this Public License is specified in Section 6(a).
	4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
	5. Downstream recipients.
		a. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
		b. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
	6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).

b. Other rights.

1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.

2. Patent and trademark rights are not licensed under this Public License.

3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

## Section 3 – License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

a. Attribution.

1. If You Share the Licensed Material (including in modified form), You must:

	A. retain the following if it is supplied by the Licensor with the Licensed Material:
		i) identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
		ii) a copyright notice;
		iii) a notice that refers to this Public License;
		iv) a notice that refers to the disclaimer of warranties;
		v) a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
	B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
	C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.

## Section 4 – Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

	a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
	b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
	c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

## Section 5 – Disclaimer of Warranties and Limitation of Liability.

	a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.

	b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.

	c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

## Section 6 – Term and Termination.

a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.

b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:

	1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
	2. upon express reinstatement by the Licensor.

For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.

d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

## Section 7 – Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

## Section 8 – Interpretation.

a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.

b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.

c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

================================================
FILE: MANIFEST.in
================================================
include ./*.*


================================================
FILE: NOTICE
================================================
Donut
Copyright (c) 2022-present NAVER Corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

--------------------------------------------------------------------------------------

This project contains subcomponents with separate copyright notices and license terms. 
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.

=====

googlefonts/noto-fonts
https://fonts.google.com/specimen/Noto+Sans


Copyright 2018 The Noto Project Authors (github.com/googlei18n/noto-fonts)

This Font Software is licensed under the SIL Open Font License,
Version 1.1.

This license is copied below, and is also available with a FAQ at:
http://scripts.sil.org/OFL

-----------------------------------------------------------
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
-----------------------------------------------------------

PREAMBLE
The goals of the Open Font License (OFL) are to stimulate worldwide
development of collaborative font projects, to support the font
creation efforts of academic and linguistic communities, and to
provide a free and open framework in which fonts may be shared and
improved in partnership with others.

The OFL allows the licensed fonts to be used, studied, modified and
redistributed freely as long as they are not sold by themselves. The
fonts, including any derivative works, can be bundled, embedded,
redistributed and/or sold with any software provided that any reserved
names are not used by derivative works. The fonts and derivatives,
however, cannot be released under any other type of license. The
requirement for fonts to remain under this license does not apply to
any document created using the fonts or their derivatives.

DEFINITIONS
"Font Software" refers to the set of files released by the Copyright
Holder(s) under this license and clearly marked as such. This may
include source files, build scripts and documentation.

"Reserved Font Name" refers to any names specified as such after the
copyright statement(s).

"Original Version" refers to the collection of Font Software
components as distributed by the Copyright Holder(s).

"Modified Version" refers to any derivative made by adding to,
deleting, or substituting -- in part or in whole -- any of the
components of the Original Version, by changing formats or by porting
the Font Software to a new environment.

"Author" refers to any designer, engineer, programmer, technical
writer or other person who contributed to the Font Software.

PERMISSION & CONDITIONS
Permission is hereby granted, free of charge, to any person obtaining
a copy of the Font Software, to use, study, copy, merge, embed,
modify, redistribute, and sell modified and unmodified copies of the
Font Software, subject to the following conditions:

1) Neither the Font Software nor any of its individual components, in
Original or Modified Versions, may be sold by itself.

2) Original or Modified Versions of the Font Software may be bundled,
redistributed and/or sold with any software, provided that each copy
contains the above copyright notice and this license. These can be
included either as stand-alone text files, human-readable headers or
in the appropriate machine-readable metadata fields within text or
binary files as long as those fields can be easily viewed by the user.

3) No Modified Version of the Font Software may use the Reserved Font
Name(s) unless explicit written permission is granted by the
corresponding Copyright Holder. This restriction only applies to the
primary font name as presented to the users.

4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
Software shall not be used to promote, endorse or advertise any
Modified Version, except to acknowledge the contribution(s) of the
Copyright Holder(s) and the Author(s) or with their explicit written
permission.

5) The Font Software, modified or unmodified, in part or in whole,
must be distributed entirely under this license, and must not be
distributed under any other license. The requirement for fonts to
remain under this license does not apply to any document created using
the Font Software.

TERMINATION
This license becomes null and void if any of the above conditions are
not met.

DISCLAIMER
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
OTHER DEALINGS IN THE FONT SOFTWARE.

=====

huggingface/transformers
https://github.com/huggingface/transformers


Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.

=====

clovaai/synthtiger
https://github.com/clovaai/synthtiger


Copyright (c) 2021-present NAVER Corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

=====

rwightman/pytorch-image-models
https://github.com/rwightman/pytorch-image-models


   Copyright 2019 Ross Wightman

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

=====

ankush-me/SynthText
https://github.com/ankush-me/SynthText


   Copyright 2017, Ankush Gupta.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

=====


================================================
FILE: README.md
================================================
<div align="center">
<h1>Nougat: Neural Optical Understanding for Academic Documents</h1>

[![Paper](https://img.shields.io/badge/Paper-arxiv.2308.13418-white)](https://arxiv.org/abs/2308.13418)
[![GitHub](https://img.shields.io/github/license/facebookresearch/nougat)](https://github.com/facebookresearch/nougat)
[![PyPI](https://img.shields.io/pypi/v/nougat-ocr?logo=pypi)](https://pypi.org/project/nougat-ocr)
[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Hugging Face Spaces](https://img.shields.io/badge/🤗%20Hugging%20Face-Community%20Space-blue)](https://huggingface.co/spaces/ysharma/nougat)

</div>

This is the official repository for Nougat, the academic document PDF parser that understands LaTeX math and tables.

Project page: https://facebookresearch.github.io/nougat/

## Install

From pip:
```
pip install nougat-ocr
```

From repository:
```
pip install git+https://github.com/facebookresearch/nougat
```

> Note for Windows: if you want to use a GPU, first install the correct PyTorch version by following the instructions [here](https://pytorch.org/get-started/locally/).

There are extra dependencies if you want to call the model from an API or generate a dataset.
Install via

`pip install "nougat-ocr[api]"` or `pip install "nougat-ocr[dataset]"`

### Get prediction for a PDF
#### CLI

To get predictions for a PDF, run:

```
$ nougat path/to/file.pdf -o output_directory
```

A path to a directory, or to a file in which each line is a path to a PDF, can also be passed as a positional argument:

```
$ nougat path/to/directory -o output_directory
```

```
usage: nougat [-h] [--batchsize BATCHSIZE] [--checkpoint CHECKPOINT] [--model MODEL] [--out OUT]
              [--recompute] [--markdown] [--no-skipping] pdf [pdf ...]

positional arguments:
  pdf                   PDF(s) to process.

options:
  -h, --help            show this help message and exit
  --batchsize BATCHSIZE, -b BATCHSIZE
                        Batch size to use.
  --checkpoint CHECKPOINT, -c CHECKPOINT
                        Path to checkpoint directory.
  --model MODEL_TAG, -m MODEL_TAG
                        Model tag to use.
  --out OUT, -o OUT     Output directory.
  --recompute           Recompute already computed PDF, discarding previous predictions.
  --full-precision      Use float32 instead of bfloat16. Can speed up CPU conversion for some setups.
  --no-markdown         Do not add postprocessing step for markdown compatibility.
  --markdown            Add postprocessing step for markdown compatibility (default).
  --no-skipping         Don't apply failure detection heuristic.
  --pages PAGES, -p PAGES
                        Provide page numbers like '1-4,7' for pages 1 through 4 and page 7. Only works for single PDFs.
```
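The `--pages` option takes a comma-separated list of page numbers and inclusive ranges. The following is a minimal sketch of how such a string could be turned into zero-based page indices. The function name `parse_page_spec` is hypothetical, and the CLI's own parser may differ, for example in how it handles duplicates or invalid input.

```python
from typing import List


def parse_page_spec(spec: str) -> List[int]:
    """Turn a 1-based spec like '1-4,7' into sorted, unique 0-based page indices."""
    pages = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue  # tolerate empty items such as a trailing comma
        if "-" in part:
            first, last = (int(x) for x in part.split("-", 1))
            if first > last:
                raise ValueError(f"invalid range: {part!r}")
            # ranges include both ends: '1-4' selects pages 1, 2, 3 and 4
            pages.update(range(first - 1, last))
        else:
            pages.add(int(part) - 1)
    if any(p < 0 for p in pages):
        raise ValueError("page numbers start at 1")
    return sorted(pages)


print(parse_page_spec("1-4,7"))  # [0, 1, 2, 3, 6]
```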

The default model tag is `0.1.0-small`. If you want to use the base model, use `0.1.0-base`.
```
$ nougat path/to/file.pdf -o output_directory -m 0.1.0-base
```

In the output directory, every PDF is saved as a `.mmd` file, a lightweight markup format that is mostly compatible with [Mathpix Markdown](https://github.com/Mathpix/mathpix-markdown-it) (we make use of the LaTeX tables).

> Note: On some devices the failure detection heuristic is not working properly. If you experience a lot of `[MISSING_PAGE]` responses, try to run with the `--no-skipping` flag. Related: [#11](https://github.com/facebookresearch/nougat/issues/11), [#67](https://github.com/facebookresearch/nougat/issues/67)

#### API

With the extra dependencies installed, you can use `app.py` to start an API. Run

```sh
$ nougat_api
```

To get a prediction for a PDF file, make a POST request to http://127.0.0.1:8503/predict/. The endpoint also accepts the parameters `start` and `stop` to limit the computation to a range of page numbers (both boundaries are included).

The response is a string with the markdown text of the document.

```sh
curl -X 'POST' \
  'http://127.0.0.1:8503/predict/' \
  -H 'accept: application/json' \
  -H 'Content-Type: multipart/form-data' \
  -F 'file=@<PDFFILE.pdf>;type=application/pdf'
```
To limit the conversion to pages 1 to 5, add the start/stop parameters to the request URL: http://127.0.0.1:8503/predict/?start=1&stop=5
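The same request can also be sent from Python. The following is a minimal sketch that uses only the standard library and assumes the server runs on the default address above. The helper names `build_predict_url`, `encode_pdf_form` and `predict` are hypothetical and not part of the `nougat` package. As in `app.py`, the page range is only applied when both `start` and `stop` are given.

```python
import json
import os
import urllib.parse
import urllib.request
import uuid


def build_predict_url(base_url: str, start=None, stop=None) -> str:
    """Build the /predict/ URL; start/stop are 1-based and inclusive."""
    url = base_url.rstrip("/") + "/predict/"
    # app.py only applies the page range when both bounds are given
    if start is not None and stop is not None:
        url += "?" + urllib.parse.urlencode({"start": start, "stop": stop})
    return url


def encode_pdf_form(pdf_bytes: bytes, filename: str = "document.pdf"):
    """Encode one form field named 'file' as multipart/form-data."""
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="file"; filename="{filename}"\r\n'
        "Content-Type: application/pdf\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + pdf_bytes + tail, f"multipart/form-data; boundary={boundary}"


def predict(pdf_path: str, base_url: str = "http://127.0.0.1:8503", start=None, stop=None) -> str:
    """POST a PDF to a running Nougat API and return the markdown text."""
    with open(pdf_path, "rb") as fh:
        body, content_type = encode_pdf_form(fh.read(), os.path.basename(pdf_path))
    request = urllib.request.Request(
        build_predict_url(base_url, start, stop),
        data=body,
        headers={"Content-Type": content_type, "accept": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        # the endpoint returns a str, which FastAPI sends as a JSON string
        return json.loads(response.read().decode("utf-8"))
```

For example, `predict("paper.pdf", start=1, stop=5)` should return the markdown for pages 1 to 5. If the `requests` package is available, the multipart encoding can instead be delegated to `requests.post(url, files=...)`.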

## Dataset
### Generate dataset

To generate a dataset you need:

1. A directory containing the PDFs
2. A directory containing the `.html` files (processed `.tex` files by [LaTeXML](https://math.nist.gov/~BMiller/LaTeXML/)) with the same folder structure
3. A binary file of [pdffigures2](https://github.com/allenai/pdffigures2) and a corresponding environment variable `export PDFFIGURES_PATH="/path/to/binary.jar"`
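Before running the split, it can help to check these prerequisites. The following sketch assumes that a PDF at `sub/dir/paper.pdf` pairs with an HTML file at `sub/dir/paper.html` under the HTML root. The README only states that the folder structure is the same, so adapt the naming rule to your data. `missing_html` and `check_pdffigures` are hypothetical helpers, not part of the `nougat` package.

```python
import os
from pathlib import Path
from typing import List


def missing_html(pdf_root: Path, html_root: Path) -> List[Path]:
    """Return PDFs (relative to pdf_root) that have no HTML file at the mirrored path.

    Assumes 'sub/dir/paper.pdf' pairs with 'sub/dir/paper.html'.
    """
    missing = []
    for pdf in sorted(pdf_root.rglob("*.pdf")):
        rel = pdf.relative_to(pdf_root)
        if not (html_root / rel.with_suffix(".html")).is_file():
            missing.append(rel)
    return missing


def check_pdffigures() -> bool:
    """True if the PDFFIGURES_PATH environment variable points to an existing file."""
    jar = os.environ.get("PDFFIGURES_PATH")
    return bool(jar) and Path(jar).is_file()
```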

Next run

```
python -m nougat.dataset.split_htmls_to_pages --html path/html/root --pdfs path/pdf/root --out path/paired/output --figure path/pdffigures/outputs
```

Additional arguments include

| Argument              | Description                                |
| --------------------- | ------------------------------------------ |
| `--recompute`         | recompute all splits                       |
| `--markdown MARKDOWN` | Markdown output dir                        |
| `--workers WORKERS`   | How many processes to use                  |
| `--dpi DPI`           | What resolution the pages will be saved at |
| `--timeout TIMEOUT`   | max time per paper in seconds              |
| `--tesseract`         | Tesseract OCR prediction for each page     |

Finally, create a `jsonl` file that contains all the image paths, markdown text and meta information.

```
python -m nougat.dataset.create_index --dir path/paired/output --out index.jsonl
```

For each `jsonl` file you also need to generate a seek map for faster data loading:

```
python -m nougat.dataset.gen_seek file.jsonl
```
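This README does not show the on-disk format produced by `gen_seek.py`. The idea behind a seek map can be sketched under the assumption that it records the byte offset at which each line of the `jsonl` file starts, so that a loader can `seek()` directly to record *i* instead of reading every preceding line. `build_offsets` and `read_record` are hypothetical names. The record keys used in the test data (`image`, `markdown`, `meta`, and optionally `ocr`) are the ones written by `nougat.dataset.create_index`.

```python
import json
from pathlib import Path
from typing import Dict, List


def build_offsets(jsonl_path: Path) -> List[int]:
    """Return the byte offset of the start of every line in a JSONL file."""
    offsets = []
    position = 0
    with jsonl_path.open("rb") as f:
        for line in f:
            offsets.append(position)
            position += len(line)
    return offsets


def read_record(jsonl_path: Path, offsets: List[int], index: int) -> Dict:
    """Read record `index` by seeking directly to its start offset."""
    with jsonl_path.open("rb") as f:
        f.seek(offsets[index])
        return json.loads(f.readline().decode("utf-8"))
```

The file is opened in binary mode because byte offsets counted while reading bytes are exact, whereas `tell()` on a text-mode file returns an opaque number.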

The resulting directory structure can look as follows:

```
root/
├── images
├── train.jsonl
├── train.seek.map
├── test.jsonl
├── test.seek.map
├── validation.jsonl
└── validation.seek.map
```

Note that the `.mmd` and `.json` files in `path/paired/output` (here `images`) are no longer required.
Deleting them can be useful when pushing to an S3 bucket, since it halves the number of files.

## Training

To train or fine-tune a Nougat model, run

```
python train.py --config config/train_nougat.yaml
```
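`lightning_module.py` pairs AdamW with an exponential learning-rate schedule controlled by `lr`, `min_lr`, `gamma`, `warmup_steps` and `lr_step` in `config/train_nougat.yaml`. The multiplier that `LambdaLR` applies to `lr` can be reproduced as a plain function to see how quickly the rate decays. The function below is a standalone copy of the `lr_lambda` inside `exponential_scheduler`, with defaults taken from the shipped config; the name `lr_multiplier` is hypothetical. Because the scheduler is registered with `"frequency": lr_step`, Lightning only advances it every `lr_step` training steps, so `x` counts scheduler updates rather than individual steps.

```python
def lr_multiplier(
    x: int,
    warmup_steps: int = 250,
    lr: float = 5e-5,
    min_lr: float = 7.5e-6,
    gamma: float = 0.9996,
) -> float:
    """Factor that LambdaLR multiplies the base lr by after x scheduler updates."""
    if x > warmup_steps or warmup_steps <= 0:
        # exponential decay after warm-up, clamped so the lr never drops below min_lr
        if lr * gamma ** (x - warmup_steps) > min_lr:
            return gamma ** (x - warmup_steps)
        return min_lr / lr
    # linear warm-up from 0 to 1
    return x / warmup_steps


for x in (0, 125, 250, 1000, 5000):
    print(x, 5e-5 * lr_multiplier(x))
```

With these values the multiplier rises linearly to 1 over the first 250 scheduler updates. It then decays by a factor of 0.9996 per update until it reaches the floor `min_lr / lr = 0.15`, roughly 4,700 updates after warm-up.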

## Evaluation

Run 

```
python test.py --checkpoint path/to/checkpoint --dataset path/to/test.jsonl --save_path path/to/results.json
```

To get the results for the different text modalities, run

```
python -m nougat.metrics path/to/results.json
```

## FAQ

- Why am I only getting `[MISSING_PAGE]`?

  Nougat was trained on scientific papers found on arXiv and PMC. Is the document you're processing similar to that?
  What language is the document in? Nougat works best with English papers; other Latin-based languages might work. **Chinese, Russian, Japanese etc. will not work**.
  If these requirements are fulfilled, the cause may be false positives in the failure detection when computing on CPU or older GPUs ([#11](https://github.com/facebookresearch/nougat/issues/11)). Try passing the `--no-skipping` flag for now.

- Where can I download the model checkpoint from?

  The checkpoints are uploaded to the release section of this GitHub repository. You can also download them during the first execution of the program. Choose the preferred model by passing `--model 0.1.0-{base,small}`.

## Citation

```
@misc{blecher2023nougat,
      title={Nougat: Neural Optical Understanding for Academic Documents}, 
      author={Lukas Blecher and Guillem Cucurull and Thomas Scialom and Robert Stojnic},
      year={2023},
      eprint={2308.13418},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

## Acknowledgments

This repository builds on top of the [Donut](https://github.com/clovaai/donut/) repository.

## License

Nougat codebase is licensed under MIT.

Nougat model weights are licensed under CC-BY-NC.


================================================
FILE: app.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import sys
from functools import partial
from http import HTTPStatus
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from pathlib import Path
import hashlib
from fastapi.middleware.cors import CORSMiddleware
import pypdfium2
import torch
from nougat import NougatModel
from nougat.postprocessing import markdown_compatible, close_envs
from nougat.utils.dataset import ImageDataset
from nougat.utils.checkpoint import get_checkpoint
from nougat.dataset.rasterize import rasterize_paper
from nougat.utils.device import move_to_device, default_batch_size
from tqdm import tqdm


SAVE_DIR = Path("./pdfs")
BATCHSIZE = int(os.environ.get("NOUGAT_BATCHSIZE", default_batch_size()))
NOUGAT_CHECKPOINT = get_checkpoint()
if NOUGAT_CHECKPOINT is None:
    print(
        "Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!"
    )
    sys.exit(1)

app = FastAPI(title="Nougat API")
origins = ["http://localhost", "http://127.0.0.1"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
model = None


@app.on_event("startup")
async def load_model(
    checkpoint: str = NOUGAT_CHECKPOINT,
):
    global model, BATCHSIZE
    if model is None:
        model = NougatModel.from_pretrained(checkpoint)
        model = move_to_device(model, cuda=BATCHSIZE > 0)
        if BATCHSIZE <= 0:
            BATCHSIZE = 1
        model.eval()


@app.get("/")
def root():
    """Health check."""
    response = {
        "status-code": HTTPStatus.OK,
        "data": {},
    }
    return response


@app.post("/predict/")
async def predict(
    file: UploadFile = File(...), start: int = None, stop: int = None
) -> str:
    """
    Perform predictions on a PDF document and return the extracted text in Markdown format.

    Args:
        file (UploadFile): The uploaded PDF file to process.
        start (int, optional): The starting page number for prediction.
        stop (int, optional): The ending page number for prediction.

    Returns:
        str: The extracted text in Markdown format.
    """
    pdfbin = file.file.read()
    pdf = pypdfium2.PdfDocument(pdfbin)
    md5 = hashlib.md5(pdfbin).hexdigest()
    save_path = SAVE_DIR / md5

    if start is not None and stop is not None:
        pages = list(range(start - 1, stop))
    else:
        pages = list(range(len(pdf)))
    predictions = [""] * len(pages)
    dellist = []
    if save_path.exists():
        for computed in (save_path / "pages").glob("*.mmd"):
            try:
                idx = int(computed.stem) - 1
                if idx in pages:
                    i = pages.index(idx)
                    print("skip page", idx + 1)
                    predictions[i] = computed.read_text(encoding="utf-8")
                    dellist.append(idx)
            except Exception as e:
                print(e)
    compute_pages = pages.copy()
    for el in dellist:
        compute_pages.remove(el)
    images = rasterize_paper(pdf, pages=compute_pages)
    global model

    dataset = ImageDataset(
        images,
        partial(model.encoder.prepare_input, random_padding=False),
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCHSIZE,
        pin_memory=True,
        shuffle=False,
    )

    for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
        if sample is None:
            continue
        model_output = model.inference(image_tensors=sample)
        for j, output in enumerate(model_output["predictions"]):
            if model_output["repeats"][j] is not None:
                if model_output["repeats"][j] > 0:
                    disclaimer = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n"
                else:
                    disclaimer = (
                        "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n"
                    )
                rest = close_envs(model_output["repetitions"][j]).strip()
                if len(rest) > 0:
                    disclaimer = disclaimer % rest
                else:
                    disclaimer = ""
            else:
                disclaimer = ""

            predictions[pages.index(compute_pages[idx * BATCHSIZE + j])] = (
                markdown_compatible(output) + disclaimer
            )

    (save_path / "pages").mkdir(parents=True, exist_ok=True)
    pdf.save(save_path / "doc.pdf")
    if len(images) > 0:
        thumb = Image.open(images[0])
        thumb.thumbnail((400, 400))
        thumb.save(save_path / "thumb.jpg")
    for idx, page_num in enumerate(pages):
        (save_path / "pages" / ("%02d.mmd" % (page_num + 1))).write_text(
            predictions[idx], encoding="utf-8"
        )
    final = "".join(predictions).strip()
    (save_path / "doc.mmd").write_text(final, encoding="utf-8")
    return final


def main():
    import uvicorn

    uvicorn.run("app:app", port=8503)


if __name__ == "__main__":
    main()


================================================
FILE: config/train_nougat.yaml
================================================
resume_from_checkpoint_path: null
result_path: "result"
model_path: null
dataset_paths: ["path/to/train.jsonl"]
tokenizer: "dataset/tokenizer.json"
exp_name: "nougat"
train_batch_sizes: [1]
num_workers: 8
val_batch_sizes: [1]
val_batches: 1
input_size: [896, 672]
max_length: 4096
max_position_embeddings: 4096
accumulate_grad_batches: 3
window_size: 7
patch_size: 4
embed_dim: 128
hidden_dimension: 1024
num_heads: [4, 8, 16, 32]
encoder_layer: [2, 2, 14, 2]
decoder_layer: 10
align_long_axis: False
num_nodes: 1
seed: 25
lr: 5e-5
min_lr: 7.5e-6
lr_step: 16
gamma: 0.9996
warmup_steps: 250
num_training_samples_per_epoch: 10000
max_epochs: 30
max_steps: -1
val_check_interval: null
check_val_every_n_epoch: 1
gradient_clip_val: 0.5
verbose: False


================================================
FILE: docker/Dockerfile
================================================
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
# Replace the CUDA version in the base image with your machine's CUDA version.
# You can check your CUDA version with:
# nvcc -V

RUN apt-get update
RUN apt-get install -y python3
RUN apt-get -y install python3-pip git
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Replace cu118 with the PyTorch wheel index that matches your CUDA version.

RUN mkdir workspace
WORKDIR /workspace

RUN pip3 install fastapi uvicorn[standard] fsspec[http]==2023.1.0
RUN git clone https://github.com/facebookresearch/nougat.git
WORKDIR /workspace/nougat

RUN python3 setup.py install

EXPOSE 8503

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8503"]
# Run this with: docker run -it -d -p <YOUR PORT>:8503 --gpus all <IMAGE NAME>


================================================
FILE: docker/README.md
================================================
## Prerequisites
Ensure you have Docker installed on your machine.
You must also have NVIDIA CUDA and cuDNN installed on your machine.

Then check your machine's CUDA version:
```sh
nvcc -V
```

You must change the base image name and the PyTorch version in the Dockerfile to match your **CUDA version**.

## Building the Docker Image
Clone this repository and navigate into the `nougat/docker` directory. You can then build the Docker image by running:
```sh
docker build -t <image-name> .
```
Replace <image-name> with a name of your choice. This will be used to refer to the image later.
Please be patient as this operation can take a while. It needs to pull the CUDA-capable image from NVIDIA’s Docker repository and install several libraries.
Image size will be about 17GB.


## Running the Docker Container
You can run your Docker container with the following command:
```sh
docker run -it -d -p <your-port>:8503 --gpus all <image-name>
```
Replace <your-port> with the port number you wish to expose on your host machine to access the nougat API server.
This can be any valid port number. Replace <image-name> with the name you chose earlier during the build step.


## Testing the API Server
Once the Docker container is running, you can access the nougat API server.
You can check the connection by running:
```sh
curl -X 'GET' \
  'http://127.0.0.1:<your-port>/'
```
The API server can take a while to become available, because it has to download the Nougat model at startup.

If the connection is successful, you should get a response like this:
```
{"status-code":200,"data":{}}
```
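Because the server may not respond until the model has been downloaded and loaded, scripts that start the container may want to poll the health-check endpoint before sending PDFs. The following is a minimal standard-library sketch that assumes the `{"status-code": 200, ...}` response shown above. The function name `wait_for_server` and the default timeout and polling interval are hypothetical choices.

```python
import http.client
import json
import time
import urllib.request


def wait_for_server(url: str, timeout: float = 600.0, interval: float = 5.0) -> bool:
    """Poll the health-check URL until it reports status-code 200 or `timeout` expires."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval) as response:
                payload = json.loads(response.read().decode("utf-8"))
            if payload.get("status-code") == 200:
                return True
        except (OSError, http.client.HTTPException, ValueError):
            # connection refused or reset, socket timeout, HTTP error status,
            # or a non-JSON body: treat the server as not ready yet
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

For example, `wait_for_server("http://127.0.0.1:<your-port>/")` returns `True` once the container answers the health check. It returns `False` if the server has not become ready within `timeout` seconds.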

## Using the API Server
To get a prediction for a PDF file, make a POST request to `http://127.0.0.1:<your-port>/predict/`. The endpoint also accepts the parameters `start` and `stop` to limit the computation to a range of page numbers (both boundaries are included).

The response is a string with the markdown text of the document.

```sh
curl -X 'POST' \
  'http://127.0.0.1:<your-port>/predict/' \
  -H 'accept: application/json' \
  -H 'Content-Type: multipart/form-data' \
  -F 'file=@<PDFFILE.pdf>;type=application/pdf'
```
To limit the conversion to pages 1 to 5, add the start/stop parameters to the request URL:
`http://127.0.0.1:<your-port>/predict/?start=1&stop=5`




================================================
FILE: lightning_module.py
================================================
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
Copyright (c) Meta Platforms, Inc. and affiliates.
"""
import math
import random
from pathlib import Path

import numpy as np
import lightning.pytorch as pl
import torch
from lightning.pytorch.utilities import rank_zero_only
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from nougat import NougatConfig, NougatModel
from nougat.metrics import get_metrics


class NougatModelPLModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.validation_step_outputs = []
        self.config = config
        if self.config.get("model_path", False):
            self.model = NougatModel.from_pretrained(
                self.config.model_path,
                input_size=self.config.input_size,
                max_length=self.config.max_length,
                align_long_axis=self.config.align_long_axis,
                window_size=self.config.window_size,
                encoder_layer=self.config.encoder_layer,
                decoder_layer=self.config.decoder_layer,
                patch_size=self.config.patch_size,
                embed_dim=self.config.embed_dim,
                num_heads=self.config.num_heads,
                hidden_dimension=self.config.hidden_dimension,
                ignore_mismatched_sizes=True,
            )
        else:
            self.model = NougatModel(
                config=NougatConfig(
                    input_size=self.config.input_size,
                    max_length=self.config.max_length,
                    align_long_axis=self.config.align_long_axis,
                    window_size=self.config.window_size,
                    encoder_layer=self.config.encoder_layer,
                    decoder_layer=self.config.decoder_layer,
                    tokenizer_file=self.config.tokenizer,
                    patch_size=self.config.patch_size,
                    embed_dim=self.config.embed_dim,
                    num_heads=self.config.num_heads,
                    hidden_dimension=self.config.hidden_dimension,
                )
            )

    def training_step(self, batch, batch_idx):
        image_tensors, decoder_input_ids, attention_masks = list(), list(), list()
        if batch is None:
            return
        for batch_data in batch:
            if batch_data is None or batch_data[0] is None:
                continue
            image_tensors.append(batch_data[0])
            decoder_input_ids.append(batch_data[1])
            attention_masks.append(batch_data[2])
        image_tensors = torch.cat(image_tensors)
        decoder_input_ids = torch.cat(decoder_input_ids)
        attention_masks = torch.cat(attention_masks)
        loss = self.model(image_tensors, decoder_input_ids, attention_masks)[0]
        if loss is not None:
            self.log_dict({"train/loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        if batch is None:
            return
        image_tensors, decoder_input_ids, _ = batch
        if image_tensors is None:
            return
        markdown = pad_sequence(
            decoder_input_ids,
            batch_first=True,
        )
        preds = self.model.inference(
            image_tensors=image_tensors,
            return_attentions=False,
        )["predictions"]
        gts = self.model.decoder.tokenizer.batch_decode(
            markdown, skip_special_tokens=True
        )
        metrics = get_metrics(gts, preds, pool=False)
        scores = {
            "val/" + key: sum(values) / len(values) for key, values in metrics.items()
        }
        self.validation_step_outputs.append(scores)
        return scores

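    def on_validation_epoch_end(self):
        if self.validation_step_outputs:
            # average every metric over all validation batches of this epoch
            num_steps = len(self.validation_step_outputs)
            averaged = {
                key: sum(scores[key] for scores in self.validation_step_outputs)
                / num_steps
                for key in self.validation_step_outputs[0]
            }
            self.log_dict(averaged, sync_dist=True)
            self.validation_step_outputs.clear()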

    def configure_optimizers(self):
        def _get_device_count():
            if torch.cuda.is_available():
                return torch.cuda.device_count()
            elif torch.backends.mps.is_available():
                # Can MPS have more than one device?
                return 1
            return 1

        max_iter = None

        if int(self.config.get("max_epochs", -1)) > 0:
            assert (
                len(self.config.train_batch_sizes) == 1
            ), "Set max_epochs only if the number of datasets is 1"
            steps = self.config.num_training_samples_per_epoch
            max_iter = (self.config.max_epochs * steps) / max(
                1,
                (
                    self.config.train_batch_sizes[0]
                    * _get_device_count()
                    * self.config.get("num_nodes", 1)
                ),
            )

        if int(self.config.get("max_steps", -1)) > 0:
            max_iter = (
                min(self.config.max_steps, max_iter)
                if max_iter is not None
                else self.config.max_steps
            )

        assert max_iter is not None
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.lr)
        scheduler = {
            "scheduler": self.exponential_scheduler(
                optimizer,
                self.config.warmup_steps,
                self.config.lr,
                self.config.get("min_lr", 5e-5),
                self.config.get("gamma", 0.9996),
            ),
            "name": "learning_rate",
            "interval": "step",
            "frequency": self.config.get("lr_step", 1),
        }
        return [optimizer], [scheduler]

    @staticmethod
    def cosine_scheduler(optimizer, training_steps, warmup_steps):
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return current_step / max(1, warmup_steps)
            progress = current_step - warmup_steps
            progress /= max(1, training_steps - warmup_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return LambdaLR(optimizer, lr_lambda)

    @staticmethod
    def exponential_scheduler(optimizer, warmup_steps, lr, min_lr=5e-5, gamma=0.9999):
        def lr_lambda(x):
            if x > warmup_steps or warmup_steps <= 0:
                if lr * gamma ** (x - warmup_steps) > min_lr:
                    return gamma ** (x - warmup_steps)
                else:
                    return min_lr / lr
            else:
                return x / warmup_steps

        return LambdaLR(optimizer, lr_lambda=lr_lambda)

    def get_progress_bar_dict(self):
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        items["exp_name"] = f"{self.config.get('exp_name', '')}"
        items["exp_version"] = f"{self.config.get('exp_version', '')}"
        return items

    @rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        save_path = (
            Path(self.config.result_path)
            / self.config.exp_name
            / self.config.exp_version
        )
        self.model.save_pretrained(save_path)
        self.model.decoder.tokenizer.save_pretrained(save_path)


class NougatDataPLModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.train_batch_sizes = self.config.train_batch_sizes
        self.val_batch_sizes = self.config.val_batch_sizes
        self.train_datasets = []
        self.val_datasets = []
        self.g = torch.Generator()
        self.g.manual_seed(self.config.seed)

    def train_dataloader(self):
        loaders = [
            DataLoader(
                torch.utils.data.ConcatDataset(self.train_datasets),
                batch_size=self.train_batch_sizes[0],
                num_workers=self.config.num_workers,
                pin_memory=True,
                worker_init_fn=self.seed_worker,
                generator=self.g,
                shuffle=True,
                collate_fn=self.ignore_none_collate,
            )
        ]
        return loaders

    def val_dataloader(self):
        loaders = [
            DataLoader(
                torch.utils.data.ConcatDataset(self.val_datasets),
                batch_size=self.val_batch_sizes[0],
                pin_memory=True,
                shuffle=True,
                collate_fn=self.ignore_none_collate,
            )
        ]
        return loaders

    @staticmethod
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    @staticmethod
    def ignore_none_collate(batch):
        if batch is None:
            return
        try:
            batch = [x for x in batch if x is not None and x[0] is not None]
            if len(batch) == 0:
                return
            return torch.utils.data.dataloader.default_collate(batch)
        except AttributeError:
            pass


================================================
FILE: nougat/__init__.py
================================================
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
Copyright (c) Meta Platforms, Inc. and affiliates.
"""
from .model import NougatConfig, NougatModel
from .utils.dataset import NougatDataset
from ._version import __version__

__all__ = [
    "NougatConfig",
    "NougatModel",
    "NougatDataset",
]


================================================
FILE: nougat/_version.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

__version__ = "0.1.18"


================================================
FILE: nougat/dataset/__init__.py
================================================


================================================
FILE: nougat/dataset/create_index.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
"""
This script creates an index of all available pages and parses the meta data for all pages into a separate file.
Optionally TesseractOCR is called for each image.
"""
import argparse
import json
from typing import Dict, List
import numpy as np
from pathlib import Path
import multiprocessing
from pebble import ProcessPool
from PIL import Image
import pytesseract
import re
import logging
from tqdm import tqdm


logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def convert_pt2px(pt, dpi=96):
    if isinstance(pt, list):
        return [round(dpi / 72 * p) for p in pt]
    elif isinstance(pt, dict):
        for k in pt:
            pt[k] = round(dpi / 72 * pt[k])
        return pt


def read_metadata(data: Dict) -> List[List[Dict]]:
    N = data["num_pages"]
    out = [[] for _ in range(N)]
    # pdffigures2 meta data
    if "pdffigures" in data and data["pdffigures"]:
        for item in data["pdffigures"]:
            p = item.pop("page", None)
            if p is None or p >= N:
                continue
            item["source"] = "fig"
            if "regionBoundary" in item:
                item["regionBoundary"] = convert_pt2px(item["regionBoundary"])
            if "captionBoundary" in item:
                item["captionBoundary"] = convert_pt2px(item["captionBoundary"])
            out[p].append(item)

    return out


def index_paper(directory: Path, args: argparse.Namespace):
    """
    Collect the image-text pairs of a single paper directory.

    Returns a list of index entries (one per page), or None if the paper's
    `meta.json` is missing or cannot be parsed.
    """
    paper = directory.name
    markdowns = directory.glob("*.mmd")
    meta_file = directory / "meta.json"
    data_samples = []
    if not meta_file.exists():
        return
    # load meta info
    try:
        meta = read_metadata(json.load(meta_file.open("r", encoding="utf-8")))
    except json.JSONDecodeError:
        return

    for md_path in markdowns:
        image = md_path.parent / (md_path.stem + ".png")
        i = int(image.stem) - 1
        if not image.exists():
            continue
        if i >= len(meta):
            continue
        data_sample = {}
        ocr_path = image.parent / (image.stem + "_OCR.txt")
        if args.tesseract and not ocr_path.exists():
            try:
                pil = Image.open(image)
                ocr = pytesseract.image_to_string(pil, lang="eng", timeout=2)
                ocr = re.sub(r"\n+\s+?([^\s])", r"\n\n\1", ocr).strip()
                with ocr_path.open("w", encoding="utf-8") as f_ocr:
                    f_ocr.write(ocr)
            except RuntimeError:
                logger.info("Page %s of paper %s timed out", image.stem, paper)
                pass
        if ocr_path.exists():
            data_sample["ocr"] = str(ocr_path.relative_to(args.root))
        data_sample["image"] = str(image.relative_to(args.root))
        data_sample["markdown"] = md_path.read_text(encoding="utf8").strip()
        data_sample["meta"] = meta[i]
        data_samples.append(data_sample)
    return data_samples


def create_index(args):
    if not args.dir.is_dir():
        logger.error("%s does not exist or is no dir.", args.dir)
        return
    papers = []
    depth = 0
    p = args.dir
    while True:
        p = next(p.iterdir())
        if p.is_file():
            break
        else:
            depth += 1
    papers = args.dir.glob("*/" * depth)
    index = []
    with ProcessPool(max_workers=args.workers) as pool:
        tasks = {}
        for j, paper in enumerate(papers):
            fname = paper.name
            tasks[fname] = pool.schedule(
                index_paper,
                args=[paper, args],
                timeout=args.timeout,
            )

        for fname in tqdm(tasks):
            try:
                res = tasks[fname].result()
                if res is None:
                    logger.info("%s is faulty", fname)
                    continue
                index.append(res)
            except TimeoutError:
                logger.info("%s timed out", fname)

        with args.out.open("w", encoding="utf-8") as f:
            for item in index:
                for page in item:
                    if len(page) == 0:
                        continue
                    f.write(json.dumps(page) + "\n")
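

# Hypothetical helper, not part of the original script: a minimal sketch for reading
# the index back. create_index() writes one JSON object per page and line, with the
# keys "image", "markdown" and "meta", plus "ocr" when an OCR text file exists.
def load_index(index_path):
    samples = []
    with open(index_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines
                samples.append(json.loads(line))
    return samples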


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", type=Path, required=True, help="Index file")
    parser.add_argument(
        "--dir", type=Path, required=True, help="Parent directory for input dirs"
    )
    parser.add_argument("--root", type=Path, default=None)
    parser.add_argument(
        "--tesseract",
        action="store_true",
        help="Tesseract OCR prediction for each page",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=multiprocessing.cpu_count(),
        help="How many processes to use",
    )
    parser.add_argument(
        "--dpi", type=int, default=96, help="DPI the images were saved with"
    )
    parser.add_argument("--timeout", type=int, default=240, help="Max time per paper")
    args = parser.parse_args()
    if args.root is None:
        args.root = args.dir
    else:
        # fail early (ValueError) if dir is not inside root
        args.dir.relative_to(args.root)
    create_index(args)


================================================
FILE: nougat/dataset/gen_seek.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from tqdm import tqdm
import json
from pathlib import Path
import argparse


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("src_file", nargs="+", type=Path, help="JSONL file(s) to build a seek map for")
    args = parser.parse_args()
    return args
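

# Hypothetical helper, not part of the original script: a minimal sketch of how the
# seek map written below can be used for random access into a JSONL file. It assumes
# the map was built from the same, unmodified file and opens it in text mode with the
# default encoding, exactly as the loop below does when recording offsets.
def read_jsonl_line(jsonl_path, seek_map, index):
    """Return the parsed JSON object stored on line `index` (0-based)."""
    with open(jsonl_path) as f:
        f.seek(seek_map[index])  # jump straight to the recorded start of the line
        return json.loads(f.readline())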


if __name__ == "__main__":
    args = get_args()
    for file in args.src_file:
        seek_map = []
        seek_pos = 0
        with open(file) as f:
            with tqdm(smoothing=0.0) as pbar:
                line = f.readline()
                while line:
                    seek_map.append(seek_pos)
                    seek_pos = f.tell()
                    line = f.readline()
                    pbar.update(1)

        out_file = file.parent / (file.stem + ".seek.map")
        with open(out_file, "w") as f:
            f.write(json.dumps(seek_map))


================================================
FILE: nougat/dataset/parser/__init__.py
================================================


================================================
FILE: nougat/dataset/parser/document.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from collections import defaultdict
from copy import copy
import itertools
import re
from dataclasses import dataclass, field, asdict
from typing import (
    Any,
    List,
    Dict,
    Optional,
    TypeVar,
    Type,
    Generic,
)
import numpy as np

import logging

logger = logging.getLogger()


T = TypeVar("T")
EL = TypeVar("EL")


@dataclass
class Element(Generic[EL]):
    """
    Generic class representing an element with children in a tree-like structure.

    Attributes:
        parent (Element): The parent element.
        children (List[Element]): List of child elements.
    """

    parent: "Element" = None
    children: List["Element"] = field(default_factory=list)

    @property
    def plaintext(self):
        return "".join([child.plaintext for child in self.children])

    def append(self, child: EL) -> EL:
        self.children.append(child)
        child.parent = self
        return child

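    def find_parent(self, class_or_tuple: Type[T]) -> Optional[T]:
        """Return the closest ancestor (including self) that is an instance of class_or_tuple, or None."""
        elem = self
        while elem is not None:
            if isinstance(elem, class_or_tuple):
                return elem
            elem = elem.parent
        return None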


@dataclass
class UnknownElement(Element):
    pass


@dataclass
class TextElement(Element):
    content: str = ""

    @property
    def plaintext(self):
        return self.content

    def append(self, child: "Element"):
        raise Exception(f"Cannot append elements to {self.__class__.__name__}")


@dataclass
class Math(Element):
    pass


@dataclass
class PlaintextMath(Math):
    pass


@dataclass
class LatexMath(Math):
    inline: bool = True
    code: str = ""

    @property
    def plaintext(self):
        return self.code


@dataclass
class Author:
    fullname: str = None
    lastname: str = None
    affiliation: str = None


@dataclass
class Link(Element):
    target: str = None


@dataclass
class InlineRef(Element):
    target: str = None

    def as_dict(self):
        return {
            "target": self.target,
        }


@dataclass
class Reference:
    """
    Data class representing a reference with various attributes.

    Attributes:
        title (Element): The title of the reference.
        authors (List[Author]): List of authors of the reference.
        ids (Dict[str, str]): Dictionary of identification information.
        date (str): The publication date of the reference.
        url (str): The URL link to the reference.
        journal (str): The journal where the reference is published.
        full_text (str): The full text content of the reference.

    Methods:
        as_dict(): Convert the reference object to a dictionary.
    """

    title: Element = None
    authors: List[Author] = field(default_factory=list)
    ids: Dict[str, str] = field(default_factory=dict)
    date: str = None
    url: str = None
    journal: str = None
    full_text: str = None

    def as_dict(self):
        return {
            "title": self.title.plaintext if self.title is not None else None,
            "authors": [asdict(auth) for auth in self.authors],
            "ids": self.ids,
            "date": self.date,
            "url": self.url,
            "journal": self.journal,
            "full_text": self.full_text,
        }


@dataclass
class SpanElement(Element):
    pass


@dataclass
class Italic(SpanElement):
    pass


@dataclass
class Bold(SpanElement):
    pass


@dataclass
class Superscript(SpanElement):
    pass


@dataclass
class Subscript(SpanElement):
    pass


@dataclass
class Paragraph(Element):
    pass


@dataclass
class TableRow(Element):
    cells: List[Element] = field(default_factory=list)

    def add_cell(self, cell: Element):
        self.cells.append(cell)
        cell.parent = self
        return cell

    @property
    def plaintext(self):
        return "\t".join([cell.plaintext for cell in self.cells])


@dataclass
class TableHead(TableRow):
    pass


@dataclass
class Table(Element):
    id: str = None
    header: Element = None
    caption: Element = None
    rows: List[TableRow] = field(default_factory=list)
    keep_table: bool = False

    def add_row(self, row: TableRow) -> TableRow:
        self.rows.append(row)
        row.parent = self
        return row

    @property
    def plaintext(self):
        return "\n".join([row.plaintext for row in self.rows])


@dataclass
class Equation(Element):
    pass


@dataclass
class EquationList(Element):
    equations: List[Equation] = field(default_factory=list)

    def add_equation(self, eqn: Equation) -> Equation:
        self.equations.append(eqn)
        eqn.parent = self
        return eqn

    @property
    def plaintext(self):
        return "\n".join([eqn.plaintext for eqn in self.equations])


@dataclass
class Algorithm(Element):
    caption: Element = None
    lines: List[Element] = field(default_factory=list)
    inline: bool = False

    def add_line(self, line: Element) -> Element:
        self.lines.append(line)
        line.parent = self
        return line

    @property
    def plaintext(self):
        return "\n".join([line.plaintext for line in self.lines])


@dataclass
class Definition(Element):
    term: Element = None
    definition: Element = None

    @property
    def plaintext(self):
        parts = []
        if self.term:
            parts.append(f"{self.term.plaintext}:")
        if self.definition:
            parts.append(self.definition.plaintext)
        return " ".join(parts)


@dataclass
class DefinitionList(Element):
    """
    Data class representing a list of definitions with an optional header.

    Attributes:
        header (Element): The header element for the definition list.
        items (List[Definition]): List of Definition elements.

    Methods:
        add_item(item: Definition) -> Definition: Add a definition item to the list.
    """

    header: Element = None
    items: List[Element] = field(default_factory=list)

    def add_item(self, item: Definition) -> Definition:
        self.items.append(item)
        item.parent = self
        return item

    @property
    def plaintext(self):
        parts = []
        if self.header:
            parts.append(self.header.plaintext)
        parts.extend([df.plaintext for df in self.items])
        return "\n".join(parts)


@dataclass
class Figure(Element):
    id: str = None
    header: Element = None
    caption: Element = None


@dataclass
class Section(Element):
    id: str = None
    header: Element = None
    level: int = 0
    hnum: int = 1


@dataclass
class SectionHeader(Element):
    id: str = None
    header: Element = None
    level: int = 0


@dataclass
class ListItem(Element):
    label: str = ""


@dataclass
class ListContainer(Element):
    level: int = 0
    ordered: bool = False
    items: List[Element] = field(default_factory=list)

    def add_item(self, item: ListItem) -> ListItem:
        self.items.append(item)
        item.parent = self
        return item

    @property
    def plaintext(self):
        return "\n".join([item.plaintext for item in self.items])


@dataclass
class Footnote(Element):
    id: str = None


@dataclass
class Document(Element, Reference):
    abstract: Element = None
    language: str = None
    keywords: List[Element] = field(default_factory=list)
    references: List[Reference] = field(default_factory=list)
    inline_refs: List[InlineRef] = field(default_factory=list)
    bib: Reference = None

    def add_reference(self, reference):
        self.references.append(reference)

    def add_inline_ref(self, in_ref):
        self.inline_refs.append(in_ref)

    def set_bib(self, reference):
        self.bib = reference


@dataclass
class Spec:
    """
    Data class representing specifications for table cells.

    Attributes:
        t (int): The top border size.
        b (int): The bottom border size.
        l (int): The left border size.
        r (int): The right border size.
        align (str): The alignment of the cell content ('c' for center, 'l' for left, 'r' for right,
                     or 'p{width}' for justified with a specified width).

    Methods:
        __hash__() -> int: Compute the hash of the specification.
        __eq__(__o: object) -> bool: Check if two specifications are equal.
        set_align(classes: List[str], style: Optional[str] = None) -> None:
            Extract alignment information from HTML classes.
        set_border(classes: List[str]) -> None: Automatically set border specifications.
        set_attrs(attrs: Dict[str, Any]) -> None: Automatically set all attributes from HTML class attributes.
        __str__() -> str: Get the string representation of the specification.
    """

    t: int = field(default=0, repr=False)
    b: int = field(default=0, repr=False)
    l: int = field(default=0)
    r: int = field(default=0)
    align: str = field(default="")

    def __hash__(self) -> int:
        return hash(repr(self))

    def __eq__(self, __o: object) -> bool:
        return repr(self) == repr(__o)

    def set_align(self, classes: List[str], style: Optional[str] = None) -> None:
        """extract alignment information from available classes (html)"""
        aligns = [s for s in classes if "align" in s]
        if len(aligns) == 0:
            return
        elif len(aligns) > 1:
            logger.warning("Found multiple aligns in classes: %s", ", ".join(classes))
        align = aligns[0]
        if "center" in align or align == "c":
            self.align = "c"
        elif "left" in align or align == "l":
            self.align = "l"
        elif "right" in align or align == "r":
            self.align = "r"
        elif "justify" in align or align == "p":
            # assert style is not None, "justify without style information"
            if style is None:
                self.align = "c"
            else:
                width = style.partition("width:")[2].partition(";")[0]
                self.align = "p{%s}" % width
        else:
            logger.warning(
                "only center, left, right, justify supported at the moment. Found %s",
                align,
            )
            self.align = "c"

    def set_border(self, classes: List[str]) -> None:
        """Set border widths from LaTeXML classes, e.g. 'ltx_border_t' (one line) or 'ltx_border_tt' (two lines)."""
        for border in classes:
            orientation = border.partition("border_")[2]
            if len(orientation) > 0 and orientation[0] in "tbrl":
                setattr(self, orientation[0], len(orientation))

    def set_attrs(self, attrs: Dict[str, Any]) -> None:
        """Set alignment and borders from the HTML attributes of a cell."""
        classes = attrs.get("class", [])
        style = attrs.get("style")

        self.set_align(classes, style=style)
        self.set_border(classes)

    def __str__(self) -> str:
        if self.align:
            return "|" * self.l + self.align + "|" * self.r
        else:
            # default center
            return "|" * self.l + "c" + "|" * self.r


@dataclass
class TableCell(Element):
    """
    Represents a cell in an HTML table.

    Attributes:
        multicolumn (Optional[int]): The number of columns spanned by the cell.
        multirow (Optional[int]): The number of rows spanned by the cell.
        spec (Spec): The specification for the cell's formatting.
        content (Element): The content of the cell.

    Methods:
        __post_init__(*args, **kwargs) -> None: Initialize the cell, ensuring that the spec property is not None.
        __hash__() -> int: Compute the hash of the cell.
        __eq__(__o: object) -> bool: Check if two cells are equal.
        set_attrs(attrs: Dict[str, Any]) -> None: Set attributes for the cell from HTML attributes.
        plaintext() -> str: Get the plaintext content of the cell.
    """

    multicolumn: Optional[int] = None
    multirow: Optional[int] = None
    spec: Spec = None
    content: Element = None

    def __post_init__(self, *args, **kwargs) -> None:
        # spec property cannot be None
        if self.spec is None:
            self.spec = Spec()

    def __hash__(self) -> int:
        return hash(repr(self))

    def __eq__(self, __o: object) -> bool:
        return repr(self) == repr(__o)

    def set_attrs(self, attrs: Dict[str, Any]) -> None:
        if "colspan" in attrs:
            self.multicolumn = int(attrs["colspan"])
        if "rowspan" in attrs:
            self.multirow = int(attrs["rowspan"])
        self.spec.set_attrs(attrs)

    @property
    def plaintext(self):
        if self.content is None:
            return ""
        return self.content.plaintext


@dataclass
class TableRow(Element):
    """
    Represents a row in an HTML table.

    Attributes:
        cells (List[TableCell]): The list of cells in the row.

    Methods:
        add_cell(cell: TableCell) -> TableCell: Add a cell to the row.
        __iter__() -> Iterator: Iterate through the cells in the row.
        __len__() -> int: Get the number of cells in the row.
        __bool__() -> bool: Check if the row is not empty.
        cum_cell_widths() -> List[int]: Get the cumulative cell widths.
        cell_widths() -> List[int]: Get the widths of individual cells.
        width() -> int: Get the total width of the row.
        _hline(orientation: str) -> str: Determine horizontal lines to be inserted.
        hline_above() -> str: Get the horizontal line description for the top of the row.
        hline_below() -> str: Get the horizontal line description for the bottom of the row.
        plaintext() -> str: Get the plaintext content of the row.
    """

    cells: List[TableCell] = field(default_factory=list)

    def add_cell(self, cell: TableCell):
        self.cells.append(cell)
        cell.parent = self
        return cell

    def __iter__(self):
        return iter(self.cells)

    def __len__(self) -> int:
        return len(self.cells)

    def __bool__(self) -> bool:
        # __len__ is defined, so without this an empty row would be falsy
        return True

    @property
    def cum_cell_widths(self) -> np.ndarray:
        return np.cumsum(self.cell_widths)

    @property
    def cell_widths(self) -> List[int]:
        return [(cell.multicolumn or 1) for cell in self.cells]

    @property
    def width(self) -> int:
        return sum(self.cell_widths)

    def _hline(self, orientation: str) -> str:
        """Figure out if and where horizontal lines need to be inserted.

        Args:
            orientation (str): Either 't' (top) or 'b' (bottom)

        Returns:
            str: LaTeX horizontal rules for the row (hline or cline commands), or an empty string.
        """
        assert orientation == "t" or orientation == "b"
        lines = []
        for cell in self.cells:
            lines.extend([getattr(cell.spec, orientation)] * (cell.multicolumn or 1))
        lines.append(0)
        indices = []
        start = None
        for i, v in enumerate(lines):
            if v and start is None:
                start = i
            elif start is not None and not v:
                indices.append((start, i - 1))
                start = None
        s = ""
        for a, b in indices:
            if b - a + 1 == self.width:
                s += "\\hline " * lines[0]
            else:
                s += "\\cline{%i-%i} " % (a + 1, b + 1)
        return s.strip()

    @property
    def hline_above(self) -> str:
        return self._hline("t")

    @property
    def hline_below(self) -> str:
        return self._hline("b")

    @property
    def plaintext(self) -> str:
        return "\t".join([cell.plaintext for cell in self.cells])
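

# Illustrative, self-contained sketch (hypothetical helper, not called by the parser)
# of the run detection in TableRow._hline. `borders` holds one border count per column
# after expanding multicolumn cells; every maximal run of bordered columns becomes
# \cline{a-b}, and a run spanning the whole row becomes \hline (repeated per line).
def _border_runs_to_rules(borders):
    lines = list(borders) + [0]  # sentinel 0 closes a run that reaches the last column
    rules = []
    start = None
    for i, v in enumerate(lines):
        if v and start is None:
            start = i
        elif start is not None and not v:
            if start == 0 and i == len(borders):
                rules.append("\\hline " * lines[0])
            else:
                rules.append("\\cline{%i-%i} " % (start + 1, i))
            start = None
    return "".join(rules).strip()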


@dataclass
class Tabular(Element):
    """
    Represents a tabular structure, such as an HTML table.

    Attributes:
        rows (List[TableRow]): The list of rows in the tabular structure.

    Methods:
        add_row(row: TableRow) -> TableRow: Add a row to the tabular structure.
        width() -> int: Get the maximum width of the tabular structure.
        cols() -> List[List[TableCell]]: Get a list of columns in the tabular structure.
        _square_table() -> None: Ensure the table has an equal number of columns in each row.
        get_table_spec() -> str: Generate a LaTeX table specification based on cell alignments.
        plaintext() -> str: Get the plaintext content of the tabular structure.
    """

    rows: List[TableRow] = field(default_factory=list)

    def add_row(self, row: TableRow) -> TableRow:
        self.rows.append(row)
        row.parent = self
        return row

    @property
    def width(self) -> int:
        if len(self.rows) > 0:
            return max([r.width for r in self.rows])
        else:
            return 0

    @property
    def cols(self) -> List[List[TableCell]]:
        return list(
            map(
                list,
                itertools.zip_longest(*[r.cells for r in self.rows], fillvalue=None),
            )
        )

    def _square_table(self) -> None:
        """Make the table square: insert empty placeholder cells into the rows spanned by `\\multirow` cells."""
        for i, row in enumerate(self.rows):
            for j, cell in enumerate(row.cells):
                if cell.multirow is not None and cell.multirow > 1:
                    spec = copy(cell.spec)
                    # assume no hlines in multi cells: disable bottom lines for top and top lines for lower cells.
                    spec.t = 0
                    cell.spec.b = 0
                    for k in range(i + 1, i + cell.multirow):
                        if k < len(self.rows):
                            for _ in range(row.cell_widths[j]):
                                # add empty cell
                                self.rows[k].cells.insert(
                                    j, TableCell(parent=self.rows[k], spec=spec)
                                )

    def get_table_spec(self) -> str:
        """Generates a LaTeX table spec."""
        # First make table square
        self._square_table()
        # Find the most used spec in regular cells (no multi-col/row)
        specs = [Spec() for _ in range(self.width)]
        for i, col in enumerate(self.cols):
            counts = defaultdict(int)
            for cell in col:
                if cell is None or cell.spec.align == "":
                    continue
                if cell.multicolumn is None and cell.multirow is None:
                    counts[cell.spec] += 1
            if len(counts) > 0:
                specs[i] = max(counts, key=counts.get)
        # convert all cells that don't match the column style into a multicol{1}{custom_spec}
        for i, col in enumerate(self.cols):
            for cell in col:
                if cell is not None and cell.spec != specs[i]:
                    # an empty cell's alignment does not matter as long as its vertical borders match the column
                    if (
                        len(cell.children) == 0
                        and cell.spec.l == specs[i].l
                        and cell.spec.r == specs[i].r
                    ):
                        continue
                    # convert any standard cell into a multicol cell of width 1
                    if cell.multicolumn is None:
                        cell.multicolumn = 1
        # generate final latex table spec
        out = " ".join([str(spec) for spec in specs])
        out = re.sub(r"(\|) +(\w)", r"\1\2", out)
        out = re.sub(r"(\w) +(\|)", r"\1\2", out)
        return out

    @property
    def plaintext(self):
        return "\n".join([row.plaintext for row in self.rows])


@dataclass
class Table(Element):
    id: str = None
    caption: Element = None


================================================
FILE: nougat/dataset/parser/html2md.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
from pathlib import Path
from typing import List, Optional
from bs4 import BeautifulSoup
from tqdm import tqdm
import htmlmin
from nougat.dataset.parser.latexml_parser import parse_latexml, _clean_html_whitespace
from nougat.dataset.parser.markdown import format_document


def check_file_path(paths: List[Path], wdir: Optional[Path] = None) -> List[Path]:
    """
    Resolves the given paths and keeps those that exist.

    Args:
        paths: A list of file paths (str or Path). A '*' in the file name is expanded as a glob pattern.
        wdir: Optional working directory. If given, each path is also tried relative to it.

    Returns:
        A deduplicated list of resolved paths that exist.
    """
    files = []
    for path in paths:
        if isinstance(path, str):
            if path == "":
                continue
            path = Path(path)
        pathsi = [path] if wdir is None else [path, wdir / path]
        for p in pathsi:
            if p.exists():
                files.append(p.resolve())
            elif "*" in path.name:
                files.extend([pi.resolve() for pi in p.parent.glob(p.name)])
    return list(set(files))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--html", type=Path, nargs="+", help="HTML file", required=True)
    parser.add_argument("--out", type=Path, help="Output file", required=True)
    args = parser.parse_args()
    args.html = check_file_path(args.html)
    for f in tqdm(args.html):
        html = BeautifulSoup(
            htmlmin.minify(
                f.read_text(encoding="utf-8").replace("\xa0", " "),
                remove_all_empty_space=1,
            ),
            features="html.parser",
        )
        try:
            doc = parse_latexml(html)
        except ValueError as e:
            print(e)
            continue
        if doc is None:
            continue
        out, fig = format_document(doc, keep_refs=True)
        outp = (args.out if args.out.is_dir() else args.out.parent) / (f.stem + ".mmd")
        with open(outp, "w", encoding="utf-8") as fout:
            fout.write(out)


================================================
FILE: nougat/dataset/parser/latexml_parser.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import re
import sys
import requests
from typing import Optional, Set
from bs4 import BeautifulSoup, NavigableString
import soupsieve as sv

from nougat.dataset.parser.document import *


def printerr(*args, **kwargs):
    # uncomment for debugging
    # print(*args, **kwargs)
    pass


latexml_wrapper_selector = sv.compile(
    ", ".join(
        [
            ".ltx_engrafo_equation_container",
            "tbody",
            ".ltx_note_content",
            ".ltx_role_footnote",
            ".ltx_note_type",
            ".ltx_theorem",
            ".ltx_proof",
            ".ltx_quote",
            "blockquote",
            ".ltx_inline-para",
            ".ltx_inline-block",
        ]
    )
)
latexml_ignore_selector = sv.compile(".ltx_rule, .ltx_pagination.ltx_role_newpage")


def is_wrapper_element(element: BeautifulSoup) -> bool:
    return latexml_wrapper_selector.match(element)


def ignore_element(element: BeautifulSoup) -> bool:
    return latexml_ignore_selector.match(element)


def _get_classes(el: BeautifulSoup) -> Set[str]:
    if not hasattr(el, "attrs"):
        return set()
    classes = el.attrs.get("class")
    if classes is None:
        return set()
    return set(classes)


def _detach_selected(element: BeautifulSoup, selector: str) -> None:
    for elem in element.select(selector):
        elem.extract()


def parse_latexml_authors(ltx_authors: BeautifulSoup) -> Paragraph:
    authors = Paragraph()
    parse_latexml_children(ltx_authors, authors)
    return authors


def parse_latexml_citations(cite: BeautifulSoup, parent: Element) -> None:
    """
    Parses LaTeXML citations and appends them as children to the given parent element.

    Args:
        cite (BeautifulSoup): The BeautifulSoup object containing the citation data.
        parent (Element): The parent element to which the citations will be added as children.
    """
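    parse_latexml_children(cite, parent)
    # Keep citations that are already bracketed or contain letters (e.g. author-year)
    # unchanged; wrap the rest (e.g. bare numbers) in square brackets.
    if ("[" in parent.plaintext and "]" in parent.plaintext) or re.search(
        r"[A-Za-z]", parent.plaintext
    ):
        return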

    parent.children.insert(0, TextElement(content="["))
    parent.children.append(TextElement(content="]"))


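def _clean_html_whitespace(text: str) -> str:
    """Normalize whitespace in a text node: collapse leading/trailing newline runs to a
    single newline (or strip them if the node is whitespace only) and squeeze runs of
    spaces and tabs into one space."""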
    if text.strip():
        text = re.sub(r"(^\n+|\n+$)", "\n", text)
    else:
        text = text.strip("\n")
    text = re.sub(r"[ \t]+", " ", text)
    return text


def parse_latexml_children(html: BeautifulSoup, parent: Element) -> None:
    """
    Parses LaTeXML children and appends them as appropriate elements to the given parent element.

    Args:
        html (BeautifulSoup): The BeautifulSoup object containing the HTML data.
        parent (Element): The parent element to which the parsed children will be added.
    """
    if html is None:
        return
    for child in html.children:
        classes = _get_classes(child)
        if isinstance(child, NavigableString):
            parent.append(TextElement(content=_clean_html_whitespace(str(child))))
        elif sv.match(
            "p, .ltx_p, div.ltx_para, span.ltx_para, section.ltx_paragraph", child
        ):
            paragraph = parent.append(Paragraph())
            parse_latexml_children(child, paragraph)
        elif sv.match(".ltx_tag", child):
            if "ltx_tag_note" not in classes:
                if sv.match(".ltx_tag_section", child):
                    if child.string is not None:
                        child.string = child.string.upper()
                elif sv.match(".ltx_tag_subsection", child):
                    child.string = ""
                parse_latexml_children(child, parent)
            elif "ltx_tag_bibitem" in classes:
                parse_latexml_children(child, parent.append(SpanElement()))
        elif sv.match(".ltx_note_outer", child):
            # try to place the footnote outside the current paragraph
            paragraph = parent.find_parent(Paragraph)
            if paragraph is not None and paragraph.parent is not None:
                footnote = paragraph.parent.append(Footnote())
            else:
                footnote = parent.append(Footnote())
            parse_latexml_children(child, footnote)
        elif sv.match(".ltx_note_content > .ltx_note_mark", child):
            footnote = parent.find_parent(Footnote)
            if footnote is not None:
                footnote.id = child.get_text(strip=True)
            else:
                printerr("Unable to find footnote to set its id", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match("sup", child):
            sup = parent.append(Superscript())
            parse_latexml_children(child, sup)
        elif sv.match("sub", child):
            sub = parent.append(Subscript())
            parse_latexml_children(child, sub)
        elif sv.match("span.ltx_Math, span.ltx_DisplayMath", child):
            inline = "ltx_DisplayMath" not in classes
            math_elem = child.select_one(".mjx-math")
            if math_elem:
                tex = math_elem.attrs["aria-label"]
                if inline:
                    tex = rf"\({tex}\)"
                else:
                    tex = rf"\[{tex}\]"
                parent.append(LatexMath(code=tex, inline=inline))
        elif sv.match("math.ltx_Math", child):
            # not sure whether the <math> tag is LaTeXML-version specific, but this seems to work
            inline = True
            if "display" in child.attrs:
                inline = child.attrs["display"] == "inline"
            tex = child.attrs["alttext"]
            if inline:
                tex = rf"\({tex}\)"
            else:
                tex = rf"\[{tex}\]"
            parent.append(LatexMath(code=tex, inline=inline))
        elif sv.match("a.ref", child):
            link = parent.append(Link())
            link.target = child.attrs.get("href")
            parse_latexml_children(child, link)
        elif sv.match(
            ".ltx_ref.ltx_missing_citation, .ltx_ref.ltx_missing_label", child
        ):
            placeholder = child.get_text().strip()
            resolved = False
            if placeholder.isnumeric():
                parent.append(TextElement(content=placeholder))
                resolved = True
            else:
                target = child.attrs.get("href")
                if target is not None:
                    potential_num = target.partition(".bib")[2]
                    if potential_num.isnumeric():
                        parent.append(TextElement(content=potential_num))
                        resolved = True
            if not resolved:
                raise ValueError("missing reference detected")
        elif sv.match(
            ".ltx_bibblock, .ltx_role_author, .ltx_contact, .ltx_role_email, .ltx_role_affiliation",
            child,
        ):
            parse_latexml_children(child, parent.append(SpanElement()))
            parent.append(TextElement(content="\n"))
        elif sv.match(
            ".ltx_authors, .ltx_personname, .ltx_role_creation.ltx_date, .ltx_engrafo_author_notes, .ltx_author_notes, .ltx_date.ltx_role_creation",
            child,
        ):
            parse_latexml_children(child, parent.append(Paragraph()))
            parent.append(TextElement(content="\n"))
        elif sv.match(
            ".ltx_author_before, .ltx_role_pubyear, .ltx_role_pagerange", child
        ):
            pass
        elif sv.match("h1.ltx_title_document", child):
            doc = parent.find_parent(Document)
            if doc is not None:
                if doc.title is None:
                    doc.title = SectionHeader(parent=doc)
                    doc.title.hnum = int(child.name[1])
                    parse_latexml_children(child, doc.title)
                else:
                    printerr("Document title is already set", file=sys.stderr)
            else:
                printerr("Unable to find document to set title", file=sys.stderr)
        elif sv.match("section", child):
            if ".ltx_bibliography" not in classes:
                section = parent.append(Section())
                parse_latexml_children(child, section)
        elif sv.match("h1, h2, h3, h4, h5, h6", child) and "ltx_title" in classes:
            if {"ltx_title_theorem", "ltx_title_proof"} & classes:
                parse_latexml_children(child, parent)
                parent.append(TextElement(content=": "))
            elif isinstance(parent, Section):
                parent.hnum = int(child.name[1])
                if parent.header is None:
                    parent.header = SpanElement()
                parse_latexml_children(child, parent.header)
            else:
                printerr("Dangling title element", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match(".ltx_TOC.ltx_toc_toc", child):
            s = parent.append(Section(hnum=6, header=TextElement(content="Contents")))
            parse_latexml_children(child, s.append(Paragraph()))
        elif sv.match(
            "ul.ltx_itemize, ul.ltx_toclist, ul.ltx_biblist, ol.ltx_enumerate", child
        ):
            lst = parent.append(ListContainer())
            lst.ordered = child.name == "ol"
            parent_list = parent.find_parent(ListContainer)
            lst.level = parent_list.level + 1 if parent_list is not None else 1
            parse_latexml_children(child, lst)
        elif sv.match("li.ltx_item, li.ltx_tocentry, li.ltx_bibitem", child):
            lst = parent.find_parent(ListContainer)
            if lst is not None:
                item = lst.add_item(ListItem())
                parse_latexml_children(child, item)
            else:
                printerr("List item outside list", file=sys.stderr)
        elif sv.match("cite", child):
            span = parent.append(SpanElement())
            parse_latexml_citations(child, span)
        elif sv.match("a.ltx_ref", child):
            target = child.attrs.get("href")
            if target is not None and target.startswith("#bib"):  # citation link
                in_ref = parent.append(InlineRef())
                in_ref.target = target
                text = child.get_text()
                if text.strip().isnumeric():
                    in_ref.append(TextElement(content=text))
                elif re.search(r"[A-Za-z][:;.,_]?\d", text):
                    # probably a broken citation, go with link number instead
                    in_ref.append(
                        TextElement(
                            content=re.sub(r"\D", "", target.partition(".bib")[2])
                        )
                    )
                else:
                    raise ValueError('unusable reference "%s"' % text)
                doc = parent.find_parent(Document)
                if doc:
                    doc.add_inline_ref(in_ref)
            else:
                link = parent.append(Link())
                link.target = target
                parse_latexml_children(child, link)
        elif sv.match("a", child) and len(classes) == 0:
            target = child.attrs.get("href")
            parse_latexml_children(child, parent.append(Link(target=target)))
        elif sv.match(".ltx_eqn_table", child):
            eqn_list = parent.append(EquationList())
            parse_latexml_children(child, eqn_list)
        elif sv.match(".ltx_eqn_row", child):
            eqn_list = parent.find_parent(EquationList)
            if eqn_list is not None:
                eqn = eqn_list.add_equation(Equation())
                parse_latexml_children(child, eqn)
            else:
                printerr("Dangling equation row", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match(".ltx_eqn_cell", child):
            parse_latexml_children(child, parent)
        elif sv.match("table, span.ltx_tabular, div.ltx_tabular", child):
            tabular = parent.append(Tabular())
            parse_latexml_children(child, tabular)
        elif sv.match("thead.ltx_thead", child):
            table = parent.find_parent(Tabular)
            if table is not None:
                parse_latexml_children(child, table)
            else:
                printerr("Table header element outside table", file=sys.stderr)
        elif sv.match("tbody.ltx_tbody", child):
            parse_latexml_children(child, parent)
        elif sv.match("tr.ltx_tr", child):
            table = parent.find_parent(Tabular)
            if table is not None:
                row = table.add_row(TableRow())
                parse_latexml_children(child, row)
            else:
                printerr("TableRow element outside table", file=sys.stderr)
        elif sv.match("td.ltx_td, th.ltx_th", child):
            row = parent.find_parent(TableRow)
            if row is not None:
                cell = TableCell()
                cell.set_attrs(child.attrs)
                row.add_cell(cell)
                parse_latexml_children(child, cell)
            else:
                printerr("TableData element outside table row", file=sys.stderr)
        elif sv.match("span.ltx_text, em.ltx_emph", child):
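            # Leftover enumitem options such as "[label=0)]" can end up as text
            # spans inside list items; when that happens, drop the list item they
            # belong to (the container's last item) instead of emitting the text.
            lst = parent.find_parent(ListContainer)
            if lst is None or child.get_text() not in (
                "[label=0)]",
                "[leftmargin=*] ",
            ):
                if "ltx_font_italic" in classes:
                    elem = Italic()
                elif "ltx_font_bold" in classes:
                    elem = Bold()
                else:
                    elem = SpanElement()
                parent.append(elem)
                parse_latexml_children(child, elem)
            elif lst.items:
                lst.items.pop()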
        elif sv.match("figure.ltx_table", child):
            figure = parent.append(Table())
            if "id" in child.attrs:
                figure.id = child.attrs["id"]
            parse_latexml_children(child, figure)
        elif sv.match("figure.ltx_figure", child):
            figure = parent.append(Figure())
            if "id" in child.attrs:
                figure.id = child.attrs["id"]
            parse_latexml_children(child, figure)
        elif sv.match("figure.ltx_float", child):
            parse_latexml_children(child, parent)
        elif sv.match(".ltx_listing", child):
            alg = parent.append(Algorithm())
            parse_latexml_children(child, alg)
        elif sv.match(".ltx_listingline", child):
            alg = parent.find_parent(Algorithm)
            if alg is not None:
                line = alg.add_line(Element())
                parse_latexml_children(child, line)
            else:
                printerr("Listing line outside algorithm environment", file=sys.stderr)
        elif sv.match("dl.ltx_description", child):
            def_list = parent.append(DefinitionList())
            parse_latexml_children(child, def_list)
        elif sv.match("dt.ltx_item", child):
            def_list = parent.find_parent(DefinitionList)
            if def_list is not None:
                item = def_list.add_item(Definition())
                item.term = SpanElement(parent=item)
                parse_latexml_children(child, item.term)
            else:
                printerr("Found dangling definition term", file=sys.stderr)
        elif sv.match("dd.ltx_item", child):
            def_list = parent.find_parent(DefinitionList)
            if def_list is not None:
                if def_list.items and def_list.items[-1].definition is None:
                    item = def_list.items[-1]
                else:
                    printerr("Found definition without term", file=sys.stderr)
                    item = def_list.add_item(Definition())
                item.definition = SpanElement(parent=item)
                parse_latexml_children(child, item.definition)
            else:
                printerr("Found dangling definition", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match("figcaption", child):
            fig = parent.find_parent((Figure, Table))
            if fig is not None:
                if fig.caption is None:
                    fig.caption = Paragraph(parent=fig)
                parse_latexml_children(child, fig.caption)
                fig.caption.append(TextElement(content="\n"))
            else:
                printerr("Figure caption outside figure element", file=sys.stderr)
                para = parent.append(Paragraph())
                parse_latexml_children(child, para)
        elif sv.match(".ltx_break", child):
            parent.append(TextElement(content="\n\n"))
        elif sv.match(".ltx_abstract, .ltx_acknowledgements", child):
            abstract = parent.append(Section())
            parse_latexml_children(child, abstract)
        elif sv.match(".ltx_ERROR", child):
            printerr(
                f"LaTeX error element: {child.get_text(strip=True)}", file=sys.stderr
            )
        elif is_wrapper_element(child):
            parse_latexml_children(child, parent)
        elif ignore_element(child):
            continue
        else:
            printerr(
                f"Unknown LaTeXML element <{child.name}> with classes {', '.join(classes)}",
                file=sys.stderr,
            )
            elem = parent.append(UnknownElement())
            parse_latexml_children(child, elem)
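

# Illustrative sketch (hypothetical helpers, not part of nougat): the two math
# branches of parse_latexml_children take the TeX source from the MathJax
# "aria-label" or from the <math alttext="..."> attribute and wrap it in
# \( ... \) for inline math or \[ ... \] for display math. A <math> element is
# treated as inline unless its "display" attribute says otherwise. The names
# `_sketch_is_inline` and `_sketch_wrap_tex` are assumptions for illustration.
def _sketch_is_inline(attrs: dict) -> bool:
    """Mirrors the <math> branch: inline unless display is set to something else."""
    return attrs.get("display", "inline") == "inline"


def _sketch_wrap_tex(tex: str, inline: bool) -> str:
    """Wraps raw TeX in inline or display math delimiters."""
    return rf"\({tex}\)" if inline else rf"\[{tex}\]"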


# TODO: move this somewhere else, so I can use it with plaintext too
sess = requests.Session()


def parse_latexml_references(html: BeautifulSoup, doc: Document) -> None:
    for child in html.select("li.ltx_bibitem"):
        reference = Reference()
        reference.title = TextElement(content=child.get_text(strip=True))
        doc.add_reference(reference)
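

# Illustrative sketch (hypothetical helper, not part of nougat): when the visible
# text of a citation link is unusable, the parser falls back to the number
# encoded in the LaTeXML link target. The "#bib.bib12" href shape is an
# assumption inferred from the `target.partition(".bib")[2]` calls above.
import re


def _sketch_citation_number(href: str) -> str:
    """Returns the digits after ".bib" in a LaTeXML citation href ("" if none)."""
    return re.sub(r"\D", "", href.partition(".bib")[2])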


def parse_latexml(
    html: BeautifulSoup,
) -> Optional[Document]:
    if html.article is None:
        printerr("Missing article element", file=sys.stderr)
        return None
    doc = Document()
    parse_latexml_children(html.article, doc)
    parse_latexml_references(
        html.article,
        doc,
    )
    return doc


================================================
FILE: nougat/dataset/parser/markdown.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Dict, Iterable, List, Optional, Tuple
import re
from uuid import uuid4
from nougat.dataset.utils import normalize_tex
from nougat.dataset.parser.document import *
from nougat.dataset.parser.latexml_parser import _clean_html_whitespace
from unidecode import unidecode

SUPERSCRIPT_MAP = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
SUBSCRIPT_MAP = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")
figure_regex = re.compile(r"\[(FOOTNOTE|FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", re.S)
conv = {
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\^{}",
    "\\": r"\textbackslash{}",
    "<": r"\textless{}",
    ">": r"\textgreater{}",
}
regex = re.compile(
    "|".join(
        re.escape(str(key)) for key in sorted(conv.keys(), key=lambda item: -len(item))
    )
)


def remove_trailing_whitespace(parts: List[str]) -> None:
    """Removes trailing whitespace-only elements from the list in place."""
    while parts and parts[-1].rstrip() == "":
        parts.pop()


def remove_line_breaks(parts: List[str]):
    out = []
    for s in parts:
        out.append(s.replace("\n", " "))
    return out


def leading_trailing_whitespace(
    parts: List[str],
) -> Tuple[List[str], List[str], List[str]]:
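    """Splits the list into three parts: the leading whitespace-only elements,
    the remaining middle slice, and the trailing whitespace-only elements.

    Args:
        parts (List[str]): List to split. A plain string is also accepted; it is
            then processed character by character and the middle part is a string.

    Returns:
        Tuple[List[str], List[str], List[str]]: Leading whitespace, middle part, trailing whitespace.
    """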
    lead = []
    trail = []
    out_slice = [None, None]
    for i, s in enumerate(parts):
        if s.strip() == "":
            lead.append(s)
            out_slice[0] = i + 1
        else:
            break
    for i, s in enumerate(reversed(parts)):
        if s.strip() == "":
            trail.append(s)
            out_slice[1] = -1 - i
        else:
            break
    return lead, parts[slice(*out_slice)], trail[::-1]


def latex_escape(string: str) -> str:
    return regex.sub(lambda match: conv[match.group()], string)
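

# Illustrative sketch (hypothetical helper, not part of nougat): latex_escape
# performs all substitutions in a single regex pass. Chained str.replace calls
# would re-escape the backslashes and braces introduced by earlier
# replacements, which the one-pass approach avoids. Only a subset of the
# `conv` table above is used here.
import re

_SKETCH_CONV = {"%": r"\%", "$": r"\$", "_": r"\_", "\\": r"\textbackslash{}"}
_SKETCH_RE = re.compile("|".join(re.escape(key) for key in _SKETCH_CONV))


def _sketch_latex_escape(text: str) -> str:
    """Escapes LaTeX special characters in one pass over ``text``."""
    return _SKETCH_RE.sub(lambda match: _SKETCH_CONV[match.group()], text)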


def is_empty(content: List) -> bool:
    """Used to determine if a Section is empty"""
    empty = True
    for part in content:
        if len(part.strip()):
            empty = False
            break
    return empty


def format_element(
    element: Element, keep_refs: bool = False, latex_env: bool = False
) -> List[str]:
    """
    Formats a given Element into a list of formatted strings.

    Args:
        element (Element): The element to be formatted.
        keep_refs (bool, optional): Whether to keep references in the formatting. Default is False.
        latex_env (bool, optional): Whether to use LaTeX environment formatting. Default is False.

    Returns:
        List[str]: A list of formatted strings representing the formatted element.
    """
    if isinstance(element, TextElement):
        if latex_env:
            return [latex_escape(element.content)]
        else:
            return [element.content]
    if isinstance(element, Bold):
        parts = format_children(element, keep_refs, latex_env)
        if element.find_parent(Algorithm) is not None:
            return parts
        lead, text, tail = leading_trailing_whitespace("".join(parts))
        return [*lead, "**", *remove_line_breaks(text), "**", *tail]
    if isinstance(element, Italic):
        parts = format_children(element, keep_refs, latex_env)
        if element.find_parent(Algorithm) is not None:
            return parts
        lead, text, tail = leading_trailing_whitespace("".join(parts))
        return [*lead, "_", *remove_line_breaks(text), "_", *tail]
    if isinstance(element, PlaintextMath):
        return format_children(element, keep_refs) + ["\n"]
    if isinstance(element, Paragraph):
        return format_children(element, keep_refs, latex_env) + ["\n\n"]
    if isinstance(element, TableCell):
        parts = format_children(element, keep_refs, latex_env)
        remove_trailing_whitespace(parts)
        if element.multirow is not None:
            parts.insert(0, "\\multirow{%i}{*}{" % (element.multirow))
            parts.append("}")
        if element.multicolumn is not None:
            parts.insert(
                0, "\\multicolumn{%i}{%s}{" % (element.multicolumn, element.spec)
            )
            parts.append("}")
        return parts
    if isinstance(element, TableRow):
        parts = []
        if element.hline_above:
            parts.append(element.hline_above + "\n")
        parts.extend(
            remove_line_breaks(
                format_iterator(element.cells, keep_refs, latex_env, join=" & ")
            )
        )
        parts.append(r" \\")
        parts.append((" " + element.hline_below).rstrip())
        return parts
    if isinstance(element, Tabular):
        parts = [
            "\\begin{tabular}",
            "{%s}\n" % element.get_table_spec(),
        ]
        parts.extend(format_iterator(element.rows, keep_refs, True, join="\n"))
        parts.append("\n\\end{tabular}\n")
        return parts
    if isinstance(element, Table):
        parts = [
            "[TABLE%s]\n\\begin{table}\n"
            % (str(uuid4())[:5] if element.id is None else ":" + str(element.id))
        ]
        parts.extend(format_children(element, keep_refs, latex_env))
        caption_parts = format_element(element.caption, keep_refs, latex_env)
        remove_trailing_whitespace(caption_parts)
        parts.append("\\end{table}\n")
        if len(caption_parts) > 0:
            parts.extend(caption_parts + ["\n"])
        parts.append("[ENDTABLE]\n\n")
        return parts
    if isinstance(element, Figure):
        parts = format_element(element.caption, keep_refs)
        remove_trailing_whitespace(parts)
        return (
            [
                "[FIGURE%s]\n"
                % (str(uuid4())[:5] if element.id is None else ":" + str(element.id))
            ]
            + parts
            + ["\n[ENDFIGURE]\n\n"]
        )
    if isinstance(element, SectionHeader):
        parts = ["# "]
        if element.id:
            parts.append(f"{element.id.upper()} ")
        if element.header:
            header = format_element(element.header, keep_refs)
        else:
            header = format_iterator(element.children, keep_refs)
        _, title, _ = leading_trailing_whitespace("".join(header))
        parts.append(title)
        parts.append("\n\n")
        return parts
    if isinstance(element, Section):
        children_parts = format_children(element, keep_refs)
        if is_empty(children_parts):
            return []
        if element.header:
            parts = [f"\n\n{'#'*element.hnum} "]
            _, title, _ = leading_trailing_whitespace(
                "".join(format_element(element.header, keep_refs))
            )
            parts.append(title)
            parts.append("\n\n")
        else:
            parts = []
        return parts + children_parts
    if isinstance(element, Footnote):
        if element.id is not None:
            foot = f"\n[FOOTNOTE:{element.id}]Footnote {element.id}: "
        else:
            foot = "\n[FOOTNOTE:%s]Footnote: " % (str(uuid4())[:5])
        return [foot] + format_children(element, keep_refs) + ["[ENDFOOTNOTE]\n\n"]
    if isinstance(element, ListContainer):
        items = [
            (
                item.label,
                "".join(format_element(item, keep_refs)).strip().replace("\n", " "),
            )
            for item in element.items
        ]
        parts = ["\n"]
        indent = "  " * max(element.level - 1, 0)
        for i, (label, item) in enumerate(items, 1):
            if label:
                bullet = label
            else:
                bullet = f"{i}." if element.ordered else "*"
            parts.append(f"{indent}{bullet} {item}\n")
        parts.append("\n")
        return parts
    if isinstance(element, Equation):
        # equation comprises of multiple displaystyle TeX formulas and optional equation label
        parts = []
        for child in element.children:
            if isinstance(child, LatexMath):
                tex = normalize_tex(
                    "".join(format_element(child, keep_refs)).strip(" \n"), inline=False
                )
                parts.append(tex)
            else:
                text = "".join(format_element(child, keep_refs))
                if text:
                    parts.append(text)
        lead, eqs, tail = leading_trailing_whitespace(parts)
        s = " ".join(eqs).replace(r"\] \[", " ")
        return [*lead, s, *tail]
    if isinstance(element, EquationList):
        parts = ["\n"]
        items = element.equations
        items = ["".join(format_element(item, keep_refs)).rstrip() for item in items]
        items = [item + "\n" for item in items if item]
        if items:
            parts.extend(items)
            parts.append("\n")
        return parts
    if isinstance(element, Algorithm):
        parts = []
        items = element.lines
        items = ["".join(format_element(item, keep_refs)).rstrip() for item in items]
        if element.inline:
            items = [item for item in items if item]
        else:
            items = [item + "\n" for item in items if item]
        if items:
            prepend = "`" if element.inline else "\n```\n"
            parts.append(prepend)
            parts.extend(items)
            append = "`" if element.inline else "```\n\n"
            parts.append(append)
        return parts
    if isinstance(element, DefinitionList):
        parts = ["\n"]
        if element.header is not None:
            parts.extend(format_element(element.header, keep_refs))
            parts.append("\n")
        items = [
            "".join(format_element(item, keep_refs)).rstrip() for item in element.items
        ]
        items = [item + "\n" for item in items if item]
        if items:
            parts.extend(items)
            parts.append("\n")
        return parts
    if isinstance(element, Definition):
        parts = []
        if element.term is not None:
            term = (
                "".join(format_element(element.term, keep_refs)).rstrip(" \n\t:") + ": "
            )
            # maths in wiki might be inside a definition without a term
            if term.strip() != ":":
                parts.append(term)
        if element.definition is not None:
            definition = "".join(format_element(element.definition, keep_refs)).rstrip()
            parts.append(definition)
        if parts:
            parts.append("\n")
        return parts
    if isinstance(element, LatexMath):
        parts = []
        if not element.inline:
            parts.append("\n\n")
        parts.append(normalize_tex(element.code, element.inline).strip())
        if not element.inline:
            parts.append("\n\n")
        return parts
    if isinstance(element, (Superscript, Subscript)):
        content = element.plaintext
        if content.strip().isdigit():
            script_map = (
                SUBSCRIPT_MAP if isinstance(element, Subscript) else SUPERSCRIPT_MAP
            )
            return [content.translate(script_map)]
        else:
            return format_children(element, keep_refs)
    if isinstance(element, InlineRef):
        parts = format_children(element, keep_refs)
        return parts
    return format_children(element, keep_refs, latex_env)
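

# Illustrative sketch (hypothetical helper, not part of nougat): purely numeric
# Superscript/Subscript contents are rendered with Unicode script digits via
# str.translate, mirroring SUPERSCRIPT_MAP/SUBSCRIPT_MAP above; any other
# content falls back to formatting the element's children.
_SKETCH_SUP = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
_SKETCH_SUB = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")


def _sketch_script_digits(content: str, subscript: bool = False) -> str:
    """Maps ASCII digits to superscript (or subscript) Unicode digits."""
    return content.translate(_SKETCH_SUB if subscript else _SKETCH_SUP)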


def format_iterator(
    iterator: Iterable,
    keep_refs: bool = False,
    latex_env: bool = False,
    join: Optional[str] = None,
) -> List[str]:
    """
    Formats every element of `iterator` and concatenates the resulting string parts.

    :param iterator: The elements to format, e.g. `element.children` or `table.rows`.
    :type iterator: Iterable
    :param keep_refs: Forwarded to `format_element` for each element, defaults to False
    :type keep_refs: bool (optional)
    :param latex_env: Whether the output is emitted inside a LaTeX environment, in which case plain text is LaTeX-escaped, defaults to False
    :type latex_env: bool (optional)
    :param join: Optional separator inserted between the parts of consecutive elements (never after the last one), defaults to None
    :type join: Optional[str]
    :return: A flat list of string parts.
    """
    parts = []
    for child in iterator:
        parts.extend(format_element(child, keep_refs, latex_env))
        if join is not None:
            parts.append(join)
    if join is not None:
        parts = parts[:-1]
    return parts
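

# Illustrative sketch (hypothetical helper, not part of nougat): format_iterator
# appends the separator after each element's parts and then drops the final
# one, so a TableRow formatted with join=" & " yields "a & b" rather than
# "a & b & ". The nested lists stand in for per-element format_element output.
from typing import List


def _sketch_join_parts(per_element: List[List[str]], join: str) -> List[str]:
    parts: List[str] = []
    for element_parts in per_element:
        parts.extend(element_parts)
        parts.append(join)
    return parts[:-1]  # drop the trailing separator (no-op for empty input)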


def format_children(
    element: Element, keep_refs: bool = False, latex_env: bool = False
) -> List[str]:
    if element is None:
        return []
    return format_iterator(element.children, keep_refs, latex_env)


def format_document(
    doc: Document, keep_refs: bool = False
) -> Tuple[str, Dict[str, str]]:
    """
    Formats a `Document` as markdown and extracts its figure, table and footnote blocks.

    :param doc: The parsed document to format.
    :type doc: Document
    :param keep_refs: Forwarded to the element formatters, defaults to False
    :type keep_refs: bool (optional)
    :return: A tuple `(text, figures)`. `text` is the formatted document in which every `[FIGURE...]`, `[TABLE...]` and `[FOOTNOTE...]` block is reduced to an empty placeholder such as `[FIGURE:id][ENDFIGURE]`; `figures` maps each block's tag (e.g. `FIGURE:id`) to its stripped content.
    """
    parts = []

    if doc.title:
        parts.extend([*format_element(doc.title), "\n"])
    parts.append("\n")
    parts.extend(format_children(doc, keep_refs))
    text = "".join(parts)
    text = text.replace("\xa0", " ")  # replace non-breakable spaces
    text = re.sub(r" $", "", text, flags=re.MULTILINE)
    text = re.sub(r"\n[\t ]*$", "\n", text, flags=re.MULTILINE)
    text = re.sub(r"(?<!\n) {2,}", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text).lstrip()
    figures = {unidecode(m[0] + m[1]): m[2].strip() for m in figure_regex.findall(text)}
    text = figure_regex.sub(
        r"[\1\2][END\1]",
        text,
    )
    return text, figures
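

# Illustrative sketch (hypothetical helper, not part of nougat): format_document
# moves the body of every [FIGURE...]/[TABLE...]/[FOOTNOTE...] block into a
# dictionary and leaves an empty placeholder in the text. This sketch repeats
# the `figure_regex` logic but omits the unidecode call applied to the keys.
import re
from typing import Dict, Tuple

_SKETCH_BLOCK_RE = re.compile(r"\[(FOOTNOTE|FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", re.S)


def _sketch_extract_blocks(text: str) -> Tuple[str, Dict[str, str]]:
    blocks = {m[0] + m[1]: m[2].strip() for m in _SKETCH_BLOCK_RE.findall(text)}
    return _SKETCH_BLOCK_RE.sub(r"[\1\2][END\1]", text), blocks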


================================================
FILE: nougat/dataset/pdffigures.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import subprocess
import logging

PDFFIGURES2_JAR_PATH = os.environ.get("PDFFIGURES_PATH", None)
logger = logging.getLogger()
if PDFFIGURES2_JAR_PATH is None:
    logger.warning(
        "The pdffigures2 JAR path is not configured. Set the environment variable 'PDFFIGURES_PATH' or edit PDFFIGURES2_JAR_PATH in nougat/dataset/pdffigures.py."
    )


def call_pdffigures(
    pdf_path: str, figures_dir: str, timeout: int = 30, verbose: bool = False
):
    """
    Extract figures from a PDF file using pdffigures2.

    Args:
        pdf_path (str): The path to the PDF file.
        figures_dir (str): The directory where the figures will be extracted.
        timeout (int, optional): The timeout in seconds for the pdffigures2 command. Defaults to 30.
        verbose (bool, optional): Whether to print the output of the pdffigures2 command. Defaults to False.

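    Returns:
        Union[str, bool, None]: The path to the JSON file containing the extracted figures, False if pdffigures2 failed or timed out, or None if PDFFIGURES_PATH is not set.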
    """
    os.makedirs(figures_dir, exist_ok=True)
    kwargs = (
        {} if verbose else {"stderr": subprocess.DEVNULL, "stdout": subprocess.DEVNULL}
    )
    if PDFFIGURES2_JAR_PATH is None:
        return
    # Pass the arguments as a list (no shell) so that paths containing spaces are
    # handled correctly and terminate() stops the java process itself.
    process = subprocess.Popen(
        [
            "java",
            "-jar",
            PDFFIGURES2_JAR_PATH,
            "-d",
            os.path.join(figures_dir, ""),  # output prefix with trailing separator
            "-c",
            "-q",
            pdf_path,
        ],
        **kwargs
    )

    try:
        exit_code = process.wait(timeout=timeout)
        if exit_code != 0:
            logger.error("Extracting figures from file %s failed.", pdf_path)
            return False
    except subprocess.TimeoutExpired as e:
        logger.error(
            "pdffigures2 command did not terminate in %s seconds, "
            "terminating. Error: %s",
            timeout,
            e,
        )
        process.terminate()  # give up
        return False
    pdf_name = os.path.basename(pdf_path).partition(".pdf")[0]
    dest_file = os.path.join(figures_dir, (pdf_name + ".json"))

    return dest_file
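

# Illustrative sketch (hypothetical helper, not part of nougat): the same
# "run an external tool with a timeout" pattern expressed with subprocess.run,
# which kills the child process itself when the timeout expires. It returns the
# exit code, or None if the command timed out.
import subprocess
from typing import List, Optional


def _sketch_run_with_timeout(cmd: List[str], timeout: float) -> Optional[int]:
    try:
        result = subprocess.run(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=timeout
        )
    except subprocess.TimeoutExpired:
        return None  # subprocess.run has already killed the child
    return result.returncode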


================================================
FILE: nougat/dataset/rasterize.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
import logging
import pypdfium2
from pathlib import Path
from tqdm import tqdm
import io
from typing import Optional, List, Union

logging.getLogger("pypdfium2").setLevel(logging.WARNING)


def rasterize_paper(
    pdf: Union[Path, bytes],
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil=False,
    pages=None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize the pages of a PDF file to PNG files or in-memory images.

    Args:
        pdf (Union[Path, bytes, pypdfium2.PdfDocument]): The path to the PDF file, its raw bytes, or an already opened document.
        outpath (Optional[Path], optional): The output directory for the PNG files. If None, the rendered pages are returned instead. Defaults to None.
        dpi (int, optional): The output DPI. Defaults to 96.
        return_pil (bool, optional): Whether to return the rendered pages instead of writing them to disk. Forced to True if `outpath` is None. Defaults to False.
        pages (Optional[List[int]], optional): Zero-based indices of the pages to rasterize. If None, all pages will be rasterized. Defaults to None.

    Returns:
        Optional[List[io.BytesIO]]: In-memory BMP-encoded page images if `return_pil` is True, otherwise None.
    """
    pils = []
    if outpath is None:
        return_pil = True
    try:
        if isinstance(pdf, (str, Path, bytes)):
            pdf = pypdfium2.PdfDocument(pdf)
        if pages is None:
            pages = range(len(pdf))
        renderer = pdf.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=pages,
            scale=dpi / 72,
        )
        for i, image in zip(pages, renderer):
            if return_pil:
                page_bytes = io.BytesIO()
                image.save(page_bytes, "bmp")
                pils.append(page_bytes)
            else:
                image.save((outpath / ("%02d.png" % (i + 1))), "png")
    except Exception as e:
        logging.error(e)
    if return_pil:
        return pils
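

# Illustrative sketch (hypothetical helper, not part of nougat): PDF page sizes
# are given in points (1/72 inch), so rendering with scale=dpi / 72 produces
# images of roughly (point size * dpi / 72) pixels; pypdfium2 may round
# slightly differently. A US Letter page (612 x 792 pt) at the default 96 DPI
# comes out at about 816 x 1056 pixels.
from typing import Tuple


def _sketch_pixel_size(width_pt: float, height_pt: float, dpi: int = 96) -> Tuple[int, int]:
    scale = dpi / 72  # same scale factor as passed to pdf.render above
    return round(width_pt * scale), round(height_pt * scale)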


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdfs", nargs="+", type=Path, help="PDF files", required=True)
    parser.add_argument("--out", type=Path, help="Output dir", default=None)
    parser.add_argument(
        "--dpi",
        type=int,
        default=96,
        help="Resolution (DPI) at which the pages are rendered",
    )
    parser.add_argument(
        "--pages",
        type=int,
        nargs="+",
        default=None,
        help="1-based page numbers to rasterize (default: all pages)",
    )
    args = parser.parse_args()
    if args.pages:
        args.pages = [p - 1 for p in args.pages]
    for pdf_file in tqdm(args.pdfs):
        assert pdf_file.exists() and pdf_file.is_file()
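        # one subdirectory per PDF, so that pages of different PDFs do not overwrite each other
        outpath: Path = (args.out or pdf_file.parent) / pdf_file.stem
        outpath.mkdir(parents=True, exist_ok=True)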
        rasterize_paper(pdf_file, outpath, pages=args.pages, dpi=args.dpi)


================================================
FILE: nougat/dataset/split_htmls_to_pages.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
from io import BytesIO
import multiprocessing
from pebble import ProcessPool
from concurrent.futures import TimeoutError
from tqdm import tqdm
from typing import Tuple
import os
from pathlib import Path
import logging
import pypdf
from PIL import Image
import pytesseract
from nougat.dataset.split_md_to_pages import *
from nougat.dataset.parser.html2md import *
from nougat.dataset.pdffigures import call_pdffigures

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def process_paper(
    fname: str,
    pdf_file: Path,
    html_file: Path,
    json_file: Path,
    args: argparse.Namespace,
) -> Optional[Tuple[int, int]]:
    """
    Process a single paper.

    Args:
        fname (str): The paper's filename.
        pdf_file (Path): The path to the PDF file.
        html_file (Path): The path to the HTML file.
        json_file (Optional[Path]): The path to the JSON file containing the extracted figures, or None to run pdffigures.
        args (argparse.Namespace): The command-line arguments.

    Returns:
        Optional[Tuple[int, int]]: The total number of pages and the number of recognized pages, or None if the paper could not be parsed or split.
    """
    total_pages = 0
    num_recognized_pages = 0
    try:
        pdf = pypdf.PdfReader(pdf_file)
        total_pages = len(pdf.pages)
        outpath: Path = args.out / fname
        # skip this paper if already processed
        dirs_with_same_stem = list(args.out.glob(fname.partition("v")[0] + "*"))
        if (
            len(dirs_with_same_stem) > 0
            and len(list(dirs_with_same_stem[0].iterdir())) > 0
            and not args.recompute
        ):
            logger.info(
                "%s (or another version thereof) already processed. Skipping paper",
                fname,
            )
            return total_pages, len(list(dirs_with_same_stem[0].glob("*.mmd")))
        with open(html_file, "r", encoding="utf-8") as f:
            raw_html = f.read().replace("\xa0", " ")
        html = BeautifulSoup(
            htmlmin.minify(raw_html, remove_all_empty_space=True),
            features="html.parser",
        )
        doc = parse_latexml(html)
        if doc is None:
            return
        out, fig = format_document(doc, keep_refs=True)

        if args.markdown:
            md_out = args.markdown / (fname + ".mmd")
            with open(md_out, "w", encoding="utf-8") as f:
                f.write(out)

        if json_file is None:
            json_file = call_pdffigures(pdf_file, args.figure)
        if json_file:
            with open(json_file, "r", encoding="utf-8") as f:
                figure_info = json.load(f)
        else:
            figure_info = None
        split = split_markdown(
            out, pdf_file, figure_info=figure_info, doc_fig=fig, min_score=0.9
        )
        if split is None:
            return
        pages, meta = split
        num_recognized_pages = sum([len(p) > 0 for p in pages])
        if all([len(p) == 0 for p in pages]):
            return
        os.makedirs(outpath, exist_ok=True)
        # the metadata is the same for all pages, so write it once
        with (outpath / "meta.json").open("w", encoding="utf-8") as f:
            f.write(json.dumps(meta))
        recognized_indices = []
        for i, content in enumerate(pages):
            if content:
                if re.search(r"\[(?:\?\?(?:. )?)+\]", content):
                    # skip pages with wrongly parsed references, e.g. [??]
                    continue
                with (outpath / ("%02d.mmd" % (i + 1))).open(
                    "w", encoding="utf-8"
                ) as f:
                    f.write(content)
                recognized_indices.append(i)
        rasterize_paper(pdf_file, outpath, dpi=args.dpi, pages=recognized_indices)
        if args.tesseract:
            for i in recognized_indices:
                ocr = pytesseract.image_to_string(
                    Image.open((outpath / ("%02d.png" % (i + 1)))), lang="eng"
                )
                ocr = re.sub(r"\n+\s+?([^\s])", r"\n\n\1", ocr).strip()
                with (outpath / ("%02d_OCR.txt" % (i + 1))).open(
                    "w", encoding="utf-8"
                ) as f_ocr:
                    f_ocr.write(ocr)
    except Exception as e:
        logger.error(e)

    return total_pages, num_recognized_pages

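# Illustrative sketch (hypothetical helper, not part of nougat's API):
# process_paper skips any page whose Markdown still contains unresolved
# references such as "[??]" or "[??, ??]". This reuses the same pattern.
import re

_BROKEN_REF_RE = re.compile(r"\[(?:\?\?(?:. )?)+\]")


def _example_has_broken_refs(content: str) -> bool:
    """Return True if `content` contains a wrongly parsed reference like [??]."""
    return _BROKEN_REF_RE.search(content) is not None
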

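def process_htmls(args: argparse.Namespace) -> None:
    """Split every LaTeXML HTML file in `args.html` into per-page Markdown files in `args.out`, in parallel."""
    for input_dir in (args.pdfs, args.html):
        if not input_dir.is_dir():
            logger.error("%s does not exist or is not a directory.", input_dir)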
            return
    htmls: List[Path] = sorted(args.html.glob("*.html"))
    args.out.mkdir(exist_ok=True)
    if args.markdown:
        args.markdown.mkdir(exist_ok=True)

    with ProcessPool(max_workers=args.workers) as pool:
        total_pages, total_pages_extracted = 0, 0
        tasks = {}
        for html_file in htmls:
            fname = html_file.stem
            pdf_file = args.pdfs / (fname + ".pdf")
            if not pdf_file.exists():
                logger.info("%s pdf could not be found.", fname)
                continue
            # --figure is optional, so only look for a JSON file if a directory was given
            json_file = args.figure / (fname + ".json") if args.figure else None
            if json_file is not None and not json_file.exists():
                logger.info("%s figure json could not be found.", fname)
                json_file = None
            tasks[fname] = pool.schedule(
                process_paper,
                args=[fname, pdf_file, html_file, json_file, args],
                timeout=args.timeout,
            )

        for fname in tqdm(tasks):
            try:
                res = tasks[fname].result()
                if res is None:
                    logger.info("%s is faulty", fname)
                    continue
                num_pages, num_recognized_pages = res
                total_pages += num_pages
                total_pages_extracted += num_recognized_pages
                logger.info(
                    "%s: %i/%i pages recognized. Percentage: %.2f%%",
                    fname,
                    num_recognized_pages,
                    num_pages,
                    (100 * num_recognized_pages / max(1, num_pages)),
                )
            except TimeoutError:
                logger.info("%s timed out", fname)
    if total_pages > 0:
        logger.info(
            "In total: %i/%i pages recognized. Percentage: %.2f%%",
            total_pages_extracted,
            total_pages,
            (100 * total_pages_extracted / max(1, total_pages)),
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--html", type=Path, help="HTML files", required=True)
    parser.add_argument("--pdfs", type=Path, help="PDF files", required=True)
    parser.add_argument("--out", type=Path, help="Output dir", required=True)
    parser.add_argument("--recompute", action="store_true", help="recompute all splits")
    parser.add_argument(
        "--markdown", type=Path, help="Markdown output dir", default=None
    )
    parser.add_argument(
        "--figure",
        type=Path,
        help="Figure info JSON dir",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=multiprocessing.cpu_count(),
        help="How many processes to use",
    )
    parser.add_argument(
        "--dpi", type=int, default=96, help="What resolution the pages will be saved at"
    )
    parser.add_argument(
        "--timeout", type=float, default=120, help="max time per paper in seconds"
    )
    parser.add_argument(
        "--tesseract",
        action="store_true",
        help="Tesseract OCR prediction for each page",
    )
    args = parser.parse_args()
    print(args)
    process_htmls(args)


================================================
FILE: nougat/dataset/split_md_to_pages.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import argparse
from collections import Counter
from copy import deepcopy
import json
import math
from operator import itemgetter
import re
from typing import Dict, List, Tuple, Union, Optional
import os
import pypdf
from unidecode import unidecode
from rapidfuzz.fuzz import ratio as ratio_perc

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import SGDClassifier

from nougat.dataset.staircase import Staircase
from nougat.dataset.splitter import (
    Splitter,
    get_first_last,
    get_glob_index,
)
from nougat.dataset.utils import unicode_to_latex, remove_pretty_linebreaks
from nougat.dataset.utils.pdf_text_extract import get_pages, get_paragraphs
from nougat.dataset.rasterize import rasterize_paper


def ratio(*args, **kwargs):
    return ratio_perc(*args, **kwargs) / 100


class BagOfWords:
    """
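    A TF-IDF bag-of-words text classifier (a linear SVM trained with SGD).
    In `split_markdown` it is trained on PDF text lines labeled with their page index
    and then predicts a page for each Markdown chunk.

    Args:
        sentences (List[str]): The training sentences.
        target (Optional[List[int]]): The target labels for the training sentences. Defaults to None, in which case every sentence gets its own label.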

    """

    def __init__(
        self,
        sentences: List[str],
        target: Optional[List[int]] = None,
    ) -> None:
        self.sentences = sentences
        self.target = target
        self.train()

    def train(self):
        if self.target is None:
            self.target = np.arange(len(self.sentences))
        self.count_vect = CountVectorizer()
        X_train_counts = self.count_vect.fit_transform(self.sentences)
        self.tfidf_transformer = TfidfTransformer(use_idf=True)
        X_train_tfidf = self.tfidf_transformer.fit_transform(X_train_counts)
        self.clf = SGDClassifier(
            loss="hinge",
            penalty="l2",
            alpha=1e-3,
            random_state=42,
            max_iter=5,
            tol=None,
        )
        self.clf.fit(X_train_tfidf, self.target)

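    def __call__(
        self, text: Union[str, List[str]], log_probs: bool = False
    ) -> np.ndarray:
        """Predict a label for each text, or log-probabilities if `log_probs` is True."""
        if isinstance(text, str):
            text = [text]
        X_new_counts = self.count_vect.transform(text)
        X_new_tfidf = self.tfidf_transformer.transform(X_new_counts)
        if log_probs:
            # SGDClassifier only provides predict_log_proba for probabilistic losses
            # such as "log_loss" or "modified_huber", not for the "hinge" loss used here
            return self.clf.predict_log_proba(X_new_tfidf)
        else:
            return self.clf.predict(X_new_tfidf)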


def remove_short_seqs(seqs: List[str], minimum: int = 10) -> List[str]:
    """Remove sequences shorter than the specified minimum length."""
    out = []
    for seq in seqs:
        if len(seq) > minimum:
            out.append(seq)
    return out


def find_figures(
    pdf_pages: List[List[str]], figure_info: Union[Dict, List]
) -> List[Tuple[int, int, int, int]]:
    """
    Find the locations of figure captions in the text of a PDF file.

    Args:
        pdf_pages (List[List[str]]): The text of the PDF pages.
        figure_info (Union[Dict, List]): The pdffigures output: either a list of figure dictionaries or a
            dictionary whose values are such lists. Each figure dictionary specifies a figure's caption
            (or text), page number, and bounding box.

    Returns:
        List[Tuple[int, int, int, int]]: A list of tuples, each containing the figure index, page number,
            start position, and length of the caption match in the joined page text.
    """
    figure_locations = []
    iterator = figure_info.values() if type(figure_info) == dict else [figure_info]
    for figure_list in iterator:
        for i, f in enumerate(figure_list):
            if "caption" in f:
                fig_string = f["caption"]
            elif "text" in f:
                fig_string = f["text"]
            else:
                continue
            fig_string = unicode_to_latex(fig_string)
            if f["page"] >= len(pdf_pages):
                continue
            block, score = Splitter.fuzzysearch(
                "\n".join(pdf_pages[f["page"]]),
                fig_string,
            )
            if score > 0.8 and block[2] > 0:
                figure_locations.append((i, f["page"], block[0], block[2]))
    return figure_locations

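# Illustrative sketch (hypothetical helper, not part of nougat's API): when
# split_markdown re-inserts figures and tables, it picks the LaTeX block most
# similar to each pdffigures caption and accepts it only if the similarity is
# at least 0.8. nougat uses rapidfuzz's `ratio`; this sketch swaps in the
# standard library's difflib.SequenceMatcher, whose scores are similar but not
# identical.
from difflib import SequenceMatcher
from typing import List, Optional


def _example_best_caption_match(
    caption: str, candidates: List[str], threshold: float = 0.8
) -> Optional[int]:
    """Index of the candidate most similar to `caption`, or None if below `threshold`."""
    if not candidates:
        return None
    ratios = [SequenceMatcher(None, c, caption).ratio() for c in candidates]
    best = max(range(len(ratios)), key=ratios.__getitem__)
    return best if ratios[best] >= threshold else None
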

def flatten(l: List) -> List:
    return [item for sublist in l for item in sublist]


def get_doc_text(
    pdf: str,
    splitn: bool = True,
    split_block: bool = True,
    minlen: Optional[int] = 10,
) -> List[List[str]]:
    """
    Get the text from a PDF document.

    Args:
        pdf (str): Path to the PDF document.
        splitn (bool): Whether to split the text into lines. Defaults to True.
        split_block (bool): Whether to split the text into blocks. Defaults to True.
        minlen (Optional[int]): Lines of at most `minlen` characters are dropped (only when `splitn` is True). Defaults to 10.

    Returns:
        List[List[str]]: The text of the PDF document, either as a list of lines or a list of blocks.
    """
    document_lines = []
    if split_block:
        pages = get_paragraphs(pdf)
    else:
        pages = [get_pages(pdf)]
    for blocks in pages:
        page_lines = []
        for block in blocks:
            if splitn:
                page_lines.extend(block.split("\n"))
            else:
                page_lines.append(block)
        if splitn:
            page_lines = remove_short_seqs(page_lines, minlen)
        document_lines.append(page_lines)
    return document_lines


def clean_pdf_text(pages: List[List[str]], num_words: int = 10) -> List[List[str]]:
    """
    Clean the text of a PDF document by removing recurring words, such as running headers, from the beginning of each page.
    A word counts as recurring if it occurs among the first `num_words` words in at least 40% of the pages.
    Note that `pages` is modified in place.

    Args:
        pages (List[List[str]]): The text of the PDF document, as a list of lists of strings.
        num_words (int, optional): The number of words to consider at the beginning of each page. Defaults to 10.

    Returns:
        List[List[str]]: The cleaned text of the PDF document.
    """
    words = []
    for page in pages:
        first = get_first_last(
            " ".join(page).lower(), num_words=num_words, first_only=True
        )
        words.extend(first.split(" "))
    word_counts = Counter(words)
    common_words = [
        "the",
        "of",
        "a",
        "and",
        "to",
        "in",
        "is",
        "that",
        "for",
        "are",
        "this",
        "we",
        "figure",
        "fig.",
        "",
    ]
    frequent_words = []
    for w, f in word_counts.items():
        if w in common_words or w.startswith("\\"):
            continue
        if f / len(pages) >= 0.4:
            frequent_words.append(w)
    if len(frequent_words) == 0:
        return pages
    # remove frequent words from the first lines of each page
    for i in range(len(pages)):
        page = pages[i]
        stop = 0
        page_num_words = 0
        for p in page:
            page_num_words += len(p.split(" "))
            stop += 1
            if page_num_words >= num_words:
                break
        for w in frequent_words:
            for j in range(stop):
                if w == "-":  # probably page number - \d -
                    pages[i][j] = re.sub(
                        r"-\s*\d{1,3}\s*-", "", pages[i][j], flags=re.IGNORECASE
                    )
                pages[i][j] = re.sub(re.escape(w), "", pages[i][j], flags=re.IGNORECASE)
    return pages

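# Illustrative sketch (hypothetical helper, not part of nougat's API): the
# detection stage of clean_pdf_text. Words among the first `num_words` words of
# each page that occur in at least 40% of the pages (ignoring stop words and
# LaTeX commands) are treated as running headers.
from collections import Counter
from typing import List

_STOP_WORDS = frozenset(
    ["the", "of", "a", "and", "to", "in", "is", "that", "for",
     "are", "this", "we", "figure", "fig.", ""]
)


def _example_frequent_header_words(
    pages: List[List[str]], num_words: int = 10, threshold: float = 0.4
) -> List[str]:
    counts: Counter = Counter()
    for page in pages:
        # same as get_first_last(..., first_only=True) on the lowercased page text
        counts.update(" ".join(page).lower().split(" ")[:num_words])
    return [
        w
        for w, f in counts.items()
        if w not in _STOP_WORDS and not w.startswith("\\") and f / len(pages) >= threshold
    ]
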

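# Illustrative sketch (hypothetical helper, not part of nougat's API): how
# split_markdown first regroups the Markdown lines into chunks of roughly
# `doc_paragraph_chars` characters before matching them against the PDF text.
# The unidecode step of the original is omitted here.
from statistics import mean
from typing import List


def _example_chunk_lines(lines: List[str], doc_paragraph_chars: int = 1000) -> List[str]:
    lengths = [len(line) for line in lines if len(line) > 1]
    # number of lines per chunk, based on the mean line length
    num_lines = 1 + int(doc_paragraph_chars / mean(lengths))
    return ["\n".join(lines[i : i + num_lines]) for i in range(0, len(lines), num_lines)]
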
def split_markdown(
    doc: str,
    pdf_file: str,
    figure_info: Optional[List[Dict]] = None,
    doc_fig: Dict[str, str] = {},
    minlen: int = 3,
    min_num_words: int = 22,
    doc_paragraph_chars: int = 1000,
    min_score: float = 0.75,
    staircase: bool = True,
) -> Optional[Tuple[List[str], Dict]]:
    """
    Split a Markdown document into pages by aligning it with the text of the corresponding PDF.

    Args:
        doc (str): The text of the Markdown document.
        pdf_file (str): Path to the PDF document.
        figure_info (Optional[List[Dict]]): The pdffigures output, a list (or a dictionary of lists) of
            dictionaries, each specifying a figure's caption, page number, and bounding box.
        doc_fig (Dict[str, str]): A dictionary mapping figure, table, and footnote ids to their code.
        minlen (int): The minimum length of a PDF text line, and of the first/last words used to locate a page break.
        min_num_words (int): The number of words taken from the start and end of each PDF page to locate a page break.
        doc_paragraph_chars (int): The approximate number of characters per Markdown chunk.
        min_score (float): The minimum mean split confidence for a page to be kept.
        staircase (bool): Whether to estimate page boundaries by fitting a staircase function to bag-of-words page predictions.

    Returns:
        Optional[Tuple[List[str], Dict]]: The Markdown content of each page (empty for rejected pages) and the
            metadata, or None if the PDF has no pages.
    """
    pdf = pypdf.PdfReader(pdf_file)
    doc_paragraphs_full: List[str] = doc.split("\n")
    doc_paragraph_lengths = [len(p) for p in doc_paragraphs_full if len(p) > 1]
    num_lines = 1 + int(doc_paragraph_chars / np.mean(doc_paragraph_lengths))
    doc_paragraphs_full = [
        unidecode("\n".join(doc_paragraphs_full[i : i + num_lines]))
        for i in range(0, len(doc_paragraphs_full), num_lines)
    ]
    doc_paragraphs: List[str] = []
    doc_paragraph_indices: List[int] = []
    for i, p in enumerate(doc_paragraphs_full):
        if len(p) > 1:
            doc_paragraphs.append(
                re.sub(r"(\[(FOOTNOTE|FIGURE|TABLE).*?END\2\])", "", p)
            )
            doc_paragraph_indices.append(i)
    meta = {"pdffigures": figure_info}
    if len(pdf.pages) > 1:
        pdf_text = get_doc_text(pdf_file, True, True, minlen)
        pdf_content = [
            [unicode_to_latex(q).replace("\n", " ") for q in p if len(q) >= minlen]
            for p in pdf_text
        ]

        pdf_content = clean_pdf_text(pdf_content)
        if figure_info is not None:
            figure_locations = sorted(
                find_figures(pdf_content, figure_info), key=itemgetter(2), reverse=True
            )
            clean_pdf_content = deepcopy(pdf_content)
            for i, page_content in enumerate(pdf_content):
                len_sentences = np.cumsum([0] + [len(p) for p in page_content])
                for match in figure_locations:
                    _, page, start, len_ = match
                    if i != page:
                        continue
                    a, b = (
                        get_glob_index(len_sentences, start),
                        get_glob_index(len_sentences, start + len_) + 1,
                    )
                    for j, k in enumerate(range(a, b + 1)):
                        if len(clean_pdf_content[i]) == k:
                            break
                        if j == 0:
                            clean_pdf_content[i][k] = clean_pdf_content[i][k][
                                : start - len_sentences[k]
                            ]
                        elif k == b:
                            clean_pdf_content[i][k] = clean_pdf_content[i][k][
                                start + len_ - len_sentences[k] :
                            ]
                        else:
                            clean_pdf_content[i][k] = ""
                clean_pdf_content[i] = remove_short_seqs(clean_pdf_content[i], 0)
            pdf_content = clean_pdf_content
        paragraphs = flatten(pdf_content)
        num_paragraphs = np.cumsum([0] + [len(page) for page in pdf_content])
        if staircase:
            # train bag of words
            page_target = np.zeros(len(paragraphs))
            page_target[num_paragraphs[1:-1] - 1] = 1
            page_target = np.cumsum(page_target).astype(int)
            model = BagOfWords(paragraphs, target=page_target)
            labels = model(doc_paragraphs)

            # fit stair case function
            x = np.arange(len(labels))
            stairs = Staircase(len(labels), labels.max() + 1)
            stairs.fit(x, labels)
            boundaries = (stairs.get_boundaries().astype(int)).tolist()
            boundaries.insert(0, 0)
        else:
            boundaries = [0] * (len(pdf.pages))
        splitter = Splitter(doc_paragraphs)
        pages = [(0, 0, 1.0)]
        meta["first_words"] = []
        meta["last_words"] = []
        for i in range(1, len(boundaries)):
            delta = (
                math.ceil(stairs.uncertainty[i - 1]) + 5
                if staircase
                else len(doc_paragraphs)
            )
            words_f = []
            words_l = []
            for p in pdf_content[i]:
                words_f.extend(p.split(" "))
                if len(words_f) >= min_num_words:
                    break
            for p in pdf_content[i - 1][::-1]:
                words_l.extend(p.split(" ")[::-1])
                if len(words_l) >= min_num_words:
                    words_l = words_l[::-1]
                    break
            first_words = " ".join(words_f[:min_num_words]).strip()
            last_words = " ".join(words_l[-min_num_words:]).strip()
            meta["first_words"].append(first_words)
            meta["last_words"].append(last_words)
            if len(words_f) < 2 or (
                len(first_words) < minlen and len(last_words) < minlen
            ):
                # not enough text to locate the page break: reuse the previous split
                # (exactly one entry per page boundary keeps `pages` aligned with the PDF)
                pages.append(pages[-1])
                continue
            pages.append(
                splitter.split_first_last(
                    boundaries[i],
                    first_words,
                    last_words,
                    delta=delta,
                )
            )
    elif len(pdf.pages) == 1:  # single page
        pages = [(0, 0, 1)]
    else:
        return
    pages.append((len(doc_paragraphs), -1, 1.0))
    out = []
    page_scores = {}
    for i in range(len(pages) - 1):
        score = (pages[i][2] + pages[i + 1][2]) * 0.5
        if score >= min_score:
            end = pages[i + 1][0]
            if end >= len(doc_paragraph_indices):
                end = None
            else:
                end = doc_paragraph_indices[pages[i + 1][0]] + 1
            lines = doc_paragraphs_full[doc_paragraph_indices[pages[i][0]] : end]
            if len(lines) > 0:
                lines[0] = lines[0][pages[i][1] :]
                lines[-1] = lines[-1][: pages[i + 1][1]]
        else:
            lines = []
        page_content = "\n".join(lines)
        page_content = remove_pretty_linebreaks(page_content)
        page_scores[i] = score
        out.append(page_content)

    meta["page_splits"] = pages
    meta["page_scores"] = page_scores
    meta["num_pages"] = len(pdf.pages)

    # Reintroduce figures, tables and footnotes
    figure_tex = list(doc_fig.keys()), list(doc_fig.values())
    if len(doc_fig) > 0:
        iterator = figure_info.values() if type(figure_info) == dict else [figure_info]
        for figure_list in iterator:
            if not figure_list:
                continue
            for i, f in enumerate(figure_list):
                if "caption" in f:
                    fig_string = f["caption"]
                elif "text" in f:
                    fig_string = f["text"]
                else:
                    continue
                ratios = []
                for tex in figure_tex[1]:
                    if f["figType"] == "Table":
                        tex = tex.partition(r"\end{table}")[2]
                    ratios.append(ratio(tex, fig_string))
                k = np.argmax(ratios)
                if ratios[k] < 0.8:
                    continue
                if f["page"] < len(out) and out[f["page"]] != "":
                    out[f["page"]] += "\n\n" + remove_pretty_linebreaks(
                        figure_tex[1][k].strip()
                    )

    for i in range(len(out)):
        foot_match = re.findall(r"\[FOOTNOTE(.*?)\]\[ENDFOOTNOTE\]", out[i])
        for match in foot_match:
            out[i] = out[i].replace(
                "[FOOTNOTE%s][ENDFOOTNOTE]" % match,
                doc_fig.get("FOOTNOTE%s" % match, ""),
            )

        out[i] = re.sub(r"\[(FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", "", out[i])
    return out, meta

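# Illustrative sketch (hypothetical helper, not part of nougat's API): the
# page filtering at the end of split_markdown. Each page's score is the mean
# confidence of its start and end splits, so one uncertain page break can
# reject both neighboring pages.
from typing import List, Tuple


def _example_page_scores(
    splits: List[Tuple[int, int, float]], min_score: float = 0.75
) -> Tuple[List[float], List[int]]:
    scores = [(splits[i][2] + splits[i + 1][2]) * 0.5 for i in range(len(splits) - 1)]
    kept = [i for i, s in enumerate(scores) if s >= min_score]
    return scores, kept
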

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--md", type=str, help="Markdown file", required=True)
    parser.add_argument("--pdf", type=str, help="PDF File", required=True)
    parser.add_argument("--out", type=str, help="Out dir", required=True)
    parser.add_argument(
        "--figure",
        type=str,
        help="Figure info JSON",
    )
    parser.add_argument("--dpi", type=int, default=96)
    args = parser.parse_args()
    from pathlib import Path  # rasterize_paper joins output paths with "/"

    with open(args.md, "r", encoding="utf-8") as f:
        md = f.read().replace("\xa0", " ")
    fig_info = None
    if args.figure:
        try:
            with open(args.figure, "r", encoding="utf-8") as f:
                fig_info = json.load(f)
        except FileNotFoundError:
            pass
    # split_markdown and rasterize_paper expect the path to the PDF, not a pypdf.PdfReader
    split = split_markdown(md, args.pdf, fig_info)
    if split is None:
        raise SystemExit("Could not split %s" % args.pdf)
    pages, meta = split
    if args.out:
        # splitext keeps dotted names such as arXiv ids (2308.13418v1.pdf) intact
        stem = os.path.splitext(os.path.basename(args.pdf))[0]
        outpath = Path(args.out) / stem
        os.makedirs(outpath, exist_ok=True)
        found_pages = []
        for i, content in enumerate(pages):
            if content:
                with open(
                    outpath / ("%02d_s=%.2f.mmd" % (i + 1, meta["page_scores"][i])),
                    "w",
                    encoding="utf-8",
                ) as f:
                    f.write(content)
                found_pages.append(i)
        rasterize_paper(args.pdf, outpath, dpi=args.dpi, pages=found_pages)


================================================
FILE: nougat/dataset/splitter.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from typing import List, Tuple, Union
import re
import numpy as np
from rapidfuzz.fuzz import ratio as ratio_perc
from fuzzysearch import find_near_matches

math_start_regex = re.compile(r"(?<!\\)\\[\[\(]", re.M)
math_end_regex = re.compile(r"(?<!\\)\\[\]\)]", re.M)


def ratio(*args, **kwargs):
    return ratio_perc(*args, **kwargs) / 100


def reverse(lst: List[str]) -> List[str]:
    """Reverses a list and the strings inside

    Args:
        lst (List[str]): List to process

    Returns:
        List[str]: Reversed list
    """
    out = lst[::-1]
    for i in range(len(out)):
        out[i] = out[i][::-1]
    return out


def get_first_last(
    s: str,
    num_words: int = 8,
    delim: str = " ",
    first_only: bool = False,
    last_only: bool = False,
) -> Union[Tuple[str, str], str]:
    """
    Get the first and last `num_words` from a string `s`.

    Args:
        s (str): The string.
        num_words (int): The number of words.
        delim (str): The delimiter between words.
        first_only (bool): Whether to only get the first `num_words`.
        last_only (bool): Whether to only get the last `num_words`.

    Returns:
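        Union[Tuple[str, str], str]: The first and last `num_words` words of `s`, or only the first (last) words if `first_only` (`last_only`) is set. `num_words` should be positive: with 0, the "last" part is the whole string.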
    """
    s = s.split(delim)
    if not first_only and not last_only:
        return delim.join(s[:num_words]), delim.join(s[-num_words:])
    elif first_only:
        return delim.join(s[:num_words])
    elif last_only:
        return delim.join(s[-num_words:])


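def get_glob_index(
    lengths: List[int], ind: int, return_breakpoints: bool = False
) -> Union[int, Tuple[int, np.ndarray]]:
    """Return the index of the largest cumulative sum of `lengths` that does not exceed `ind`.

    Falls back to 0 if every cumulative sum exceeds `ind`. If `return_breakpoints` is True,
    the cumulative sums are returned as well.
    """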
    breakpoints = np.cumsum(lengths)
    overlap = breakpoints - ind
    overlap[overlap > 0] = -int(1e5)
    indices = overlap.argmax(0)
    if return_breakpoints:
        return indices, breakpoints
    else:
        return indices

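# Illustrative sketch (hypothetical helper, not part of nougat's API): an
# equivalent of get_glob_index for non-negative `lengths`, using binary search
# on the cumulative sums instead of a masked argmax. It returns the first index
# of the largest cumulative sum that does not exceed `ind`, or 0 if every sum
# exceeds `ind`. The -1e5 sentinel edge case of the original is ignored.
from bisect import bisect_left, bisect_right
from itertools import accumulate
from typing import List


def _example_glob_index(lengths: List[int], ind: int) -> int:
    breakpoints = list(accumulate(lengths))
    j = bisect_right(breakpoints, ind) - 1
    if j < 0:
        return 0
    # argmax returns the first maximum, so step back over repeated sums
    return bisect_left(breakpoints, breakpoints[j])
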

# table-header-figure regex
# thf_regex = re.compile(r"(\[(FOOTNOTE|FIGURE|TABLE).*?END\2\])")

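# Illustrative sketch (hypothetical helper, not part of nougat's API): a
# simplified, single-character version of the fixed-point idea behind
# Splitter.count_special_chars. Given an index into the string with all
# `removed` characters stripped, it finds the matching index in the original:
# grow the offset until the prefix contains exactly as many removed characters
# as the offset, then skip removed characters at the cut. Multi-character
# tokens and the math-environment adjustments of the original are left out.
def _example_map_index(original: str, norm_index: int, removed: str = " \n") -> int:
    offset = 0
    while True:
        count = sum(ch in removed for ch in original[: norm_index + offset])
        if count == offset:
            break
        offset = count
    while norm_index + offset < len(original) and original[norm_index + offset] in removed:
        offset += 1
    return norm_index + offset
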

class Splitter:
    _split_locs: List[Tuple[int, int]] = None

    def __init__(self, paragraphs: List[str]) -> None:
        self.paragraphs = paragraphs
        self.paragraphs_no_space = [self.remove_special_chars(h) for h in paragraphs]
        self._split_locs = [(0, 0)]
        self.paragraphs_rev = reverse(self.paragraphs)
        self.paragraphs_rev_no_space = reverse(self.paragraphs_no_space)

    @staticmethod
    def remove_special_chars(string: str) -> str:
        # string = thf_regex.sub(r"", string)
        return (
            string.replace("\\ ", "")
            .replace(" ", "")
            .replace("\n", "")
            .replace("*", "")
            .replace("_", "")
            .replace("^", "")
            .replace("\\[", "")
            .replace("\\]", "")
            .replace("\\(", "")
            .replace("\\)", "")
            .replace("\\right", "")
            .replace("\\left", "")
            .replace("\\sum", "X")  # old latex unicode encoding issue
            .replace("{", "")
            .replace("}", "")
            .replace("#", "")
            .replace("[REF]", "")
            .replace("[ENDREF]", "")
            .replace("\\varphi", "\\phi")  # https://meta.stackexchange.com/a/349360
            .replace("\\quad", "")
            .replace("\\qquad", "")
            .replace("\\hskip", "")
            .replace("\\vskip", "")
            .replace("\\frac", "")
            .replace("\\rm", "")
            .replace("\\,", "")
            .replace("-", "")
            .lower()
        )

    @staticmethod
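    def count_special_chars(string: str, char_ind: int) -> int:
        """Approximately map `char_ind`, an index into `remove_special_chars(string)`, back to `string`.

        Returns the offset to add to `char_ind`: the number of characters removed before that
        position, adjusted at the split point for whitespace, math delimiters, section headings
        and emphasis markers, and moved back to the start of a math environment that would
        otherwise be cut in half.
        """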
        if len(string) == 0:
            return 0
        add_space_ind = 0
        while True:
            string_ = string[: char_ind + add_space_ind]
            # last_first = string[: char_ind + add_space_ind+]
            add = (
                string_.count(" ")
                + string_.count("\\ ") * 2
                + string_.count("\n")
                + string_.count("*")
                + string_.count("_")
                + string_.count("^")
                + string_.count("\\[") * 2
                + string_.count("\\]") * 2
                + string_.count("\\(") * 2
                + string_.count("\\)") * 2
                + string_.count("\\right") * 6
                + string_.count("\\left") * 5
                + string_.count("\\sum") * 3  # replaced to X that's why not 4
                + string_.count("{")
                + string_.count("}")
                + string_.count("#")
                + string_.count("[REF]") * 5
                + string_.count("[ENDREF]") * 8
                + string_.count("\\varphi") * 3
                + string_.count("\\quad") * 5
                + string_.count("\\qquad") * 6
                + string_.count("\\hskip") * 6
                + string_.count("\\vskip") * 6
                + string_.count("\\frac") * 5
                + string_.count("\\rm") * 3
                + string_.count("\\,") * 2
                + string_.count("-")
            )
            if add == add_space_ind:
                break
            add_space_ind = add
        if len(string) <= char_ind + add_space_ind:
            add_space_ind = max(0, len(string) - 1 - char_ind)

        # check first chars of rest if they match closing expressions
        while True:
            rest = string[char_ind + add_space_ind :]
            string_ = string[: char_ind + add_space_ind]
            section_title = re.match(r"#+\s?\d*\s*$", string_)
            if rest.startswith("\\]") or rest.startswith("\\)"):
                add_space_ind += 2
            elif (rest.startswith(")") or rest.startswith("]")) and string_.endswith(
                "\\"
            ):
                add_space_ind += 1
            elif (rest.startswith("(") or rest.startswith("[")) and string_.endswith(
                "\\"
            ):
                add_space_ind -= 1
            elif rest.startswith(" "):
                add_space_ind += 1
            elif section_title:
                add_space_ind -= section_title.end() - section_title.start()
            elif (
                re.match(r"^[^\w\s]*_\s", rest)
                or re.match(r"^[^\w\s]*\*\*?\s", rest)
                or re.match(r"^.\n", rest)
            ):
                add_space_ind += 1
            else:
                break
        # check if it starts in a math env and include everything before
        end = math_end_regex.search(rest)
        if end is not None:
            start = math_start_regex.search(rest)
            if start is None or start.start() > end.start():
                inds = [
                    m.start()
                    for m in math_start_regex.finditer(string_)
                    if m.start() < end.start() + len(string_)
                ]
                if len(inds) > 0:
                    add_space_ind = inds[-1] - char_ind
                    # assert string_[char_ind+add_space_ind]=='\\'
        return add_space_ind

    def split_first_last(
        self, index: int, first: str, last: str, delta: int = 5
    ) -> Tuple[int, int, float]:
        """Refines a split by looking at both the first words from a new page and the last words from the previous page.

        Args:
            index (int): paragraph index
            first (str): first words
            last (str): last words
            delta (int, optional): paragraph search radius. Defaults to 5.

        Returns:
            Tuple[int, int, float]: split prediction
        """
        if first:
            first_split = glob_f, char_f, score_f = self.split(
                index, first, delta=delta
            )
        if last:
            last_split = glob_l, char_l, score_l = self.split(
                index, last, delta=delta, reverse=True
            )
        if first and not last:
            return first_split
        elif not first and last:
            return last_split
        elif not first and not last:
            return index, 0, 0.0
        if char_f == char_l and glob_f == glob_l and (score_f > 0.5 or score_l > 0.5):
            return glob_l, char_l, 1.0

        # score calculation
        first, last = self.remove_special_chars(first), self.remove_special_chars(last)
        matching = []
        for split in (first_split, last_split):
            first_source = []
            num_chars_first = len(first)
            num_chars_last = len(last)
            last_source = []
            for i, p in enumerate(self.paragraphs[split[0] :]):
                if i == 0:
                    p = p[split[1] :]
                first_source.append(self.remove_special_chars(p))
                if sum([len(s) for s in first_source]) >= num_chars_first:
                    break
            first_source = "".join(first_source)[:num_chars_first]
            for i, p in enumerate(self.paragraphs[split[0] :: -1]):
                if i == 0:
                    p = p[: split[1]]
                last_source.insert(0, self.remove_special_chars(p))
                if sum([len(s) for s in last_source]) >= num_chars_last:
                    break
            last_source = "".join(last_source)[-num_chars_last:]
            matching.append(
                [
                    ratio(first, first_source) * ratio(first[:10], first_source[:10]),
                    ratio(last, last_source) * ratio(last[-10:], last_source[-10:]),
                ]
            )
        scores = np.asarray(matching).max(0)
        return (
            (glob_l, char_l, scores[1])
            if scores.argmax()
            else (glob_f, char_f, scores[0])
        )

    def split(
        self, index: int, string: str, delta: int = 5, reverse: bool = False
    ) -> Tuple[int, int, float]:
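        """Refines a split prediction by fuzzy-matching `string` against nearby paragraphs.

        Args:
            index (int): predicted paragraph index
            string (str): first words of the new page, or the last words of the
                previous page if `reverse` is True
            delta (int, optional): paragraph search radius; can serve as an
                uncertainty measure. Defaults to 5.
            reverse (bool, optional): search backwards from the end of the
                document. Defaults to False.

        Returns:
            Tuple[int, int, float]: refined paragraph index, character index of
                the split within that paragraph, and match score
        """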
        if reverse:
            index = len(self.paragraphs) - 1 - index
            string = string[::-1]
            paragraphs = self.paragraphs_rev
            paragraphs_no_space = self.paragraphs_rev_no_space
        else:
            paragraphs = self.paragraphs
            paragraphs_no_space = self.paragraphs_no_space

        string_ = self.remove_special_chars(string)
        start_ind = max(0, index - delta)
        search_corpus = paragraphs_no_space[start_ind : index + delta + 1]
        lengths = np.asarray([0] + [len(p) for p in search_corpus])
        corp = "".join(search_corpus)
        if len(corp) == 0:
            self._split_locs.append((index, 0))
            return index, 0, 1
        ind, score = self._find_match(corp, string_)
        indices, breakpoints = get_glob_index(lengths, ind, True)
        global_ind, char_ind = int(start_ind + indices), int(ind - breakpoints[indices])
        self._split_locs.append((global_ind, char_ind))
        if reverse:
            char_ind = len(paragraphs_no_space[global_ind]) - char_ind
            global_ind = len(paragraphs) - global_ind - 1
        add_space_ind = self.count_special_chars(self.paragraphs[global_ind], char_ind)
        return global_ind, char_ind + add_space_ind, score

    def _find_match(
        self, corp: str, key: str, get_start: bool = True
    ) -> Tuple[int, float]:
        block, score = self._fuzzy(corp, key)
        index = max(0, block[0])
        if not get_start:
            index += block[2]
        return index, score

    @staticmethod
    def _fuzzy(
        corpus: str, string: str, max_error_rate: float = 0.025
    ) -> Tuple[Tuple[int, int, int], float]:
        max_dist = min(len(string) - 1, int(len(string) * min(0.9, max_error_rate)) + 5)
        matches = find_near_matches(string, corpus, max_l_dist=max_dist)
        if len(matches) > 0 and max_dist > 0:
            match = min(matches, key=lambda x: x.dist)
            block = (match.start, 0, match.end - match.start)
            score = 1 - match.dist / max_dist
            return block, score
        return (0, 0, 0), 0

    @staticmethod
    def fuzzysearch(
        corpus: str, string: str, max_error_rate: float = 0.025
    ) -> Tuple[Tuple[int, int, int], float]:
        corpus_ = Splitter.remove_special_chars(corpus)
        string_ = Splitter.remove_special_chars(string)
        (start, _, length), score = Splitter._fuzzy(
            corpus_, string_, max_error_rate=max_error_rate
        )
        # shift the match boundaries by the special characters removed before them
        end = Splitter.count_special_chars(corpus, start + length) + start + length
        start = start + Splitter.count_special_chars(corpus, start)
        return (start, 0, end - start), score

    def evaluate_split(self, page_num: int, page_content: str) -> float:
        if page_num > len(self._split_locs) or page_num < 1:
            return 0
        page_content = self.remove_special_chars(page_content)
        if page_num == len(self._split_locs):
            start, end = self._split_locs[-1], (-1, -1)
        else:
            start, end = self._split_locs[page_num - 1], self._split_locs[page_num]
        if (end[0] + 1) - start[0] < 0:
            return 0
        doc_content = self.paragraphs_no_space[start[0] : (end[0] + 1) or None]
        if (
            len(doc_content) < 1
            or len(doc_content[0]) < start[1]
            or len(doc_content[-1]) < end[1]
        ):
            return 0
        doc_content[0] = doc_content[0][start[1] :]
        doc_content[-1] = doc_content[-1][: end[1]]
        doc_content = "".join(doc_content)
        match = ratio(page_content, doc_content)
        return match


================================================
FILE: nougat/dataset/staircase.py
================================================
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from collections import deque
import operator
import itertools
from typing import Optional, List, Tuple
import numpy as np
import warnings

warnings.filterwarnings("ignore", message="All-NaN slice encountered")


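def stair_func(x: np.ndarray, thresholds: np.ndarray) -> np.ndarray:
    """Map each value in `x` to the number of floored thresholds it strictly exceeds.

    For sorted thresholds this yields a monotone step (staircase) function, e.g.
    ``stair_func(np.array([0, 3, 7]), np.array([2.5, 5.5]))`` returns
    ``array([0., 1., 2.])``.
    """
    return np.heaviside(x[:, None] - np.floor(thresholds)[None, :], 0).sum(1)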


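def compute_gini(labels: np.ndarray) -> float:
    """Gini impurity of non-negative integer `labels`, scaled by the sample count.

    Returns ``N * (1 - sum_k p_k**2)``, where ``p_k`` is the fraction of labels equal to k.
    """
    N = len(labels)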
    if N == 0:
        return 0
    G = N - np.square(np.bincount(labels)).sum() / N
    return G


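def compute_binary_gini(labels: np.ndarray) -> float:
    """Impurity of a 0/1 vector that only penalizes zeros, scaled by the sample count.

    Returns ``N * (1 - p**2)``, where ``p`` is the fraction of ones. The result is 0
    when every label is 1. Unlike the symmetric binary Gini ``2 * N * p * (1 - p)``,
    an all-zero vector is maximally impure.
    """
    N = len(labels)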
    if N == 0:
        return 0
    G = N - labels.sum() ** 2 / N
    return G


def gini_impurity(
    thresholds: np.ndarray,
    data: np.ndarray,
    labels: np.ndarray,
    classes: Optional[List[int]] = None,
    reduction: Optional[str] = "sum",
    padded: bool = True,
) -> float:
    """
    Calculate the Gini impurity of a dataset split on a set of thresholds.

    Args:
        thresholds (np.ndarray): The thresholds to split the data on.
        data (np.ndarray): The data to split.
        labels (np.ndarray): The labels for the data.
        classes (Optional[List[int]]): The classes to consider. If None, all classes are used.
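        reduction (Optional[str]): The reduction to apply to the per-class impurities. One of None, "none", "sum", or "mean".
        padded (bool): Whether `thresholds` already contain the outer bounds. If False,
            `-0.5` and `data.max() + 0.5` are added at the start and end.

    Returns:
        float: The reduced impurity, or a list with one value per class if `reduction` is None or "none".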
    """
    G = []
    if not padded:
        thresholds = np.insert(
            thresholds, [0, len(thresholds)], [-0.5, data.max() + 0.5]
        )
    if classes is None:
        classes = np.arange(len(thresholds) - 1)
    else:
        classes = np.asarray(classes)
    if data.ndim == 1:
        data = np.expand_dims(data, 0)
    masks = np.logical_and(
        data > thresholds[classes, None],
        data <= thresholds[classes + 1, None],
    )
    for i, c in enumerate(classes):
        G.append(compute_binary_gini(np.where(labels[masks[i]] == c, 1, 0)))

    if reduction is None or reduction == "none":
        return G
    elif reduction == "sum":
        return sum(G)
    elif reduction == "mean":
        return sum(G) / len(G)
    else:
        raise NotImplementedError


def step_impurity(
    thresholds: np.ndarray,
    data: np.ndarray,
    labels: np.ndarray,
    classes: Optional[List[int]] = None,
) -> List[float]:
    """
    Calculate the step-wise impurity of a dataset split on a set of thresholds.

    Args:
        thresholds (np.ndarray): The padded thresholds to split the data on.
        data (np.ndarray): The data to split.
        labels (np.ndarray): The labels for the data.
        classes (Optional[List[int]]): The classes to consider. If None, all classes are used.

    Returns:
        List[float]: The summed impurity of each pair of adjacent classes, ``G[i] + G[i + 1]``.
    """
    G = gini_impurity(thresholds, data, labels, reduction=None, classes=classes)
    out = []
    for i in range(len(G) - 1):
        out.append(G[i] + G[i + 1])
    return out


class PaddedArray:
    """
    A wrapper class for an array that allows for relative indexing.

    Args:
        array (np.ndarray): The array to wrap.
        range (Optional[Tuple[int, int]]): The range of the array to expose. Defaults to (1, -1).
    """

    def __init__(
        self, array: np.ndarray, range: Optional[Tuple[int, int]] = (1, -1)
    ) -> None:
        self.array = array
        mi, ma = range
        assert ma <= 0, "relative assignment only"
        self.range = mi, ma

    def __len__(self):
        return len(self.array) + self.range[1] - self.range[0]

    def _process_index(self, index):
        if isinstance(index, slice):
            index = slice(
                (index.start or 0) + self.range[0],
                self.range[0] + (len(self) if index.stop is None else index.stop),
                index.step,
            )
            if index.stop > len(self.array):
                raise IndexError
        else:
            index = index + self.range[0]
            if index > len(self):
                raise IndexError
        return index

    def __getitem__(self, index):
        index = self._process_index(index)
        return self.array[index]

    def __setitem__(self, index, value):
        self.array[self._process_index(index)] = value

    def copy(self):
        return PaddedArray(self.array.copy(), self.range)

    def toarray(self):
        return self.array[self.range[0] : self.range[1]]


class Staircase:
    """
    A class for learning a staircase decision tree.

    Args:
        domain: The number of points in the domain.
        n_classes: The number of classes.
    """

    def __init__(self, domain: int, n_classes: int) -> None:
        self.domain = domain
        self.classes = n_classes
        assert domain > 0
        assert n_classes > 0
        self.thresholds = self._back_thres = self._forward_thres = np.linspace(
            domain / n_classes, domain, n_classes - 1, endpoint=False
        )
        self.uncertainty = np.zeros_like(self.thresholds)

    def statistic_fit(
        self,
        data: np.ndarray,
        labels: np.ndarray,
    ):
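        """
        Initialize the forward and backward thresholds from robust per-class statistics.

        For each class, this computes the median of the data points carrying that label
        and a spread of 5 times their median absolute deviation (MAD). Classes without
        data inherit the previous class's median (or 0 for the first class). For the
        boundary between classes i and i + 1, the forward threshold is the larger and the
        backward threshold the smaller of ``med[i] - mad[i]`` and ``med[i + 1] + mad[i + 1]``.

        Args:
            data (np.ndarray): The input data.
            labels (np.ndarray): The labels corresponding to the data.

        Note:
            This method modifies the internal state of the object: it sets `_forward_thres`,
            `_back_thres`, `_stat_forward` and `_stat_back`.
        """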
        onehot = np.eye(self.classes)[labels.reshape(-1)]
        k = onehot * data.T.repeat(self.classes, 1)
        k[k == 0] = np.nan
        med = np.nanmedian(k, 0)
        for i in range(len(med)):
            # classes without any data points inherit the previous class's median
            if np.isnan(med[i]):
                med[i] = 0 if i == 0 else med[i - 1]
        mad = 5 * np.nan_to_num(
            np.nanmedian(np.absolute(k - np.nanmedian(k, 0)), 0),
            nan=self.domain / self.classes / 2,
        )
        arr = np.vstack(((med - mad)[:-1], (med + mad)[1:]))
        self._forward_thres[:] = arr.max(0)
        self._back_thres[:] = arr.min(0)

        self._stat_forward = self._forward_thres.copy()
        self._stat_back = self._back_thres.copy()

    def fit(
        self,
        data: np.ndarray,
        labels: np.ndarray,
        early_stop_after: int = 10,
        fixed: bool = True,
    ) -> None:
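        """
        Fit the staircase thresholds to the data.

        The thresholds are first initialized with `statistic_fit`. Then, for each boundary
        between consecutive classes, candidate positions are scanned and the one that
        minimizes the impurity of the two adjacent classes (`step_impurity`) is kept.

        Args:
            data (np.ndarray): 1-D array of input positions.
            labels (np.ndarray): The labels corresponding to the data.
            early_stop_after (int, optional): Length of the window of recent impurity values.
                A scan stops early once this window is full and non-increasing while the
                current impurity is still above the best one found so far. Default is 10.
            fixed (bool, optional): If True, early stopping is only allowed once the scan is
                more than `domain / n_classes` positions beyond the last accepted threshold.
                Default is True.

        Note:
            This method modifies the internal state of the object: it sets `thresholds`
            and `uncertainty`.
        """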
        assert data.ndim == 1
        assert labels.ndim <= 2
        if self.classes == 1:
            self.thresholds = np.array([0.5 + data.max()])
            self.uncertainty = np.zeros_like(self.thresholds)
        if data.ndim == 1:
            data = np.expand_dims(data, 0)
        thresholds = PaddedArray(
            np.insert(
                np.arange(self.domain - self.classes + 1, self.domain) - 1,
                [0, self.classes - 1],
                [-0.5, self.domain + 0.5],
            ).astype(int)
        )
        self._back_thres = thresholds.copy()
        self._forward_thres = thresholds.copy()
        self.statistic_fit(data, labels)
        last = -0.5
        for n in range(self.classes):
            G = np.inf
            Gis = deque([], early_stop_after)
            # forward pass
            if n < self.classes - 1:
                new_forward_n: float = self._forward_thres[n]
                for i in range(
                    max(0, self._back_thres[n - 1]) if n - 1 >= 0 else int(last),
                    min(self.domain, self._forward_thres[n + 1])
                    if n + 2 < self.classes
                    else self.domain - 1,
                ):
                    thresholds.array[n + 1] = i + 0.5
                    Gi = step_impurity(
                        thresholds.array, data, labels, classes=[n, n + 1]
                    )[0]
                    Gis.append(Gi)
                    if Gi <= G:
                        last = i + 0.5
                        new_forward_n = last
                        G = Gi
                    elif (
                        (not fixed or i - last > self.domain / self.classes)
                        and len(Gis) == early_stop_after
                        and all(
                            itertools.starmap(
                                operator.ge,
                                zip(Gis, itertools.islice(Gis, 1, early_stop_after)),
                            )
                        )
                    ):
                        break
                thresholds.array[n + 1] = new_forward_n
                self._forward_thres.array[n + 1] = new_forward_n
                self._back_thres.array[n + 1] = new_forward_n
            G = np.inf
        self._forward_thres = self._forward_thres.toarray().clip(
            min=0, max=self.domain - 1
        )
        self._back_thres = self._back_thres.toarray().clip(min=0, max=self.domain - 1)
        self.thresholds = (self._forward_thres + self._back_thres) / 2
        self.uncertainty = np.abs(self._forward_thres - self._back_thres) / 2

    @property
    def score(self):
        try:
            return gini_impurity(self.thresholds, self._data, self._labels) / len(
                self._data
            )
        except AttributeError:
            return np.inf

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Return the predicted class index for each value in `x` (as floats)."""
        return stair_func(x, self.get_boundaries())

    def __call__(self, *args):
        return self.predict(*args)

    def get_boundaries(self) -> np.ndarray:
        return self.thresholds.astype(int).clip(min=0, max=self.domain - 1) + 0.5


================================================
FILE: nougat/dataset/tokenizer.json
================================================
{
  "version": "1.0",
  "truncation": {
    "direction": "Right",
    "max_length": 4096,
    "strategy": "LongestFirst",
    "stride": 0
  },
  "padding": {
    "strategy": {
      "Fixed": 4096
    },
    "direction": "Right",
    "pad_to_multiple_of": null,
    "pad_id": 1,
    "pad_type_id": 0,
    "pad_token": "<pad>"
  },
  "added_tokens": [
    {
      "id": 0,
      "content": "<s>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 1,
      "content": "<pad>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 2,
      "content": "</s>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 3,
      "content": "<unk>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 4,
      "content": "[START_REF]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 5,
      "content": "[END_REF]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 6,
      "content": "[IMAGE]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 7,
      "content": "<fragments>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 8,
      "content": "</fragments>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 9,
      "content": "<work>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 10,
      "content": "</work>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 11,
      "content": "[START_SUP]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 12,
      "content": "[END_SUP]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 13,
      "content": "[START_SUB]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 14,
      "content": "[END_SUB]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 15,
      "content": "[START_DNA]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 16,
      "content": "[END_DNA]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 17,
      "content": "[START_AMINO]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 18,
      "content": "[END_AMINO]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 19,
      "content": "[START_SMILES]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 20,
      "content": "[END_SMILES]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 21,
      "content": "[START_I_SMILES]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 22,
      "content": "[END_I_SMILES]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    }
  ],
  "normalizer": {
    "type": "NFKC"
  },
  "pre_tokenizer": {
    "type": "Sequence",
    "pretokenizers": [
      {
        "type": "Split",
        "pattern": {
          "String": "SPL1T-TH1S-Pl3A5E"
        },
        "behavior": "Removed",
        "invert": false
      },
      {
        "type": "Digits",
        "individual_digits": true
      },
      {
        "type": "Split",
        "pattern": {
          "Regex": "[\\(\\)\\[\\]\\{\\}]|([!\"\\#\\$%\\&'\\*\\+,\\-\\./:;<=>\\?\\\\\\^_`\\|\\~])\\1*"
        },
        "behavior": "Isolated",
        "invert": false
      },
      {
        "type": "Split",
        "pattern": {
          "String": "\n"
        },
        "behavior": "Isolated",
        "invert": false
      },
      {
        "type": "ByteLevel",
        "add_prefix_space": false,
        "trim_offsets": true,
        "use_regex": true
      }
    ]
  },
  "post_processor": {
    "type": "TemplateProcessing",
    "single": [
      {
        "SpecialToken": {
          "id": "<s>",
          "type_id": 0
        }
      },
      {
        "Sequence": {
          "id": "A",
          "type_id": 0
        }
      },
      {
        "SpecialToken": {
          "id": "</s>",
          "type_id": 0
        }
      }
    ],
    "pair": [
      {
        "Sequence": {
          "id": "A",
          "type_id": 0
        }
      },
      {
        "Sequence": {
          "id": "B",
          "type_id": 1
        }
      }
    ],
    "special_tokens": {
      "</s>": {
        "id": "</s>",
        "ids": [
          2
        ],
        "tokens": [
          "</s>"
        ]
      },
      "<s>": {
        "id": "<s>",
        "ids": [
          0
        ],
        "tokens": [
          "<s>"
        ]
      }
    }
  },
  "decoder": {
    "type": "ByteLevel",
    "add_prefix_space": true,
    "trim_offsets": true,
    "use_regex": true
  },
  "model": {
    "type": "BPE",
    "dropout": null,
    "unk_token": null,
    "continuing_subword_prefix": null,
    "end_of_word_suffix": null,
    "fuse_unk": false,
    "vocab": {
      "<s>": 0,
      "<pad>": 1,
      "</s>": 2,
      "<unk>": 3,
      "[START_REF]": 4,
      "[END_REF]": 5,
      "[IMAGE]": 6,
      "<fragments>": 7,
      "</fragments>": 8,
      "<work>": 9,
      "</work>": 10,
      "[START_SUP]": 11,
      "[END_SUP]": 12,
      "[START_SUB]": 13,
      "[END_SUB]": 14,
      "[START_DNA]": 15,
      "[END_DNA]": 16,
      "[START_AMINO]": 17,
      "[END_AMINO]": 18,
      "[START_SMILES]": 19,
      "[END_SMILES]": 20,
      "[START_I_SMILES]": 21,
      "[END_I_SMILES]": 22,
      "!": 23,
      "\"": 24,
      "#": 25,
      "$": 26,
      "%": 27,
      "&": 28,
      "'": 29,
      "(": 30,
      ")": 31,
      "*": 32,
      "+": 33,
      ",": 34,
      "-": 35,
      ".": 36,
      "/": 37,
      "0": 38,
      "1": 39,
      "2": 40,
      "3": 41,
      "4": 42,
      "5": 43,
      "6": 44,
      "7": 45,
      "8": 46,
      "9": 47,
      ":": 48,
      ";": 49,
      "<": 50,
      "=": 51,
      ">": 52,
      "?": 53,
      "@": 54,
      "A": 55,
      "B": 56,
      "C": 57,
      "D": 58,
      "E": 59,
      "F": 60,
      "G": 61,
      "H": 62,
      "I": 63,
      "J": 64,
      "K": 65,
      "L": 66,
      "M": 67,
      "N": 68,
      "O": 69,
      "P": 70,
      "Q": 71,
      "R": 72,
      "S": 73,
      "T": 74,
      "U": 75,
      "V": 76,
      "W": 77,
      "X": 78,
      "Y": 79,
      "Z": 80,
      "[": 81,
      "\\": 82,
      "]": 83,
      "^": 84,
      "_": 85,
      "`": 86,
      "a": 87,
      "b": 88,
      "c": 89,
      "d": 90,
      "e": 91,
      "f": 92,
      "g": 93,
      "h": 94,
      "i": 95,
      "j": 96,
      "k": 97,
      "l": 98,
      "m": 99,
      "n": 100,
      "o": 101,
      "p": 102,
      "q": 103,
      "r": 104,
      "s": 105,
      "t": 106,
      "u": 107,
      "v": 108,
      "w": 109,
      "x": 110,
      "y": 111,
      "z": 112,
      "{": 113,
      "|": 114,
      "}": 115,
      "~": 116,
      "¡": 117,
      "¢": 118,
      "£": 119,
      "¤": 120,
      "¥": 121,
      "¦": 122,
      "§": 123,
      "¨": 124,
      "©": 125,
      "ª": 126,
      "«": 127,
      "¬": 128,
      "®": 129,
      "¯": 130,
      "°": 131,
      "±": 132,
      "²": 133,
      "³": 134,
      "´": 135,
      "µ": 136,
      "¶": 137,
      "·": 138,
      "¸": 139,
      "¹": 140,
      "º": 141,
      "»": 142,
      "¼": 143,
      "½": 144,
      "¾": 145,
      "¿": 146,
      "À": 147,
      "Á": 148,
      "Â": 149,
      "Ã": 150,
      "Ä": 151,
      "Å": 152,
      "Æ": 153,
      "Ç": 154,
      "È": 155,
      "É": 156,
      "Ê": 157,
      "Ë": 158,
      "Ì": 159,
      "Í": 160,
      "Î": 161,
      "Ï": 162,
      "Ð": 163,
      "Ñ": 164,
      "Ò": 165,
      "Ó": 166,
      "Ô": 167,
      "Õ": 168,
      "Ö": 169,
      "×": 170,
      "Ø": 171,
      "Ù": 172,
      "Ú": 173,
      "Û": 174,
      "Ü": 175,
      "Ý": 176,
      "Þ": 177,
      "ß": 178,
      "à": 179,
      "á": 180,
      "â": 181,
      "ã": 182,
      "ä": 183,
      "å": 184,
      "æ": 185,
      "ç": 186,
      "è": 187,
      "é": 188,
      "ê": 189,
      "ë": 190,
      "ì": 191,
      "í": 192,
      "î": 193,
      "ï": 194,
      "ð": 195,
      "ñ": 196,
      "ò": 197,
      "ó": 198,
      "ô": 199,
      "õ": 200,
      "ö": 201,
      "÷": 202,
      "ø": 203,
      "ù": 204,
      "ú": 205,
      "û": 206,
      "ü": 207,
      "ý": 208,
      "þ": 209,
      "ÿ": 210,
      "Ā": 211,
      "ā": 212,
      "Ă": 213,
      "ă": 214,
      "Ą": 215,
      "ą": 216,
      "Ć": 217,
      "ć": 218,
      "Ĉ": 219,
      "ĉ": 220,
      "Ċ": 221,
      "ċ": 222,
      "Č": 223,
      "č": 224,
      "Ď": 225,
      "ď": 226,
      "Đ": 227,
      "đ": 228,
      "Ē": 229,
      "ē": 230,
      "Ĕ": 231,
      "ĕ": 232,
      "Ė": 233,
      "ė": 234,
      "Ę": 235,
      "ę": 236,
      "Ě": 237,
      "ě": 238,
      "Ĝ": 239,
      "ĝ": 240,
      "Ğ": 241,
      "ğ": 242,
      "Ġ": 243,
      "ġ": 244,
      "Ģ": 245,
      "ģ": 246,
      "Ĥ": 247,
      "ĥ": 248,
      "Ħ": 249,
      "ħ": 250,
      "Ĩ": 251,
      "ĩ": 252,
      "Ī": 253,
      "ī": 254,
      "Ĭ": 255,
      "ĭ": 256,
      "Į": 257,
      "į": 258,
      "İ": 259,
      "ı": 260,
      "IJ": 261,
      "ij": 262,
      "Ĵ": 263,
      "ĵ": 264,
      "Ķ": 265,
      "ķ": 266,
      "ĸ": 267,
      "Ĺ": 268,
      "ĺ": 269,
      "Ļ": 270,
      "ļ": 271,
      "Ľ": 272,
      "ľ": 273,
      "Ŀ": 274,
      "ŀ": 275,
      "Ł": 276,
      "ł": 277,
      "Ń": 278,
      "Ġt": 279,
      "in": 280,
      "Ġa": 281,
      "he": 282,
      "on": 283,
      "re": 284,
      "at": 285,
      "Ġthe": 286,
      "er": 287,
      "Ġs": 288,
      "Ġo": 289,
      "en": 290,
      "al": 291,
      "Ġc": 292,
      "ti": 293,
      "or": 294,
      "ed": 295,
      "es": 296,
      "is": 297,
      "Ġp": 298,
      "Ġof": 299,
      "nd": 300,
      "Ġin": 301,
      "Ġf": 302,
      "Ġw": 303,
      "ĠĠ": 304,
      "it": 305,
      "an": 306,
      "ro": 307,
      "ar": 308,
      "Ġd": 309,
      "Ġm": 310,
      "Ġb": 311,
      "Ġand": 312,
      "ic": 313,
      "le": 314,
      "ing": 315,
      "ion": 316,
      "as": 317,
      "Ġe": 318,
      "Ġre": 319,
      "ation": 320,
      "Ġto": 321,
      "el": 322,
      "ent": 323,
      "ac": 324,
      "et": 325,
      "ec": 326,
      "tion": 327,
      "om": 328,
      "st": 329,
      "ĠT": 330,
      "Ġn": 331,
      "Ġth": 332,
      "ol": 333,
      "ul": 334,
      "im": 335,
      "RE": 336,
      "ig": 337,
      "us": 338,
      "REF": 339,
      "Ġl": 340,
      "Ġh": 341,
      "ur": 342,
      "Ġis": 343,
      "ĠĠĠĠ": 344,
      "Ġfor": 345,
      "id": 346,
      "am": 347,
      "ĠS": 348,
      "ve": 349,
      "il": 350,
      "ĠA": 351,
      "ĠC": 352,
      "Ġg": 353,
      "ot": 354,
      "ith": 355,
      "ly": 356,
      "ce": 357,
      "Ġcon": 358,
      "ow": 359,
      "Ġst": 360,
      "ut": 361,
      "os": 362,
      "Ġwith": 363,
      "od": 364,
      "ra": 365,
      "Ġv": 366,
      "Ġpro": 367,
      "um": 368,
      "ĠI": 369,
      "if": 370,
      "uc": 371,
      "ter": 372,
      "un": 373,
      "AR": 374,
      "ST": 375,
      "res": 376,
      "Ġon": 377,
      "EN": 378,
      "ere": 379,
      "ĠP": 380,
      "ĠThe": 381,
      "ĠM": 382,
      "Ġas": 383,
      "ART": 384,
      "Ġan": 385,
      "END": 386,
      "START": 387,
      "Ġthat": 388,
      "qu": 389,
      "em": 390,
      "Ġbe": 391,
      "Ġex": 392,
      "ri": 393,
      "ab": 394,
      "ity": 395,
      "tic": 396,
      "ver": 397,
      "Ġal": 398,
      "pl": 399,
      "ts": 400,
      "ĠF": 401,
      "Ġâ": 402,
      "ure": 403,
      "Ġby": 404,
      "ate": 405,
      "ag": 406,
      "ir": 407,
      "oc": 408,
      "per": 409,
      "ĠB": 410,
      "ay": 411,
      "ĠD": 412,
      "Ġcom": 413,
      "ĠH": 414,
      "ated": 415,
      "ĠR": 416,
      "Ġare": 417,
      "rom": 418,
      "ĠE": 419,
      "op": 420,
      "ad": 421,
      "se": 422,
      "ĠL": 423,
      "igh": 424,
      "ĠN": 425,
      "ment": 426,
      "her": 427,
      "og": 428,
      "ain": 429,
      "ect": 430,
      "ud": 431,
      "Ġde": 432,
      "Ġr": 433,
      "Ġat": 434,
      "Ġwas": 435,
      "Ġus": 436,
      "Ġres": 437,
      "ell": 438,
      "iz": 439,
      "ine": 440,
      "ph": 441,
      "Ġac": 442,
      "ess": 443,
      "ore": 444,
      "ical": 445,
      "th": 446,
      "und": 447,
      "rac": 448,
      "Ġwe": 449,
      "ath": 450,
      "ĠG": 451,
      "Ġfrom": 452,
      "ati": 453,
      "up": 454,
      "ist": 455,
      "ant": 456,
      "Ġor": 457,
      "ff": 458,
      "Ġcomp": 459,
      "Ġwh": 460,
      "ĠW": 461,
      "ch": 462,
      "ers": 463,
      "Ġsp": 464,
      "orm": 465,
      "Ġch": 466,
      "ations": 467,
      "ran": 468,
      "ub": 469,
      "te": 470,
      "di": 471,
      "Ġsh": 472,
      "ge": 473,
      "ase": 474,
      "Ġwere": 475,
      "ĠĠĠĠĠĠĠĠ": 476,
      "ĠÎ": 477,
      "ap": 478,
      "ĠIn": 479,
      "and": 480,
      "Ġse": 481,
      "vel": 482,
      "Ġim": 483,
      "ĠâĪ": 484,
      "ens": 485,
      "ies": 486,
      "ich": 487,
      "ight": 488,
      "duc": 489,
      "ĠO": 490,
      "Ġit": 491,
      "tions": 492,
      "end": 493,
      "Ġco": 494,
      "Ġthis": 495,
      "Ġcan": 496,
      "Ġk": 497,
      "âĢ": 498,
      "lec": 499,
      "ted": 500,
      "Ġmod": 501,
      "math": 502,
      "Ġcont": 503,
      "Ġne": 504,
      "Ġpar": 505,
      "ib": 506,
      "ĠĠĠ": 507,
      "Ġle": 508,
      "iv": 509,
      "ug": 510,
      "ence": 511,
      "ign": 512,
      "ous": 513,
      "ents": 514,
      "ys": 515,
      "ave": 516,
      "red": 517,
      "ress": 518,
      "able": 519,
      "por": 520,
      "all": 521,
      "iff": 522,
      "est": 523,
      "Ġap": 524,
      "Ġinc": 525,
      "nt": 526,
      "ary": 527,
      "iti": 528,
      "Ġwhich": 529,
      "Ġnot": 530,
      "form": 531,
      "Ġsy": 532,
      "Ġad": 533,
      "low": 534,
      "ak": 535,
      "Ġper": 536,
      "Ġhe": 537,
      "pro": 538,
      "ance": 539,
      "ial": 540,
      "ue": 541,
      "Ġen": 542,
      "Ġcl": 543,
      "ass": 544,
      "ip": 545,
      "rans": 546,
      "Ġob": 547,
      "Ġgen": 548,
      "tim": 549,
      "Ġdis": 550,
      "unc": 551,
      "Ġint": 552,
      "ep": 553,
      "etw": 554,
      "Ġdiff": 555,
      "ach": 556,
      "ther": 557,
      "ime": 558,
      "age": 559,
      "ple": 560,
      "ill": 561,
      "yp": 562,
      "ĠK": 563,
      "act": 564,
      "ari": 565,
      "Ġmet": 566,
      "ors": 567,
      "Ġhave": 568,
      "Ġstud": 569,
      "ong": 570,
      "ĠU": 571,
      "Ġpl": 572,
      "ide": 573,
      "ma": 574,
      "hen": 575,
      "ific": 576,
      "ome": 577,
      "Ġi": 578,
      "ular": 579,
      "ĠV": 580,
      "ally": 581,
      "Ġshow": 582,
      "rib": 583,
      "ia": 584,
      "enti": 585,
      "Ġass": 586,
      "ond": 587,
      "ft": 588,
      "Ġab": 589,
      "Ġinter": 590,
      "ĠTh": 591,
      "The": 592,
      "str": 593,
      "Ġcell": 594,
      "cal": 595,
      "Ġmodel": 596,
      "ata": 597,
      "ast": 598,
      "Ġeff": 599,
      "Ġtrans": 600,
      "ates": 601,
      "ased": 602,
      "ost": 603,
      "vi": 604,
      "ang": 605,
      "our": 606,
      "Ġme": 607,
      "ard": 608,
      "Ġdiffere": 609,
      "Ġpre": 610,
      "Ġdi": 611,
      "ĠâĪĴ": 612,
      "olog": 613,
      "ution": 614,
      "ound": 615,
      "ace": 616,
      "Ġresul": 617,
      "erm": 618,
      "pos": 619,
      "here": 620,
      "tive": 621,
      "ord": 622,
      "so": 623,
      "stem": 624,
      "yl": 625,
      "Ġph": 626,
      "Ġy": 627,
      "ame": 628,
      "ork": 629,
      "ative": 630,
      "Ġqu": 631,
      "ric": 632,
      "SU": 633,
      "wo": 634,
      "Ġun": 635,
      "Ġev": 636,
      "are": 637,
      "##": 638,
      "de": 639,
      "een": 640,
      "tiv": 641,
      "Ġgro": 642,
      "ory": 643,
      "Ġcons": 644,
      "Ġsub": 645,
      "ta": 646,
      "--": 647,
      "Ġstr": 648,
      "ber": 649,
      "erv": 650,
      "etween": 651,
      "enc": 652,
      "Ġanal": 653,
      "int": 654,
      "Ġhas": 655,
      "uch": 656,
      "Ġreg": 657,
      "Ġbetween": 658,
      "Ġdet": 659,
      "Ġall": 660,
      "cess": 661,
      "Ġexp": 662,
      "ection": 663,
      "ĠâĢ": 664,
      "ind": 665,
      "ater": 666,
      "Ġsign": 667,
      "pt": 668,
      "ugh": 669,
      "ite": 670,
      "ility": 671,
      "Ġusing": 672,
      "Ġval": 673,
      "Ġro": 674,
      "ree": 675,
      "Ġrel": 676,
      "out": 677,
      "Ġfunc": 678,
      "ition": 679,
      "Ġcor": 680,
      "Ġalso": 681,
      "Ġtwo": 682,
      "ne": 683,
      "ĠJ": 684,
      "Ġsystem": 685,
      "cl": 686,
      "uct": 687,
      "Ġsim": 688,
      "tain": 689,
      "ust": 690,
      "ied": 691,
      "port": 692,
      "Ġrec": 693,
      "Ġresp": 694,
      "Ġdata": 695,
      "rm": 696,
      "resent": 697,
      "uld": 698,
      "xt": 699,
      "Ġj": 700,
      "ry": 701,
      "ack": 702,
      "Ġra": 703,
      "par": 704,
      "Ġform": 705,
      "Ġsc": 706,
      "frac": 707,
      "ĠWe": 708,
      "ating": 709,
      "ech": 710,
      "hod": 711,
      "Ġfol": 712,
      "ined": 713,
      "ĠSt": 714,
      "ual": 715,
      "Ġused": 716,
      "Ġone": 717,
      "Ġdes": 718,
      "ĠÏ": 719,
      "Ġvari": 720,
      "Ġdist": 721,
      "Ġnum": 722,
      "ym": 723,
      "ew": 724,
      "rec": 725,
      "ob": 726,
      "Ġinf": 727,
      "Ġar": 728,
      "lect": 729,
      "ll": 730,
      "ons": 731,
      "ĠThis": 732,
      "ose": 733,
      "ile": 734,
      "play": 735,
      "ear": 736,
      "ox": 737,
      "ures": 738,
      "one": 739,
      "Ġstudy": 740,
      "ysis": 741,
      "Ġfollow": 742,
      "yle": 743,
      "ract": 744,
      "dis": 745,
      "Ġpos": 746,
      "right": 747,
      "Ġthan": 748,
      "ros": 749,
      "av": 750,
      "Fig": 751,
      "Ġtime": 752,
      "ization": 753,
      "ulation": 754,
      "ized": 755,
      "Ġsur": 756,
      "oth": 757,
      "Ġout": 758,
      "Ġcol": 759,
      "ature": 760,
      "ive": 761,
      "Ġsol": 762,
      "Ġx": 763,
      "eld": 764,
      "Ġother": 765,
      "plic": 766,
      "Ġdef": 767,
      "erg": 768,
      "Ġgener": 769,
      "ely": 770,
      "Ġbeen": 771,
      "Ġincre": 772,
      "Ġthese": 773,
      "Ġno": 774,
      "ax": 775,
      "style": 776,
      "arg": 777,
      "ian": 778,
      "Ġind": 779,
      "Ġsuch": 780,
      "Ġfunction": 781,
      "ting": 782,
      "Ġequ": 783,
      "aus": 784,
      "Ġund": 785,
      "mathb": 786,
      "tical": 787,
      "Ġhigh": 788,
      "rain": 789,
      "Ġam": 790,
      "ield": 791,
      "oun": 792,
      "ression": 793,
      "Ġspec": 794,
      "Ġop": 795,
      "Ġdec": 796,
      "Ġover": 797,
      "Ġmethod": 798,
      "Ġset": 799,
      "âĪ": 800,
      "Ġif": 801,
      "dition": 802,
      "ues": 803,
      "ects": 804,
      "display": 805,
      "hem": 806,
      "Ġpati": 807,
      "Ġresults": 808,
      "old": 809,
      "anc": 810,
      "displaystyle": 811,
      "Ġeach": 812,
      "Ġmore": 813,
      "les": 814,
      "pr": 815,
      "acter": 816,
      "Ġtheir": 817,
      "Ġacc": 818,
      "Ġappro": 819,
      "iss": 820,
      "ize": 821,
      "Ġinv": 822,
      "ases": 823,
      "Ġcells": 824,
      "irst": 825,
      "lu": 826,
      "ail": 827,
      "Ġmeas": 828,
      "Ġlow": 829,
      "ov": 830,
      "the": 831,
      "ik": 832,
      "**": 833,
      "ef": 834,
      "Ġbut": 835,
      "hes": 836,
      "fter": 837,
      "Ġdifferent": 838,
      "vely": 839,
      "Ġext": 840,
      "Ġthere": 841,
      "oci": 842,
      "Ġprob": 843,
      "Ġits": 844,
      "ron": 845,
      "ments": 846,
      "Ġag": 847,
      "NA": 848,
      "Ġpo": 849,
      "ice": 850,
      "ype": 851,
      "Ġgroup": 852,
      "âĢĵ": 853,
      "ever": 854,
      "ult": 855,
      "ism": 856,
      "tern": 857,
      "ability": 858,
      "ions": 859,
      "ark": 860,
      "Ġnon": 861,
      "to": 862,
      "ĠĠĠĠĠĠĠ": 863,
      "Ġobs": 864,
      "Ġtre": 865,
      "als": 866,
      "left": 867,
      "ĠPro": 868,
      "Ġonly": 869,
      "Ġman": 870,
      "der": 871,
      "Ġpol": 872,
      "uring": 873,
      "amet": 874,
      "rol": 875,
      "In": 876,
      "yn": 877,
      "Ġunder": 878,
      "ĠCh": 879,
      "Ġwhere": 880,
      "ood": 881,
      "ĠX": 882,
      "nce": 883,
      "Ġpartic": 884,
      "ected": 885,
      "ĠFig": 886,
      "Ġem": 887,
      "Ġfact": 888,
      "ĠAn": 889,
      "Ġperform": 890,
      "Ġso": 891,
      "Ġanalysis": 892,
      "stract": 893,
      "hed": 894,
      "Ġmay": 895,
      "atic": 896,
      "Ġrep": 897,
      "tein": 898,
      "duced": 899,
      "Ġup": 900,
      "Ġinto": 901,
      "Ġnumber": 902,
      "Ġour": 903,
      "Ġet": 904,
      "eg": 905,
      "itle": 906,
      "over": 907,
      "ix": 908,
      "ator": 909,
      "ulti": 910,
      "Ġincl": 911,
      "ould": 912,
      "ici": 913,
      "bstract": 914,
      "Ġcomple": 915,
      "Ġpatients": 916,
      "Ġdo": 917,
      "Ġexper": 918,
      "vid": 919,
      "ange": 920,
      "Ġlevel": 921,
      "Ġprocess": 922,
      
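The `tokenizer.json` vocabulary entries above look garbled, but they are not encoding debris. They appear to follow the byte-level BPE convention that Hugging Face `tokenizers` inherits from GPT-2: every byte 0-255 is mapped to a printable character, so `Ġ` stands for a space byte and multi-byte UTF-8 characters show up as runs such as `âĢĵ`. The sketch below assumes the standard GPT-2 `bytes_to_unicode` table (it is not code from the nougat repository) and turns such token strings back into text.

```python
"""Sketch: decode GPT-2-style byte-level BPE token strings back to text."""
from typing import Dict, List


def bytes_to_unicode() -> Dict[int, str]:
    """Map each byte 0-255 to a printable Unicode character (GPT-2 scheme)."""
    # Printable ASCII plus two printable Latin-1 ranges keep their own code point.
    bs = (
        list(range(ord("!"), ord("~") + 1))  # 0x21-0x7E
        + list(range(0xA1, 0xAC + 1))  # inverted exclamation mark .. not sign
        + list(range(0xAE, 0xFF + 1))  # registered sign .. y with diaeresis
    )
    cs = bs[:]
    n = 0
    # Remaining bytes (controls, space, 0x7F-0xA0, 0xAD) are shifted to 256 + n,
    # which is why the space byte 0x20 becomes "Ġ" (U+0120).
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}


# Inverse table: printable character -> original byte value.
BYTE_DECODER = {ch: b for b, ch in bytes_to_unicode().items()}


def decode_tokens(tokens: List[str]) -> str:
    """Join token strings, map characters back to bytes, then decode UTF-8."""
    raw = bytes(BYTE_DECODER[ch] for tok in tokens for ch in tok)
    # A single token may end mid-character (e.g. "ĠÎ"), so decode only after joining.
    return raw.decode("utf-8", errors="replace")


if __name__ == "__main__":
    for tok in ["ĠThe", "ĠĠĠĠ", "âĢĵ"]:
        print(repr(tok), "->", repr(decode_tokens([tok])))
    print(repr(decode_tokens(["ĠâĪ", "Ĵ"])))
```

Decoding after joining matters: `ĠâĪ` (id 484 above) carries only the first two bytes of U+2212 MINUS SIGN, so decoding it on its own would yield a replacement character instead of `−`.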
SYMBOL INDEX (289 symbols across 28 files)

FILE: app.py
  function load_model (line 50) | async def load_model(
  function root (line 63) | def root():
  function predict (line 73) | async def predict(
  function main (line 166) | def main():

FILE: lightning_module.py
  class NougatModelPLModule (line 23) | class NougatModelPLModule(pl.LightningModule):
    method __init__ (line 24) | def __init__(self, config):
    method training_step (line 60) | def training_step(self, batch, batch_idx):
    method validation_step (line 78) | def validation_step(self, batch, batch_idx, dataset_idx=0):
    method on_validation_epoch_end (line 102) | def on_validation_epoch_end(self):
    method configure_optimizers (line 110) | def configure_optimizers(self):
    method cosine_scheduler (line 159) | def cosine_scheduler(optimizer, training_steps, warmup_steps):
    method exponential_scheduler (line 170) | def exponential_scheduler(optimizer, warmup_steps, lr, min_lr=5e-5, ga...
    method get_progress_bar_dict (line 182) | def get_progress_bar_dict(self):
    method on_save_checkpoint (line 190) | def on_save_checkpoint(self, checkpoint):
  class NougatDataPLModule (line 200) | class NougatDataPLModule(pl.LightningDataModule):
    method __init__ (line 201) | def __init__(self, config):
    method train_dataloader (line 211) | def train_dataloader(self):
    method val_dataloader (line 226) | def val_dataloader(self):
    method seed_worker (line 239) | def seed_worker(wordker_id):
    method ignore_none_collate (line 245) | def ignore_none_collate(batch):

FILE: nougat/dataset/create_index.py
  function convert_pt2px (line 30) | def convert_pt2px(pt, dpi=96):
  function read_metadata (line 39) | def read_metadata(data: Dict) -> List[List[Dict]]:
  function index_paper (line 58) | def index_paper(directory: Path, args: argparse.Namespace):
  function create_index (line 102) | def create_index(args):

FILE: nougat/dataset/gen_seek.py
  function get_args (line 13) | def get_args():

FILE: nougat/dataset/parser/document.py
  class Element (line 35) | class Element(Generic[EL]):
    method plaintext (line 48) | def plaintext(self):
    method append (line 51) | def append(self, child: EL) -> EL:
    method find_parent (line 56) | def find_parent(self, class_or_tuple: Type[T]) -> T:
  class UnknownElement (line 66) | class UnknownElement(Element):
  class TextElement (line 71) | class TextElement(Element):
    method plaintext (line 75) | def plaintext(self):
    method append (line 78) | def append(self, child: "Element"):
  class Math (line 83) | class Math(Element):
  class PlaintextMath (line 88) | class PlaintextMath(Math):
  class LatexMath (line 93) | class LatexMath(Math):
    method plaintext (line 98) | def plaintext(self):
  class Author (line 103) | class Author:
  class Link (line 110) | class Link(Element):
  class InlineRef (line 115) | class InlineRef(Element):
    method as_dict (line 118) | def as_dict(self):
  class Reference (line 125) | class Reference:
    method as_dict (line 150) | def as_dict(self):
  class SpanElement (line 163) | class SpanElement(Element):
  class Italic (line 168) | class Italic(SpanElement):
  class Bold (line 173) | class Bold(SpanElement):
  class Superscript (line 178) | class Superscript(SpanElement):
  class Subscript (line 183) | class Subscript(SpanElement):
  class Paragraph (line 188) | class Paragraph(Element):
  class TableRow (line 193) | class TableRow(Element):
    method add_cell (line 196) | def add_cell(self, cell: Element):
    method plaintext (line 202) | def plaintext(self):
  class TableHead (line 207) | class TableHead(TableRow):
  class Table (line 212) | class Table(Element):
    method add_row (line 219) | def add_row(self, row: TableRow) -> TableRow:
    method plaintext (line 225) | def plaintext(self):
  class Equation (line 230) | class Equation(Element):
  class EquationList (line 235) | class EquationList(Element):
    method add_equation (line 238) | def add_equation(self, eqn: Equation) -> Equation:
    method plaintext (line 244) | def plaintext(self):
  class Algorithm (line 249) | class Algorithm(Element):
    method add_line (line 254) | def add_line(self, line: Element) -> Element:
    method plaintext (line 260) | def plaintext(self):
  class Definition (line 265) | class Definition(Element):
    method plaintext (line 270) | def plaintext(self):
  class DefinitionList (line 280) | class DefinitionList(Element):
    method add_item (line 295) | def add_item(self, item: Definition) -> Definition:
    method plaintext (line 301) | def plaintext(self):
  class Figure (line 310) | class Figure(Element):
  class Section (line 317) | class Section(Element):
  class SectionHeader (line 325) | class SectionHeader(Element):
  class ListItem (line 332) | class ListItem(Element):
  class ListContainer (line 337) | class ListContainer(Element):
    method add_item (line 342) | def add_item(self, item: ListItem) -> ListItem:
    method plaintext (line 348) | def plaintext(self):
  class Footnote (line 353) | class Footnote(Element):
  class Document (line 358) | class Document(Element, Reference):
    method add_reference (line 366) | def add_reference(self, reference):
    method add_inline_ref (line 369) | def add_inline_ref(self, in_ref):
    method set_bib (line 372) | def set_bib(self, reference):
  class Spec (line 377) | class Spec:
    method __hash__ (line 405) | def __hash__(self) -> int:
    method __eq__ (line 408) | def __eq__(self, __o: object) -> bool:
    method set_align (line 411) | def set_align(self, classes: List[str], style: Optional[str] = None) -...
    method set_border (line 439) | def set_border(self, classes: List[str]) -> None:
    method set_attrs (line 446) | def set_attrs(self, attrs: Dict[str, Any]) -> None:
    method __str__ (line 454) | def __str__(self) -> str:
  class TableCell (line 463) | class TableCell(Element):
    method __post_init__ (line 486) | def __post_init__(self, *args, **kwargs) -> None:
    method __hash__ (line 491) | def __hash__(self) -> int:
    method __eq__ (line 494) | def __eq__(self, __o: object) -> bool:
    method set_attrs (line 497) | def set_attrs(self, attrs: Dict[str, Any]) -> None:
    method plaintext (line 505) | def plaintext(self):
  class TableRow (line 512) | class TableRow(Element):
    method add_cell (line 535) | def add_cell(self, cell: TableCell):
    method __iter__ (line 540) | def __iter__(self):
    method __len__ (line 543) | def __len__(self) -> int:
    method __bool__ (line 546) | def __bool__(self) -> bool:
    method cum_cell_widths (line 550) | def cum_cell_widths(self) -> List[int]:
    method cell_widths (line 554) | def cell_widths(self) -> List[int]:
    method width (line 558) | def width(self) -> int:
    method _hline (line 561) | def _hline(self, orientation: str) -> str:
    method hline_above (line 592) | def hline_above(self) -> str:
    method hline_below (line 596) | def hline_below(self) -> str:
    method plaintext (line 600) | def plaintext(self) -> str:
  class Tabular (line 605) | class Tabular(Element):
    method add_row (line 622) | def add_row(self, row: TableRow) -> TableRow:
    method width (line 628) | def width(self) -> int:
    method cols (line 635) | def cols(self) -> List[List[TableCell]]:
    method _square_table (line 643) | def _square_table(self) -> None:
    method get_table_spec (line 660) | def get_table_spec(self) -> str:
    method plaintext (line 696) | def plaintext(self):
  class Table (line 701) | class Table(Element):

FILE: nougat/dataset/parser/html2md.py
  function check_file_path (line 17) | def check_file_path(paths: List[Path], wdir: Optional[Path] = None) -> L...

FILE: nougat/dataset/parser/latexml_parser.py
  function printerr (line 17) | def printerr(*args, **kwargs):
  function is_wrapper_element (line 43) | def is_wrapper_element(element: BeautifulSoup) -> bool:
  function ignore_element (line 47) | def ignore_element(element: BeautifulSoup) -> bool:
  function _get_classes (line 51) | def _get_classes(el: BeautifulSoup) -> Set[str]:
  function _detach_selected (line 60) | def _detach_selected(element: BeautifulSoup, selector: str) -> None:
  function parse_latexml_authors (line 65) | def parse_latexml_authors(ltx_authors: BeautifulSoup) -> List[Author]:
  function parse_latexml_citations (line 71) | def parse_latexml_citations(cite: BeautifulSoup, parent: Element) -> None:
  function _clean_html_whitespace (line 89) | def _clean_html_whitespace(text: str) -> str:
  function parse_latexml_children (line 98) | def parse_latexml_children(html: BeautifulSoup, parent: Element) -> None:
  function parse_latexml_references (line 420) | def parse_latexml_references(html: BeautifulSoup, doc: Document) -> None:
  function parse_latexml (line 429) | def parse_latexml(

FILE: nougat/dataset/parser/markdown.py
  function remove_trailing_whitespace (line 39) | def remove_trailing_whitespace(parts: List[str]) -> None:
  function remove_line_breaks (line 48) | def remove_line_breaks(parts: List[str]):
  function leading_trailing_whitespace (line 55) | def leading_trailing_whitespace(
  function latex_escape (line 84) | def latex_escape(string: str) -> str:
  function is_empty (line 88) | def is_empty(content: List) -> bool:
  function format_element (line 98) | def format_element(
  function format_iterator (line 330) | def format_iterator(
  function format_children (line 359) | def format_children(
  function format_document (line 367) | def format_document(

FILE: nougat/dataset/pdffigures.py
  function call_pdffigures (line 19) | def call_pdffigures(

FILE: nougat/dataset/rasterize.py
  function rasterize_paper (line 18) | def rasterize_paper(

FILE: nougat/dataset/split_htmls_to_pages.py
  function process_paper (line 29) | def process_paper(
  function process_htmls (line 130) | def process_htmls(args):

FILE: nougat/dataset/split_md_to_pages.py
  function ratio (line 37) | def ratio(*args, **kwargs):
  class BagOfWords (line 41) | class BagOfWords:
    method __init__ (line 51) | def __init__(
    method train (line 60) | def train(self):
    method __call__ (line 77) | def __call__(
  function remove_short_seqs (line 90) | def remove_short_seqs(seqs: List[str], minimum: int = 10) -> List[str]:
  function find_figures (line 99) | def find_figures(
  function flatten (line 136) | def flatten(l: List) -> List:
  function get_doc_text (line 140) | def get_doc_text(
  function clean_pdf_text (line 176) | def clean_pdf_text(pages: List[List[str]], num_words: int = 10) -> List[...
  function split_markdown (line 239) | def split_markdown(

FILE: nougat/dataset/splitter.py
  function ratio (line 18) | def ratio(*args, **kwargs):
  function reverse (line 22) | def reverse(lst: List[str]) -> List[str]:
  function get_first_last (line 37) | def get_first_last(
  function get_glob_index (line 66) | def get_glob_index(
  class Splitter (line 84) | class Splitter:
    method __init__ (line 87) | def __init__(self, paragraphs: List[str]) -> None:
    method remove_special_chars (line 95) | def remove_special_chars(string: str) -> str:
    method count_special_chars (line 129) | def count_special_chars(string: str, char_ind: int) -> int:
    method split_first_last (line 213) | def split_first_last(
    method split (line 280) | def split(
    method _find_match (line 315) | def _find_match(
    method _fuzzy (line 325) | def _fuzzy(
    method fuzzysearch (line 338) | def fuzzysearch(
    method evaluate_split (line 350) | def evaluate_split(self, page_num: int, page_content: str) -> float:

FILE: nougat/dataset/staircase.py
  function stair_func (line 17) | def stair_func(x: np.ndarray, thresholds: np.ndarray) -> np.ndarray:
  function compute_gini (line 21) | def compute_gini(labels: np.ndarray) -> float:
  function compute_binary_gini (line 29) | def compute_binary_gini(labels: np.ndarray) -> float:
  function gini_impurity (line 37) | def gini_impurity(
  function step_impurity (line 87) | def step_impurity(
  class PaddedArray (line 112) | class PaddedArray:
    method __init__ (line 121) | def __init__(
    method __len__ (line 129) | def __len__(self):
    method _process_index (line 132) | def _process_index(self, index):
    method __getitem__ (line 147) | def __getitem__(self, index):
    method __setitem__ (line 151) | def __setitem__(self, index, value):
    method copy (line 154) | def copy(self):
    method toarray (line 157) | def toarray(self):
  class Staircase (line 161) | class Staircase:
    method __init__ (line 170) | def __init__(self, domain: int, n_classes: int) -> None:
    method statistic_fit (line 180) | def statistic_fit(
    method fit (line 216) | def fit(
    method score (line 299) | def score(self):
    method predict (line 307) | def predict(self, x: np.ndarray) -> np.ndarray:
    method __call__ (line 310) | def __call__(self, *args):
    method get_boundaries (line 313) | def get_boundaries(self) -> np.ndarray:

FILE: nougat/dataset/utils/latex_conversion.py
  function remove_style (line 60) | def remove_style(string: str) -> str:
  function replace_duplicate_definitions (line 69) | def replace_duplicate_definitions(string: str) -> str:
  function unicode_to_latex (line 76) | def unicode_to_latex(s: str) -> str:
  function remove_line_breaks (line 108) | def remove_line_breaks(string: str) -> str:
  function normalize_tex (line 113) | def normalize_tex(math: str, inline: bool) -> str:

FILE: nougat/dataset/utils/pdf_text_extract.py
  function replace_ligatures (line 18) | def replace_ligatures(text: str) -> str:
  function remove_hyphens (line 36) | def remove_hyphens(text: str) -> str:
  function dehyphenate (line 59) | def dehyphenate(lines: List[str], line_no: int) -> List[str]:
  function get_pages (line 68) | def get_pages(pdf: str) -> List[str]:
  function get_paragraphs (line 84) | def get_paragraphs(pdf: str) -> List[List[str]]:

FILE: nougat/dataset/utils/utils.py
  function remove_pretty_linebreaks (line 10) | def remove_pretty_linebreaks(string: str) -> str:

FILE: nougat/metrics.py
  function compute_metrics (line 27) | def compute_metrics(pred, gt, minlen=4):
  function get_parser (line 47) | def get_parser():
  function split_text (line 63) | def split_text(pages: List[str]):
  function get_metrics (line 86) | def get_metrics(gt: List[str], pred: List[str], pool: bool = True):

FILE: nougat/model.py
  class SwinEncoder (line 37) | class SwinEncoder(nn.Module):
    method __init__ (line 52) | def __init__(
    method forward (line 116) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method crop_margin (line 127) | def crop_margin(img: Image.Image) -> Image.Image:
    method to_tensor (line 142) | def to_tensor(self):
    method prepare_input (line 148) | def prepare_input(
  class BARTDecoder (line 191) | class BARTDecoder(nn.Module):
    method __init__ (line 207) | def __init__(
    method add_special_tokens (line 271) | def add_special_tokens(self, list_of_tokens: List[str]):
    method prepare_inputs_for_inference (line 281) | def prepare_inputs_for_inference(
    method forward (line 312) | def forward(
    method resize_bart_abs_pos_emb (line 337) | def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> ...
  class NougatConfig (line 359) | class NougatConfig(PretrainedConfig):
    method __init__ (line 385) | def __init__(
  class RunningVarTorch (line 418) | class RunningVarTorch:
    method __init__ (line 419) | def __init__(self, L=15, norm=False):
    method push (line 424) | def push(self, x: torch.Tensor):
    method variance (line 433) | def variance(self):
  class StoppingCriteriaScores (line 442) | class StoppingCriteriaScores(StoppingCriteria):
    method __init__ (line 443) | def __init__(self, threshold: float = 0.015, window_size: int = 200):
    method __call__ (line 454) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  function batch (line 477) | def batch(l, b=15):
  function subdiv (line 484) | def subdiv(l, b=10):
  class NougatModel (line 491) | class NougatModel(PreTrainedModel):
    method __init__ (line 501) | def __init__(self, config: NougatConfig):
    method forward (line 521) | def forward(
    method _init_weights (line 544) | def _init_weights(self, *args, **kwargs):
    method inference (line 547) | def inference(
    method from_pretrained (line 671) | def from_pretrained(

FILE: nougat/postprocessing.py
  function ratio (line 18) | def ratio(*args, **kwargs):
  function markdown_compatible (line 25) | def markdown_compatible(s: str) -> str:
  function find_next_punctuation (line 70) | def find_next_punctuation(s: str, start_inx=0):
  function find_last_punctuation (line 86) | def find_last_punctuation(s: str, start_inx=0):
  function truncate_repetitions (line 102) | def truncate_repetitions(s: str, min_len=30):
  function close_envs (line 168) | def close_envs(s: str) -> str:
  function remove_numbers (line 178) | def remove_numbers(lines):
  function get_slices (line 190) | def get_slices(lines, clean_lines):
  function remove_slice_from_lines (line 233) | def remove_slice_from_lines(lines, clean_text, sli) -> str:
  function remove_hallucinated_references (line 301) | def remove_hallucinated_references(text: str) -> str:
  function postprocess_single (line 332) | def postprocess_single(generation: str, markdown_fix: bool = True) -> str:
  function postprocess (line 487) | def postprocess(

FILE: nougat/transforms.py
  function alb_wrapper (line 16) | def alb_wrapper(transform):
  class Erosion (line 23) | class Erosion(alb.ImageOnlyTransform):
    method __init__ (line 41) | def __init__(self, scale, always_apply=False, p=0.5):
    method apply (line 49) | def apply(self, img, **params):
  class Dilation (line 57) | class Dilation(alb.ImageOnlyTransform):
    method __init__ (line 75) | def __init__(self, scale, always_apply=False, p=0.5):
    method apply (line 83) | def apply(self, img, **params):
  class Bitmap (line 91) | class Bitmap(alb.ImageOnlyTransform):
    method __init__ (line 107) | def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
    method apply (line 112) | def apply(self, img, **params):

FILE: nougat/utils/checkpoint.py
  function download_as_bytes_with_progress (line 20) | def download_as_bytes_with_progress(url: str, name: str = None) -> bytes:
  function download_checkpoint (line 49) | def download_checkpoint(checkpoint: Path, model_tag: str = MODEL_TAG):
  function torch_hub (line 74) | def torch_hub(model_tag: Optional[str] = MODEL_TAG) -> Path:
  function get_checkpoint (line 85) | def get_checkpoint(

FILE: nougat/utils/dataset.py
  class ImageDataset (line 25) | class ImageDataset(torch.utils.data.Dataset):
    method __init__ (line 40) | def __init__(self, img_list, prepare: Callable):
    method __len__ (line 45) | def __len__(self):
    method ignore_none_collate (line 49) | def ignore_none_collate(batch):
    method __getitem__ (line 60) | def __getitem__(self, idx):
  class LazyDataset (line 68) | class LazyDataset(Dataset):
    method __init__ (line 83) | def __init__(self, pdf, prepare: Callable, pages: Optional[List[int]] ...
    method __len__ (line 91) | def __len__(self):
    method __getitem__ (line 94) | def __getitem__(self, i):
    method ignore_none_collate (line 103) | def ignore_none_collate(batch):
  class SciPDFDataset (line 125) | class SciPDFDataset(Dataset):
    method __init__ (line 144) | def __init__(
    method __len__ (line 172) | def __len__(self) -> int:
    method __getitem__ (line 175) | def __getitem__(self, index: int) -> Dict:
    method __iter__ (line 203) | def __iter__(self):
  class NougatDataset (line 208) | class NougatDataset(Dataset):
    method __init__ (line 214) | def __init__(
    method __len__ (line 234) | def __len__(self) -> int:
    method __getitem__ (line 237) | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:

FILE: nougat/utils/device.py
  function default_batch_size (line 11) | def default_batch_size():
  function move_to_device (line 28) | def move_to_device(model, bf16: bool = True, cuda: bool = True):

FILE: predict.py
  function get_args (line 28) | def get_args():
  function main (line 125) | def main():

FILE: setup.py
  function read_version (line 14) | def read_version():
  function read_long_description (line 22) | def read_long_description():

FILE: test.py
  function test (line 27) | def test(args):

FILE: train.py
  class CustomCheckpointIO (line 42) | class CustomCheckpointIO(CheckpointIO):
    method save_checkpoint (line 62) | def save_checkpoint(self, checkpoint, path, storage_options=None):
    method load_checkpoint (line 73) | def load_checkpoint(self, path, storage_options=None):
    method remove_checkpoint (line 101) | def remove_checkpoint(self, path) -> None:
  class GradNormCallback (line 105) | class GradNormCallback(Callback):
    method gradient_norm (line 111) | def gradient_norm(model):
    method on_after_backward (line 120) | def on_after_backward(self, trainer, model):
  function save_config_file (line 125) | def save_config_file(config, path):
  function train (line 135) | def train(config):
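GradNormCallback defines gradient_norm and on_after_backward, which suggests it measures the gradient norm after each backward pass. The quantity usually meant by "gradient norm" is the global L2 norm: the square root of the sum of each parameter gradient's squared L2 norm. The sketch below computes that quantity over plain Python lists so it runs without PyTorch. Representing gradients as flat float lists and skipping None entries (parameters with no gradient) are assumptions for illustration. In PyTorch the same loop would iterate over p.grad for each parameter p in model.parameters().

```python
import math
from typing import Iterable, Optional, Sequence


def gradient_norm(grads: Iterable[Optional[Sequence[float]]]) -> float:
    """Global L2 norm over all parameter gradients, skipping missing ones."""
    total_sq = 0.0
    for g in grads:
        if g is None:
            continue  # frozen or unused parameter: no gradient to count
        param_norm = math.sqrt(sum(x * x for x in g))
        total_sq += param_norm ** 2
    return math.sqrt(total_sq)
```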
Condensed preview of 47 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1941,
    "preview": "core.*\n*.bin\n.nfs*\n.vscode/*\nresult/*\n!result/extract.py\nmisc/*\nwandb/\n!misc/*.png\n!dataset/gen_seek.py\n!result/.gitkeep"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 3537,
    "preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 569,
    "preview": "# Contributing to Nougat\n\n## Pull Requests\n\nIn order to accept your pull request, we need you to submit a CLA. You only "
  },
  {
    "path": "LICENSE",
    "chars": 1088,
    "preview": "MIT License\n\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nPermission is hereby granted, free of charge, to any pe"
  },
  {
    "path": "LICENSE-MODEL.md",
    "chars": 13567,
    "preview": "# Creative Commons Attribution-NonCommercial 4.0 International Public License\n\nBy exercising the Licensed Rights (define"
  },
  {
    "path": "MANIFEST.in",
    "chars": 14,
    "preview": "include ./*.*\n"
  },
  {
    "path": "NOTICE",
    "chars": 8932,
    "preview": "Donut\nCopyright (c) 2022-present NAVER Corp.\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "README.md",
    "chars": 7995,
    "preview": "<div align=\"center\">\n<h1>Nougat: Neural Optical Understanding for Academic Documents</h1>\n\n[![Paper](https://img.shields"
  },
  {
    "path": "app.py",
    "chars": 5214,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "config/train_nougat.yaml",
    "chars": 748,
    "preview": "resume_from_checkpoint_path: null\nresult_path: \"result\"\nmodel_path: null\ndataset_paths: [\"path/to/train.jsonl\"]\ntokenize"
  },
  {
    "path": "docker/Dockerfile",
    "chars": 766,
    "preview": "FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04\n# replace CUDA version to your CUDA version.\n# You can check your CUDA "
  },
  {
    "path": "docker/README.md",
    "chars": 2263,
    "preview": "## Prerequisites\nEnsure you have Docker installed on your machine. \nAnd you must also have NVIDIA CUDA and CuDNN install"
  },
  {
    "path": "lightning_module.py",
    "chars": 9121,
    "preview": "\"\"\"\nDonut\nCopyright (c) 2022-present NAVER Corp.\nMIT License\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\"\"\"\nimpo"
  },
  {
    "path": "nougat/__init__.py",
    "chars": 311,
    "preview": "\"\"\"\nDonut\nCopyright (c) 2022-present NAVER Corp.\nMIT License\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\"\"\"\nfrom"
  },
  {
    "path": "nougat/_version.py",
    "chars": 204,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "nougat/dataset/create_index.py",
    "chars": 5482,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/gen_seek.py",
    "chars": 1015,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/parser/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "nougat/dataset/parser/document.py",
    "chars": 19783,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/parser/html2md.py",
    "chars": 2220,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/parser/latexml_parser.py",
    "chars": 18357,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/parser/markdown.py",
    "chars": 15343,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/pdffigures.py",
    "chars": 2300,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/rasterize.py",
    "chars": 2806,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/split_htmls_to_pages.py",
    "chars": 7717,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/split_md_to_pages.py",
    "chars": 17432,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/splitter.py",
    "chars": 13880,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/staircase.py",
    "chars": 10486,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/tokenizer.json",
    "chars": 2068443,
    "preview": "{\n  \"version\": \"1.0\",\n  \"truncation\": {\n    \"direction\": \"Right\",\n    \"max_length\": 4096,\n    \"strategy\": \"LongestFirst\""
  },
  {
    "path": "nougat/dataset/utils/__init__.py",
    "chars": 273,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/utils/latex_conversion.py",
    "chars": 4234,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/utils/pdf_text_extract.py",
    "chars": 2531,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/dataset/utils/utils.py",
    "chars": 528,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/metrics.py",
    "chars": 3844,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/model.py",
    "chars": 26035,
    "preview": "\"\"\"\nDonut\nCopyright (c) 2022-present NAVER Corp.\nMIT License\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\"\"\"\nimpo"
  },
  {
    "path": "nougat/postprocessing.py",
    "chars": 16892,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/transforms.py",
    "chars": 5986,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "nougat/utils/checkpoint.py",
    "chars": 3890,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "nougat/utils/dataset.py",
    "chars": 9275,
    "preview": "\"\"\"\nDonut\nCopyright (c) 2022-present NAVER Corp.\nMIT License\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\"\"\"\nimpo"
  },
  {
    "path": "nougat/utils/device.py",
    "chars": 1198,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "predict.py",
    "chars": 7439,
    "preview": "\"\"\"\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nThis source code is licensed under the MIT license found in the\n"
  },
  {
    "path": "setup.cfg",
    "chars": 39,
    "preview": "[metadata]\ndescription_file = README.md"
  },
  {
    "path": "setup.py",
    "chars": 2775,
    "preview": "\"\"\"\nDonut\nCopyright (c) 2022-present NAVER Corp.\nMIT License\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\"\"\"\n\nimp"
  },
  {
    "path": "test.py",
    "chars": 3810,
    "preview": "\"\"\"\nDonut\nCopyright (c) 2022-present NAVER Corp.\nMIT License\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\"\"\"\nimpo"
  },
  {
    "path": "train.py",
    "chars": 7430,
    "preview": "\"\"\"\nDonut\nCopyright (c) 2022-present NAVER Corp.\nMIT License\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\"\"\"\nimpo"
  }
]
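The condensed preview above is a JSON array whose entries each carry a path, a chars count, and a truncated preview string. A short sketch for ranking files by size follows. The function name largest_files is illustrative only.

```python
import json
from typing import List, Tuple


def largest_files(manifest_json: str, top: int = 3) -> List[Tuple[str, int]]:
    """Return the `top` (path, chars) pairs from the manifest, largest first."""
    entries = json.loads(manifest_json)
    ranked = sorted(entries, key=lambda e: e["chars"], reverse=True)
    return [(e["path"], e["chars"]) for e in ranked[:top]]
```

Applied to the full array, tokenizer.json (2,068,443 chars) ranks first and accounts for most of the 2.2 MB total.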

About this extraction

This page contains the full source code of the facebookresearch/nougat GitHub repository, extracted and formatted as plain text. The extraction includes 47 files (2.2 MB), approximately 586.5k tokens, and a symbol index with 289 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.