Repository: opendatalab/PDF-Extract-Kit Branch: main Commit: fdb25fd4bd90 Files: 143 Total size: 459.0 KB Directory structure: gitextract_dsl5uwwk/ ├── .gitignore ├── .readthedocs.yaml ├── .vscode/ │ └── launch.json ├── LICENSE.md ├── README.md ├── README_zh-CN.md ├── configs/ │ ├── config.yaml │ ├── formula_detection.yaml │ ├── formula_recognition.yaml │ ├── layout_detection.yaml │ ├── layout_detection_layoutlmv3.yaml │ ├── layout_detection_yolo.yaml │ ├── ocr.yaml │ └── table_parsing.yaml ├── docs/ │ ├── en/ │ │ ├── .readthedocs.yaml │ │ ├── Makefile │ │ ├── algorithm/ │ │ │ ├── formula_detection.rst │ │ │ ├── formula_recognition.rst │ │ │ ├── layout_detection.rst │ │ │ ├── ocr.rst │ │ │ ├── reading_order.rst │ │ │ └── table_recognition.rst │ │ ├── conf copy.py │ │ ├── conf.bak │ │ ├── conf.py │ │ ├── evaluation/ │ │ │ ├── formula_detection.rst │ │ │ ├── formula_recognition.rst │ │ │ ├── layout_detection.rst │ │ │ ├── ocr.rst │ │ │ ├── pdf_extract.rst │ │ │ ├── reading_order.rst │ │ │ └── table_recognition.rst │ │ ├── get_started/ │ │ │ ├── installation.rst │ │ │ ├── pretrained_model.rst │ │ │ └── quickstart.rst │ │ ├── index.rst │ │ ├── make.bat │ │ ├── models/ │ │ │ └── supported.md │ │ ├── notes/ │ │ │ └── changelog.md │ │ ├── project/ │ │ │ ├── doc_translate.rst │ │ │ ├── pdf_extract.rst │ │ │ └── speed_up.rst │ │ ├── switch_language.md │ │ └── task_extend/ │ │ ├── code.rst │ │ ├── doc.rst │ │ └── evaluation.rst │ ├── requirements.txt │ └── zh_cn/ │ ├── .readthedocs.yaml │ ├── Makefile │ ├── algorithm/ │ │ ├── formula_detection.rst │ │ ├── formula_recognition.rst │ │ ├── layout_detection.rst │ │ ├── ocr.rst │ │ ├── reading_order.rst │ │ └── table_recognition.rst │ ├── conf.py │ ├── evaluation/ │ │ ├── formula_detection.rst │ │ ├── formula_recognition.rst │ │ ├── layout_detection.rst │ │ ├── ocr.rst │ │ ├── pdf_extract.rst │ │ ├── reading_order.rst │ │ └── table_recognition.rst │ ├── get_started/ │ │ ├── installation.rst │ │ ├── pretrained_model.rst │ │ └── 
quickstart.rst │ ├── index.rst │ ├── make.bat │ ├── models/ │ │ └── supported.md │ ├── notes/ │ │ └── changelog.md │ ├── project/ │ │ ├── doc_translate.rst │ │ ├── pdf_extract.rst │ │ └── speed_up.rst │ ├── switch_language.md │ └── task_extend/ │ ├── code.rst │ ├── doc.rst │ └── evaluation.rst ├── pdf_extract_kit/ │ ├── __init__.py │ ├── configs/ │ │ └── unimernet.yaml │ ├── dataset/ │ │ ├── __init__.py │ │ └── dataset.py │ ├── registry/ │ │ ├── __init__.py │ │ └── registry.py │ ├── tasks/ │ │ ├── __init__.py │ │ ├── base_task.py │ │ ├── formula_detection/ │ │ │ ├── __init__.py │ │ │ ├── models/ │ │ │ │ └── yolo.py │ │ │ └── task.py │ │ ├── formula_recognition/ │ │ │ ├── __init__.py │ │ │ ├── models/ │ │ │ │ └── unimernet.py │ │ │ └── task.py │ │ ├── layout_detection/ │ │ │ ├── __init__.py │ │ │ ├── models/ │ │ │ │ ├── __init__.py │ │ │ │ ├── layoutlmv3.py │ │ │ │ ├── layoutlmv3_util/ │ │ │ │ │ ├── backbone.py │ │ │ │ │ ├── beit.py │ │ │ │ │ ├── deit.py │ │ │ │ │ ├── layoutlmft/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── data/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── cord.py │ │ │ │ │ │ │ ├── data_collator.py │ │ │ │ │ │ │ ├── funsd.py │ │ │ │ │ │ │ ├── image_utils.py │ │ │ │ │ │ │ └── xfund.py │ │ │ │ │ │ └── models/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── layoutlmv3/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── configuration_layoutlmv3.py │ │ │ │ │ │ ├── modeling_layoutlmv3.py │ │ │ │ │ │ ├── tokenization_layoutlmv3.py │ │ │ │ │ │ └── tokenization_layoutlmv3_fast.py │ │ │ │ │ ├── layoutlmv3_base_inference.yaml │ │ │ │ │ ├── model_init.py │ │ │ │ │ ├── rcnn_vl.py │ │ │ │ │ └── visualizer.py │ │ │ │ └── yolo.py │ │ │ └── task.py │ │ ├── ocr/ │ │ │ ├── __init__.py │ │ │ ├── models/ │ │ │ │ └── paddle_ocr.py │ │ │ └── task.py │ │ └── table_parsing/ │ │ ├── __init__.py │ │ ├── models/ │ │ │ └── struct_eqtable.py │ │ └── task.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── config_loader.py │ │ ├── data_preprocess.py │ │ ├── merge_blocks_and_spans.py │ │ ├── 
pdf_utils.py │ │ └── visualization.py │ └── version.py ├── project/ │ └── pdf2markdown/ │ ├── README.md │ ├── configs/ │ │ └── pdf2markdown.yaml │ └── scripts/ │ ├── pdf2markdown.py │ └── run_project.py ├── pyproject.toml ├── requirements/ │ └── docs.txt ├── requirements-cpu.txt ├── requirements.txt └── scripts/ ├── formula_detection.py ├── formula_recognition.py ├── layout_detection.py ├── ocr.py ├── run_task.py └── table_parsing.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.ipynb* *.ipynb # local data outputs/* data/* temp* test* # python .ipynb_checkpoints *.ipynb **/__pycache__/ # logs *.log *.out models/* # Sphinx documentation docs/*/_build/ ================================================ FILE: .readthedocs.yaml ================================================ version: 2 build: os: ubuntu-22.04 tools: python: "3.10" formats: - epub python: install: - requirements: requirements/docs.txt sphinx: configuration: docs/zh_cn/conf.py ================================================ FILE: .vscode/launch.json ================================================ { // 使用 IntelliSense 了解相关属性。 // 悬停以查看现有属性的描述。 // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ { "name": "run_mfd", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/scripts/run_mfd.py", "console": "integratedTerminal", "args": [ "--config", "configs/config_mfd.yaml" ], "env": { "PYTHONPATH": "/Users/bin/anaconda3/envs/mfd_test" } }, { "name": "run_formula_recognition", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/scripts/formula_recognition.py", "console": "integratedTerminal", "args": [ "--config", "configs/formula_recognition.yaml" ], "env": { "PYTHONPATH": "/Users/bin/anaconda3/envs/mfd_test" } }, { "name": "run_ocr", "type": 
"debugpy", "request": "launch", "program": "${workspaceFolder}/scripts/ocr.py", "console": "integratedTerminal", "args": [ "--config", "configs/ocr.yaml" ], "env": { "PYTHONPATH": "/Users/bin/anaconda3/envs/mfd_test" } }, { "name": "run_formula_detection", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/scripts/formula_detection.py", "console": "integratedTerminal", "args": [ "--config", "configs/formula_detection.yaml" ], "env": { "PYTHONPATH": "/Users/bin/anaconda3/envs/mfd_test" } }, { "name": "run_layout_detection", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/scripts/layout_detection.py", "console": "integratedTerminal", "args": [ "--config", "configs/layout_detection.yaml" ], "env": { "PYTHONPATH": "/Users/bin/anaconda3/envs/mfd_test" } }, { "name": "run_layout_detection_layoutlmv3", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/scripts/layout_detection.py", "console": "integratedTerminal", "args": [ "--config", "configs/layout_detection_layoutlmv3.yaml" ], "env": { "PYTHONPATH": "/Users/bin/anaconda3/envs/mfd_test" } } ] } ================================================ FILE: LICENSE.md ================================================ GNU AFFERO GENERAL PUBLIC LICENSE Version 3, 19 November 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU Affero General Public License is a free, copyleft license for software and other kinds of works, specifically designed to ensure cooperation with the community in the case of network server software. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, our General Public Licenses are intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. 
When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. Developers that use our General Public Licenses protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License which gives you legal permission to copy, distribute and/or modify the software. A secondary benefit of defending all users' freedom is that improvements made in alternate versions of the program, if they receive widespread use, become available for other developers to incorporate. Many developers of free software are heartened and encouraged by the resulting cooperation. However, in the case of software used on network servers, this result may fail to come about. The GNU General Public License permits making a modified version and letting the public access it on a server without ever releasing its source code to the public. The GNU Affero General Public License is designed specifically to ensure that, in such cases, the modified source code becomes available to the community. It requires the operator of a network server to provide the source code of the modified version running there to the users of that server. Therefore, public use of a modified version, on a publicly accessible server, gives the public access to the source code of the modified version. An older license, called the Affero General Public License and published by Affero, was designed to accomplish similar goals. This is a different license, not a version of the Affero GPL, but Affero has released a new version of the Affero GPL which permits relicensing under this license. 
The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. 
The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. 
The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. 
When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. 
This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. 
b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. 
A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. 
But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. 
Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. 
If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. 
Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. 
The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. 
"Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. 
If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Remote Network Interaction; Use with the GNU General Public License. Notwithstanding any other provision of this License, if you modify the Program, your modified version must prominently offer all users interacting with it remotely through a computer network (if your version supports such interaction) an opportunity to receive the Corresponding Source of your version by providing access to the Corresponding Source from a network server at no charge, through some standard or customary means of facilitating copying of software. This Corresponding Source shall include the Corresponding Source for any work covered by version 3 of the GNU General Public License that is incorporated pursuant to the following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the work with which it is combined will remain governed by version 3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU Affero General Public License from time to time. 
Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. <one line to give the program's name and a brief idea of what it does.> Copyright (C) <year> <name of author> This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. Also add information on how to contact you by electronic and paper mail. If your software can interact with users remotely through a computer network, you should also make sure that it provides a way for users to get its source. For example, if your program is a web application, its interface could display a "Source" link that leads users to an archive of the code. There are many ways you could offer source, and different solutions will be better for different programs; see section 13 for the specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU AGPL, see <https://www.gnu.org/licenses/>. ================================================ FILE: README.md ================================================

English | [简体中文](./README_zh-CN.md) [PDF-Extract-Kit-1.0 Tutorial](https://pdf-extract-kit.readthedocs.io/en/latest/get_started/pretrained_model.html) [[Models (🤗Hugging Face)]](https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0) | [[Models(ModelScope)]](https://www.modelscope.cn/models/OpenDataLab/PDF-Extract-Kit-1.0) 🔥🔥🔥 [MinerU: Efficient Document Content Extraction Tool Based on PDF-Extract-Kit](https://github.com/opendatalab/MinerU)

👋 join us on Discord and WeChat

## Overview `PDF-Extract-Kit` is a powerful open-source toolkit designed to efficiently extract high-quality content from complex and diverse PDF documents. Here are its main features and advantages: - **Integration of Leading Document Parsing Models**: Incorporates state-of-the-art models for layout detection, formula detection, formula recognition, OCR, and other core document parsing tasks. - **High-Quality Parsing Across Diverse Documents**: Fine-tuned with diverse document annotation data to deliver high-quality results across various complex document types. - **Modular Design**: The flexible modular design allows users to easily combine and construct various applications by modifying configuration files and minimal code, making application building as straightforward as stacking blocks. - **Comprehensive Evaluation Benchmarks**: Provides diverse and comprehensive PDF evaluation benchmarks, enabling users to choose the most suitable model based on evaluation results. **Experience PDF-Extract-Kit now and unlock the limitless potential of PDF documents!** > **Note:** PDF-Extract-Kit is designed for high-quality document processing and functions as a model toolbox. > If you are interested in extracting high-quality document content (e.g., converting PDFs to Markdown), please use [MinerU](https://github.com/opendatalab/MinerU), which combines the high-quality predictions from PDF-Extract-Kit with specialized engineering optimizations for more convenient and efficient content extraction. > If you're a developer looking to create engaging applications such as document translation, document Q&A, or document assistants, you'll find it very convenient to build your own projects using PDF-Extract-Kit. In particular, we will periodically update the PDF-Extract-Kit/project directory with interesting applications, so stay tuned! 
**We welcome researchers and engineers from the community to contribute outstanding models and innovative applications by submitting PRs to become contributors to the PDF-Extract-Kit project.** ## Model Overview | **Task Type** | **Description** | **Models** | |-------------------|---------------------------------------------------------------------------------|-------------------------------| | **Layout Detection** | Locate different elements in a document: including images, tables, text, titles, formulas | `DocLayout-YOLO_ft`, `YOLO-v10_ft`, `LayoutLMv3_ft` | | **Formula Detection** | Locate formulas in documents: including inline and block formulas | `YOLOv8_ft` | | **Formula Recognition** | Recognize formula images into LaTeX source code | `UniMERNet` | | **OCR** | Extract text content from images (including location and recognition) | `PaddleOCR` | | **Table Recognition** | Recognize table images into corresponding source code (LaTeX/HTML/Markdown) | `PaddleOCR+TableMaster`, `StructEqTable` | | **Reading Order** | Sort and concatenate discrete text paragraphs | Coming Soon! | ## News and Updates - `2024.10.22` 🎉🎉🎉 We are excited to announce that the table recognition model [StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B), which supports LaTeX, HTML, and Markdown output formats, has been officially integrated into `PDF-Extract-Kit 1.0`. Please refer to the [table recognition algorithm documentation](https://pdf-extract-kit.readthedocs.io/en/latest/algorithm/table_recognition.html) for usage instructions! - `2024.10.17` 🎉🎉🎉 We are excited to announce that the more accurate and faster layout detection model, [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO), has been officially integrated into `PDF-Extract-Kit 1.0`. Please refer to the [layout detection algorithm documentation](https://pdf-extract-kit.readthedocs.io/en/latest/algorithm/layout_detection.html) for usage instructions! 
- `2024.10.10` 🎉🎉🎉 The official release of `PDF-Extract-Kit 1.0`, rebuilt with modularity for more convenient and flexible model usage! Please switch to the [release/0.1.1](https://github.com/opendatalab/PDF-Extract-Kit/tree/release/0.1.1) branch for the old version. - `2024.08.01` 🎉🎉🎉 Added the [StructEqTable](demo/TabRec/StructEqTable/README_TABLE.md) module for table content extraction. Welcome to use it! - `2024.07.01` 🎉🎉🎉 We released `PDF-Extract-Kit`, a comprehensive toolkit for high-quality PDF content extraction, including `Layout Detection`, `Formula Detection`, `Formula Recognition`, and `OCR`. ## Performance Demonstration Many current open-source SOTA models are trained and evaluated on academic datasets, achieving high-quality results only on single document types. To enable models to achieve stable and robust high-quality results on diverse documents, we constructed diverse fine-tuning datasets and fine-tuned some SOTA models to obtain practical parsing models. Below are some visual results of the models. ### Layout Detection We trained robust `Layout Detection` models using diverse PDF document annotations. Our fine-tuned models achieve accurate extraction results on diverse PDF documents such as papers, textbooks, research reports, and financial reports, and demonstrate high robustness to challenges like blurring and watermarks. The visualization example below shows the inference results of the fine-tuned LayoutLMv3 model. ![](assets/readme/layout_example.png) ### Formula Detection Similarly, we collected and annotated documents containing formulas in both English and Chinese, and fine-tuned advanced formula detection models. The visualization result below shows the inference results of the fine-tuned YOLO formula detection model: ![](assets/readme/mfd_example.png) ### Formula Recognition [UniMERNet](https://github.com/opendatalab/UniMERNet) is an algorithm designed for diverse formula recognition in real-world scenarios. 
By constructing large-scale training data and a carefully designed network architecture, it achieves excellent recognition performance for complex long formulas, handwritten formulas, and noisy screenshot formulas. ### Table Recognition [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) is a high-efficiency toolkit that converts table images into LaTeX/HTML/Markdown. The latest version, powered by the InternVL2-1B foundation model, improves Chinese recognition accuracy and expands multi-format output options. #### For more visual and inference results of the models, please refer to the [PDF-Extract-Kit tutorial documentation](https://pdf-extract-kit.readthedocs.io/en/latest/). ## Evaluation Metrics Coming Soon! ## Usage Guide ### Environment Setup ```bash conda create -n pdf-extract-kit-1.0 python=3.10 conda activate pdf-extract-kit-1.0 pip install -r requirements.txt ``` > **Note:** If your device does not have a GPU, please install the CPU version dependencies using `requirements-cpu.txt` instead of `requirements.txt`. > **Note:** Currently, DocLayout-YOLO only supports installation from PyPI. If an error occurs during DocLayout-YOLO installation, please install it via `pip3 install doclayout-yolo==0.0.2 --extra-index-url=https://pypi.org/simple`. ### Model Download Please refer to the [Model Weights Download Tutorial](https://pdf-extract-kit.readthedocs.io/en/latest/get_started/pretrained_model.html) to download the required model weights. Note: You can choose to download all the weights or select specific ones. For detailed instructions, please refer to the tutorial. ### Running Demos #### Layout Detection Model ```bash python scripts/layout_detection.py --config=configs/layout_detection.yaml ``` Layout detection models support **DocLayout-YOLO** (default model), YOLO-v10, and LayoutLMv3. For YOLO-v10 and LayoutLMv3, please refer to [Layout Detection Algorithm](https://pdf-extract-kit.readthedocs.io/en/latest/algorithm/layout_detection.html). 
You can view the layout detection results in the `outputs/layout_detection` folder. #### Formula Detection Model ```bash python scripts/formula_detection.py --config=configs/formula_detection.yaml ``` You can view the formula detection results in the `outputs/formula_detection` folder. #### OCR Model ```bash python scripts/ocr.py --config=configs/ocr.yaml ``` You can view the OCR results in the `outputs/ocr` folder. #### Formula Recognition Model ```bash python scripts/formula_recognition.py --config=configs/formula_recognition.yaml ``` You can view the formula recognition results in the `outputs/formula_recognition` folder. #### Table Recognition Model ```bash python scripts/table_parsing.py --config configs/table_parsing.yaml ``` You can view the table recognition results in the `outputs/table_parsing` folder. > **Note:** For more details on using the models, please refer to the [PDF-Extract-Kit-1.0 Tutorial](https://pdf-extract-kit.readthedocs.io/en/latest/get_started/pretrained_model.html). > This project focuses on using models for `high-quality` content extraction from `diverse` documents and does not involve reconstructing extracted content into new documents, such as PDF to Markdown. For such needs, please refer to our other GitHub project: [MinerU](https://github.com/opendatalab/MinerU). ## To-Do List - [x] **Table Parsing**: Develop functionality to convert table images into corresponding LaTeX/Markdown format source code. - [ ] **Chemical Equation Detection**: Implement automatic detection of chemical equations. - [ ] **Chemical Equation/Diagram Recognition**: Develop models to recognize and parse chemical equations and diagrams. - [ ] **Reading Order Sorting Model**: Build a model to determine the correct reading order of text in documents. **PDF-Extract-Kit** aims to provide high-quality PDF content extraction capabilities. 
We encourage the community to propose specific and valuable needs and welcome everyone to participate in continuously improving the PDF-Extract-Kit tool to advance research and industry development. ## License This project is open-sourced under the [AGPL-3.0](LICENSE.md) license. Since this project uses YOLO code and PyMuPDF for file processing, these components require compliance with the AGPL-3.0 license. Therefore, to ensure adherence to the licensing requirements of these dependencies, this repository as a whole adopts the AGPL-3.0 license. ## Acknowledgement - [LayoutLMv3](https://github.com/microsoft/unilm/tree/master/layoutlmv3): Layout detection model - [UniMERNet](https://github.com/opendatalab/UniMERNet): Formula recognition model - [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy): Table recognition model - [YOLO](https://github.com/ultralytics/ultralytics): Formula detection model - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR): OCR model - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO): Layout detection model ## Citation If you find our models / code / papers useful in your research, please consider giving ⭐ and citations 📝, thx :) ```bibtex @article{wang2024mineru, title={MinerU: An Open-Source Solution for Precise Document Content Extraction}, author={Wang, Bin and Xu, Chao and Zhao, Xiaomeng and Ouyang, Linke and Wu, Fan and Zhao, Zhiyuan and Xu, Rui and Liu, Kaiwen and Qu, Yuan and Shang, Fukai and others}, journal={arXiv preprint arXiv:2409.18839}, year={2024} } @misc{zhao2024doclayoutyoloenhancingdocumentlayout, title={DocLayout-YOLO: Enhancing Document Layout Analysis through Diverse Synthetic Data and Global-to-Local Adaptive Perception}, author={Zhiyuan Zhao and Hengrui Kang and Bin Wang and Conghui He}, year={2024}, eprint={2410.12628}, archivePrefix={arXiv}, primaryClass={cs.CV}, url={https://arxiv.org/abs/2410.12628}, } @misc{wang2024unimernet, title={UniMERNet: A Universal Network for 
Real-World Mathematical Expression Recognition}, author={Bin Wang and Zhuangcheng Gu and Chao Xu and Bo Zhang and Botian Shi and Conghui He}, year={2024}, eprint={2404.15254}, archivePrefix={arXiv}, primaryClass={cs.CV} } @article{he2024opendatalab, title={Opendatalab: Empowering general artificial intelligence with open datasets}, author={He, Conghui and Li, Wei and Jin, Zhenjiang and Xu, Chao and Wang, Bin and Lin, Dahua}, journal={arXiv preprint arXiv:2407.13773}, year={2024} } ``` ## Star History Star History Chart ## Related Links - [UniMERNet (Real-World Formula Recognition Algorithm)](https://github.com/opendatalab/UniMERNet) - [LabelU (Lightweight Multimodal Annotation Tool)](https://github.com/opendatalab/labelU) - [LabelLLM (Open Source LLM Dialogue Annotation Platform)](https://github.com/opendatalab/LabelLLM) - [MinerU (One-Stop High-Quality Data Extraction Tool)](https://github.com/opendatalab/MinerU) ================================================ FILE: README_zh-CN.md ================================================

[English](./README.md) | 简体中文 [PDF-Extract-Kit-1.0中文教程](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/get_started/pretrained_model.html) [[Models (🤗Hugging Face)]](https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0) | [[Models(ModelScope)]](https://www.modelscope.cn/models/OpenDataLab/PDF-Extract-Kit-1.0) 🔥🔥🔥 [MinerU:基于PDF-Extract-Kit的高效文档内容提取工具](https://github.com/opendatalab/MinerU)

👋 join us on Discord and WeChat

## 整体介绍 `PDF-Extract-Kit` 是一款功能强大的开源工具箱,旨在从复杂多样的 PDF 文档中高效提取高质量内容。以下是其主要功能和优势: - **集成文档解析主流模型**:汇聚布局检测、公式检测、公式识别、OCR等文档解析核心任务的众多SOTA模型; - **多样性文档下高质量解析结果**:结合多样性文档标注数据进行模型微调,在复杂多样的文档下提供高质量解析结果; - **模块化设计**:模块化设计使用户可以通过修改配置文件及少量代码即可自由组合构建各种应用,让应用构建像搭积木一样简便; - **全面评测基准**:提供多样性全面的PDF评测基准,用户可根据评测结果选择最适合自己的模型。 **立即体验 PDF-Extract-Kit,解锁 PDF 文档的无限潜力!** > **注意:** PDF-Extract-Kit 专注于高质量文档处理,适合作为模型工具箱使用。 > 如果你想提取高质量文档内容(PDF转Markdown),请直接使用[MinerU](https://github.com/opendatalab/MinerU),MinerU结合PDF-Extract-Kit的高质量预测结果,进行了专门的工程优化,使得PDF文档内容提取更加便捷高效; > 如果你是一位开发者,希望搭建更多有意思的应用(如文档翻译,文档问答,文档助手等),基于PDF-Extract-Kit自行进行DIY将会十分便捷。特别地,我们会在`PDF-Extract-Kit/project`下面不定期更新一些有趣的应用,敬请期待! **我们欢迎社区研究员和工程师贡献优秀模型和创新应用,通过提交 PR 成为 PDF-Extract-Kit 的贡献者。** ## 模型概览 | **任务类型** | **任务描述** | **模型** | |--------------|---------------------------------------------------------------------------------|------------------------------| | **布局检测** | 定位文档中不同元素位置:包含图像、表格、文本、标题、公式等 | `DocLayout-YOLO_ft`, `YOLO-v10_ft`, `LayoutLMv3_ft` | | **公式检测** | 定位文档中公式位置:包含行内公式和行间公式 | `YOLOv8_ft` | | **公式识别** | 识别公式图像为LaTeX源码 | `UniMERNet` | | **OCR** | 提取图像中的文本内容(包括定位和识别) | `PaddleOCR` | | **表格识别** | 识别表格图像为对应源码(LaTeX/HTML/Markdown) | `PaddleOCR+TableMaster`,`StructEqTable` | | **阅读顺序** | 将离散的文本段落进行排序拼接 | Coming Soon ! | ## 新闻和更新 - `2024.10.22` 🎉🎉🎉 支持LaTeX和HTML等多种输出格式的表格模型[StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B)正式接入`PDF-Extract-Kit 1.0`,请参考[表格识别算法文档](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/algorithm/table_recognition.html)进行使用! - `2024.10.17` 🎉🎉🎉 检测结果更准确,速度更快的布局检测模型`DocLayout-YOLO`正式接入`PDF-Extract-Kit 1.0`,请参考[布局检测算法文档](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/algorithm/layout_detection.html)进行使用! - `2024.10.10` 🎉🎉🎉 基于模块化重构的`PDF-Extract-Kit 1.0`正式版本发布,模型使用更加便捷灵活!老版本请切换至[release/0.1.1](https://github.com/opendatalab/PDF-Extract-Kit/tree/release/0.1.1)分支进行使用。 - `2024.08.01` 🎉🎉🎉 新增了[StructEqTable](demo/TabRec/StructEqTable/README_TABLE.md)表格识别模块用于表格内容提取,欢迎使用! 
- `2024.07.01` 🎉🎉🎉 我们发布了`PDF-Extract-Kit`,一个用于高质量PDF内容提取的综合工具包,包括`布局检测`、`公式检测`、`公式识别`和`OCR`。 ## 效果展示 当前的一些开源SOTA模型多基于学术数据集进行训练评测,仅能在单一的文档类型上获取高质量结果。为了使得模型能够在多样性文档上也能获得稳定鲁棒的高质量结果,我们构建多样性的微调数据集,并在一些SOTA模型上微调以得到可实用解析模型。下边是一些模型的可视化结果。 ### 布局检测 结合多样性PDF文档标注,我们训练了鲁棒的`布局检测`模型。在论文、教材、研报、财报等多样性的PDF文档上,我们微调后的模型都能得到准确的提取结果,对于扫描模糊、水印等情况也有较高鲁棒性。下面可视化示例是经过微调后的LayoutLMv3模型的推理结果。 ![](assets/readme/layout_example.png) ### 公式检测 同样地,我们收集了包含公式的中英文文档进行标注,基于先进的公式检测模型进行微调,下面可视化结果是微调后的YOLO公式检测模型的推理结果: ![](assets/readme/mfd_example.png) ### 公式识别 [UniMERNet](https://github.com/opendatalab/UniMERNet)是针对真实场景下多样性公式识别的算法,通过构建大规模训练数据及精心设计的网络结构,使得其可以对复杂长公式、手写公式、含噪声的截图公式均有不错的识别效果。 ### 表格识别 [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)是一个高效表格内容提取工具,能够将表格图像转换为LaTeX/HTML/Markdown格式,最新版本使用InternVL2-1B基础模型,提高了中文识别准确度并增加了多格式输出能力。 #### 更多模型的可视化结果及推理结果可以参考[PDF-Extract-Kit教程文档](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/) ## 评测指标 Coming Soon! ## 使用教程 ### 环境安装 ```bash conda create -n pdf-extract-kit-1.0 python=3.10 conda activate pdf-extract-kit-1.0 pip install -r requirements.txt ``` > **注意:** 如果你的设备不支持 GPU,请使用 `requirements-cpu.txt` 安装 CPU 版本的依赖。 > **注意:** 目前doclayout-yolo仅支持从pypi源安装,如果出现doclayout-yolo无法安装,请通过 `pip3 install doclayout-yolo==0.0.2 --extra-index-url=https://pypi.org/simple` 安装。 ### 模型下载 参考[模型权重下载教程](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/get_started/pretrained_model.html)下载所需模型权重。注:可以选择全部下载,也可以选择部分下载,具体操作参考教程。 ### Demo运行 #### 布局检测模型 ```bash python scripts/layout_detection.py --config=configs/layout_detection.yaml ``` 布局检测模型支持**DocLayout-YOLO**(默认模型),YOLO-v10,以及LayoutLMv3。对于YOLO-v10和LayoutLMv3的布局检测,请参考[Layout Detection Algorithm](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/algorithm/layout_detection.html)。你可以在 `outputs/layout_detection` 文件夹下查看布局检测结果。 #### 公式检测模型 ```bash python scripts/formula_detection.py --config=configs/formula_detection.yaml ``` 你可以在 `outputs/formula_detection` 文件夹下查看公式检测结果。 #### 文本识别(OCR)模型 ```bash python scripts/ocr.py --config=configs/ocr.yaml ``` 你可以在 
`outputs/ocr` 文件夹下查看OCR结果。 #### 公式识别模型 ```bash python scripts/formula_recognition.py --config=configs/formula_recognition.yaml ``` 你可以在 `outputs/formula_recognition` 文件夹下查看公式识别结果。 #### 表格识别模型 ```bash python scripts/table_parsing.py --config configs/table_parsing.yaml ``` 你可以在 `outputs/table_parsing` 文件夹下查看表格内容识别结果。 > **注意:** 更多模型使用细节请查看[PDF-Extract-Kit-1.0 中文教程](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/get_started/pretrained_model.html). > 本项目专注使用模型对`多样性`文档进行`高质量`内容提取,不涉及提取后内容拼接成新文档,如PDF转Markdown。如果有此类需求,请参考我们另一个Github项目: [MinerU](https://github.com/opendatalab/MinerU) ## 待办事项 - [x] **表格解析**:开发能够将表格图像转换成对应的LaTeX/Markdown格式源码的功能。 - [ ] **化学方程式检测**:实现对化学方程式的自动检测。 - [ ] **化学方程式/图解识别**:开发识别并解析化学方程式的模型。 - [ ] **阅读顺序排序模型**:构建模型以确定文档中文本的正确阅读顺序。 **PDF-Extract-Kit** 旨在提供高质量PDF文件的提取能力。我们鼓励社区提出具体且有价值的需求,并欢迎大家共同参与,以不断改进PDF-Extract-Kit工具,推动科研及产业发展。 ## 协议 本项目采用 [AGPL-3.0](LICENSE.md) 协议开源。 由于本项目中使用了 YOLO 代码和 PyMuPDF 进行文件处理,这些组件都需要遵循 AGPL-3.0 协议。因此,为了确保遵守这些依赖项的许可证要求,本仓库整体采用 AGPL-3.0 协议。 ## 致谢 - [LayoutLMv3](https://github.com/microsoft/unilm/tree/master/layoutlmv3): 布局检测模型 - [UniMERNet](https://github.com/opendatalab/UniMERNet): 公式识别模型 - [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy): 表格识别模型 - [YOLO](https://github.com/ultralytics/ultralytics): 公式检测模型 - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR): OCR模型 - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO): 布局检测模型 ## Citation 如果你觉得我们模型/代码/技术报告对你有帮助,请给我们⭐和引用📝,谢谢 :) ```bibtex @article{wang2024mineru, title={MinerU: An Open-Source Solution for Precise Document Content Extraction}, author={Wang, Bin and Xu, Chao and Zhao, Xiaomeng and Ouyang, Linke and Wu, Fan and Zhao, Zhiyuan and Xu, Rui and Liu, Kaiwen and Qu, Yuan and Shang, Fukai and others}, journal={arXiv preprint arXiv:2409.18839}, year={2024} } @misc{wang2024unimernet, title={UniMERNet: A Universal Network for Real-World Mathematical Expression Recognition}, author={Bin Wang and Zhuangcheng Gu and Chao Xu and Bo 
Zhang and Botian Shi and Conghui He}, year={2024}, eprint={2404.15254}, archivePrefix={arXiv}, primaryClass={cs.CV} } @misc{zhao2024doclayoutyoloenhancingdocumentlayout, title={DocLayout-YOLO: Enhancing Document Layout Analysis through Diverse Synthetic Data and Global-to-Local Adaptive Perception}, author={Zhiyuan Zhao and Hengrui Kang and Bin Wang and Conghui He}, year={2024}, eprint={2410.12628}, archivePrefix={arXiv}, primaryClass={cs.CV}, url={https://arxiv.org/abs/2410.12628}, } @article{he2024opendatalab, title={Opendatalab: Empowering general artificial intelligence with open datasets}, author={He, Conghui and Li, Wei and Jin, Zhenjiang and Xu, Chao and Wang, Bin and Lin, Dahua}, journal={arXiv preprint arXiv:2407.13773}, year={2024} } ``` ## Star历史 Star History Chart ## 友情链接 - [UniMERNet(真实场景公式识别算法)](https://github.com/opendatalab/UniMERNet) - [LabelU(轻量级多模态标注工具)](https://github.com/opendatalab/labelU) - [LabelLLM(开源LLM对话标注平台)](https://github.com/opendatalab/LabelLLM) - [MinerU(一站式高质量数据提取工具)](https://github.com/opendatalab/MinerU) ================================================ FILE: configs/config.yaml ================================================ inputs: assets/demo/formula_detection_pdfs outputs: outputs/formula_detection_pdfs tasks: formula_detection: model: formula_detection_yolo model_config: img_size: 1280 conf_thres: 0.25 iou_thres: 0.45 model_path: models/MFD/weights.pt visualize: True formula_recognition: model: formula_recognition_unimernet model_config: cfg_path: pdf_extract_kit/configs/unimernet.yaml model_path: models/MFR/UniMERNet visualize: True ================================================ FILE: configs/formula_detection.yaml ================================================ inputs: assets/demo/formula_detection outputs: outputs/formula_detection tasks: formula_detection: model: formula_detection_yolo model_config: img_size: 1280 conf_thres: 0.25 iou_thres: 0.45 batch_size: 1 model_path: models/MFD/YOLO/yolo_v8_ft.pt visualize: True 
================================================ FILE: configs/formula_recognition.yaml ================================================ inputs: assets/demo/formula_recognition outputs: outputs/formula_recognition tasks: formula_recognition: model: formula_recognition_unimernet model_config: cfg_path: pdf_extract_kit/configs/unimernet.yaml model_path: models/MFR/unimernet_tiny visualize: False ================================================ FILE: configs/layout_detection.yaml ================================================ inputs: assets/demo/layout_detection outputs: outputs/layout_detection tasks: layout_detection: model: layout_detection_yolo model_config: img_size: 1024 conf_thres: 0.25 iou_thres: 0.45 model_path: models/Layout/YOLO/doclayout_yolo_ft.pt visualize: True ================================================ FILE: configs/layout_detection_layoutlmv3.yaml ================================================ inputs: assets/demo/layout_detection outputs: outputs/layout_detection tasks: layout_detection: model: layout_detection_layoutlmv3 model_config: model_path: models/Layout/LayoutLMv3/model_final.pth ================================================ FILE: configs/layout_detection_yolo.yaml ================================================ inputs: assets/demo/layout_detection outputs: outputs/layout_detection tasks: layout_detection: model: layout_detection_yolo model_config: img_size: 1024 conf_thres: 0.25 iou_thres: 0.45 model_path: models/Layout/YOLO/doclayout_yolo_ft.pt visualize: True device: 0 ================================================ FILE: configs/ocr.yaml ================================================ inputs: assets/demo/ocr outputs: outputs/ocr visualize: True tasks: ocr: model: ocr_ppocr model_config: lang: ch show_log: True det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec det_db_box_thresh: 0.3 ================================================ FILE: configs/table_parsing.yaml 
================================================ inputs: assets/demo/table_parsing outputs: outputs/table_parsing tasks: table_parsing: model: table_parsing_struct_eqtable model_config: model_path: models/TabRec/StructEqTable max_new_tokens: 1024 max_time: 30 output_format: latex lmdeploy: False flash_atten: True ================================================ FILE: docs/en/.readthedocs.yaml ================================================ version: 2 build: os: ubuntu-22.04 tools: python: "3.10" formats: - epub python: install: - requirements: requirements/docs.txt sphinx: configuration: docs/en/conf.py ================================================ FILE: docs/en/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/en/algorithm/formula_detection.rst ================================================ .. _algorithm_formula_detection: ==================== Formula Detection Algorithm ==================== Introduction ==================== Formula detection involves identifying the positions of all formulas (including inline and block formulas) in a given input image. .. note:: Formula detection is technically a subtask of layout detection. However, due to its complexity, we recommend using a dedicated formula detection model to decouple it. 
This approach typically makes data annotation easier and improves detection performance. Model Usage ==================== With the environment properly set up, simply run the formula detection algorithm script by executing ``scripts/formula_detection.py``. .. code:: shell $ python scripts/formula_detection.py --config configs/formula_detection.yaml Model Configuration -------------------- .. code:: yaml inputs: assets/demo/formula_detection outputs: outputs/formula_detection tasks: formula_detection: model: formula_detection_yolo model_config: img_size: 1280 conf_thres: 0.25 iou_thres: 0.45 batch_size: 1 model_path: models/MFD/yolov8/weights.pt visualize: True - inputs/outputs: Define the input file path and the visualization output directory, respectively. - tasks: Define the task type, currently only a formula detection task is included. - model: Define the specific model type: currently, only the YOLO formula detection model is available. - model_config: Define the model configuration. - img_size: Define the image's longer side size; the shorter side will be scaled proportionally. - conf_thres: Define the confidence threshold; only targets above this threshold will be detected. - iou_thres: Define the IoU threshold to remove targets with an overlap greater than this value. - batch_size: Define the batch size, i.e. the number of images inferred simultaneously. Generally, the larger the batch size, the faster the inference speed. A GPU with more memory supports a larger batch size. - model_path: Path to the model weights. - visualize: Whether to visualize the model results. Visualized results will be saved in the outputs directory. Diverse Input Support -------------------- The formula detection script in PDF-Extract-Kit supports various input formats such as ``a single image``, ``a directory of image files``, ``a single PDF file``, and ``a directory of PDF files``. ..
note:: Modify the ``inputs`` path in ``configs/formula_detection.yaml`` according to your actual data format: - Single image: path/to/image - Image directory: path/to/images - Single PDF file: path/to/pdf - PDF directory: path/to/pdfs .. note:: When using a PDF as input, you need to change ``predict_images`` to ``predict_pdfs`` in ``formula_detection.py``. .. code:: python # for image detection detection_results = model_formula_detection.predict_images(input_data, result_path) Change to: .. code:: python # for pdf detection detection_results = model_formula_detection.predict_pdfs(input_data, result_path) Viewing Visualization Results -------------------- When the ``visualize`` option in the config file is set to ``True``, visualization results will be saved in the ``outputs/formula_detection`` directory. .. note:: Visualization facilitates the analysis of model results. However, for large-scale tasks, it is recommended to disable visualization (set ``visualize`` to ``False`` ) to reduce memory and disk usage. ================================================ FILE: docs/en/algorithm/formula_recognition.rst ================================================ .. _algorithm_formula_recognition: ============ Formula Recognition Algorithm ============ Introduction ================= Formula recognition involves recognizing the content of a given input formula image and converting it to ``LaTeX`` format. Model Usage ================= With the environment properly configured, you can run the formula recognition algorithm script by executing ``scripts/formula_recognition.py``. .. code:: shell $ python scripts/formula_recognition.py --config configs/formula_recognition.yaml Model Configuration ----------------- ..
code:: yaml inputs: assets/demo/formula_recognition outputs: outputs/formula_recognition tasks: formula_recognition: model: formula_recognition_unimernet model_config: cfg_path: pdf_extract_kit/configs/unimernet.yaml model_path: models/MFR/unimernet_tiny visualize: False - inputs/outputs: Define the input file path and the directory for LaTeX prediction results, respectively. - tasks: Define the task type, currently only containing a formula recognition task. - model: Define the specific model type: Currently, only the `UniMERNet `_ formula recognition model is provided. - model_config: Define the model configuration. - cfg_path: Path to the UniMERNet configuration file. - model_path: Path to the model weights. - visualize: Whether to visualize the model results. Visualized results will be saved in the outputs directory. Support for Diverse Inputs ----------------- The formula recognition script in PDF-Extract-Kit supports ``single formula images`` and ``document images with corresponding formula regions``. Viewing Visualization Results ----------------- When ``visualize`` is set to ``True`` in the config file, ``LaTeX`` prediction results will be saved in the outputs directory. ================================================ FILE: docs/en/algorithm/layout_detection.rst ================================================ .. _algorithm_layout_detection: ================= Layout Detection Algorithm ================= Introduction ================= Layout detection is a fundamental task in document content extraction, aiming to locate different types of regions on a page, such as images, tables, text, and headings, to facilitate high-quality content extraction. For text and heading regions, OCR models can be used for text recognition, while table regions can be converted using table recognition models. Model Usage ================= Layout detection supports the following models: .. raw:: html
Model Description Characteristics Model weight Config file
DocLayout-YOLO Improved based on YOLO-v10:
1. Generates diverse pre-training data to enhance generalization across multiple document types
2. Improves the model architecture to better perceive instances of varying scale
Details in DocLayout-YOLO
Speed:Fast, Accuracy:High doclayout_yolo_ft.pt layout_detection.yaml
YOLO-v10 Base YOLO-v10 model Speed:Fast, Accuracy:Moderate yolov10l_ft.pt layout_detection_yolo.yaml
LayoutLMv3 Base LayoutLMv3 model Speed:Slow, Accuracy:High layoutlmv3_ft layout_detection_layoutlmv3.yaml
Once the environment is set up, you can perform layout detection by executing ``scripts/layout_detection.py`` directly. **Run demo** .. code:: shell $ python scripts/layout_detection.py --config configs/layout_detection.yaml Model Configuration ----------------- **1. DocLayout-YOLO / YOLO-v10** .. code:: yaml inputs: assets/demo/layout_detection outputs: outputs/layout_detection tasks: layout_detection: model: layout_detection_yolo model_config: img_size: 1024 conf_thres: 0.25 iou_thres: 0.45 model_path: path/to/doclayout_yolo_model visualize: True - inputs/outputs: Define the input file path and the directory for visualization output. - tasks: Define the task type, currently only a layout detection task is included. - model: Specify the model type, e.g., layout_detection_yolo. - model_config: Define the model configuration. - img_size: Define the image long edge size; the short edge will be scaled proportionally based on the long edge, with the default long edge being 1024. - conf_thres: Define the confidence threshold, detecting only targets above this threshold. - iou_thres: Define the IoU threshold, removing targets with an overlap greater than this threshold. - model_path: Path to the model weights. - visualize: Whether to visualize the model results; visualized results will be saved in the outputs directory. **2. LayoutLMv3** .. note:: LayoutLMv3 cannot be run out of the box. Please follow the steps below to modify the configuration: 1. **Detectron2 Environment Setup** .. code-block:: bash # For Linux pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-linux_x86_64.whl # For macOS pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-macosx_10_9_universal2.whl # For Windows pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-win_amd64.whl 2.
**Enable LayoutLMv3 Registration Code** Uncomment the lines at the following links: - `line 2 `_ - `line 8 `_ .. code-block:: python from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO from pdf_extract_kit.tasks.layout_detection.models.layoutlmv3 import LayoutDetectionLayoutlmv3 from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "LayoutDetectionYOLO", "LayoutDetectionLayoutlmv3", ] .. code:: yaml inputs: assets/demo/layout_detection outputs: outputs/layout_detection tasks: layout_detection: model: layout_detection_layoutlmv3 model_config: model_path: path/to/layoutlmv3_model - inputs/outputs: Define the input file path and the directory for visualization output. - tasks: Define the task type, currently only a layout detection task is included. - model: Specify the specific model type, e.g., layout_detection_layoutlmv3. - model_config: Define the model configuration. - model_path: Path to the model weights. Diverse Input Support ----------------- The layout detection script in PDF-Extract-Kit supports input formats such as a ``single image``, a ``directory containing only image files``, a ``single PDF file``, and a ``directory containing only PDF files``. .. note:: Modify the path to inputs in configs/layout_detection.yaml according to your actual data format: - Single image: path/to/image - Image directory: path/to/images - Single PDF file: path/to/pdf - PDF directory: path/to/pdfs .. note:: When using PDF as input, you need to change ``predict_images`` to ``predict_pdfs`` in ``layout_detection.py``. .. code:: python # for image detection detection_results = model_layout_detection.predict_images(input_data, result_path) Change to: .. 
code:: python # for pdf detection detection_results = model_layout_detection.predict_pdfs(input_data, result_path) Viewing Visualization Results ----------------- When ``visualize`` is set to ``True`` in the config file, the visualization results will be saved in the ``outputs`` directory. .. note:: Visualization is helpful for analyzing model results, but for large-scale tasks, it is recommended to turn off visualization (set ``visualize`` to ``False`` ) to reduce memory and disk usage. ================================================ FILE: docs/en/algorithm/ocr.rst ================================================ .. _algorithm_ocr: ========================== OCR (Optical Character Recognition) Algorithm ========================== Introduction ==================== OCR (Optical Character Recognition) involves identifying the positions and contents of all text blocks in an image. Model Usage ==================== With the environment properly set up, simply run the OCR algorithm script by executing ``scripts/ocr.py``. .. code:: shell $ python scripts/ocr.py --config configs/ocr.yaml Model Configuration -------------------- .. code:: yaml inputs: assets/demo/ocr outputs: outputs/ocr visualize: True tasks: ocr: model: ocr_ppocr model_config: lang: ch show_log: True det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec det_db_box_thresh: 0.3 - inputs/outputs: Define the input path and the output path, respectively. - visualize: Whether to visualize the model results. Visualized results will be saved in the outputs directory. - tasks: Define the task type, currently only an OCR task is included. - model: Define the specific model type, currently, only the PaddleOCR model is available. - model_config: Define the model configuration. - lang: Define the language; the default language ``ch`` supports both English and Chinese. - show_log: Whether to print running logs.
- det_model_dir: Define the path of PaddleOCR's detection model. If the specified path does not exist, the model weights will be automatically downloaded to the path. - rec_model_dir: Define the path of PaddleOCR's recognition model. If the specified path does not exist, the model weights will be automatically downloaded to the path. - det_db_box_thresh: Confidence filter threshold; bounding boxes whose confidence is lower than the threshold are discarded. Diverse Input Support -------------------- The OCR script in PDF-Extract-Kit supports various input formats such as ``a single image/PDF`` and ``a directory of image/PDF files``. Viewing Visualization Results -------------------- When the ``visualize`` option in the config file is set to ``True``, visualization results will be saved in the ``outputs`` directory. .. note:: Visualization facilitates the analysis of model results. However, for large-scale tasks, it is recommended to disable visualization (set ``visualize`` to ``False`` ) to reduce memory and disk usage. ================================================ FILE: docs/en/algorithm/reading_order.rst ================================================ .. _algorithm_reading_oder: ============== Reading Order Algorithm ============== Coming soon. ================================================ FILE: docs/en/algorithm/table_recognition.rst ================================================ .. _algorithm_table_recognition: ======================== Table Recognition Algorithm ======================== Introduction ================= Table recognition refers to the process of inputting a table image, identifying the table structure and content, and converting it into formats such as ``LaTeX`` or ``HTML``. Model Usage ================= With the environment properly configured, you can run the table recognition algorithm script by directly executing ``scripts/table_parsing.py``. ..
code:: shell $ python scripts/table_parsing.py --config configs/table_parsing.yaml Model Configuration ----------------- .. code:: yaml inputs: assets/demo/table_parsing outputs: outputs/table_parsing tasks: table_parsing: model: table_parsing_struct_eqtable model_config: model_path: models/TabRec/StructEqTable max_new_tokens: 1024 max_time: 30 output_format: latex lmdeploy: False flash_attn: True - inputs/outputs: Define the input file path and table recognition result directory respectively - tasks: Define the task type, currently only including one table recognition task - model: Define the specific model type: currently using the `StructEqTable `_ table recognition model - model_config: Define the model configuration - model_path: Path to the model weights - max_new_tokens: Maximum number of tokens to generate, default is 1024, maximum supported is 4096 - max_time: Maximum runtime for the model (in seconds) - output_format: Output format, default is set to ``latex``, options include ``html`` and ``markdown`` - lmdeploy: Whether to use LMDeploy for deployment, currently set to False - flash_attn: Whether to use flash attention, only available for Ampere GPUs Diverse Input Support ----------------- The table recognition script in PDF-Extract-Kit supports ``single table images`` and ``multiple table images`` as input. .. note:: The StructEqTable model only supports running on GPU devices .. note:: Adjust ``max_new_tokens`` and ``max_time`` according to the table content, defaults are 1024 and 30 respectively. .. note:: lmdeploy is an option for accelerated inference. If set to True, it will use LMDeploy for accelerated inference deployment. To use LMDeploy deployment, you need to install LMDeploy. For installation methods, refer to `LMDeploy `_. ================================================ FILE: docs/en/conf copy.py ================================================ # Configuration file for the Sphinx documentation builder. 
# # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. import os import subprocess import sys # def install(package): # subprocess.check_call([sys.executable, "-m", "pip", "install", package]) # # 安装 requirements.txt 中的依赖项 # requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) # if os.path.exists(requirements_path): # with open(requirements_path) as f: # packages = f.readlines() # for package in packages: # install(package.strip()) from sphinx.ext import autodoc sys.path.insert(0, os.path.abspath('../..')) # -- Project information ----------------------------------------------------- project = 'PDF-Extract-Kit' copyright = '2024, OpenDataLab' author = 'PDF-Extract-Kit Contributors' # The full version, including alpha/beta/rc tags version_file = '../../pdf_extract_kit/version.py' with open(version_file) as f: exec(compile(f.read(), version_file, 'exec')) __version__ = locals()['__version__'] # The short X.Y version version = __version__ # The full version, including alpha/beta/rc tags release = __version__ # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx', 'sphinx_copybutton', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'myst_parser', 'sphinxarg.ext', ] # Add any paths that contain templates here, relative to this directory. 
templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # Exclude the prompt "$" when copying code copybutton_prompt_text = r'\$ ' copybutton_prompt_is_regexp = True language = 'zh_CN' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_book_theme' html_logo = '_static/image/logo.png' html_theme_options = { 'path_to_docs': 'docs/zh_cn', 'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit', 'use_repository_button': True, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] # Mock out external dependencies here. autodoc_mock_imports = [ 'cpuinfo', 'torch', 'transformers', 'psutil', 'prometheus_client', 'sentencepiece', 'vllm.cuda_utils', 'vllm._C', 'numpy', 'tqdm', ] class MockedClassDocumenter(autodoc.ClassDocumenter): """Remove note about base class when a class is derived from object.""" def add_line(self, line: str, source: str, *lineno: int) -> None: if line == ' Bases: :py:class:`object`': return super().add_line(line, source, *lineno) autodoc.ClassDocumenter = MockedClassDocumenter navigation_with_keys = False ================================================ FILE: docs/en/conf.bak ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. 
For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. import os import sys from sphinx.ext import autodoc sys.path.insert(0, os.path.abspath('../..')) # -- Project information ----------------------------------------------------- project = 'PDF-Extract-Kit' copyright = '2024, PDF-Extract-Kit Contributors' author = 'PDF-Extract-Kit Contributors' # The full version, including alpha/beta/rc tags version_file = '../../pdf_extract_kit/version.py' with open(version_file) as f: exec(compile(f.read(), version_file, 'exec')) __version__ = locals()['__version__'] # The short X.Y version version = __version__ # The full version, including alpha/beta/rc tags release = __version__ # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx', 'sphinx_copybutton', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'myst_parser', 'sphinxarg.ext', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # Exclude the prompt "$" when copying code copybutton_prompt_text = r'\$ ' copybutton_prompt_is_regexp = True language = 'en' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_book_theme' html_logo = '_static/image/logo.png' html_theme_options = { 'path_to_docs': 'docs/en', 'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit', 'use_repository_button': True, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] # Mock out external dependencies here. autodoc_mock_imports = [ 'cpuinfo', 'torch', 'transformers', 'psutil', 'prometheus_client', 'sentencepiece', 'vllm.cuda_utils', 'vllm._C', 'numpy', 'tqdm', ] class MockedClassDocumenter(autodoc.ClassDocumenter): """Remove note about base class when a class is derived from object.""" def add_line(self, line: str, source: str, *lineno: int) -> None: if line == ' Bases: :py:class:`object`': return super().add_line(line, source, *lineno) autodoc.ClassDocumenter = MockedClassDocumenter navigation_with_keys = False ================================================ FILE: docs/en/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. 
If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. import os import subprocess import sys def install(package): subprocess.check_call([sys.executable, "-m", "pip", "install", package]) # 安装 requirements.txt 中的依赖项 requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) if os.path.exists(requirements_path): with open(requirements_path) as f: packages = f.readlines() for package in packages: install(package.strip()) from sphinx.ext import autodoc sys.path.insert(0, os.path.abspath('../..')) # -- Project information ----------------------------------------------------- project = 'PDF-Extract-Kit' copyright = '2024, PDF-Extract-Kit Contributors' author = 'OpenDataLab' # The full version, including alpha/beta/rc tags version_file = '../../pdf_extract_kit/version.py' with open(version_file) as f: exec(compile(f.read(), version_file, 'exec')) __version__ = locals()['__version__'] # The short X.Y version version = __version__ # The full version, including alpha/beta/rc tags release = __version__ # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx', 'sphinx_copybutton', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'myst_parser', 'sphinxarg.ext', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # Exclude the prompt "$" when copying code copybutton_prompt_text = r'\$ ' copybutton_prompt_is_regexp = True language = 'en' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_book_theme' html_logo = '_static/image/logo.png' html_theme_options = { 'path_to_docs': 'docs/en', 'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit', 'use_repository_button': True, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] # Mock out external dependencies here. autodoc_mock_imports = [ 'cpuinfo', 'torch', 'transformers', 'psutil', 'prometheus_client', 'sentencepiece', 'vllm.cuda_utils', 'vllm._C', 'numpy', 'tqdm', ] class MockedClassDocumenter(autodoc.ClassDocumenter): """Remove note about base class when a class is derived from object.""" def add_line(self, line: str, source: str, *lineno: int) -> None: if line == ' Bases: :py:class:`object`': return super().add_line(line, source, *lineno) autodoc.ClassDocumenter = MockedClassDocumenter navigation_with_keys = False ================================================ FILE: docs/en/evaluation/formula_detection.rst ================================================ ===================== Formula Detection Evaluation ===================== XXX ================================================ FILE: docs/en/evaluation/formula_recognition.rst ================================================ ===================== Formula Recognition Evaluation ===================== XXX ================================================ FILE: docs/en/evaluation/layout_detection.rst ================================================ 
===================== Layout Detection Evaluation ===================== XXX ================================================ FILE: docs/en/evaluation/ocr.rst ================================================ ===================== OCR Evaluation ===================== XXX ================================================ FILE: docs/en/evaluation/pdf_extract.rst ================================================ ===================== PDF Content Extraction Evaluation [End-to-End] ===================== XXX ================================================ FILE: docs/en/evaluation/reading_order.rst ================================================ ===================== Reading Order Evaluation ===================== XXX ================================================ FILE: docs/en/evaluation/table_recognition.rst ================================================ ===================== Table Recognition Evaluation ===================== XXX ================================================ FILE: docs/en/get_started/installation.rst ================================================ ================================== Installation ================================== In this section, we will demonstrate how to install PDF-Extract-Kit. Best Practices ============== We recommend users follow our best practices for installing PDF-Extract-Kit. It is recommended to use a Python 3.10 conda virtual environment for the installation. **Step 1.** Create a Python 3.10 virtual environment using conda. .. code-block:: console $ conda create -n pdf-extract-kit-1.0 python=3.10 -y $ conda activate pdf-extract-kit-1.0 **Step 2.** Install the dependencies for PDF-Extract-Kit. .. code-block:: console $ # For GPU devices $ pip install -r requirements.txt $ # For CPU-only devices $ pip install -r requirements-cpu.txt .. 
note:: For the convenience of user environment configuration, ``requirements.txt`` only includes the dependencies needed for the current best models, which currently include: - Layout Detection: YOLO series (YOLOv10, DocLayout-YOLO) - Formula Detection: YOLO series (YOLOv8) - Formula Recognition: UniMERNet - OCR: PaddleOCR For other models, such as LayoutLMv3, additional environment setup is required. For details, see \ :ref:`Layout Detection Algorithms <algorithm_layout_detection>`. ================================================ FILE: docs/en/get_started/pretrained_model.rst ================================================ ================================== Model Weights Download ================================== Before using PDF-Extract-Kit, we need to download the required model weights. You can download all models or specific model files (e.g., formula detection MFD) according to your needs. [Recommended] Method 1: ``snapshot_download`` ======================================== HuggingFace ------------ ``huggingface_hub.snapshot_download`` supports downloading specific model weights from the HuggingFace Hub and allows multithreading. You can use the following code to download model weights in parallel: .. code:: python from huggingface_hub import snapshot_download snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', max_workers=20) If you want to download a single algorithm model (e.g., the YOLO model for the formula detection task), use the following code: .. code:: python from huggingface_hub import snapshot_download snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', allow_patterns='models/MFD/YOLO/*') .. note:: Here, ``repo_id`` represents the name of the model on HuggingFace Hub, ``local_dir`` indicates the desired local storage path, ``max_workers`` specifies the maximum number of parallel downloads, and ``allow_patterns`` specifies the files you want to download. ..
tip:: If ``local_dir`` is not specified, the weights will be downloaded to the default cache path of HuggingFace (``~/.cache/huggingface/hub``). To change the default cache path, set the relevant environment variable: .. code:: console $ # Default is `~/.cache/huggingface/` $ export HF_HOME=/path/to/your/cache .. tip:: If the download speed is slow (e.g., unable to reach maximum bandwidth), install ``hf_transfer`` (``pip install hf_transfer``) and set ``export HF_HUB_ENABLE_HF_TRANSFER=1`` for higher download speeds. ModelScope ----------- ``modelscope.snapshot_download`` supports downloading specified model weights. You can use the following code to download the model: .. code:: python from modelscope import snapshot_download snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./') If you want to download a single algorithm model (e.g., the YOLO model for the formula detection task), use the following code: .. code:: python from modelscope import snapshot_download snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./', allow_patterns='models/MFD/YOLO/*') .. note:: Here, ``model_id`` represents the name of the model in the ModelScope library, ``cache_dir`` indicates the desired local storage path, and ``allow_patterns`` specifies the files you want to download. .. note:: ``modelscope.snapshot_download`` does not support multithreaded parallel downloads. .. tip:: If ``cache_dir`` is not specified, the weights will be downloaded to the default cache path of ModelScope (``~/.cache/modelscope/hub``). To change the default cache path, set the relevant environment variable: .. code:: console $ # Default is ~/.cache/modelscope/hub/ $ export MODELSCOPE_CACHE=/path/to/your/cache Method 2: Git LFS =================== The remote model repositories of HuggingFace and ModelScope are Git repositories managed by Git LFS. Therefore, we can use ``git clone`` to download the weights: ..
code:: console $ git lfs install $ # From HuggingFace $ git clone https://huggingface.co/opendatalab/pdf-extract-kit-1.0 $ # From ModelScope $ git clone https://www.modelscope.cn/opendatalab/pdf-extract-kit-1.0.git ================================================ FILE: docs/en/get_started/quickstart.rst ================================================ ================================== Quick Start ================================== Once the PDF-Extract-Kit environment is set up and the models are downloaded, we can start using PDF-Extract-Kit. Layout Detection Example ======================== Layout detection offers several models: ``LayoutLMv3``, ``YOLOv10``, and ``DocLayout-YOLO``. Compared to ``LayoutLMv3``, ``YOLOv10`` is faster. ``DocLayout-YOLO`` is based on YOLOv10 and adds diverse document pre-training and model optimizations, offering both high speed and high accuracy. **1. Using Layout Detection Models** .. code-block:: console $ python scripts/layout_detection.py --config configs/layout_detection.yaml After execution, we can view the detection results in the ``outputs/layout_detection`` directory. .. note:: The ``layout_detection.yaml`` file sets the input, output, and model configuration. For a more detailed tutorial on layout detection, see :ref:`Layout Detection Algorithm <algorithm_layout_detection>`. Formula Detection Example ========================= .. code-block:: console $ python scripts/formula_detection.py --config configs/formula_detection.yaml After execution, we can view the detection results in the ``outputs/formula_detection`` directory. .. note:: The ``formula_detection.yaml`` file sets the input, output, and model configuration. For a more detailed tutorial on formula detection, see :ref:`Formula Detection Algorithm <algorithm_formula_detection>`. ================================================ FILE: docs/en/index.rst ================================================ .. PDF-Extract-Kit documentation master file, created by sphinx-quickstart on Tue Jan 9 16:33:06 2024.
You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to the PDF-Extract-Kit Documentation ============================================== .. figure:: ./_static/image/logo.png :align: center :alt: pdf-extract-kit :class: no-scaled-link .. raw:: html

High-Quality Document Parsing Toolkit

Star Watch Fork

Tutorial ------------- .. toctree:: :maxdepth: 2 :caption: Getting Started get_started/installation.rst get_started/pretrained_model.rst get_started/quickstart.rst .. toctree:: :maxdepth: 2 :caption: Core Algorithm Modules algorithm/layout_detection.rst algorithm/formula_detection.rst algorithm/formula_recognition.rst algorithm/ocr.rst algorithm/table_recognition.rst algorithm/reading_order.rst .. toctree:: :maxdepth: 2 :caption: Task Extensions task_extend/code.rst task_extend/doc.rst task_extend/evaluation.rst .. toctree:: :maxdepth: 2 :caption: Supported Models models/supported.md .. toctree:: :maxdepth: 2 :caption: Model Performance Evaluation evaluation/layout_detection.rst evaluation/formula_detection.rst evaluation/formula_recognition.rst evaluation/ocr.rst evaluation/table_recognition.rst evaluation/reading_order.rst evaluation/pdf_extract.rst .. toctree:: :maxdepth: 2 :caption: PDF Projects project/pdf_extract.rst project/doc_translate.rst project/speed_up.rst ================================================ FILE: docs/en/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo.
echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) if "%1" == "" goto help %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/en/models/supported.md ================================================ # The Supported Models ================================================ FILE: docs/en/notes/changelog.md ================================================ # Changelog ## v1.0.0 (2024-10-10) PDF-Extract-Kit-1.0 has been refactored with a more streamlined and user-friendly modular design! 🔥🔥🔥 ## v0.1.0 (2024-07-01) Official release of PDF-Extract-Kit! 🔥🔥🔥 ### Highlights - PDF-Extract-Kit-1.0 offers a high-quality layout detection model, DocLayout-YOLO. ================================================ FILE: docs/en/project/doc_translate.rst ================================================ ============================ Document Translation Project ============================ XXXX XXXX ================================================ FILE: docs/en/project/pdf_extract.rst ================================================ =================================== Document Content Extraction Project =================================== Introduction ==================== Document content extraction aims to extract all information from a document file and convert it into a machine-readable result (such as a Markdown file). Its subtasks include layout detection, formula detection, formula recognition, OCR, and other tasks. Project Usage ==================== With the environment properly set up, simply run the project by executing ``project/pdf2markdown/scripts/run_project.py`` . .. code:: shell $ python project/pdf2markdown/scripts/run_project.py --config project/pdf2markdown/configs/pdf2markdown.yaml Project Configuration --------------------- ..
code:: yaml inputs: assets/demo/formula_detection outputs: outputs/pdf2markdown visualize: True merge2markdown: True tasks: layout_detection: model: layout_detection_yolo model_config: img_size: 1024 conf_thres: 0.25 iou_thres: 0.45 model_path: models/Layout/YOLO/doclayout_yolo_ft.pt formula_detection: model: formula_detection_yolo model_config: img_size: 1280 conf_thres: 0.25 iou_thres: 0.45 batch_size: 1 model_path: models/MFD/YOLO/yolo_v8_ft.pt formula_recognition: model: formula_recognition_unimernet model_config: batch_size: 128 cfg_path: pdf_extract_kit/configs/unimernet.yaml model_path: models/MFR/unimernet_tiny ocr: model: ocr_ppocr model_config: lang: ch show_log: True det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec det_db_box_thresh: 0.3 - inputs/outputs: Define the input path and the output path, respectively. - visualize: Whether to visualize the project results. Visualized results will be saved in the outputs directory. - merge2markdown: Whether to merge the results into Markdown documents. Only simple single-column text is supported. For Markdown conversion of documents with more complex layouts, please refer to `MinerU <https://github.com/opendatalab/MinerU>`_ . - tasks: Define the task types. PDF document extraction includes layout detection, formula detection, formula recognition, and OCR tasks. - For details about the parameters of each task and model, see the tutorial documentation of that task. Diverse Input Support --------------------- The document content extraction script in PDF-Extract-Kit supports various input formats, such as ``a single image/PDF`` and ``a directory of image/PDF files``. Output result -------------------- The extraction results of PDF documents are stored in the outputs path in JSON format. The JSON format is as follows: ..
code:: json [ { "layout_dets": [ { "category_type": "text", "poly": [ 380.6792698635707, 159.85058512958923, 765.1419999999998, 159.85058512958923, 765.1419999999998, 192.51073013642917, 380.6792698635707, 192.51073013642917 ], "text": "this is an example text", "score": 0.97 }, ... ], "page_info": { "page_no": 0, "height": 2339, "width": 1654 } }, ... ] - layout_dets: Content extraction results for a single PDF page or image - category_type: The category of a single piece of content, such as headings, images, inline formulas, and so on - poly: The location of a single content block, given as the (x, y) coordinates of its four corners (top-left, top-right, bottom-right, bottom-left) - text: Text content of a single content block - score: Confidence score - page_info: Page information, including page number and page size - page_no: Page number, counting from 0 - height: Page height - width: Page width If the ``merge2markdown`` parameter is True, an additional Markdown file will be saved. ================================================ FILE: docs/en/project/speed_up.rst ================================================ ========================== Model Acceleration Project ========================== XXXX XXXX ================================================ FILE: docs/en/switch_language.md ================================================ ## English ## 简体中文 ================================================ FILE: docs/en/task_extend/code.rst ================================================ ================================== Code Implementation ================================== The core code of the PDF-Extract-Kit project is implemented in the `pdf_extract_kit` directory, which contains the following modules: - configs: Configuration files for specific modules, such as `pdf_extract_kit/configs/unimernet.yaml`. If the configuration is simple, it is recommended to define it in the `yaml` file's `model_config` in `repo_root/configs` for easier user modification. - dataset: A custom `ImageDataset` class used for loading and preprocessing image data.
It supports various input types and can perform unified preprocessing operations on images (such as resizing, converting to tensors, etc.) to accelerate subsequent model inference. - evaluation: A module for evaluating model results, supporting evaluations for various task types such as `layout detection`, `formula detection`, `formula recognition`, etc., allowing users to fairly compare different tasks and models. - registry: The `Registry` class is a generic registry class that provides functions for registering, retrieving, and listing registered items. Users can use this class to create different types of registries, such as task registries, model registries, etc. - tasks: The core task module contains many different types of tasks, such as `layout detection`, `formula detection`, `formula recognition`, etc. Users typically only need to add code here to add new tasks and models. .. note:: Based on the above modular design, users generally only need to implement their new task class and corresponding model in `tasks` to extend new modules (in most cases, only the corresponding model needs to be implemented, as the task is already defined), and then register it in `registry`. Below we take adding a YOLO-based `layout detection` model as an example to introduce how to add new tasks and models. Task Definition and Registration ============== First, we add a `layout_detection` directory under `tasks`, and then add a `task.py` file in that directory to define the layout detection task class, as follows: .. code-block:: python from pdf_extract_kit.registry.registry import TASK_REGISTRY from pdf_extract_kit.tasks.base_task import BaseTask @TASK_REGISTRY.register("layout_detection") class LayoutDetectionTask(BaseTask): def __init__(self, model): super().__init__(model) def predict_images(self, input_data, result_path): """ Predict layouts in images. Args: input_data (str): Path to a single image file or a directory containing image files. 
result_path (str): Path to save the prediction results. Returns: list: List of prediction results. """ images = self.load_images(input_data) # Perform detection return self.model.predict(images, result_path) def predict_pdfs(self, input_data, result_path): """ Predict layouts in PDF files. Args: input_data (str): Path to a single PDF file or a directory containing PDF files. result_path (str): Path to save the prediction results. Returns: list: List of prediction results. """ pdf_images = self.load_pdf_images(input_data) # Perform detection return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys())) As you can see, the task definition includes the following key points: * Use the `@TASK_REGISTRY.register("layout_detection")` syntax to directly register the layout task class under `TASK_REGISTRY`. * The `__init__` initialization function takes `model` as an argument and passes it to the parent `BaseTask` class via `super().__init__(model)`. * Implement inference functions. Because layout detection usually processes both images and PDF files, two functions, `predict_images` and `predict_pdfs`, are provided so that users can choose flexibly. Model Definition and Registration ================================= Next, we implement the specific model by creating a `models` directory under `tasks/layout_detection` and adding `yolo.py` for the YOLO model definition, as follows: .. code-block:: python import os import cv2 import torch from torch.utils.data import DataLoader, Dataset from ultralytics import YOLO from pdf_extract_kit.registry import MODEL_REGISTRY from pdf_extract_kit.utils.visualization import visualize_bbox from pdf_extract_kit.dataset.dataset import ImageDataset import torchvision.transforms as transforms @MODEL_REGISTRY.register('layout_detection_yolo') class LayoutDetectionYOLO: def __init__(self, config): """ Initialize the LayoutDetectionYOLO class. Args: config (dict): Configuration dictionary containing model parameters.
""" # Mapping from class IDs to class names self.id_to_names = { 0: 'title', 1: 'plain text', 2: 'abandon', 3: 'figure', 4: 'figure_caption', 5: 'table', 6: 'table_caption', 7: 'table_footnote', 8: 'isolate_formula', 9: 'formula_caption' } # Load the YOLO model from the specified path self.model = YOLO(config['model_path']) # Set model parameters self.img_size = config.get('img_size', 1280) self.pdf_dpi = config.get('pdf_dpi', 200) self.conf_thres = config.get('conf_thres', 0.25) self.iou_thres = config.get('iou_thres', 0.45) self.visualize = config.get('visualize', False) self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu') self.batch_size = config.get('batch_size', 1) def predict(self, images, result_path, image_ids=None): """ Predict layouts in images. Args: images (list): List of images to be predicted. result_path (str): Path to save the prediction results. image_ids (list, optional): List of image IDs corresponding to the images. Returns: list: List of prediction results. """ results = [] for idx, image in enumerate(images): result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False)[0] if self.visualize: if not os.path.exists(result_path): os.makedirs(result_path) boxes = result.__dict__['boxes'].xyxy classes = result.__dict__['boxes'].cls vis_result = visualize_bbox(image, boxes, classes, self.id_to_names) # Determine the base name of the image if image_ids: base_name = image_ids[idx] else: base_name = os.path.basename(image) result_name = f"{base_name}_MFD.png" # Save the visualized result cv2.imwrite(os.path.join(result_path, result_name), vis_result) results.append(result) return results As you can see, the model definition includes the following key points: * Use the `@MODEL_REGISTRY.register('layout_detection_yolo')` syntax to directly register the YOLO layout model under `MODEL_REGISTRY`. 
* The initialization function needs to implement: + The `id_to_names` category mapping for visualization. + Model parameter configuration. + Model initialization. * The model inference function needs to implement various types of model inference: it supports image lists and `PIL.Image` class, allowing users to perform inference directly based on image paths or image streams. After implementing the above class definition, add `LayoutDetectionYOLO` to the `__all__` in `__init__.py` under the `layout_detection` task. .. code-block:: python from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "LayoutDetectionYOLO", ] .. note:: For the same task, we support multiple models. Users can choose which one to use based on evaluation results, considering model `accuracy`, `speed`, and `scenario adaptability`. After implementing the tasks and models, you can add a script program `layout_detection.py` under `repo_root/scripts`. Example Script ============== .. 
code-block:: python import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks # Ensure all task modules are imported TASK_NAME = 'layout_detection' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME) # layout_detection_task model_layout_detection = task_instances[TASK_NAME] # for image detection detection_results = model_layout_detection.predict_images(input_data, result_path) # for pdf detection # detection_results = model_layout_detection.predict_pdfs(input_data, result_path) # print(detection_results) print(f'The predicted results can be found at {result_path}') if __name__ == "__main__": args = parse_args() main(args.config) Support Type Extension ============== Batch Processing Extension ============== ================================================ FILE: docs/en/task_extend/doc.rst ================================================ ================================== Documentation Supplement ================================== ================================================ FILE: docs/en/task_extend/evaluation.rst ================================================ ================================== Model Performance Evaluation ================================== ================================================ FILE: docs/requirements.txt ================================================ sphinx sphinx_rtd_theme myst-parser sphinx-copybutton 
sphinx-argparse sphinx-book-theme ================================================ FILE: docs/zh_cn/.readthedocs.yaml ================================================ version: 2 build: os: ubuntu-22.04 tools: python: "3.10" formats: - epub python: install: - requirements: requirements/docs.txt sphinx: configuration: docs/zh_cn/conf.py ================================================ FILE: docs/zh_cn/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/zh_cn/algorithm/formula_detection.rst ================================================ .. _algorithm_formula_detection: ==================== 公式检测算法 ==================== 简介 ==================== 公式检测是针对给定的输入图像,检测出图像中所有包含公式的位置(包含行内公式和行间公式)。 .. note:: 公式检测实际上属于布局检测子任务,但由于公式检测的复杂性,我们建议使用单独的公式检测模型解耦。 这样通常使得数据标注更加方便,且公式检测效果也更好。 模型使用 ==================== 在配置好环境的情况下,直接执行 ``scripts/formula_detection.py`` 即可运行公式检测算法脚本。 .. code:: shell $ python scripts/formula_detection.py --config configs/formula_detection.yaml 模型配置 -------------------- ..
code:: yaml inputs: assets/demo/formula_detection outputs: outputs/formula_detection tasks: formula_detection: model: formula_detection_yolo model_config: img_size: 1280 conf_thres: 0.25 iou_thres: 0.45 batch_size: 1 model_path: models/MFD/yolov8/weights.pt visualize: True - inputs/outputs: 分别定义输入文件路径和可视化输出目录 - tasks: 定义任务类型,当前只包含一个公式检测任务 - model: 定义具体模型类型: 当前仅提供YOLO公式检测模型 - model_config: 定义模型配置 - img_size: 定义图像长边大小,短边会根据长边等比例缩放 - conf_thres: 定义置信度阈值,仅检测大于该阈值的目标 - iou_thres: 定义IoU阈值,去除重叠度大于该阈值的目标 - batch_size: 定义批量大小,推理时每次同时推理的图像数,一般情况下越大推理速度越快,显卡越好该数值可以设置得越大 - model_path: 模型权重路径 - visualize: 是否对模型结果进行可视化,可视化结果会保存在outputs目录下。 多样化输入支持 -------------------- PDF-Extract-Kit中的公式检测脚本支持 ``单个图像`` 、 ``只包含图像文件的目录`` 、 ``单个PDF文件`` 、 ``只包含PDF文件的目录`` 等输入形式。 .. note:: 根据自己实际数据形式,修改 ``configs/formula_detection.yaml`` 中 ``inputs`` 的路径即可 - 单个图像: path/to/image - 图像文件夹: path/to/images - 单个PDF文件: path/to/pdf - PDF文件夹: path/to/pdfs .. note:: 当使用PDF作为输入时,需要将 ``formula_detection.py`` 中的 ``predict_images`` 修改为 ``predict_pdfs`` 。 .. code:: python # for image detection detection_results = model_formula_detection.predict_images(input_data, result_path) .. code:: python # for pdf detection detection_results = model_formula_detection.predict_pdfs(input_data, result_path) 可视化结果查看 -------------------- 当config文件中 ``visualize`` 设置为 ``True`` 时,可视化结果会保存在 ``outputs/formula_detection`` 目录下。 .. note:: 可视化可以方便对模型结果进行分析,但当进行大批量任务时,建议关掉可视化(设置 ``visualize`` 为 ``False`` ),减少内存和磁盘占用。 ================================================ FILE: docs/zh_cn/algorithm/formula_recognition.rst ================================================ .. _algorithm_formula_recognition: ============ 公式识别算法 ============ 简介 ================= 公式识别是指给定输入公式图像,识别公式图像内容并转为 ``LaTeX`` 格式。 模型使用 ================= 在配置好环境的情况下,直接执行 ``scripts/formula_recognition.py`` 即可运行公式识别算法脚本。 .. code:: shell $ python scripts/formula_recognition.py --config configs/formula_recognition.yaml 模型配置 ----------------- ..
code:: yaml inputs: assets/demo/formula_recognition outputs: outputs/formula_recognition tasks: formula_recognition: model: formula_recognition_unimernet model_config: cfg_path: pdf_extract_kit/configs/unimernet.yaml model_path: models/MFR/unimernet_tiny visualize: False - inputs/outputs: 分别定义输入文件路径和LaTeX预测结果目录 - tasks: 定义任务类型,当前只包含一个公式识别任务 - model: 定义具体模型类型: 当前仅提供 `UniMERNet <https://github.com/opendatalab/UniMERNet>`_ 公式识别模型 - model_config: 定义模型配置 - cfg_path: UniMERNet配置文件路径 - model_path: 模型权重路径 - visualize: 是否对模型结果进行可视化,可视化结果会保存在outputs目录下。 多样化输入支持 ----------------- PDF-Extract-Kit中的公式识别脚本支持 ``单个公式图像`` 、 ``文档图像及对应公式区域`` 等输入形式。 可视化结果查看 ----------------- 当config文件中visualize设置为True时, ``LaTeX`` 预测结果会保存在outputs目录下。 ================================================ FILE: docs/zh_cn/algorithm/layout_detection.rst ================================================ .. _algorithm_layout_detection: ================= 布局检测算法 ================= 简介 ================= ``布局检测`` 是文档内容提取的基础任务,目标是对页面中不同类型的区域进行定位:如 ``图像`` 、 ``表格`` 、 ``文本`` 、 ``标题`` 等,方便后续高质量内容提取。对于 ``文本`` 、 ``标题`` 等区域,可以基于 ``OCR模型`` 进行文字识别,对于表格区域可以基于表格识别模型进行转换。 模型使用 ================= 布局检测模型支持以下模型: .. raw:: html
模型 简述 特点 模型权重 配置文件
DocLayout-YOLO 基于YOLO-v10模型改进:
1. 生成多样性预训练数据,提升对多种类型文档泛化性
2. 模型结构改进,提升对多尺度目标感知能力
详见DocLayout-YOLO
速度快、精度高 doclayout_yolo_ft.pt layout_detection.yaml
YOLO-v10 基础YOLO-v10模型 速度快,精度一般 yolov10l_ft.pt layout_detection_yolo.yaml
LayoutLMv3 基础LayoutLMv3模型 速度慢,精度较好 layoutlmv3_ft layout_detection_layoutlmv3.yaml
在配置好环境的情况下,直接执行 ``scripts/layout_detection.py`` 即可运行布局检测算法脚本。 **执行布局检测程序** .. code:: shell $ python scripts/layout_detection.py --config configs/layout_detection.yaml 模型配置 ----------------- **1. DocLayout-YOLO / YOLO-v10** .. code:: yaml inputs: assets/demo/layout_detection outputs: outputs/layout_detection tasks: layout_detection: model: layout_detection_yolo model_config: img_size: 1024 conf_thres: 0.25 iou_thres: 0.45 model_path: path/to/doclayout_yolo_model visualize: True - inputs/outputs: 分别定义输入文件路径和可视化输出目录 - tasks: 定义任务类型,当前只包含一个布局检测任务 - model: 定义具体模型类型,例如 ``layout_detection_yolo`` - model_config: 定义模型配置 - img_size: 定义图像长边大小,短边会根据长边等比例缩放,默认长边保持1024 - conf_thres: 定义置信度阈值,仅检测大于该阈值的目标 - iou_thres: 定义IoU阈值,去除重叠度大于该阈值的目标 - model_path: 模型权重路径 - visualize: 是否对模型结果进行可视化,可视化结果会保存在outputs目录下 **2. LayoutLMv3** .. note:: LayoutLMv3 默认情况下不能直接运行。运行时请将配置文件修改为configs/layout_detection_layoutlmv3.yaml,并且请按照以下步骤进行配置修改: 1. **Detectron2 环境配置** .. code-block:: bash # 对于 Linux pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-linux_x86_64.whl # 对于 macOS pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-macosx_10_9_universal2.whl # 对于 Windows pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-win_amd64.whl 2. **启用 LayoutLMv3 注册代码** 请取消注释以下链接中的代码行: - `第2行 `_ - `第8行 `_ .. code-block:: python from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO from pdf_extract_kit.tasks.layout_detection.models.layoutlmv3 import LayoutDetectionLayoutlmv3 from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "LayoutDetectionYOLO", "LayoutDetectionLayoutlmv3", ] .. 
code:: yaml inputs: assets/demo/layout_detection outputs: outputs/layout_detection tasks: layout_detection: model: layout_detection_layoutlmv3 model_config: model_path: path/to/layoutlmv3_model - inputs/outputs: 分别定义输入文件路径和可视化输出目录 - tasks: 定义任务类型,当前只包含一个布局检测任务 - model: 定义具体模型类型,例如layout_detection_layoutlmv3 - model_config: 定义模型配置 - model_path: 模型权重路径 多样化输入支持 ----------------- PDF-Extract-Kit中的布局检测脚本支持 ``单个图像`` 、 ``只包含图像文件的目录`` 、 ``单个PDF文件`` 、 ``只包含PDF文件的目录`` 等输入形式。 .. note:: 根据自己实际数据形式,修改configs/layout_detection.yaml中inputs的路径即可 - 单个图像: path/to/image - 图像文件夹: path/to/images - 单个PDF文件: path/to/pdf - PDF文件夹: path/to/pdfs .. note:: 当使用PDF作为输入时,需要将 ``layout_detection.py`` .. code:: python # for image detection detection_results = model_layout_detection.predict_images(input_data, result_path) 中的 ``predict_images`` 修改为 ``predict_pdfs`` 。 .. code:: python # for pdf detection detection_results = model_layout_detection.predict_pdfs(input_data, result_path) 可视化结果查看 ----------------- 当config文件中 ``visualize`` 设置为 ``True`` 时,可视化结果会保存在 ``outputs`` 目录下。 .. note:: 可视化可以方便对模型结果进行分析,但当进行大批量任务时,建议关掉可视化(设置 ``visualize`` 为 ``False`` ),减少内存和磁盘占用。 ================================================ FILE: docs/zh_cn/algorithm/ocr.rst ================================================ .. _algorithm_ocr: ========================== 光学字符识别(OCR)算法 ========================== 简介 ==================== 光学字符识别(OCR)是指对图片中的文字块进行检测和识别。 模型使用 ==================== 在配置好环境的情况下,直接执行 ``scripts/ocr.py`` 即可运行OCR算法脚本。 .. code:: shell $ python scripts/ocr.py --config configs/ocr.yaml 模型配置 -------------------- .. 
code:: yaml inputs: assets/demo/ocr outputs: outputs/ocr visualize: True tasks: ocr: model: ocr_ppocr model_config: lang: ch show_log: True det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec det_db_box_thresh: 0.3 - inputs/outputs: 分别定义输入文件路径和输出路径 - visualize: 是否对模型结果进行可视化,可视化结果会保存在outputs目录下。 - tasks: 定义任务类型,当前只包含一个OCR任务 - model: 定义具体模型类型, 当前仅提供PaddleOCR模型 - model_config: 定义模型配置 - lang: 定义语种,默认语种ch支持中英文文字的检测和识别 - show_log: 是否打印检测识别过程的日志 - det_model_dir: 定义PaddleOCR检测模型的路径,指定路径不存在时,会自动下载模型权重到该路径 - rec_model_dir: 定义PaddleOCR识别模型的路径,指定路径不存在时,会自动下载模型权重到该路径 - det_db_box_thresh: 检测框筛选阈值,置信度低于该阈值的框会被舍弃 多样化输入支持 -------------------- PDF-Extract-Kit中的OCR脚本支持 ``单个图像/PDF文件`` 、 ``包含图像/PDF文件的目录`` 等输入形式。 可视化结果查看 -------------------- 当config文件中 ``visualize`` 设置为 ``True`` 时,可视化结果会保存在 ``outputs`` 参数指定的目录下。 .. note:: 可视化可以方便对模型结果进行分析,但当进行大批量任务时,建议关掉可视化(设置 ``visualize`` 为 ``False`` ),减少内存和磁盘占用。 ================================================ FILE: docs/zh_cn/algorithm/reading_order.rst ================================================ .. _algorithm_reading_oder: ============== 阅读顺序算法 ============== Coming soon. ================================================ FILE: docs/zh_cn/algorithm/table_recognition.rst ================================================ .. _algorithm_table_recognition: ============ 表格识别算法 ============ 简介 ================= 表格识别是指输入表格图像,识别表格结构和内容,并将其转换为 ``LaTeX`` 或 ``HTML`` 等格式。 模型使用 ================= 在配置好环境的情况下,直接执行 ``scripts/table_parsing.py`` 即可运行表格识别算法脚本。 .. code:: shell $ python scripts/table_parsing.py --config configs/table_parsing.yaml 模型配置 ----------------- ..
code:: yaml inputs: assets/demo/table_parsing outputs: outputs/table_parsing tasks: table_parsing: model: table_parsing_struct_eqtable model_config: model_path: models/TabRec/StructEqTable max_new_tokens: 1024 max_time: 30 output_format: latex lmdeploy: False flash_attn: True - inputs/outputs: 分别定义输入文件路径和表格识别结果目录 - tasks: 定义任务类型,当前只包含一个表格识别任务 - model: 定义具体模型类型: 当前使用 `StructEqTable `_ 表格识别模型 - model_config: 定义模型配置 - model_path: 模型权重路径 - max_new_tokens: 生成的最大token数量, 默认为1024, 最大支持4096 - max_time: 模型运行的最大时间(秒) - output_format: 输出格式,默认设置为 ``latex``, 可选有 ``html`` 和 ``markdown`` - lmdeploy: 是否使用 LMDeploy 进行部署,当前设置为 False - flash_attn: 是否使用flash attention,仅适用于Ampere GPU 多样化输入支持 ----------------- PDF-Extract-Kit中的表格识别脚本支持 ``单个表格图像`` 和 ``多个表格图像`` 作为输入。 .. note:: StructEqTable表格模型仅支持GPU设备下运行 .. note:: 根据表格内容调整 ``max_new_tokens`` 和 ``max_time``, 默认分别为1024和30。 .. note:: lmdeploy为加速推理的选项,如果设置为True,将使用LMDeploy进行加速推理部署。 使用LMDeploy部署需要安装LMDeploy,安装方法参考 `LMDeploy `_ 。 ================================================ FILE: docs/zh_cn/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
import os
import subprocess
import sys

# Directory containing this conf.py. Relative paths below are anchored here so
# the build does not depend on the current working directory.
_CONF_DIR = os.path.dirname(os.path.abspath(__file__))


def install(*pip_args):
    """Run ``pip install`` with ``pip_args`` using the interpreter running Sphinx.

    ``install("some-package")`` installs a single package (the original usage);
    ``install("-r", path)`` installs everything listed in a requirements file.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", *pip_args])


# Install the documentation dependencies listed in docs/requirements.txt.
# The whole file is handed to ``pip install -r`` in a single call so that pip
# itself handles blank lines, ``#`` comments and option lines; passing each
# stripped line to ``pip install`` separately would hand pip empty strings and
# option fragments, which it rejects.
requirements_path = os.path.abspath(os.path.join(_CONF_DIR, '..', 'requirements.txt'))
if os.path.exists(requirements_path):
    install("-r", requirements_path)

from sphinx.ext import autodoc  # noqa: E402

sys.path.insert(0, os.path.abspath(os.path.join(_CONF_DIR, '..', '..')))

# -- Project information -----------------------------------------------------

project = 'PDF-Extract-Kit'
copyright = '2024, OpenDataLab'
author = 'PDF-Extract-Kit Contributors'

# Read __version__ from version.py without importing the package. Executing it
# in a private namespace keeps its other names out of the Sphinx config.
version_file = os.path.join(_CONF_DIR, '..', '..', 'pdf_extract_kit', 'version.py')
_version_namespace = {}
with open(version_file, encoding='utf-8') as f:
    exec(compile(f.read(), version_file, 'exec'), _version_namespace)
__version__ = _version_namespace['__version__']
# The short X.Y version
version = __version__
# The full version, including alpha/beta/rc tags
release = __version__

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx.ext.intersphinx',
    'sphinx_copybutton',
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'myst_parser',
    'sphinxarg.ext',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # Exclude the prompt "$" when copying code copybutton_prompt_text = r'\$ ' copybutton_prompt_is_regexp = True language = 'zh_CN' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_book_theme' html_logo = '_static/image/logo.png' html_theme_options = { 'path_to_docs': 'docs/zh_cn', 'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit', 'use_repository_button': True, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] # Mock out external dependencies here. autodoc_mock_imports = [ 'cpuinfo', 'torch', 'transformers', 'psutil', 'prometheus_client', 'sentencepiece', 'vllm.cuda_utils', 'vllm._C', 'numpy', 'tqdm', ] class MockedClassDocumenter(autodoc.ClassDocumenter): """Remove note about base class when a class is derived from object.""" def add_line(self, line: str, source: str, *lineno: int) -> None: if line == ' Bases: :py:class:`object`': return super().add_line(line, source, *lineno) autodoc.ClassDocumenter = MockedClassDocumenter navigation_with_keys = False ================================================ FILE: docs/zh_cn/evaluation/formula_detection.rst ================================================ ===================== 公式检测算法评测 ===================== XXX ================================================ FILE: docs/zh_cn/evaluation/formula_recognition.rst ================================================ ===================== 公式识别算法评测 ===================== Comming soon! 
================================================ FILE: docs/zh_cn/evaluation/layout_detection.rst ================================================ ===================== 布局检测算法评测 ===================== Coming soon! ================================================ FILE: docs/zh_cn/evaluation/ocr.rst ================================================ ===================== OCR算法评测 ===================== Coming soon! ================================================ FILE: docs/zh_cn/evaluation/pdf_extract.rst ================================================ ============================ PDF内容提取评测【端到端】 ============================ Coming soon! ================================================ FILE: docs/zh_cn/evaluation/reading_order.rst ================================================ ===================== 阅读顺序算法评测 ===================== Coming soon! ================================================ FILE: docs/zh_cn/evaluation/table_recognition.rst ================================================ ===================== 表格识别算法评测 ===================== Coming soon! ================================================ FILE: docs/zh_cn/get_started/installation.rst ================================================ ================================== 安装 ================================== 本节中,我们将演示如何安装 PDF-Extract-Kit。 最佳实践 ======== 我们推荐用户参照我们的最佳实践安装 PDF-Extract-Kit。 推荐使用 Python-3.10 的 conda 虚拟环境安装 PDF-Extract-Kit。 **步骤 1.** 使用 conda 先构建一个 Python-3.10 的虚拟环境 .. code-block:: console $ conda create -n pdf-extract-kit-1.0 python=3.10 -y $ conda activate pdf-extract-kit-1.0 **步骤 2.** 安装 PDF-Extract-Kit 的依赖项 .. code-block:: console $ # 对于GPU设备 $ pip install -r requirements.txt $ # 对于CPU设备 $ pip install -r requirements-cpu.txt ..
note:: 考虑到用户环境配置的便捷性,我们在requirements.txt只包含当前最好模型需要的环境,目前包含 - 布局检测:YOLO系列(YOLOv10, DocLayout-YOLO) - 公式检测:YOLO系列 (YOLOv8) - 公式识别:UniMERNet - OCR: PaddleOCR 对于其他模型,如LayoutLMv3,需要单独安装环境,具体见\ :ref:`布局检测算法 <algorithm_layout_detection>` ================================================ FILE: docs/zh_cn/get_started/pretrained_model.rst ================================================ ================================== 模型权重下载 ================================== 在使用PDF-Extract-Kit前,我们需要下载所需要的模型权重。可以根据自己需求下载全部模型或者特定的模型文件(如公式检测MFD) [推荐] 方法 1:``snapshot_download`` ======================================== HuggingFace ------------ ``huggingface_hub.snapshot_download`` 支持下载特定的 HuggingFace Hub 模型权重,并且允许多线程。您可以利用下列代码并行下载模型权重: .. code:: python from huggingface_hub import snapshot_download snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', max_workers=20) 如果想仅下载单个算法模型(如公式检测任务的YOLO模型),可以使用如下代码: .. code:: python from huggingface_hub import snapshot_download snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', allow_patterns='models/MFD/YOLO/*') .. note:: 其中,\ ``repo_id`` 表示模型在 HuggingFace Hub 的名字、\ ``local_dir`` 表示期望存储到的本地路径、\ ``max_workers`` 表示下载的最大并行数,\ ``allow_patterns`` 表示想要下载的文件。 .. tip:: 如果未指定 ``local_dir``\ ,则将下载至 HuggingFace 的默认 cache 路径中(\ ``~/.cache/huggingface/hub``\ )。若要修改默认 cache 路径,需要修改相关环境变量: .. code:: console $ # 默认为 ~/.cache/huggingface/ $ export HF_HOME=XXXX .. tip:: 如果觉得下载较慢(例如无法达到最大带宽等情况),可以尝试设置\ ``export HF_HUB_ENABLE_HF_TRANSFER=1`` 以获得更高的下载速度。 ModelScope ----------- ``modelscope.snapshot_download`` 支持下载指定的模型权重,您可以利用下列命令下载模型: .. code:: python from modelscope import snapshot_download snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./') 如果想仅下载单个算法模型(如公式检测任务的YOLO模型),可以使用如下代码: .. code:: python from modelscope import snapshot_download snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./', allow_patterns='models/MFD/YOLO/*') ..
note:: 其中,\ ``model_id`` 表示模型在 ModelScope 模型库的名字,\ ``cache_dir`` 表示期望存储到的本地路径, \ ``allow_patterns`` 表示想要下载的文件。 .. note:: ``modelscope.snapshot_download`` 不支持多线程并行下载。 .. tip:: 如果未指定 ``cache_dir``\ ,则将下载至 ModelScope 的默认 cache 路径中(\ ``~/.cache/modelscope/hub``\ )。 若要修改默认 cache 路径,需要修改相关环境变量: .. code:: console $ # 默认为 ~/.cache/modelscope/hub/ $ export MODELSCOPE_CACHE=XXXX 方法 2: Git LFS =================== HuggingFace 和 ModelScope 的远程模型仓库就是一个由 Git LFS 管理的 Git 仓库。因此,我们可以利用 ``git clone`` 完成权重的下载: .. code:: console $ git lfs install $ # From HuggingFace $ git clone https://huggingface.co/opendatalab/pdf-extract-kit-1.0 $ # From ModelScope $ git clone https://www.modelscope.cn/opendatalab/pdf-extract-kit-1.0.git ================================================ FILE: docs/zh_cn/get_started/quickstart.rst ================================================ ================================== 快速开始 ================================== 配置好PDF-Extract-Kit环境,并下载好模型后,我们可以开始使用PDF-Extract-Kit了。 布局检测示例 ============== 布局检测提供了多种模型: ``LayoutLMv3``、 ``YOLOv10``、 ``DocLayout-YOLO``, 相比于 ``LayoutLMv3``, ``YOLOv10`` 速度更快, ``DocLayout-YOLO`` 则是在 ``YOLOv10`` 的基础上进行多样性文档预训练及模型优化,速度快,精度高。 **1. 使用布局检测模型** .. code-block:: console $ python scripts/layout_detection.py --config configs/layout_detection.yaml 执行完之后,我们可以在 ``outputs/layout_detection`` 目录下查看检测结果。 .. note:: ``layout_detection.yaml`` 设置输入、输出及模型配置,布局检测更详细教程见\ :ref:`布局检测算法 <algorithm_layout_detection>` \ 。 公式检测示例 ============== .. code-block:: console $ python scripts/formula_detection.py --config configs/formula_detection.yaml 执行完之后,我们可以在 ``outputs/formula_detection`` 目录下查看检测结果。 .. note:: ``formula_detection.yaml`` 设置输入、输出及模型配置,公式检测更详细教程见 \ :ref:`公式检测算法 <algorithm_formula_detection>` \ 。 ================================================ FILE: docs/zh_cn/index.rst ================================================ .. PDF-Extract-Kit documentation master file, created by sphinx-quickstart on Tue Jan 9 16:33:06 2024.
You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. 欢迎来到 PDF-Extract-Kit 的中文文档 ============================================== .. figure:: ./_static/image/logo.png :align: center :alt: pdf-extract-kit :class: no-scaled-link .. raw:: html

高质量文档解析工具箱

Star Watch Fork

文档 ------------- .. toctree:: :maxdepth: 2 :caption: 快速上手 get_started/installation.rst get_started/pretrained_model.rst get_started/quickstart.rst .. toctree:: :maxdepth: 2 :caption: 基础算法模块 algorithm/layout_detection.rst algorithm/formula_detection.rst algorithm/formula_recognition.rst algorithm/ocr.rst algorithm/table_recognition.rst algorithm/reading_order.rst .. toctree:: :maxdepth: 2 :caption: 新任务拓展 task_extend/code.rst task_extend/doc.rst task_extend/evaluation.rst .. toctree:: :maxdepth: 2 :caption: 支持的模型列表 models/supported.md .. toctree:: :maxdepth: 2 :caption: 模型性能评测 evaluation/layout_detection.rst evaluation/formula_detection.rst evaluation/formula_recognition.rst evaluation/ocr.rst evaluation/table_recognition.rst evaluation/reading_order.rst evaluation/pdf_extract.rst .. toctree:: :maxdepth: 2 :caption: PDF项目 project/pdf_extract.rst project/doc_translate.rst project/speed_up.rst ================================================ FILE: docs/zh_cn/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) if "%1" == "" goto help %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/zh_cn/models/supported.md ================================================ # 已支持的模型 Coming soon!
================================================ FILE: docs/zh_cn/notes/changelog.md ================================================ # 变更日志 ## v0.2.0 (2024.09.30) PDF-Extract-Kit 代码重构,模块化设计更加简洁易用! 🔥🔥🔥 ## v0.1.0 (2024.07.01) PDF-Extract-Kit 正式发布!🔥🔥🔥 ### 亮点 - PDF-Extract-Kit提供高质量布局检测模型 DocLayout-YOLO - PDF-Extract-Kit提供高质量公式检测模型 YOLOv8 ================================================ FILE: docs/zh_cn/project/doc_translate.rst ================================================ ================= 文档翻译项目 ================= Coming soon! ================================================ FILE: docs/zh_cn/project/pdf_extract.rst ================================================ ================= 文档内容提取项目 ================= 简介 ==================== 文档内容提取是利用布局检测,公式检测,公式识别,OCR等模型,提取文档中的信息,并转换为markdown文本。 项目使用 ==================== 在配置好环境的情况下,直接执行 ``project/pdf2markdown/scripts/run_project.py`` 即可运行文档内容提取项目。 .. code:: shell $ python project/pdf2markdown/scripts/run_project.py --config project/pdf2markdown/configs/pdf2markdown.yaml 项目配置 -------------------- ..
code:: yaml inputs: assets/demo/formula_detection outputs: outputs/pdf2markdown visualize: True merge2markdown: True tasks: layout_detection: model: layout_detection_yolo model_config: img_size: 1024 conf_thres: 0.25 iou_thres: 0.45 model_path: models/Layout/YOLO/doclayout_yolo_ft.pt formula_detection: model: formula_detection_yolo model_config: img_size: 1280 conf_thres: 0.25 iou_thres: 0.45 batch_size: 1 model_path: models/MFD/YOLO/yolo_v8_ft.pt formula_recognition: model: formula_recognition_unimernet model_config: batch_size: 128 cfg_path: pdf_extract_kit/configs/unimernet.yaml model_path: models/MFR/unimernet_tiny ocr: model: ocr_ppocr model_config: lang: ch show_log: True det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec det_db_box_thresh: 0.3 - inputs/outputs: 分别定义输入文件路径和输出路径 - visualize: 是否对模型结果进行可视化,可视化结果会保存在outputs目录下。 - merge2markdown: 是否将结果合并为markdown文档,这里只支持简单的单栏文本从上往下进行拼接,更复杂布局文档的markdown转换请参考 `MinerU <https://github.com/opendatalab/MinerU>`_ - tasks: 定义任务类型,PDF文档提取包含了布局检测、公式检测、公式识别、OCR等任务 - 具体每个任务和模型的参数含义请参考各任务的教程文档 多样化输入支持 -------------------- PDF文档内容提取支持 ``单个图像/PDF文件`` 、 ``包含图像/PDF文件的目录`` 等输入形式。 输出结果 -------------------- PDF文档提取的结果以json形式保存在 ``outputs`` 路径下,json的格式如下所示: .. code:: json [ { "layout_dets": [ { "category_type": "text", "poly": [ 380.6792698635707, 159.85058512958923, 765.1419999999998, 159.85058512958923, 765.1419999999998, 192.51073013642917, 380.6792698635707, 192.51073013642917 ], "text": "this is an example text", "score": 0.97 }, ... ], "page_info": { "page_no": 0, "height": 2339, "width": 1654 } }, ...
] - layout_dets: 单页PDF或图片的内容提取结果 - category_type: 单个内容块的所属类别,比如标题、图片、行内公式等等 - poly: 单个内容块的位置坐标 - text: 该文本块的文本内容 - score: 检测的置信度 - page_info: 页面信息,包含页码和页面尺寸 - page_no: 页码,从0开始计数 - height: 页面尺寸: 高 - width: 页面尺寸: 宽 如果 ``merge2markdown`` 参数为True的话,则会额外保存一个markdown文件。 ================================================ FILE: docs/zh_cn/project/speed_up.rst ================================================ ================= 模型加速项目 ================= Coming soon! ================================================ FILE: docs/zh_cn/switch_language.md ================================================ ## English ## 简体中文 ================================================ FILE: docs/zh_cn/task_extend/code.rst ================================================ ================================== 代码实现 ================================== PDF-Extract-Kit项目的核心代码实现在pdf_extract_kit目录下,该路径下包含下述几个模块: - configs: 特定模块的配置文件,如 ``pdf_extract_kit/configs/unimernet.yaml`` ,如果本身配置简单,建议放在 ``repo_root/configs`` 的 ``yaml`` 文件中的 ``model_config`` 里进行定义,方便用户修改。 - dataset: 自定义的 ``ImageDataset`` 类,用于加载和预处理图像数据。它支持多种输入类型,并且可以对图像进行统一的预处理操作(如调整大小、转换为张量等),以便于后续的模型推理加速。 - evaluation: 模型结果评测模块,支持多种任务类型评测,如 ``布局检测`` 、 ``公式检测`` 、 ``公式识别`` 等等,方便用户对不同任务、不同模型进行公平对比。 - registry: ``Registry`` 类是一个通用的注册表类,提供了注册、获取和列出注册项的功能。用户可以使用该类创建不同类型的注册表,例如任务注册表、模型注册表等。 - tasks: 最核心的任务模块,包含了许多不同类型的任务,如 ``布局检测`` 、 ``公式检测`` 、 ``公式识别`` 等等,用户添加新任务和新模型一般仅需要在这里进行代码添加。 .. note:: 基于上述的模块化设计,用户拓展新模块一般只需要在tasks里实现自己的新任务类及对应模型(更多情况下仅需要实现对应模型,任务已经定义好),然后在registry里注册即可。 下面我们以添加基于 ``YOLO`` 的 ``布局检测`` 模型为例,介绍如何添加新任务和新模型。 任务定义及注册 ============== 首先我们在 ``tasks`` 下添加一个 ``layout_detection`` 目录,然后在该目录下添加一个 ``task.py`` 文件用于定义布局检测任务类,具体如下: ..
code-block:: python from pdf_extract_kit.registry.registry import TASK_REGISTRY from pdf_extract_kit.tasks.base_task import BaseTask @TASK_REGISTRY.register("layout_detection") class LayoutDetectionTask(BaseTask): def __init__(self, model): super().__init__(model) def predict_images(self, input_data, result_path): """ Predict layouts in images. Args: input_data (str): Path to a single image file or a directory containing image files. result_path (str): Path to save the prediction results. Returns: list: List of prediction results. """ images = self.load_images(input_data) # Perform detection return self.model.predict(images, result_path) def predict_pdfs(self, input_data, result_path): """ Predict layouts in PDF files. Args: input_data (str): Path to a single PDF file or a directory containing PDF files. result_path (str): Path to save the prediction results. Returns: list: List of prediction results. """ pdf_images = self.load_pdf_images(input_data) # Perform detection return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys())) 可以看到,任务定义包含下面几个要点: * 使用 ``@TASK_REGISTRY.register("layout_detection")`` 语法直接将布局任务类注册到 ``TASK_REGISTRY`` 下 ; * ``__init__`` 初始化函数传入 ``model`` , 具体参考 ``BaseTask`` 类 * 实现推理函数,这里考虑到布局检测通常会处理图像类及PDF文件,所以提供了两个函数 ``predict_images`` 和 ``predict_pdfs`` ,方便用户灵活选择。 模型定义及注册 ============== 接下来我们实现具体模型,在task下面新建models目录,并添加yolo.py用于YOLO模型定义,具体定义如下: .. code-block:: python import os import cv2 import torch from torch.utils.data import DataLoader, Dataset from ultralytics import YOLO from pdf_extract_kit.registry import MODEL_REGISTRY from pdf_extract_kit.utils.visualization import visualize_bbox from pdf_extract_kit.dataset.dataset import ImageDataset import torchvision.transforms as transforms @MODEL_REGISTRY.register('layout_detection_yolo') class LayoutDetectionYOLO: def __init__(self, config): """ Initialize the LayoutDetectionYOLO class. Args: config (dict): Configuration dictionary containing model parameters. 
""" # Mapping from class IDs to class names self.id_to_names = { 0: 'title', 1: 'plain text', 2: 'abandon', 3: 'figure', 4: 'figure_caption', 5: 'table', 6: 'table_caption', 7: 'table_footnote', 8: 'isolate_formula', 9: 'formula_caption' } # Load the YOLO model from the specified path self.model = YOLO(config['model_path']) # Set model parameters self.img_size = config.get('img_size', 1280) self.pdf_dpi = config.get('pdf_dpi', 200) self.conf_thres = config.get('conf_thres', 0.25) self.iou_thres = config.get('iou_thres', 0.45) self.visualize = config.get('visualize', False) self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu') self.batch_size = config.get('batch_size', 1) def predict(self, images, result_path, image_ids=None): """ Predict layouts in images. Args: images (list): List of images to be predicted. result_path (str): Path to save the prediction results. image_ids (list, optional): List of image IDs corresponding to the images. Returns: list: List of prediction results. 
""" results = [] for idx, image in enumerate(images): result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False)[0] if self.visualize: if not os.path.exists(result_path): os.makedirs(result_path) boxes = result.__dict__['boxes'].xyxy classes = result.__dict__['boxes'].cls vis_result = visualize_bbox(image, boxes, classes, self.id_to_names) # Determine the base name of the image if image_ids: base_name = image_ids[idx] else: base_name = os.path.basename(image) result_name = f"{base_name}_MFD.png" # Save the visualized result cv2.imwrite(os.path.join(result_path, result_name), vis_result) results.append(result) return results 可以看到,模型定义包含下面几个要点: * 使用 ``@MODEL_REGISTRY.register('layout_detection_yolo')`` 语法直接将yolo布局模型注册到 ``MODEL_REGISTRY`` 下; * 初始化函数需要实现: + id_to_names的类别映射,用于可视化展示 + 模型参数配置 + 模型初始化 * 模型推理函数需要实现多种类型的模型推理:这里支持图像列表和PIL.Image类,可以方便用户直接基于图像路径或者图像流进行推理。 实现上述类定义后,将 ``LayoutDetectionYOLO`` 添加到 ``layout_detection`` 任务下 ``__init__.py`` 的 ``__all__`` 中即可。 .. code-block:: python from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "LayoutDetectionYOLO", ] .. note:: 对于同一个任务,我们支持多种模型,用户具体选择哪个可以根据评测结果进行选择,结合模型 ``精度`` 、 ``速度`` 和 ``场景适配程度`` 进行选择。 实现了任务和模型后,可以在 repo_root/scripts下添加脚本程序 ``layout_detection.py`` 示例脚本 ============== .. 
code-block:: python import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks # 确保所有任务模块被导入 TASK_NAME = 'layout_detection' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME) # layout_detection_task model_layout_detection = task_instances[TASK_NAME] # for image detection detection_results = model_layout_detection.predict_images(input_data, result_path) # for pdf detection # detection_results = model_layout_detection.predict_pdfs(input_data, result_path) # print(detection_results) print(f'The predicted results can be found at {result_path}') if __name__ == "__main__": args = parse_args() main(args.config) 支持类型拓展 ============== Coming soon! 批处理拓展 ============== Coming soon! ================================================ FILE: docs/zh_cn/task_extend/doc.rst ================================================ ================================== 文档补充 ================================== 在实现新的任务和模块后,需要在文档中补充相关内容,以便用户了解如何使用。 具体可以参考布局检测任务使用文档:\ :ref:`布局检测算法 <algorithm_layout_detection>` 主要补充下述几个部分: * 任务简介 * 模型使用方式 * 配置文件解释 * 多样化输入支持(如果有) * 可视化结果查看 ================================================ FILE: docs/zh_cn/task_extend/evaluation.rst ================================================ ================================== 模型评测 ================================== Coming soon!
================================================ FILE: pdf_extract_kit/__init__.py ================================================ import os import sys current_dir = os.path.dirname(os.path.abspath(__file__)) root_dir = os.path.abspath(os.path.join(current_dir, '..')) if root_dir not in sys.path: sys.path.insert(0, root_dir) ================================================ FILE: pdf_extract_kit/configs/unimernet.yaml ================================================ model: arch: unimernet model_type: unimernet model_config: model_name: ./models/unimernet_tiny max_seq_len: 1536 load_pretrained: True pretrained: './models/unimernet_tiny/pytorch_model.pth' tokenizer_config: path: ./models/unimernet_tiny datasets: formula_rec_eval: vis_processor: eval: name: "formula_image_eval" image_size: - 192 - 672 run: runner: runner_iter task: unimernet_train batch_size_train: 64 batch_size_eval: 64 num_workers: 1 iters_per_inner_epoch: 2000 max_iters: 60000 seed: 42 output_dir: "../output/demo" evaluate: True test_splits: [ "eval" ] device: "cuda" world_size: 1 dist_url: "env://" distributed: True distributed_type: ddp # or fsdp when train llm generate_cfg: temperature: 0.0 ================================================ FILE: pdf_extract_kit/dataset/__init__.py ================================================ ================================================ FILE: pdf_extract_kit/dataset/dataset.py ================================================ import numpy as np import torch from PIL import Image from torch.utils.data import Dataset import torchvision.transforms as transforms class ResizeLongestSide: def __init__(self, size): self.size = size def __call__(self, img): # Get the original dimensions width, height = img.size # Determine the scaling factor if width > height: new_width = self.size new_height = int(height * (self.size / float(width))) else: new_height = self.size new_width = int(width * (self.size / float(height))) # Resize the image return img.resize((new_width, 
new_height), Image.BILINEAR) class ImageDataset(Dataset): def __init__(self, images, image_ids=None, img_size=1280): """ Initialize the ImageDataset class. Args: - images (list): List of image paths or PIL.Image.Image objects. - image_ids (list, optional): List of corresponding image IDs. If None, assumes images are paths. - img_size (int): Size to which images' longest side will be resized. """ self.images = images self.image_ids = image_ids if image_ids is not None else images self.img_size = img_size self.transform = transforms.Compose([ ResizeLongestSide(self.img_size), transforms.ToTensor() ]) def __len__(self): """ Return the size of the dataset. Returns: int: Number of images in the dataset. """ return len(self.images) def __getitem__(self, idx): """ Get an image and its corresponding ID by index. Args: - idx (int): Index of the image to retrieve. Returns: tuple: Transformed image tensor and corresponding image ID. """ image = self.images[idx] image_id = self.image_ids[idx] # Check if the image is a path or a PIL.Image object if isinstance(image, str): image = Image.open(image).convert('RGB') elif isinstance(image, Image.Image): image = image.convert('RGB') else: raise ValueError("Image must be a file path or a PIL.Image object") # Apply transformations image = self.transform(image) return image, image_id class MathDataset(Dataset): def __init__(self, image_paths, transform=None): self.image_paths = image_paths self.transform = transform def __len__(self): return len(self.image_paths) def __getitem__(self, idx): # if not pil image, then convert to pil image if isinstance(self.image_paths[idx], str): raw_image = Image.open(self.image_paths[idx]) else: raw_image = self.image_paths[idx] if self.transform: image = self.transform(raw_image) return image ================================================ FILE: pdf_extract_kit/registry/__init__.py ================================================ from .registry import TASK_REGISTRY, MODEL_REGISTRY 
================================================ FILE: pdf_extract_kit/registry/registry.py ================================================ class Registry: def __init__(self): self._registry = {} def register(self, name): def decorator(item): if name in self._registry: raise ValueError(f"Item {name} already registered.") self._registry[name] = item return item return decorator def get(self, name): if name not in self._registry: raise ValueError(f"Item {name} not found in registry.") return self._registry[name] def list_items(self): return list(self._registry.keys()) # Create global registries for tasks and models TASK_REGISTRY = Registry() MODEL_REGISTRY = Registry() ================================================ FILE: pdf_extract_kit/tasks/__init__.py ================================================ from pdf_extract_kit.tasks.base_task import BaseTask from pdf_extract_kit.tasks.formula_detection.task import FormulaDetectionTask from pdf_extract_kit.tasks.formula_recognition.task import FormulaRecognitionTask from pdf_extract_kit.tasks.layout_detection.task import LayoutDetectionTask from pdf_extract_kit.tasks.ocr.task import OCRTask from pdf_extract_kit.tasks.table_parsing.task import TableParsingTask from pdf_extract_kit.registry.registry import TASK_REGISTRY __all__ = [ "BaseTask", "LayoutDetectionTask", "FormulaRecognitionTask", "LayoutDetectionTask", "OCRTask", "TableParsingTask", ] def load_task(name, cfg=None): """ Example >>> task = load_task("formula_detection", cfg=None) """ task_class = TASK_REGISTRY.get(name) task_instance = task_class(cfg) return task_instance ================================================ FILE: pdf_extract_kit/tasks/base_task.py ================================================ import os from pdf_extract_kit.utils.data_preprocess import load_pdf class BaseTask: def __init__(self, model): self.model = model def load_images(self, input_data): """ Loads images from a single image path or a directory containing multiple images. 
Args: input_data (str): Path to a single image file or a directory containing image files. Returns: list: List of paths to all images to be predicted. """ images = [] if os.path.isdir(input_data): # If input_data is a directory, check for nested directories for root, dirs, files in os.walk(input_data): if dirs: raise ValueError("Input directory should not contain nested directories: {}".format(input_data)) for file in files: if file.lower().endswith(('.png', '.jpg', '.jpeg')): image_path = os.path.join(root, file) images.append(image_path) images = sorted(images) break # Only process the top-level directory else: # Determine the type of input data and process accordingly if input_data.lower().endswith(('.png', '.jpg', '.jpeg')): # If input is a single image file images = [input_data] else: raise ValueError("Unsupported input data format: {}".format(input_data)) return images def load_pdf_images(self, input_data): """ Loads images from a single PDF file or directory containing multiple PDF files. Args: input_data (str): Path to a single PDF file or a directory containing PDF files. Returns: dict: Dictionary with image IDs (formed by PDF path and page number) as keys and corresponding PIL.Image objects as values. Note: Loading multiple PDFs at once is not recommended due to high memory consumption. Consider processing one PDF at a time externally using loops or multithreading. 
""" pdf_images = {} if os.path.isdir(input_data): # If input_data is a directory, check for nested directories for root, dirs, files in os.walk(input_data): if dirs: raise ValueError("Input directory should not contain nested directories: {}".format(input_data)) for file in files: if file.lower().endswith(('.pdf')): pdf_path = os.path.join(root, file) images = load_pdf(pdf_path) for i, img in enumerate(images): img_id = f"{os.path.splitext(file)[0]}_page_{i+1:04d}" pdf_images[img_id] = img # images = sorted(images) break # Only process the top-level directory else: # Determine the type of input data and process accordingly if input_data.lower().endswith(('.pdf')): # If input is a single image file images = load_pdf(input_data) for i, img in enumerate(images): img_id = f"{os.path.splitext(os.path.basename(input_data))[0]}_page_{i+1:04d}" pdf_images[img_id] = img else: raise ValueError("Unsupported input data format: {}".format(input_data)) return pdf_images ================================================ FILE: pdf_extract_kit/tasks/formula_detection/__init__.py ================================================ from pdf_extract_kit.tasks.formula_detection.models.yolo import FormulaDetectionYOLO from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "FurmulaDetectionYOLO", ] ================================================ FILE: pdf_extract_kit/tasks/formula_detection/models/yolo.py ================================================ import os import cv2 import torch from torch.utils.data import DataLoader, Dataset from ultralytics import YOLO from pdf_extract_kit.registry import MODEL_REGISTRY from pdf_extract_kit.utils.visualization import visualize_bbox from pdf_extract_kit.dataset.dataset import ImageDataset import torchvision.transforms as transforms @MODEL_REGISTRY.register('formula_detection_yolo') class FormulaDetectionYOLO: def __init__(self, config): """ Initialize the FormulaDetectionYOLO class. 
Args: config (dict): Configuration dictionary containing model parameters. """ # Mapping from class IDs to class names self.id_to_names = { 0: 'inline', 1: 'isolated' } # Load the YOLO model from the specified path self.model = YOLO(config['model_path']) # Set model parameters self.img_size = config.get('img_size', 1280) self.pdf_dpi = config.get('pdf_dpi', 200) self.conf_thres = config.get('conf_thres', 0.25) self.iou_thres = config.get('iou_thres', 0.45) self.visualize = config.get('visualize', False) self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu') self.batch_size = config.get('batch_size', 1) def predict(self, images, result_path, image_ids=None): """ Predict formulas in images. Args: images (list): List of images to be predicted. result_path (str): Path to save the prediction results. image_ids (list, optional): List of image IDs corresponding to the images. Returns: list: List of prediction results. """ results = [] for idx, image in enumerate(images): result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False)[0] if self.visualize: if not os.path.exists(result_path): os.makedirs(result_path) boxes = result.__dict__['boxes'].xyxy classes = result.__dict__['boxes'].cls scores = result.__dict__['boxes'].conf vis_result = visualize_bbox(image, boxes, classes, scores, self.id_to_names) # Determine the base name of the image if image_ids: base_name = image_ids[idx] else: # base_name = os.path.basename(image) base_name = os.path.splitext(os.path.basename(image))[0] # Remove file extension result_name = f"{base_name}_MFD.png" # Save the visualized result cv2.imwrite(os.path.join(result_path, result_name), vis_result) results.append(result) return results ================================================ FILE: pdf_extract_kit/tasks/formula_detection/task.py ================================================ from pdf_extract_kit.registry.registry import TASK_REGISTRY from 
pdf_extract_kit.tasks.base_task import BaseTask @TASK_REGISTRY.register("formula_detection") class FormulaDetectionTask(BaseTask): def __init__(self, model): super().__init__(model) def predict_images(self, input_data, result_path): """ Predict formulas in images. Args: input_data (str): Path to a single image file or a directory containing image files. result_path (str): Path to save the prediction results. Returns: list: List of prediction results. """ images = self.load_images(input_data) # Perform detection return self.model.predict(images, result_path) def predict_pdfs(self, input_data, result_path): """ Predict formulas in PDF files. Args: input_data (str): Path to a single PDF file or a directory containing PDF files. result_path (str): Path to save the prediction results. Returns: list: List of prediction results. """ pdf_images = self.load_pdf_images(input_data) # Perform detection return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys())) ================================================ FILE: pdf_extract_kit/tasks/formula_recognition/__init__.py ================================================ from pdf_extract_kit.tasks.formula_recognition.models.unimernet import FormulaRecognitionUniMERNet from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "FurmulaRecognitionUniMERNet", ] ================================================ FILE: pdf_extract_kit/tasks/formula_recognition/models/unimernet.py ================================================ import os import logging import argparse import cv2 import torch import numpy as np from PIL import Image import unimernet.tasks as tasks from unimernet.common.config import Config from unimernet.processors import load_processor from pdf_extract_kit.registry import MODEL_REGISTRY @MODEL_REGISTRY.register('formula_recognition_unimernet') class FormulaRecognitionUniMERNet: def __init__(self, config): """ Initialize the FormulaRecognitionUniMERNet class. 
Args: config (dict): Configuration dictionary containing model parameters. """ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model_dir = config['model_path'] self.cfg_path = config.get('cfg_path', "pdf_extract_kit/configs/unimernet.yaml") self.batch_size = config.get('batch_size', 1) # Load the UniMERNet model self.model, self.vis_processor = self.load_model_and_processor() def load_model_and_processor(self): try: args = argparse.Namespace(cfg_path=self.cfg_path, options=None) cfg = Config(args) cfg.config.model.pretrained = os.path.join(self.model_dir, "pytorch_model.pth") cfg.config.model.model_config.model_name = self.model_dir cfg.config.model.tokenizer_config.path = self.model_dir task = tasks.setup_task(cfg) model = task.build_model(cfg).to(self.device) vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval) return model, vis_processor except Exception as e: logging.error(f"Error loading model and processor: {e}") raise def predict(self, images, result_path): results = [] for image_path in images: # Read the image using OpenCV open_cv_image = cv2.imread(image_path) if open_cv_image is None: logging.error(f"Error: Unable to open image at {image_path}") continue # Convert the OpenCV image to PIL.Image format raw_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)) try: # Process the image using the visual processor and prepare it for the model image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) # Generate the prediction using the model output = self.model.generate({"image": image}) pred = output["pred_str"][0] logging.info(f'Prediction for {image_path}:\n{pred}') # cv2.imshow('Original Image', open_cv_image) # cv2.waitKey(0) # cv2.destroyAllWindows() results.append(pred) except Exception as e: logging.error(f"Error processing image {image_path}: {e}") return results ================================================ FILE: 
pdf_extract_kit/tasks/formula_recognition/task.py ================================================ from pdf_extract_kit.registry.registry import TASK_REGISTRY from pdf_extract_kit.tasks.base_task import BaseTask @TASK_REGISTRY.register("formula_recognition") class FormulaRecognitionTask(BaseTask): def __init__(self, model): super().__init__(model) def predict(self, input_data, result_path, bboxes=None): images = self.load_images(input_data) # Perform recognition return self.model.predict(images, result_path) ================================================ FILE: pdf_extract_kit/tasks/layout_detection/__init__.py ================================================ from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO # from pdf_extract_kit.tasks.layout_detection.models.layoutlmv3 import LayoutDetectionLayoutlmv3 from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "LayoutDetectionYOLO", # "LayoutDetectionLayoutlmv3", ] ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/__init__.py ================================================ ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3.py ================================================ import os import cv2 import numpy as np from PIL import Image from pdf_extract_kit.registry.registry import MODEL_REGISTRY from pdf_extract_kit.utils.visualization import visualize_bbox from .layoutlmv3_util.model_init import Layoutlmv3_Predictor @MODEL_REGISTRY.register("layout_detection_layoutlmv3") class LayoutDetectionLayoutlmv3: def __init__(self, config): """ Initialize the LayoutDetectionYOLO class. Args: config (dict): Configuration dictionary containing model parameters. 
""" # Mapping from class IDs to class names self.id_to_names = { 0: 'title', 1: 'plain text', 2: 'abandon', 3: 'figure', 4: 'figure_caption', 5: 'table', 6: 'table_caption', 7: 'table_footnote', 8: 'isolate_formula', 9: 'formula_caption' } self.model = Layoutlmv3_Predictor(config.get('model_path', None)) self.visualize = config.get('visualize', False) def predict(self, images, result_path, image_ids=None): """ Predict layouts in images. Args: images (list): List of images to be predicted. result_path (str): Path to save the prediction results. image_ids (list, optional): List of image IDs corresponding to the images. Returns: list: List of prediction results. """ if not os.path.exists(result_path): os.makedirs(result_path) results = [] for idx, im_file in enumerate(images): if isinstance(im_file, Image.Image): im = im_file.convert("RGB") # extracted PDF pages elif isinstance(im_file, str): im = Image.open(im_file).convert("RGB") # image path layout_res = self.model(np.array(im), ignore_catids=[]) poly = np.array([det["poly"] for det in layout_res["layout_dets"]]) boxes = poly[:, [0,1,4,5]] scores = np.array([det["score"] for det in layout_res["layout_dets"]]) classes = np.array([det["category_id"] for det in layout_res["layout_dets"]]) if self.visualize: vis_result = visualize_bbox(im_file, boxes, classes, scores, self.id_to_names) # Determine the base name of the image if image_ids: base_name = image_ids[idx] else: base_name = os.path.splitext(os.path.basename(im_file))[0] # Remove file extension result_name = f"{base_name}_layout.png" # Save the visualized result cv2.imwrite(os.path.join(result_path, result_name), vis_result) # append result results.append({ "im_path": im_file, "boxes": boxes, "scores": scores, "classes": classes, }) return results ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/backbone.py ================================================ # 
-------------------------------------------------------------------------------- # VIT: Multi-Path Vision Transformer for Dense Prediction # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI). # All Rights Reserved. # Written by Youngwan Lee # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the # LICENSE file in the root directory of this source tree. # -------------------------------------------------------------------------------- # References: # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm # CoaT: https://github.com/mlpc-ucsd/CoaT # -------------------------------------------------------------------------------- import torch from detectron2.layers import ( ShapeSpec, ) from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16 from .deit import deit_base_patch16, mae_base_patch16 from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model from transformers import AutoConfig __all__ = [ "build_vit_fpn_backbone", ] class VIT_Backbone(Backbone): """ Implement VIT backbone. 
""" def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs, config_path=None, image_only=False, cfg=None): super().__init__() self._out_features = out_features if 'base' in name: self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32} self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768} else: self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32} self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024} if name == 'beit_base_patch16': model_func = beit_base_patch16 elif name == 'dit_base_patch16': model_func = dit_base_patch16 elif name == "deit_base_patch16": model_func = deit_base_patch16 elif name == "mae_base_patch16": model_func = mae_base_patch16 elif name == "dit_large_patch16": model_func = dit_large_patch16 elif name == "beit_large_patch16": model_func = beit_large_patch16 if 'beit' in name or 'dit' in name: if pos_type == "abs": self.backbone = model_func(img_size=img_size, out_features=out_features, drop_path_rate=drop_path, use_abs_pos_emb=True, **model_kwargs) elif pos_type == "shared_rel": self.backbone = model_func(img_size=img_size, out_features=out_features, drop_path_rate=drop_path, use_shared_rel_pos_bias=True, **model_kwargs) elif pos_type == "rel": self.backbone = model_func(img_size=img_size, out_features=out_features, drop_path_rate=drop_path, use_rel_pos_bias=True, **model_kwargs) else: raise ValueError() elif "layoutlmv3" in name: config = AutoConfig.from_pretrained(config_path) # disable relative bias as DiT config.has_spatial_attention_bias = False config.has_relative_attention_bias = False self.backbone = LayoutLMv3Model(config, detection=True, out_features=out_features, image_only=image_only) else: self.backbone = model_func(img_size=img_size, out_features=out_features, drop_path_rate=drop_path, **model_kwargs) self.name = name def forward(self, x): """ Args: x: Tensor of 
shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ if "layoutlmv3" in self.name: return self.backbone.forward( input_ids=x["input_ids"] if "input_ids" in x else None, bbox=x["bbox"] if "bbox" in x else None, images=x["images"] if "images" in x else None, attention_mask=x["attention_mask"] if "attention_mask" in x else None, # output_hidden_states=True, ) assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!" return self.backbone.forward_features(x) def output_shape(self): return { name: ShapeSpec( channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] ) for name in self._out_features } def build_VIT_backbone(cfg): """ Create a VIT instance from config. Args: cfg: a detectron2 CfgNode Returns: A VIT backbone instance. """ # fmt: off name = cfg.MODEL.VIT.NAME out_features = cfg.MODEL.VIT.OUT_FEATURES drop_path = cfg.MODEL.VIT.DROP_PATH img_size = cfg.MODEL.VIT.IMG_SIZE pos_type = cfg.MODEL.VIT.POS_TYPE model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", "")) if 'layoutlmv3' in name: if cfg.MODEL.CONFIG_PATH != '': config_path = cfg.MODEL.CONFIG_PATH else: config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '') # layoutlmv3 pre-trained models config_path = config_path.replace('model_final.pth', '') # detection fine-tuned models else: config_path = None return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs, config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg) @BACKBONE_REGISTRY.register() def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec): """ Create a VIT w/ FPN backbone. Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
""" bottom_up = build_VIT_backbone(cfg) in_features = cfg.MODEL.FPN.IN_FEATURES out_channels = cfg.MODEL.FPN.OUT_CHANNELS backbone = FPN( bottom_up=bottom_up, in_features=in_features, out_channels=out_channels, norm=cfg.MODEL.FPN.NORM, top_block=LastLevelMaxPool(), fuse_type=cfg.MODEL.FPN.FUSE_TYPE, ) return backbone ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/beit.py ================================================ """ Vision Transformer (ViT) in PyTorch A PyTorch implement of Vision Transformers as described in 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 The official jax code is released and available at https://github.com/google-research/vision_transformer Status/TODO: * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out for some einops/einsum fun * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert Hacked together by / Copyright 2020 Ross Wightman """ import warnings import math import torch from functools import partial import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import drop_path, to_2tuple, trunc_normal_ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), **kwargs } class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads # NOTE scale factor was wrong 
in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None if window_size: self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.0) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def 
forward(self, x, rel_pos_bias=None, training_window_size=None): B, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: if training_window_size == self.window_size: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) else: training_window_size = tuple(training_window_size.tolist()) new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3 # new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok new_relative_position_bias_table = F.interpolate( self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads, 2 * self.window_size[0] - 1, 2 * self.window_size[1] - 1), size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic', align_corners=False) new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads, new_num_relative_distance - 3).permute( 1, 0) new_relative_position_bias_table = torch.cat( [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(training_window_size[0]) coords_w = 
torch.arange(training_window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += training_window_size[1] - 1 relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1 relative_position_index = \ torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = new_num_relative_distance - 3 relative_position_index[0:, 0] = new_num_relative_distance - 2 relative_position_index[0, 0] = new_num_relative_distance - 1 relative_position_bias = \ new_relative_position_bias_table[relative_position_index.view(-1)].view( training_window_size[0] * training_window_size[1] + 1, training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if rel_pos_bias is not None: attn = attn + rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) # NOTE: drop 
path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values is not None: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None def forward(self, x, rel_pos_bias=None, training_window_size=None): if self.gamma_1 is None: x = x + self.drop_path( self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches_w = self.patch_shape[0] self.num_patches_h = self.patch_shape[1] # the so-called patch_shape is the patch shape during pre-training self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, position_embedding=None, **kwargs): # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
x = self.proj(x) Hp, Wp = x.shape[2], x.shape[3] if position_embedding is not None: # interpolate the position embedding to the corresponding size position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3, 1, 2) position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic') x = x + position_embedding x = x.flatten(2).transpose(1, 2) return x, (Hp, Wp) class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. """ def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature # map for all networks, the feature metadata has reliable channel and stride info, but using # stride to calc feature dim requires info about padding of each stage that isn't captured. 
training = backbone.training if training: backbone.eval() o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) feature_dim = self.backbone.feature_info.channels()[-1] self.num_patches = feature_size[0] * feature_size[1] self.proj = nn.Linear(feature_dim, embed_dim) def forward(self, x): x = self.backbone(x)[-1] x = x.flatten(2).transpose(1, 2) x = self.proj(x) return x class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size self.num_heads = num_heads self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", 
relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.02) def forward(self, training_window_size): if training_window_size == self.window_size: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww else: training_window_size = tuple(training_window_size.tolist()) new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3 # new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok new_relative_position_bias_table = F.interpolate( self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads, 2 * self.window_size[0] - 1, 2 * self.window_size[1] - 1), size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic', align_corners=False) new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads, new_num_relative_distance - 3).permute( 1, 0) new_relative_position_bias_table = torch.cat( [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(training_window_size[0]) coords_w = torch.arange(training_window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += training_window_size[1] - 1 relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1 relative_position_index = \ 
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = new_num_relative_distance - 3 relative_position_index[0:, 0] = new_num_relative_distance - 2 relative_position_index[0, 0] = new_num_relative_distance - 1 relative_position_bias = \ new_relative_position_bias_table[relative_position_index.view(-1)].view( training_window_size[0] * training_window_size[1] + 1, training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww return relative_position_bias class BEiT(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_abs_pos_emb=False, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, use_checkpoint=True, pretrained=None, out_features=None, ): super(BEiT, self).__init__() norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.use_checkpoint = use_checkpoint if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.out_features = out_features self.out_indices = [int(name[5:]) for name in out_features] self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, 
embed_dim)) if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) self.use_shared_rel_pos_bias = use_shared_rel_pos_bias if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) else: self.rel_pos_bias = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.use_rel_pos_bias = use_rel_pos_bias self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)]) # trunc_normal_(self.mask_token, std=.02) if patch_size == 16: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), # nn.SyncBatchNorm(embed_dim), nn.BatchNorm2d(embed_dim), nn.GELU(), nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) elif patch_size == 8: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Identity() self.fpn3 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn4 = nn.Sequential( nn.MaxPool2d(kernel_size=4, stride=4), ) if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) self.apply(self._init_weights) self.fix_init_weight() def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) 
rescale(layer.mlp.fc2.weight.data, layer_id + 1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) ''' def init_weights(self): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ logger = get_root_logger() if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) self.apply(self._init_weights) self.fix_init_weight() if self.init_cfg is None: logger.warn(f'No pre-trained weights for ' f'{self.__class__.__name__}, ' f'training start from scratch') else: assert 'checkpoint' in self.init_cfg, f'Only support ' \ f'specify `Pretrained` in ' \ f'`init_cfg` in ' \ f'{self.__class__.__name__} ' logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}") load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger, beit_spec_expand_rel_pos = self.use_rel_pos_bias, ) ''' def get_num_layers(self): return len(self.blocks) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} def forward_features(self, x): B, C, H, W = x.shape x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None) # Hp, Wp are HW for patches batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks if self.pos_embed is not None: cls_tokens = cls_tokens + self.pos_embed[:, :1, :] x = torch.cat((cls_tokens, x), dim=1) x = self.pos_drop(x) features = [] training_window_size = torch.tensor([Hp, Wp]) rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None for i, blk in enumerate(self.blocks): if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, rel_pos_bias, 
training_window_size) else: x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size) if i in self.out_indices: xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) features.append(xp.contiguous()) ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] for i in range(len(features)): features[i] = ops[i](features[i]) feat_out = {} for name, value in zip(self.out_features, features): feat_out[name] = value return feat_out def forward(self, x): x = self.forward_features(x) return x def beit_base_patch16(pretrained=False, **kwargs): model = BEiT( patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None, **kwargs) model.default_cfg = _cfg() return model def beit_large_patch16(pretrained=False, **kwargs): model = BEiT( patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None, **kwargs) model.default_cfg = _cfg() return model def dit_base_patch16(pretrained=False, **kwargs): model = BEiT( patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0.1, **kwargs) model.default_cfg = _cfg() return model def dit_large_patch16(pretrained=False, **kwargs): model = BEiT( patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=1e-5, **kwargs) model.default_cfg = _cfg() return model if __name__ == '__main__': model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True) model = model.to("cuda:0") input1 = torch.rand(2, 3, 512, 762).to("cuda:0") input2 = torch.rand(2, 3, 800, 1200).to("cuda:0") input3 = torch.rand(2, 3, 720, 1000).to("cuda:0") output1 = model(input1) output2 = model(input2) output3 = model(input3) print("all done") ================================================ FILE: 
pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/deit.py ================================================ """ Mostly copy-paste from DINO and timm library: https://github.com/facebookresearch/dino https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ import warnings import math import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.models.layers import trunc_normal_, drop_path, to_2tuple from functools import partial def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), **kwargs } class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 
self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches_w, self.num_patches_h = self.window_size self.num_patches = self.window_size[0] * self.window_size[1] self.img_size = img_size self.patch_size = patch_size self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) return x class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. 
""" def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature # map for all networks, the feature metadata has reliable channel and stride info, but using # stride to calc feature dim requires info about padding of each stage that isn't captured. training = backbone.training if training: backbone.eval() o = self.backbone(torch.zeros( 1, in_chans, img_size[0], img_size[1]))[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) feature_dim = self.backbone.feature_info.channels()[-1] self.num_patches = feature_size[0] * feature_size[1] self.proj = nn.Linear(feature_dim, embed_dim) def forward(self, x): x = self.backbone(x)[-1] x = x.flatten(2).transpose(1, 2) x = self.proj(x) return x class ViT(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, model_name='vit_base_patch16_224', img_size=384, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, num_classes=19, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0.1, attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_cfg=None, pos_embed_interp=False, random_init=False, align_corners=False, use_checkpoint=False, num_extra_tokens=1, out_features=None, **kwargs, ): super(ViT, self).__init__() self.model_name = model_name self.img_size = img_size self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.depth = depth self.num_heads = num_heads self.num_classes = num_classes self.mlp_ratio = mlp_ratio self.qkv_bias = qkv_bias self.qk_scale = qk_scale self.drop_rate = drop_rate 
self.attn_drop_rate = attn_drop_rate self.drop_path_rate = drop_path_rate self.hybrid_backbone = hybrid_backbone self.norm_layer = norm_layer self.norm_cfg = norm_cfg self.pos_embed_interp = pos_embed_interp self.random_init = random_init self.align_corners = align_corners self.use_checkpoint = use_checkpoint self.num_extra_tokens = num_extra_tokens self.out_features = out_features self.out_indices = [int(name[5:]) for name in out_features] # self.num_stages = self.depth # self.out_indices = tuple(range(self.num_stages)) if self.hybrid_backbone is not None: self.patch_embed = HybridEmbed( self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim) else: self.patch_embed = PatchEmbed( img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim) self.num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if self.num_extra_tokens == 2: self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) self.pos_embed = nn.Parameter(torch.zeros( 1, self.num_patches + self.num_extra_tokens, self.embed_dim)) self.pos_drop = nn.Dropout(p=self.drop_rate) # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias, qk_scale=self.qk_scale, drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer) for i in range(self.depth)]) # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here # self.repr = nn.Linear(embed_dim, representation_size) # self.repr_act = nn.Tanh() if patch_size == 16: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), nn.SyncBatchNorm(embed_dim), nn.GELU(), 
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) elif patch_size == 8: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Identity() self.fpn3 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn4 = nn.Sequential( nn.MaxPool2d(kernel_size=4, stride=4), ) trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) if self.num_extra_tokens==2: trunc_normal_(self.dist_token, std=0.2) self.apply(self._init_weights) # self.fix_init_weight() def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) ''' def init_weights(self): logger = get_root_logger() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) self.apply(self._init_weights) if self.init_cfg is None: logger.warn(f'No pre-trained weights for ' f'{self.__class__.__name__}, ' f'training start from scratch') else: assert 'checkpoint' in self.init_cfg, f'Only support ' \ f'specify `Pretrained` in ' \ f'`init_cfg` in ' \ f'{self.__class__.__name__} ' logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}") load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger) ''' def get_num_layers(self): return len(self.blocks) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} def _conv_filter(self, state_dict, 
patch_size=16): """ convert patch embedding weight from manual patchify + linear proj to conv""" out_dict = {} for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k: v = v.reshape((v.shape[0], 3, patch_size, patch_size)) out_dict[k] = v return out_dict def to_2D(self, x): n, hw, c = x.shape h = w = int(math.sqrt(hw)) x = x.transpose(1, 2).reshape(n, c, h, w) return x def to_1D(self, x): n, c, h, w = x.shape x = x.reshape(n, c, -1).transpose(1, 2) return x def interpolate_pos_encoding(self, x, w, h): npatch = x.shape[1] - self.num_extra_tokens N = self.pos_embed.shape[1] - self.num_extra_tokens if npatch == N and w == h: return self.pos_embed class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens] patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:] dim = x.shape[-1] w0 = w // self.patch_embed.patch_size[0] h0 = h // self.patch_embed.patch_size[1] # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 w0, h0 = w0 + 0.1, h0 + 0.1 patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), mode='bicubic', ) assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1) def prepare_tokens(self, x, mask=None): B, nc, w, h = x.shape # patch linear embedding x = self.patch_embed(x) # mask image modeling if mask is not None: x = self.mask_model(x, mask) x = x.flatten(2).transpose(1, 2) # add the [CLS] token to the embed patch tokens all_tokens = [self.cls_token.expand(B, -1, -1)] if self.num_extra_tokens == 2: dist_tokens = self.dist_token.expand(B, -1, -1) all_tokens.append(dist_tokens) all_tokens.append(x) x = torch.cat(all_tokens, dim=1) # add positional encoding to 
each token x = x + self.interpolate_pos_encoding(x, w, h) return self.pos_drop(x) def forward_features(self, x): # print(f"==========shape of x is {x.shape}==========") B, _, H, W = x.shape Hp, Wp = H // self.patch_size, W // self.patch_size x = self.prepare_tokens(x) features = [] for i, blk in enumerate(self.blocks): if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if i in self.out_indices: xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) features.append(xp.contiguous()) ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] for i in range(len(features)): features[i] = ops[i](features[i]) feat_out = {} for name, value in zip(self.out_features, features): feat_out[name] = value return feat_out def forward(self, x): x = self.forward_features(x) return x def deit_base_patch16(pretrained=False, **kwargs): model = ViT( patch_size=16, drop_rate=0., embed_dim=768, depth=12, num_heads=12, num_classes=1000, mlp_ratio=4., qkv_bias=True, use_checkpoint=True, num_extra_tokens=2, **kwargs) model.default_cfg = _cfg() return model def mae_base_patch16(pretrained=False, **kwargs): model = ViT( patch_size=16, drop_rate=0., embed_dim=768, depth=12, num_heads=12, num_classes=1000, mlp_ratio=4., qkv_bias=True, use_checkpoint=True, num_extra_tokens=1, **kwargs) model.default_cfg = _cfg() return model ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/__init__.py ================================================ from .models import ( LayoutLMv3Config, LayoutLMv3ForTokenClassification, LayoutLMv3ForQuestionAnswering, LayoutLMv3ForSequenceClassification, LayoutLMv3Tokenizer, ) ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/__init__.py ================================================ # flake8: noqa from .data_collator import DataCollatorForKeyValueExtraction 
================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/cord.py ================================================ ''' Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py ''' import json import os from pathlib import Path import datasets from .image_utils import load_image, normalize_bbox logger = datasets.logging.get_logger(__name__) _CITATION = """\ @article{park2019cord, title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing}, author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk} booktitle={Document Intelligence Workshop at Neural Information Processing Systems} year={2019} } """ _DESCRIPTION = """\ https://github.com/clovaai/cord/ """ def quad_to_box(quad): # test 87 is wrongly annotated box = ( max(0, quad["x1"]), max(0, quad["y1"]), quad["x3"], quad["y3"] ) if box[3] < box[1]: bbox = list(box) tmp = bbox[3] bbox[3] = bbox[1] bbox[1] = tmp box = tuple(bbox) if box[2] < box[0]: bbox = list(box) tmp = bbox[2] bbox[2] = bbox[0] bbox[0] = tmp box = tuple(bbox) return box def _get_drive_url(url): base_url = 'https://drive.google.com/uc?id=' split_url = url.split('/') return base_url + split_url[5] _URLS = [ _get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"), _get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/") # If you failed to download the dataset through the automatic downloader, # you can download it manually and modify the code to get the local dataset. # Or you can use the following links. Please follow the original LICENSE of CORD for usage. # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip", # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip" ] class CordConfig(datasets.BuilderConfig): """BuilderConfig for CORD""" def __init__(self, **kwargs): """BuilderConfig for CORD. 
Args: **kwargs: keyword arguments forwarded to super. """ super(CordConfig, self).__init__(**kwargs) class Cord(datasets.GeneratorBasedBuilder): BUILDER_CONFIGS = [ CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"), ] def _info(self): return datasets.DatasetInfo( description=_DESCRIPTION, features=datasets.Features( { "id": datasets.Value("string"), "words": datasets.Sequence(datasets.Value("string")), "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), "ner_tags": datasets.Sequence( datasets.features.ClassLabel( names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"] ) ), "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"), "image_path": datasets.Value("string"), } ), supervised_keys=None, citation=_CITATION, 
homepage="https://github.com/clovaai/cord/", ) def _split_generators(self, dl_manager): """Returns SplitGenerators.""" """Uses local files located with data_dir""" downloaded_file = dl_manager.download_and_extract(_URLS) # move files from the second URL together with files from the first one. dest = Path(downloaded_file[0])/"CORD" for split in ["train", "dev", "test"]: for file_type in ["image", "json"]: if split == "test" and file_type == "json": continue files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir() for f in files: os.rename(f, dest/split/file_type/f.name) return [ datasets.SplitGenerator( name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"} ), datasets.SplitGenerator( name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"} ), datasets.SplitGenerator( name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"} ), ] def get_line_bbox(self, bboxs): x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)] y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)] x0, y0, x1, y1 = min(x), min(y), max(x), max(y) assert x1 >= x0 and y1 >= y0 bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))] return bbox def _generate_examples(self, filepath): logger.info("⏳ Generating examples from = %s", filepath) ann_dir = os.path.join(filepath, "json") img_dir = os.path.join(filepath, "image") for guid, file in enumerate(sorted(os.listdir(ann_dir))): words = [] bboxes = [] ner_tags = [] file_path = os.path.join(ann_dir, file) with open(file_path, "r", encoding="utf8") as f: data = json.load(f) image_path = os.path.join(img_dir, file) image_path = image_path.replace("json", "png") image, size = load_image(image_path) for item in data["valid_line"]: cur_line_bboxes = [] line_words, label = item["words"], item["category"] line_words = [w for w in line_words if w["text"].strip() != ""] if len(line_words) == 0: continue if label == "other": for w in line_words: words.append(w["text"]) 
ner_tags.append("O") cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size)) else: words.append(line_words[0]["text"]) ner_tags.append("B-" + label.upper()) cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size)) for w in line_words[1:]: words.append(w["text"]) ner_tags.append("I-" + label.upper()) cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size)) # by default: --segment_level_layout 1 # if do not want to use segment_level_layout, comment the following line cur_line_bboxes = self.get_line_bbox(cur_line_bboxes) bboxes.extend(cur_line_bboxes) # yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image} yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image, "image_path": image_path} ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/data_collator.py ================================================ import torch from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union from transformers import BatchEncoding, PreTrainedTokenizerBase from transformers.data.data_collator import ( DataCollatorMixin, _torch_collate_batch, ) from transformers.file_utils import PaddingStrategy from typing import NewType InputDataClass = NewType("InputDataClass", Any) def pre_calc_rel_mat(segment_ids): valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]), device=segment_ids.device, dtype=torch.bool) for i in range(segment_ids.shape[0]): for j in range(segment_ids.shape[1]): valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j] return valid_span @dataclass class DataCollatorForKeyValueExtraction(DataCollatorMixin): """ Data collator that will dynamically pad the inputs received, as well as the labels. 
Args: tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): The tokenizer used for encoding the data. padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence if provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different lengths). max_length (:obj:`int`, `optional`): Maximum length of the returned list and optionally padding length (see above). pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). label_pad_token_id (:obj:`int`, `optional`, defaults to -100): The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). 
""" tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None label_pad_token_id: int = -100 def __call__(self, features): label_name = "label" if "label" in features[0].keys() else "labels" labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None images = None if "images" in features[0]: images = torch.stack([torch.tensor(d.pop("images")) for d in features]) IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1 batch = self.tokenizer.pad( features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, # Conversion to tensors will fail if we have labels as they are not of the same length yet. return_tensors="pt" if labels is None else None, ) if images is not None: batch["images"] = images batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v for k, v in batch.items()} visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1) if labels is None: return batch has_bbox_input = "bbox" in features[0] has_position_input = "position_ids" in features[0] padding_idx=self.tokenizer.pad_token_id sequence_length = torch.tensor(batch["input_ids"]).shape[1] padding_side = self.tokenizer.padding_side if padding_side == "right": batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels] if has_bbox_input: batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]] if has_position_input: batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id)) for position_id in batch["position_ids"]] else: batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in 
labels] if has_bbox_input: batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]] if has_position_input: batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id)) + position_id for position_id in batch["position_ids"]] if 'segment_ids' in batch: assert 'position_ids' in batch for i in range(len(batch['segment_ids'])): batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [ batch['segment_ids'][i][-1] + 2] * IMAGE_LEN batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()} if 'segment_ids' in batch: valid_span = pre_calc_rel_mat( segment_ids=batch['segment_ids'] ) batch['valid_span'] = valid_span del batch['segment_ids'] if images is not None: visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100 batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1) return batch ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/funsd.py ================================================ # coding=utf-8 ''' Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py ''' import json import os import datasets from .image_utils import load_image, normalize_bbox logger = datasets.logging.get_logger(__name__) _CITATION = """\ @article{Jaume2019FUNSDAD, title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents}, author={Guillaume Jaume and H. K. Ekenel and J. Thiran}, journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)}, year={2019}, volume={2}, pages={1-6} } """ _DESCRIPTION = """\ https://guillaumejaume.github.io/FUNSD/ """ class FunsdConfig(datasets.BuilderConfig): """BuilderConfig for FUNSD""" def __init__(self, **kwargs): """BuilderConfig for FUNSD. 
Args: **kwargs: keyword arguments forwarded to super. """ super(FunsdConfig, self).__init__(**kwargs) class Funsd(datasets.GeneratorBasedBuilder): """Conll2003 dataset.""" BUILDER_CONFIGS = [ FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"), ] def _info(self): return datasets.DatasetInfo( description=_DESCRIPTION, features=datasets.Features( { "id": datasets.Value("string"), "tokens": datasets.Sequence(datasets.Value("string")), "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), "ner_tags": datasets.Sequence( datasets.features.ClassLabel( names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"] ) ), "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"), "image_path": datasets.Value("string"), } ), supervised_keys=None, homepage="https://guillaumejaume.github.io/FUNSD/", citation=_CITATION, ) def _split_generators(self, dl_manager): """Returns SplitGenerators.""" downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip") return [ datasets.SplitGenerator( name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"} ), datasets.SplitGenerator( name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"} ), ] def get_line_bbox(self, bboxs): x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)] y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)] x0, y0, x1, y1 = min(x), min(y), max(x), max(y) assert x1 >= x0 and y1 >= y0 bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))] return bbox def _generate_examples(self, filepath): logger.info("⏳ Generating examples from = %s", filepath) ann_dir = os.path.join(filepath, "annotations") img_dir = os.path.join(filepath, "images") for guid, file in enumerate(sorted(os.listdir(ann_dir))): tokens = [] bboxes = [] ner_tags = [] file_path = os.path.join(ann_dir, file) 
with open(file_path, "r", encoding="utf8") as f: data = json.load(f) image_path = os.path.join(img_dir, file) image_path = image_path.replace("json", "png") image, size = load_image(image_path) for item in data["form"]: cur_line_bboxes = [] words, label = item["words"], item["label"] words = [w for w in words if w["text"].strip() != ""] if len(words) == 0: continue if label == "other": for w in words: tokens.append(w["text"]) ner_tags.append("O") cur_line_bboxes.append(normalize_bbox(w["box"], size)) else: tokens.append(words[0]["text"]) ner_tags.append("B-" + label.upper()) cur_line_bboxes.append(normalize_bbox(words[0]["box"], size)) for w in words[1:]: tokens.append(w["text"]) ner_tags.append("I-" + label.upper()) cur_line_bboxes.append(normalize_bbox(w["box"], size)) # by default: --segment_level_layout 1 # if do not want to use segment_level_layout, comment the following line cur_line_bboxes = self.get_line_bbox(cur_line_bboxes) # box = normalize_bbox(item["box"], size) # cur_line_bboxes = [box for _ in range(len(words))] bboxes.extend(cur_line_bboxes) yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags, "image": image, "image_path": image_path} ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/image_utils.py ================================================ import torchvision.transforms.functional as F import warnings import math import random import numpy as np from PIL import Image import torch from detectron2.data.detection_utils import read_image from detectron2.data.transforms import ResizeTransform, TransformList def normalize_bbox(bbox, size): return [ int(1000 * bbox[0] / size[0]), int(1000 * bbox[1] / size[1]), int(1000 * bbox[2] / size[0]), int(1000 * bbox[3] / size[1]), ] def load_image(image_path): image = read_image(image_path, format="BGR") h = image.shape[0] w = image.shape[1] img_trans = TransformList([ResizeTransform(h=h, w=w, 
new_h=224, new_w=224)]) image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable return image, (w, h) def crop(image, i, j, h, w, boxes=None): cropped_image = F.crop(image, i, j, h, w) if boxes is not None: # Currently we cannot use this case since when some boxes is out of the cropped image, # it may be better to drop out these boxes along with their text input (instead of min or clamp) # which haven't been implemented here max_size = torch.as_tensor([w, h], dtype=torch.float32) cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i]) cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) cropped_boxes = cropped_boxes.clamp(min=0) boxes = cropped_boxes.reshape(-1, 4) return cropped_image, boxes def resize(image, size, interpolation, boxes=None): # It seems that we do not need to resize boxes here, since the boxes will be resized to 1000x1000 finally, # which is compatible with a square image size of 224x224 rescaled_image = F.resize(image, size, interpolation) if boxes is None: return rescaled_image, None ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) ratio_width, ratio_height = ratios # boxes = boxes.copy() scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) return rescaled_image, scaled_boxes def clamp(num, min_value, max_value): return max(min(num, max_value), min_value) def get_bb(bb, page_size): bbs = [float(j) for j in bb] xs, ys = [], [] for i, b in enumerate(bbs): if i % 2 == 0: xs.append(b) else: ys.append(b) (width, height) = page_size return_bb = [ clamp(min(xs), 0, width - 1), clamp(min(ys), 0, height - 1), clamp(max(xs), 0, width - 1), clamp(max(ys), 0, height - 1), ] return_bb = [ int(1000 * return_bb[0] / width), int(1000 * return_bb[1] / height), int(1000 * return_bb[2] / width), int(1000 * return_bb[3] / height), ] return return_bb class ToNumpy: def __call__(self, pil_img): np_img = 
np.array(pil_img, dtype=np.uint8) if np_img.ndim < 3: np_img = np.expand_dims(np_img, axis=-1) np_img = np.rollaxis(np_img, 2) # HWC to CHW return np_img class ToTensor: def __init__(self, dtype=torch.float32): self.dtype = dtype def __call__(self, pil_img): np_img = np.array(pil_img, dtype=np.uint8) if np_img.ndim < 3: np_img = np.expand_dims(np_img, axis=-1) np_img = np.rollaxis(np_img, 2) # HWC to CHW return torch.from_numpy(np_img).to(dtype=self.dtype) _pil_interpolation_to_str = { F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST', F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR', F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC', F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS', F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING', F.InterpolationMode.BOX: 'F.InterpolationMode.BOX', } def _pil_interp(method): if method == 'bicubic': return F.InterpolationMode.BICUBIC elif method == 'lanczos': return F.InterpolationMode.LANCZOS elif method == 'hamming': return F.InterpolationMode.HAMMING else: # default bilinear, do we want to allow nearest? return F.InterpolationMode.BILINEAR class Compose: """Composes several transforms together. This transform does not support torchscript. Please, see the note below. Args: transforms (list of ``Transform`` objects): list of transforms to compose. Example: >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.PILToTensor(), >>> transforms.ConvertImageDtype(torch.float), >>> ]) .. note:: In order to script the transformations, please use ``torch.nn.Sequential`` as below. >>> transforms = torch.nn.Sequential( >>> transforms.CenterCrop(10), >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), >>> ) >>> scripted_transforms = torch.jit.script(transforms) Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require `lambda` functions or ``PIL.Image``. 
""" def __init__(self, transforms): self.transforms = transforms def __call__(self, img, augmentation=False, box=None): for t in self.transforms: img = t(img, augmentation, box) return img class RandomResizedCropAndInterpolationWithTwoPic: """Crop the given PIL Image to random size and aspect ratio with random interpolation. A crop of random size (default: of 0.08 to 1.0) of the original size and a random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop is finally resized to given size. This is popularly used to train the Inception networks. Args: size: expected output size of each edge scale: range of size of the origin size cropped ratio: range of aspect ratio of the origin aspect ratio cropped interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation='bilinear', second_interpolation='lanczos'): if isinstance(size, tuple): self.size = size else: self.size = (size, size) if second_size is not None: if isinstance(second_size, tuple): self.second_size = second_size else: self.second_size = (second_size, second_size) else: self.second_size = None if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("range should be of kind (min, max)") self.interpolation = _pil_interp(interpolation) self.second_interpolation = _pil_interp(second_interpolation) self.scale = scale self.ratio = ratio @staticmethod def get_params(img, scale, ratio): """Get parameters for ``crop`` for a random sized crop. Args: img (PIL Image): Image to be cropped. scale (tuple): range of size of the origin size cropped ratio (tuple): range of aspect ratio of the origin aspect ratio cropped Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. 
""" area = img.size[0] * img.size[1] for attempt in range(10): target_area = random.uniform(*scale) * area log_ratio = (math.log(ratio[0]), math.log(ratio[1])) aspect_ratio = math.exp(random.uniform(*log_ratio)) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) if w <= img.size[0] and h <= img.size[1]: i = random.randint(0, img.size[1] - h) j = random.randint(0, img.size[0] - w) return i, j, h, w # Fallback to central crop in_ratio = img.size[0] / img.size[1] if in_ratio < min(ratio): w = img.size[0] h = int(round(w / min(ratio))) elif in_ratio > max(ratio): h = img.size[1] w = int(round(h * max(ratio))) else: # whole image w = img.size[0] h = img.size[1] i = (img.size[1] - h) // 2 j = (img.size[0] - w) // 2 return i, j, h, w def __call__(self, img, augmentation=False, box=None): """ Args: img (PIL Image): Image to be cropped and resized. Returns: PIL Image: Randomly cropped and resized image. """ if augmentation: i, j, h, w = self.get_params(img, self.scale, self.ratio) img = F.crop(img, i, j, h, w) # img, box = crop(img, i, j, h, w, box) img = F.resize(img, self.size, self.interpolation) second_img = F.resize(img, self.second_size, self.second_interpolation) \ if self.second_size is not None else None return img, second_img def __repr__(self): if isinstance(self.interpolation, (tuple, list)): interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) else: interpolate_str = _pil_interpolation_to_str[self.interpolation] format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) format_string += ', interpolation={0}'.format(interpolate_str) if self.second_size is not None: format_string += ', second_size={0}'.format(self.second_size) format_string += ', 
second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation]) format_string += ')' return format_string def pil_loader(path: str) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, 'rb') as f: img = Image.open(f) return img.convert('RGB') ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/xfund.py ================================================ import os import json import torch from torch.utils.data.dataset import Dataset from torchvision import transforms from PIL import Image from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic XFund_label2ids = { "O":0, 'B-HEADER':1, 'I-HEADER':2, 'B-QUESTION':3, 'I-QUESTION':4, 'B-ANSWER':5, 'I-ANSWER':6, } class xfund_dataset(Dataset): def box_norm(self, box, width, height): def clip(min_num, num, max_num): return min(max(num, min_num), max_num) x0, y0, x1, y1 = box x0 = clip(0, int((x0 / width) * 1000), 1000) y0 = clip(0, int((y0 / height) * 1000), 1000) x1 = clip(0, int((x1 / width) * 1000), 1000) y1 = clip(0, int((y1 / height) * 1000), 1000) assert x1 >= x0 assert y1 >= y0 return [x0, y0, x1, y1] def get_segment_ids(self, bboxs): segment_ids = [] for i in range(len(bboxs)): if i == 0: segment_ids.append(0) else: if bboxs[i - 1] == bboxs[i]: segment_ids.append(segment_ids[-1]) else: segment_ids.append(segment_ids[-1] + 1) return segment_ids def get_position_ids(self, segment_ids): position_ids = [] for i in range(len(segment_ids)): if i == 0: position_ids.append(2) else: if segment_ids[i] == segment_ids[i - 1]: position_ids.append(position_ids[-1] + 1) else: position_ids.append(2) return position_ids def load_data( self, data_file, ): # re-org data format total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []} for i in range(len(data_file['documents'])): width, height = 
data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][ 'height'] cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], [] for j in range(len(data_file['documents'][i]['document'])): cur_item = data_file['documents'][i]['document'][j] cur_doc_lines.append(cur_item['text']) cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height)) cur_doc_ner_tags.append(cur_item['label']) total_data['id'] += [len(total_data['id'])] total_data['lines'] += [cur_doc_lines] total_data['bboxes'] += [cur_doc_bboxes] total_data['ner_tags'] += [cur_doc_ner_tags] total_data['image_path'] += [data_file['documents'][i]['img']['fname']] # tokenize text and get bbox/label total_input_ids, total_bboxs, total_label_ids = [], [], [] for i in range(len(total_data['lines'])): cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], [] for j in range(len(total_data['lines'][i])): cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids'] if len(cur_input_ids) == 0: continue cur_label = total_data['ner_tags'][i][j].upper() if cur_label == 'OTHER': cur_labels = ["O"] * len(cur_input_ids) for k in range(len(cur_labels)): cur_labels[k] = self.label2ids[cur_labels[k]] else: cur_labels = [cur_label] * len(cur_input_ids) cur_labels[0] = self.label2ids['B-' + cur_labels[0]] for k in range(1, len(cur_labels)): cur_labels[k] = self.label2ids['I-' + cur_labels[k]] assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels) cur_doc_input_ids += cur_input_ids cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids) cur_doc_labels += cur_labels assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels) assert len(cur_doc_input_ids) > 0 total_input_ids.append(cur_doc_input_ids) total_bboxs.append(cur_doc_bboxs) total_label_ids.append(cur_doc_labels) assert len(total_input_ids) == 
len(total_bboxs) == len(total_label_ids) # split text to several slices because of over-length input_ids, bboxs, labels = [], [], [] segment_ids, position_ids = [], [] image_path = [] for i in range(len(total_input_ids)): start = 0 cur_iter = 0 while start < len(total_input_ids[i]): end = min(start + 510, len(total_input_ids[i])) input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id]) bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]]) labels.append([-100] + total_label_ids[i][start: end] + [-100]) cur_segment_ids = self.get_segment_ids(bboxs[-1]) cur_position_ids = self.get_position_ids(cur_segment_ids) segment_ids.append(cur_segment_ids) position_ids.append(cur_position_ids) image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i])) start = end cur_iter += 1 assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids) assert len(segment_ids) == len(image_path) res = { 'input_ids': input_ids, 'bbox': bboxs, 'labels': labels, 'segment_ids': segment_ids, 'position_ids': position_ids, 'image_path': image_path, } return res def __init__( self, args, tokenizer, mode ): self.args = args self.mode = mode self.cur_la = args.language self.tokenizer = tokenizer self.label2ids = XFund_label2ids self.common_transform = Compose([ RandomResizedCropAndInterpolationWithTwoPic( size=args.input_size, interpolation=args.train_interpolation, ), ]) self.patch_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor((0.5, 0.5, 0.5)), std=torch.tensor((0.5, 0.5, 0.5))) ]) data_file = json.load( open(os.path.join(args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val')), 'r')) self.feature = self.load_data(data_file) def __len__(self): return len(self.feature['input_ids']) def __getitem__(self, index): input_ids = self.feature["input_ids"][index] # attention_mask = 
self.feature["attention_mask"][index] attention_mask = [1] * len(input_ids) labels = self.feature["labels"][index] bbox = self.feature["bbox"][index] segment_ids = self.feature['segment_ids'][index] position_ids = self.feature['position_ids'][index] img = pil_loader(self.feature['image_path'][index]) for_patches, _ = self.common_transform(img, augmentation=False) patch = self.patch_transform(for_patches) assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids) res = { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "bbox": bbox, "segment_ids": segment_ids, "position_ids": position_ids, "images": patch, } return res def pil_loader(path: str) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, 'rb') as f: img = Image.open(f) return img.convert('RGB') ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/__init__.py ================================================ from .layoutlmv3 import ( LayoutLMv3Config, LayoutLMv3ForTokenClassification, LayoutLMv3ForQuestionAnswering, LayoutLMv3ForSequenceClassification, LayoutLMv3Tokenizer, ) ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/__init__.py ================================================ from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \ AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter from .configuration_layoutlmv3 import LayoutLMv3Config from .modeling_layoutlmv3 import ( LayoutLMv3ForTokenClassification, LayoutLMv3ForQuestionAnswering, LayoutLMv3ForSequenceClassification, LayoutLMv3Model, ) from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer 
from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast #AutoConfig.register("layoutlmv3", LayoutLMv3Config) #AutoModel.register(LayoutLMv3Config, LayoutLMv3Model) #AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification) #AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering) #AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification) #AutoTokenizer.register( # LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast #) SLOW_TO_FAST_CONVERTERS.update({"LayoutLMv3Tokenizer": RobertaConverter}) ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py ================================================ # coding=utf-8 from transformers.models.bert.configuration_bert import BertConfig from transformers.utils import logging logger = logging.get_logger(__name__) LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = { "layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json", "layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/resolve/main/config.json", # See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3 } class LayoutLMv3Config(BertConfig): model_type = "layoutlmv3" def __init__( self, pad_token_id=1, bos_token_id=0, eos_token_id=2, max_2d_position_embeddings=1024, coordinate_size=None, shape_size=None, has_relative_attention_bias=False, rel_pos_bins=32, max_rel_pos=128, has_spatial_attention_bias=False, rel_2d_pos_bins=64, max_rel_2d_pos=256, visual_embed=True, mim=False, wpa_task=False, discrete_vae_weight_path='', discrete_vae_type='dall-e', input_size=224, second_input_size=112, device='cuda', **kwargs ): """Constructs RobertaConfig.""" super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, 
eos_token_id=eos_token_id, **kwargs) self.max_2d_position_embeddings = max_2d_position_embeddings self.coordinate_size = coordinate_size self.shape_size = shape_size self.has_relative_attention_bias = has_relative_attention_bias self.rel_pos_bins = rel_pos_bins self.max_rel_pos = max_rel_pos self.has_spatial_attention_bias = has_spatial_attention_bias self.rel_2d_pos_bins = rel_2d_pos_bins self.max_rel_2d_pos = max_rel_2d_pos self.visual_embed = visual_embed self.mim = mim self.wpa_task = wpa_task self.discrete_vae_weight_path = discrete_vae_weight_path self.discrete_vae_type = discrete_vae_type self.input_size = input_size self.second_input_size = second_input_size self.device = device ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py ================================================ # coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch LayoutLMv3 model. 
""" import math import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import apply_chunking_to_forward from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput, TokenClassifierOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from transformers.models.roberta.modeling_roberta import ( RobertaIntermediate, RobertaLMHead, RobertaOutput, RobertaSelfOutput, ) from transformers.utils import logging from .configuration_layoutlmv3 import LayoutLMv3Config from timm.models.layers import to_2tuple logger = logging.get_logger(__name__) class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # The following variables are used in detection mycheckpointer.py self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.num_patches_w = self.patch_shape[0] self.num_patches_h = self.patch_shape[1] def forward(self, x, position_embedding=None): x = self.proj(x) if position_embedding is not None: # interpolate the position embedding to the corresponding size position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3, 1, 2) Hp, Wp = x.shape[2], x.shape[3] position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic') x = x + position_embedding x = x.flatten(2).transpose(1, 2) return x class LayoutLMv3Embeddings(nn.Module): """ 
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. """ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) # End copy self.padding_idx = config.pad_token_id self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx ) self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) def _calc_spatial_position_embeddings(self, bbox): try: assert torch.all(0 <= bbox) and torch.all(bbox <= 1023) left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) except IndexError as e: raise IndexError("The :obj:`bbox` coordinate values should be within 0-1000 range.") from e h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023)) w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023)) 
# below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add) spatial_position_embeddings = torch.cat( [ left_position_embeddings, upper_position_embeddings, right_position_embeddings, lower_position_embeddings, h_position_embeddings, w_position_embeddings, ], dim=-1, ) return spatial_position_embeddings def create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. Args: x: torch.Tensor x: Returns: torch.Tensor """ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. mask = input_ids.ne(padding_idx).int() incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask return incremental_indices.long() + padding_idx def forward( self, input_ids=None, bbox=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0, ): if position_ids is None: if input_ids is not None: # Create the position ids from the input token ids. Any padded tokens remain padded. 
position_ids = self.create_position_ids_from_input_ids( input_ids, self.padding_idx, past_key_values_length).to(input_ids.device) else: position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) if input_ids is not None: input_shape = input_ids.size() else: input_shape = inputs_embeds.size()[:-1] if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings position_embeddings = self.position_embeddings(position_ids) embeddings += position_embeddings spatial_position_embeddings = self._calc_spatial_position_embeddings(bbox) embeddings = embeddings + spatial_position_embeddings embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings def create_position_ids_from_inputs_embeds(self, inputs_embeds): """ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. Args: inputs_embeds: torch.Tensor≈ Returns: torch.Tensor """ input_shape = inputs_embeds.size()[:-1] sequence_length = input_shape[1] position_ids = torch.arange( self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device ) return position_ids.unsqueeze(0).expand(input_shape) class LayoutLMv3PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config_class = LayoutLMv3Config base_model_prefix = "layoutlmv3" # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class LayoutLMv3SelfAttention(nn.Module): def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.has_relative_attention_bias = config.has_relative_attention_bias self.has_spatial_attention_bias = config.has_spatial_attention_bias def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def cogview_attn(self, attention_scores, alpha=32): ''' https://arxiv.org/pdf/2105.13290.pdf 
Section 2.4 Stabilization of training: Precision Bottleneck Relaxation (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores) Seems the new attention_probs will result in a slower speed and a little bias Can use torch.allclose(standard_attention_probs, cogview_attention_probs, atol=1e-08) for comparison The smaller atol (e.g., 1e-08), the better. ''' scaled_attention_scores = attention_scores / alpha max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1) # max_value = scaled_attention_scores.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1) new_attention_scores = (scaled_attention_scores - max_value) * alpha return nn.Softmax(dim=-1)(new_attention_scores) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, rel_pos=None, rel_2d_pos=None, ): mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_layer = past_key_value[0] value_layer = past_key_value[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) # Take the dot product between "query" and "key" to get the raw attention scores. # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. # Changing the computational order into QT(K/√d) alleviates the problem. 
(https://arxiv.org/pdf/2105.13290.pdf) attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) if self.has_relative_attention_bias and self.has_spatial_attention_bias: attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size) elif self.has_relative_attention_bias: attention_scores += rel_pos / math.sqrt(self.attention_head_size) # if self.has_relative_attention_bias: # attention_scores += rel_pos # if self.has_spatial_attention_bias: # attention_scores += rel_2d_pos # attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. # attention_probs = nn.Softmax(dim=-1)(attention_scores) # comment the line below and use this line for speedup attention_probs = self.cogview_attn(attention_scores) # to stablize training # assert torch.allclose(attention_probs, nn.Softmax(dim=-1)(attention_scores), atol=1e-8) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
attention_probs = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs class LayoutLMv3Attention(nn.Module): def __init__(self, config): super().__init__() self.self = LayoutLMv3SelfAttention(config) self.output = RobertaSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, rel_pos=None, rel_2d_pos=None, ): self_outputs = self.self( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class 
LayoutLMv3Layer(nn.Module): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = LayoutLMv3Attention(config) assert not config.is_decoder and not config.add_cross_attention, \ "This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder." self.intermediate = RobertaIntermediate(config) self.output = RobertaOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, rel_pos=None, rel_2d_pos=None, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output class LayoutLMv3Encoder(nn.Module): def __init__(self, config, detection=False, out_features=None): super().__init__() self.config = config self.detection = detection self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self.has_relative_attention_bias = config.has_relative_attention_bias self.has_spatial_attention_bias = config.has_spatial_attention_bias if 
self.has_relative_attention_bias: self.rel_pos_bins = config.rel_pos_bins self.max_rel_pos = config.max_rel_pos self.rel_pos_onehot_size = config.rel_pos_bins self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False) if self.has_spatial_attention_bias: self.max_rel_2d_pos = config.max_rel_2d_pos self.rel_2d_pos_bins = config.rel_2d_pos_bins self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) if self.detection: self.gradient_checkpointing = True embed_dim = self.config.hidden_size self.out_features = out_features self.out_indices = [int(name[5:]) for name in out_features] self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), # nn.SyncBatchNorm(embed_dim), nn.BatchNorm2d(embed_dim), nn.GELU(), nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) self.ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128): ret = 0 if bidirectional: num_buckets //= 2 ret += (relative_position > 0).long() * num_buckets n = torch.abs(relative_position) else: n = torch.max(-relative_position, torch.zeros_like(relative_position)) # now n is in the range [0, inf) # half of the buckets are for exact increments in positions max_exact = num_buckets // 2 is_small = n < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).to(torch.long) val_if_large = 
torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def _cal_1d_pos_emb(self, hidden_states, position_ids, valid_span): VISUAL_NUM = 196 + 1 rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) if valid_span is not None: # for the text part, if two words are not in the same line, # set their distance to the max value (position_ids.shape[-1]) rel_pos_mat[(rel_pos_mat > 0) & (valid_span == False)] = position_ids.shape[1] rel_pos_mat[(rel_pos_mat < 0) & (valid_span == False)] = -position_ids.shape[1] # image-text, minimum distance rel_pos_mat[:, -VISUAL_NUM:, :-VISUAL_NUM] = 0 rel_pos_mat[:, :-VISUAL_NUM, -VISUAL_NUM:] = 0 rel_pos = self.relative_position_bucket( rel_pos_mat, num_buckets=self.rel_pos_bins, max_distance=self.max_rel_pos, ) rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states) rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2) rel_pos = rel_pos.contiguous() return rel_pos def _cal_2d_pos_emb(self, hidden_states, bbox): position_coord_x = bbox[:, :, 0] position_coord_y = bbox[:, :, 3] rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1) rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1) rel_pos_x = self.relative_position_bucket( rel_pos_x_2d_mat, num_buckets=self.rel_2d_pos_bins, max_distance=self.max_rel_2d_pos, ) rel_pos_y = self.relative_position_bucket( rel_pos_y_2d_mat, num_buckets=self.rel_2d_pos_bins, max_distance=self.max_rel_2d_pos, ) rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states) rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states) rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2) rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2) rel_pos_x = rel_pos_x.contiguous() rel_pos_y = rel_pos_y.contiguous() rel_2d_pos = rel_pos_x + rel_pos_y 
return rel_2d_pos def forward( self, hidden_states, bbox=None, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, position_ids=None, Hp=None, Wp=None, valid_span=None, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None next_decoder_cache = () if use_cache else None rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids, valid_span) if self.has_relative_attention_bias else None rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None if self.detection: feat_out = {} j = 0 for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: if use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) # The above line will cause error: # RuntimeError: Trying to backward through the graph a second time # (or directly access saved tensors after they have already been freed). 
return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, rel_pos, rel_2d_pos ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if self.detection and i in self.out_indices: xp = hidden_states[:, -Hp*Wp:, :].permute(0, 2, 1).reshape(len(hidden_states), -1, Hp, Wp) feat_out[self.out_features[j]] = self.ops[j](xp.contiguous()) j += 1 if self.detection: return feat_out if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class LayoutLMv3Model(LayoutLMv3PreTrainedModel): """ """ _keys_to_ignore_on_load_missing = [r"position_ids"] # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta def __init__(self, config, detection=False, out_features=None, image_only=False): super().__init__(config) self.config = config assert not config.is_decoder and not config.add_cross_attention, \ "This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder." 
self.detection = detection if not self.detection: self.image_only = False else: assert config.visual_embed self.image_only = image_only if not self.image_only: self.embeddings = LayoutLMv3Embeddings(config) self.encoder = LayoutLMv3Encoder(config, detection=detection, out_features=out_features) if config.visual_embed: embed_dim = self.config.hidden_size # use the default pre-training parameters for fine-tuning (e.g., input_size) # when the input_size is larger in fine-tuning, we will interpolate the position embedding in forward self.patch_embed = PatchEmbed(embed_dim=embed_dim) patch_size = 16 size = int(self.config.input_size / patch_size) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, embed_dim)) self.pos_drop = nn.Dropout(p=0.) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: self._init_visual_bbox(img_size=(size, size)) from functools import partial norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm = norm_layer(embed_dim) self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def _init_visual_bbox(self, img_size=(14, 14), max_len=1000): visual_bbox_x = torch.div(torch.arange(0, max_len * (img_size[1] + 1), max_len), img_size[1], rounding_mode='trunc') visual_bbox_y = torch.div(torch.arange(0, max_len * (img_size[0] + 1), max_len), img_size[0], rounding_mode='trunc') visual_bbox = torch.stack( [ visual_bbox_x[:-1].repeat(img_size[0], 1), visual_bbox_y[:-1].repeat(img_size[1], 1).transpose(0, 1), visual_bbox_x[1:].repeat(img_size[0], 1), visual_bbox_y[1:].repeat(img_size[1], 1).transpose(0, 1), ], dim=-1, ).view(-1, 4) cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]]) self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0) def _calc_visual_bbox(self, device, dtype, bsz): # , img_size=(14, 14), max_len=1000): visual_bbox = self.visual_bbox.repeat(bsz, 1, 1) visual_bbox = visual_bbox.to(device).type(dtype) return visual_bbox def forward_image(self, x): if self.detection: x = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None) else: x = self.patch_embed(x) batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks if self.pos_embed is not None and self.detection: cls_tokens = cls_tokens + self.pos_embed[:, :1, :] x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None and not self.detection: x = x + self.pos_embed x = self.pos_drop(x) x = self.norm(x) return x # Copied from transformers.models.bert.modeling_bert.BertModel.forward def forward( self, input_ids=None, bbox=None, attention_mask=None, token_type_ids=None, valid_span=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, 
output_attentions=None, output_hidden_states=None, return_dict=None, images=None, ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = False # if input_ids is not None and inputs_embeds is not None: # raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") if input_ids is not None: input_shape = input_ids.size() batch_size, seq_length = input_shape device = input_ids.device elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] batch_size, seq_length = input_shape device = inputs_embeds.device elif images is not None: batch_size = len(images) device = images.device else: raise ValueError("You have to specify either input_ids or inputs_embeds or images") if not self.image_only: # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
# extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) if not self.image_only: if bbox is None: bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device) embedding_output = self.embeddings( input_ids=input_ids, bbox=bbox, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) final_bbox = final_position_ids = None Hp = Wp = None if images is not None: patch_size = 16 Hp, Wp = int(images.shape[2] / patch_size), int(images.shape[3] / patch_size) visual_emb = self.forward_image(images) if self.detection: visual_attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device) if self.image_only: attention_mask = visual_attention_mask else: attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1) elif self.image_only: attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device) if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: if self.config.has_spatial_attention_bias: visual_bbox = self._calc_visual_bbox(device, dtype=torch.long, bsz=batch_size) if self.image_only: final_bbox = visual_bbox else: final_bbox = torch.cat([bbox, visual_bbox], dim=1) visual_position_ids = torch.arange(0, visual_emb.shape[1], dtype=torch.long, device=device).repeat( batch_size, 1) if self.image_only: final_position_ids = visual_position_ids else: position_ids = torch.arange(0, input_shape[1], 
device=device).unsqueeze(0) position_ids = position_ids.expand_as(input_ids) final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1) if self.image_only: embedding_output = visual_emb else: embedding_output = torch.cat([embedding_output, visual_emb], dim=1) embedding_output = self.LayerNorm(embedding_output) embedding_output = self.dropout(embedding_output) elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: if self.config.has_spatial_attention_bias: final_bbox = bbox if self.config.has_relative_attention_bias: position_ids = self.embeddings.position_ids[:, :input_shape[1]] position_ids = position_ids.expand_as(input_ids) final_position_ids = position_ids extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device) encoder_outputs = self.encoder( embedding_output, bbox=final_bbox, position_ids=final_position_ids, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, Hp=Hp, Wp=Wp, valid_span=valid_span, ) if self.detection: return encoder_outputs sequence_output = encoder_outputs[0] pooled_output = None if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) class LayoutLMv3ClassificationHead(nn.Module): """ Head for sentence-level classification tasks. 
Reference: RobertaClassificationHead """ def __init__(self, config, pool_feature=False): super().__init__() self.pool_feature = pool_feature if pool_feature: self.dense = nn.Linear(config.hidden_size*3, config.hidden_size) else: self.dense = nn.Linear(config.hidden_size, config.hidden_size) classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) def forward(self, x): # x = features[:, 0, :] # take token (equiv. to [CLS]) x = self.dropout(x) x = self.dense(x) x = torch.tanh(x) x = self.dropout(x) x = self.out_proj(x) return x class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.layoutlmv3 = LayoutLMv3Model(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) if config.num_labels < 10: self.classifier = nn.Linear(config.hidden_size, config.num_labels) else: self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False) self.init_weights() def forward( self, input_ids=None, bbox=None, attention_mask=None, token_type_ids=None, position_ids=None, valid_span=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, images=None, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - 1]``. 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.layoutlmv3( input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, images=images, valid_span=valid_span, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() # Only keep active parts of the loss if attention_mask is not None: active_loss = attention_mask.view(-1) == 1 active_logits = logits.view(-1, self.num_labels) active_labels = torch.where( active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) ) loss = loss_fct(active_logits, active_labels) else: loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.layoutlmv3 = LayoutLMv3Model(config) # self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False) self.init_weights() def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, valid_span=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=None, output_hidden_states=None, return_dict=None, bbox=None, images=None, ): r""" 
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.layoutlmv3( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, bbox=bbox, images=images, valid_span=valid_span, ) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = 
loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.config = config self.layoutlmv3 = LayoutLMv3Model(config) self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False) self.init_weights() def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, valid_span=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, bbox=None, images=None, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.layoutlmv3( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, bbox=bbox, images=images, valid_span=valid_span, ) sequence_output = outputs[0][:, 0, :] logits = self.classifier(sequence_output) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py ================================================ # coding=utf-8 # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for LayoutLMv3, refer to RoBERTa.""" from transformers.models.roberta import RobertaTokenizer from transformers.utils import logging logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = { "vocab_file": "vocab.json", "merges_file": "merges.txt", } class LayoutLMv3Tokenizer(RobertaTokenizer): vocab_files_names = VOCAB_FILES_NAMES # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES model_input_names = ["input_ids", "attention_mask"] ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py ================================================ # coding=utf-8 # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Fast Tokenization classes for LayoutLMv3, refer to RoBERTa.""" from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast from transformers.utils import logging from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} class LayoutLMv3TokenizerFast(RobertaTokenizerFast): vocab_files_names = VOCAB_FILES_NAMES # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES model_input_names = ["input_ids", "attention_mask"] slow_tokenizer_class = LayoutLMv3Tokenizer ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmv3_base_inference.yaml ================================================ AUG: DETR: true CACHE_DIR: ~/cache/huggingface CUDNN_BENCHMARK: false DATALOADER: ASPECT_RATIO_GROUPING: true FILTER_EMPTY_ANNOTATIONS: false NUM_WORKERS: 4 REPEAT_THRESHOLD: 0.0 SAMPLER_TRAIN: TrainingSampler DATASETS: PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000 PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000 PROPOSAL_FILES_TEST: [] PROPOSAL_FILES_TRAIN: [] TEST: - scihub_train TRAIN: - scihub_train GLOBAL: HACK: 1.0 ICDAR_DATA_DIR_TEST: '' ICDAR_DATA_DIR_TRAIN: '' INPUT: CROP: ENABLED: true SIZE: - 384 - 600 TYPE: absolute_range FORMAT: RGB MASK_FORMAT: polygon MAX_SIZE_TEST: 1333 MAX_SIZE_TRAIN: 1333 MIN_SIZE_TEST: 800 MIN_SIZE_TRAIN: - 480 - 512 - 544 - 576 - 608 - 640 - 672 - 704 - 736 - 768 - 800 MIN_SIZE_TRAIN_SAMPLING: choice RANDOM_FLIP: horizontal MODEL: ANCHOR_GENERATOR: ANGLES: - - -90 - 0 - 90 ASPECT_RATIOS: - - 0.5 - 1.0 - 2.0 NAME: DefaultAnchorGenerator OFFSET: 0.0 SIZES: - - 32 - - 64 - - 128 - - 256 - - 512 BACKBONE: FREEZE_AT: 2 NAME: build_vit_fpn_backbone CONFIG_PATH: '' DEVICE: cuda FPN: FUSE_TYPE: sum IN_FEATURES: - layer3 - layer5 - layer7 - layer11 NORM: '' 
OUT_CHANNELS: 256 IMAGE_ONLY: true KEYPOINT_ON: false LOAD_PROPOSALS: false MASK_ON: true META_ARCHITECTURE: VLGeneralizedRCNN PANOPTIC_FPN: COMBINE: ENABLED: true INSTANCES_CONFIDENCE_THRESH: 0.5 OVERLAP_THRESH: 0.5 STUFF_AREA_LIMIT: 4096 INSTANCE_LOSS_WEIGHT: 1.0 PIXEL_MEAN: - 127.5 - 127.5 - 127.5 PIXEL_STD: - 127.5 - 127.5 - 127.5 PROPOSAL_GENERATOR: MIN_SIZE: 0 NAME: RPN RESNETS: DEFORM_MODULATED: false DEFORM_NUM_GROUPS: 1 DEFORM_ON_PER_STAGE: - false - false - false - false DEPTH: 50 NORM: FrozenBN NUM_GROUPS: 1 OUT_FEATURES: - res4 RES2_OUT_CHANNELS: 256 RES5_DILATION: 1 STEM_OUT_CHANNELS: 64 STRIDE_IN_1X1: true WIDTH_PER_GROUP: 64 RETINANET: BBOX_REG_LOSS_TYPE: smooth_l1 BBOX_REG_WEIGHTS: - 1.0 - 1.0 - 1.0 - 1.0 FOCAL_LOSS_ALPHA: 0.25 FOCAL_LOSS_GAMMA: 2.0 IN_FEATURES: - p3 - p4 - p5 - p6 - p7 IOU_LABELS: - 0 - -1 - 1 IOU_THRESHOLDS: - 0.4 - 0.5 NMS_THRESH_TEST: 0.5 NORM: '' NUM_CLASSES: 10 NUM_CONVS: 4 PRIOR_PROB: 0.01 SCORE_THRESH_TEST: 0.05 SMOOTH_L1_LOSS_BETA: 0.1 TOPK_CANDIDATES_TEST: 1000 ROI_BOX_CASCADE_HEAD: BBOX_REG_WEIGHTS: - - 10.0 - 10.0 - 5.0 - 5.0 - - 20.0 - 20.0 - 10.0 - 10.0 - - 30.0 - 30.0 - 15.0 - 15.0 IOUS: - 0.5 - 0.6 - 0.7 ROI_BOX_HEAD: BBOX_REG_LOSS_TYPE: smooth_l1 BBOX_REG_LOSS_WEIGHT: 1.0 BBOX_REG_WEIGHTS: - 10.0 - 10.0 - 5.0 - 5.0 CLS_AGNOSTIC_BBOX_REG: true CONV_DIM: 256 FC_DIM: 1024 NAME: FastRCNNConvFCHead NORM: '' NUM_CONV: 0 NUM_FC: 2 POOLER_RESOLUTION: 7 POOLER_SAMPLING_RATIO: 0 POOLER_TYPE: ROIAlignV2 SMOOTH_L1_BETA: 0.0 TRAIN_ON_PRED_BOXES: false ROI_HEADS: BATCH_SIZE_PER_IMAGE: 512 IN_FEATURES: - p2 - p3 - p4 - p5 IOU_LABELS: - 0 - 1 IOU_THRESHOLDS: - 0.5 NAME: CascadeROIHeads NMS_THRESH_TEST: 0.5 NUM_CLASSES: 10 POSITIVE_FRACTION: 0.25 PROPOSAL_APPEND_GT: true SCORE_THRESH_TEST: 0.05 ROI_KEYPOINT_HEAD: CONV_DIMS: - 512 - 512 - 512 - 512 - 512 - 512 - 512 - 512 LOSS_WEIGHT: 1.0 MIN_KEYPOINTS_PER_IMAGE: 1 NAME: KRCNNConvDeconvUpsampleHead NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS: true NUM_KEYPOINTS: 17 POOLER_RESOLUTION: 14 
POOLER_SAMPLING_RATIO: 0 POOLER_TYPE: ROIAlignV2 ROI_MASK_HEAD: CLS_AGNOSTIC_MASK: false CONV_DIM: 256 NAME: MaskRCNNConvUpsampleHead NORM: '' NUM_CONV: 4 POOLER_RESOLUTION: 14 POOLER_SAMPLING_RATIO: 0 POOLER_TYPE: ROIAlignV2 RPN: BATCH_SIZE_PER_IMAGE: 256 BBOX_REG_LOSS_TYPE: smooth_l1 BBOX_REG_LOSS_WEIGHT: 1.0 BBOX_REG_WEIGHTS: - 1.0 - 1.0 - 1.0 - 1.0 BOUNDARY_THRESH: -1 CONV_DIMS: - -1 HEAD_NAME: StandardRPNHead IN_FEATURES: - p2 - p3 - p4 - p5 - p6 IOU_LABELS: - 0 - -1 - 1 IOU_THRESHOLDS: - 0.3 - 0.7 LOSS_WEIGHT: 1.0 NMS_THRESH: 0.7 POSITIVE_FRACTION: 0.5 POST_NMS_TOPK_TEST: 1000 POST_NMS_TOPK_TRAIN: 2000 PRE_NMS_TOPK_TEST: 1000 PRE_NMS_TOPK_TRAIN: 2000 SMOOTH_L1_BETA: 0.0 SEM_SEG_HEAD: COMMON_STRIDE: 4 CONVS_DIM: 128 IGNORE_VALUE: 255 IN_FEATURES: - p2 - p3 - p4 - p5 LOSS_WEIGHT: 1.0 NAME: SemSegFPNHead NORM: GN NUM_CLASSES: 10 VIT: DROP_PATH: 0.1 IMG_SIZE: - 224 - 224 NAME: layoutlmv3_base OUT_FEATURES: - layer3 - layer5 - layer7 - layer11 POS_TYPE: abs WEIGHTS: OUTPUT_DIR: SCIHUB_DATA_DIR_TRAIN: ~/publaynet/layout_scihub/train SEED: 42 SOLVER: AMP: ENABLED: true BACKBONE_MULTIPLIER: 1.0 BASE_LR: 0.0002 BIAS_LR_FACTOR: 1.0 CHECKPOINT_PERIOD: 2000 CLIP_GRADIENTS: CLIP_TYPE: full_model CLIP_VALUE: 1.0 ENABLED: true NORM_TYPE: 2.0 GAMMA: 0.1 GRADIENT_ACCUMULATION_STEPS: 1 IMS_PER_BATCH: 32 LR_SCHEDULER_NAME: WarmupCosineLR MAX_ITER: 20000 MOMENTUM: 0.9 NESTEROV: false OPTIMIZER: ADAMW REFERENCE_WORLD_SIZE: 0 STEPS: - 10000 WARMUP_FACTOR: 0.01 WARMUP_ITERS: 333 WARMUP_METHOD: linear WEIGHT_DECAY: 0.05 WEIGHT_DECAY_BIAS: null WEIGHT_DECAY_NORM: 0.0 TEST: AUG: ENABLED: false FLIP: true MAX_SIZE: 4000 MIN_SIZES: - 400 - 500 - 600 - 700 - 800 - 900 - 1000 - 1100 - 1200 DETECTIONS_PER_IMAGE: 100 EVAL_PERIOD: 1000 EXPECTED_RESULTS: [] KEYPOINT_OKS_SIGMAS: [] PRECISE_BN: ENABLED: false NUM_ITER: 200 VERSION: 2 VIS_PERIOD: 0 ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/model_init.py 
================================================ from .visualizer import Visualizer from .rcnn_vl import * from .backbone import * from detectron2.config import get_cfg from detectron2.config import CfgNode as CN from detectron2.data import MetadataCatalog, DatasetCatalog from detectron2.data.datasets import register_coco_instances from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor def add_vit_config(cfg): """ Add config for VIT. """ _C = cfg _C.MODEL.VIT = CN() # CoaT model name. _C.MODEL.VIT.NAME = "" # Output features from CoaT backbone. _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"] _C.MODEL.VIT.IMG_SIZE = [224, 224] _C.MODEL.VIT.POS_TYPE = "shared_rel" _C.MODEL.VIT.DROP_PATH = 0. _C.MODEL.VIT.MODEL_KWARGS = "{}" _C.SOLVER.OPTIMIZER = "ADAMW" _C.SOLVER.BACKBONE_MULTIPLIER = 1.0 _C.AUG = CN() _C.AUG.DETR = False _C.MODEL.IMAGE_ONLY = True _C.PUBLAYNET_DATA_DIR_TRAIN = "" _C.PUBLAYNET_DATA_DIR_TEST = "" _C.FOOTNOTE_DATA_DIR_TRAIN = "" _C.FOOTNOTE_DATA_DIR_VAL = "" _C.SCIHUB_DATA_DIR_TRAIN = "" _C.SCIHUB_DATA_DIR_TEST = "" _C.JIAOCAI_DATA_DIR_TRAIN = "" _C.JIAOCAI_DATA_DIR_TEST = "" _C.ICDAR_DATA_DIR_TRAIN = "" _C.ICDAR_DATA_DIR_TEST = "" _C.M6DOC_DATA_DIR_TEST = "" _C.DOCSTRUCTBENCH_DATA_DIR_TEST = "" _C.DOCSTRUCTBENCHv2_DATA_DIR_TEST = "" _C.CACHE_DIR = "" _C.MODEL.CONFIG_PATH = "" # effective update steps would be MAX_ITER/GRADIENT_ACCUMULATION_STEPS # maybe need to set MAX_ITER *= GRADIENT_ACCUMULATION_STEPS _C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1 def setup(args): """ Create configs and perform basic setups. 
""" cfg = get_cfg() # add_coat_config(cfg) add_vit_config(cfg) cfg.merge_from_file(args.config_file) cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 # set threshold for this model cfg.merge_from_list(args.opts) cfg.freeze() default_setup(cfg, args) register_coco_instances( "scihub_train", {}, cfg.SCIHUB_DATA_DIR_TRAIN + ".json", cfg.SCIHUB_DATA_DIR_TRAIN ) return cfg class DotDict(dict): def __init__(self, *args, **kwargs): super(DotDict, self).__init__(*args, **kwargs) def __getattr__(self, key): if key not in self.keys(): return None value = self[key] if isinstance(value, dict): value = DotDict(value) return value def __setattr__(self, key, value): self[key] = value class Layoutlmv3_Predictor(object): def __init__(self, weights): layout_args = { "config_file": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmv3_base_inference.yaml", "resume": False, "eval_only": False, "num_gpus": 1, "num_machines": 1, "machine_rank": 0, "dist_url": "tcp://127.0.0.1:57823", "opts": ["MODEL.WEIGHTS", weights], } layout_args = DotDict(layout_args) cfg = setup(layout_args) self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption", "table_footnote", "isolate_formula", "formula_caption"] MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping self.predictor = DefaultPredictor(cfg) def __call__(self, image, ignore_catids=[]): page_layout_result = { "layout_dets": [] } outputs = self.predictor(image) boxes = outputs["instances"].to("cpu")._fields["pred_boxes"].tensor.tolist() labels = outputs["instances"].to("cpu")._fields["pred_classes"].tolist() scores = outputs["instances"].to("cpu")._fields["scores"].tolist() for bbox_idx in range(len(boxes)): if labels[bbox_idx] in ignore_catids: continue page_layout_result["layout_dets"].append({ "category_id": labels[bbox_idx], "poly": [ boxes[bbox_idx][0], boxes[bbox_idx][1], boxes[bbox_idx][2], boxes[bbox_idx][1], boxes[bbox_idx][2], boxes[bbox_idx][3], 
boxes[bbox_idx][0], boxes[bbox_idx][3], ], "score": scores[bbox_idx] }) return page_layout_result ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/rcnn_vl.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. import logging import numpy as np from typing import Dict, List, Optional, Tuple import torch from torch import nn from detectron2.config import configurable from detectron2.structures import ImageList, Instances from detectron2.utils.events import get_event_storage from detectron2.modeling.backbone import Backbone, build_backbone from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY from detectron2.modeling.meta_arch import GeneralizedRCNN from detectron2.modeling.postprocessing import detector_postprocess from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image from contextlib import contextmanager from itertools import count @META_ARCH_REGISTRY.register() class VLGeneralizedRCNN(GeneralizedRCNN): """ Generalized R-CNN. Any models that contains the following three components: 1. Per-image feature extraction (aka backbone) 2. Region proposal generation 3. Per-region feature extraction and prediction """ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): """ Args: batched_inputs: a list, batched outputs of :class:`DatasetMapper` . Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains: * image: Tensor, image in (C, H, W) format. * instances (optional): groundtruth :class:`Instances` * proposals (optional): :class:`Instances`, precomputed proposals. Other information that's included in the original dicts, such as: * "height", "width" (int): the output resolution of the model, used in inference. See :meth:`postprocess` for details. Returns: list[dict]: Each dict is the output for one input image. 
The dict contains one key "instances" whose value is a :class:`Instances`. The :class:`Instances` object has the following keys: "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints" """ if not self.training: return self.inference(batched_inputs) images = self.preprocess_image(batched_inputs) if "instances" in batched_inputs[0]: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] else: gt_instances = None # features = self.backbone(images.tensor) input = self.get_batch(batched_inputs, images) features = self.backbone(input) if self.proposal_generator is not None: proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) else: assert "proposals" in batched_inputs[0] proposals = [x["proposals"].to(self.device) for x in batched_inputs] proposal_losses = {} _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) if self.vis_period > 0: storage = get_event_storage() if storage.iter % self.vis_period == 0: self.visualize_training(batched_inputs, proposals) losses = {} losses.update(detector_losses) losses.update(proposal_losses) return losses def inference( self, batched_inputs: List[Dict[str, torch.Tensor]], detected_instances: Optional[List[Instances]] = None, do_postprocess: bool = True, ): """ Run inference on the given inputs. Args: batched_inputs (list[dict]): same as in :meth:`forward` detected_instances (None or list[Instances]): if not None, it contains an `Instances` object per image. The `Instances` object contains "pred_boxes" and "pred_classes" which are known boxes in the image. The inference will then skip the detection of bounding boxes, and only predict other per-ROI outputs. do_postprocess (bool): whether to apply post-processing on the outputs. Returns: When do_postprocess=True, same as in :meth:`forward`. Otherwise, a list[Instances] containing raw network outputs. 
""" assert not self.training images = self.preprocess_image(batched_inputs) # features = self.backbone(images.tensor) input = self.get_batch(batched_inputs, images) features = self.backbone(input) if detected_instances is None: if self.proposal_generator is not None: proposals, _ = self.proposal_generator(images, features, None) else: assert "proposals" in batched_inputs[0] proposals = [x["proposals"].to(self.device) for x in batched_inputs] results, _ = self.roi_heads(images, features, proposals, None) else: detected_instances = [x.to(self.device) for x in detected_instances] results = self.roi_heads.forward_with_given_boxes(features, detected_instances) if do_postprocess: assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) else: return results def get_batch(self, examples, images): if len(examples) >= 1 and "bbox" not in examples[0]: # image_only return {"images": images.tensor} return input def _batch_inference(self, batched_inputs, detected_instances=None): """ Execute inference on a list of inputs, using batch size = self.batch_size (e.g., 2), instead of the length of the list. Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference` """ if detected_instances is None: detected_instances = [None] * len(batched_inputs) outputs = [] inputs, instances = [], [] for idx, input, instance in zip(count(), batched_inputs, detected_instances): inputs.append(input) instances.append(instance) if len(inputs) == 2 or idx == len(batched_inputs) - 1: outputs.extend( self.inference( inputs, instances if instances[0] is not None else None, do_postprocess=True, # False ) ) inputs, instances = [], [] return outputs ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/visualizer.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
import colorsys import logging import math import numpy as np from enum import Enum, unique import cv2 import matplotlib as mpl import matplotlib.colors as mplc import matplotlib.figure as mplfigure import pycocotools.mask as mask_util import torch from matplotlib.backends.backend_agg import FigureCanvasAgg from PIL import Image from detectron2.data import MetadataCatalog from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes from detectron2.utils.file_io import PathManager from detectron2.utils.colormap import random_color import pdb logger = logging.getLogger(__name__) __all__ = ["ColorMode", "VisImage", "Visualizer"] _SMALL_OBJECT_AREA_THRESH = 1000 _LARGE_MASK_AREA_THRESH = 120000 _OFF_WHITE = (1.0, 1.0, 240.0 / 255) _BLACK = (0, 0, 0) _RED = (1.0, 0, 0) _KEYPOINT_THRESHOLD = 0.05 #CLASS_NAMES = ["footnote", "footer", "header"] @unique class ColorMode(Enum): """ Enum of different color modes to use for instance visualizations. """ IMAGE = 0 """ Picks a random color for every instance and overlay segmentations with low opacity. """ SEGMENTATION = 1 """ Let instances of the same category have similar colors (from metadata.thing_colors), and overlay them with high opacity. This provides more attention on the quality of segmentation. """ IMAGE_BW = 2 """ Same as IMAGE, but convert all areas without masks to gray-scale. Only available for drawing per-instance mask predictions. """ class GenericMask: """ Attribute: polygons (list[ndarray]): list[ndarray]: polygons for this mask. Each ndarray has format [x, y, x, y, ...] 
mask (ndarray): a binary mask """ def __init__(self, mask_or_polygons, height, width): self._mask = self._polygons = self._has_holes = None self.height = height self.width = width m = mask_or_polygons if isinstance(m, dict): # RLEs assert "counts" in m and "size" in m if isinstance(m["counts"], list): # uncompressed RLEs h, w = m["size"] assert h == height and w == width m = mask_util.frPyObjects(m, h, w) self._mask = mask_util.decode(m)[:, :] return if isinstance(m, list): # list[ndarray] self._polygons = [np.asarray(x).reshape(-1) for x in m] return if isinstance(m, np.ndarray): # assumed to be a binary mask assert m.shape[1] != 2, m.shape assert m.shape == ( height, width, ), f"mask shape: {m.shape}, target dims: {height}, {width}" self._mask = m.astype("uint8") return raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m))) @property def mask(self): if self._mask is None: self._mask = self.polygons_to_mask(self._polygons) return self._mask @property def polygons(self): if self._polygons is None: self._polygons, self._has_holes = self.mask_to_polygons(self._mask) return self._polygons @property def has_holes(self): if self._has_holes is None: if self._mask is not None: self._polygons, self._has_holes = self.mask_to_polygons(self._mask) else: self._has_holes = False # if original format is polygon, does not have holes return self._has_holes def mask_to_polygons(self, mask): # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level # hierarchy. External contours (boundary) of the object are placed in hierarchy-1. # Internal contours (holes) are placed in hierarchy-2. # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours. 
mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) hierarchy = res[-1] if hierarchy is None: # empty mask return [], False has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0 res = res[-2] res = [x.flatten() for x in res] # These coordinates from OpenCV are integers in range [0, W-1 or H-1]. # We add 0.5 to turn them into real-value coordinate space. A better solution # would be to first +0.5 and then dilate the returned polygon by 0.5. res = [x + 0.5 for x in res if len(x) >= 6] return res, has_holes def polygons_to_mask(self, polygons): rle = mask_util.frPyObjects(polygons, self.height, self.width) rle = mask_util.merge(rle) return mask_util.decode(rle)[:, :] def area(self): return self.mask.sum() def bbox(self): p = mask_util.frPyObjects(self.polygons, self.height, self.width) p = mask_util.merge(p) bbox = mask_util.toBbox(p) bbox[2] += bbox[0] bbox[3] += bbox[1] return bbox class _PanopticPrediction: """ Unify different panoptic annotation/prediction formats """ def __init__(self, panoptic_seg, segments_info, metadata=None): if segments_info is None: assert metadata is not None # If "segments_info" is None, we assume "panoptic_img" is a # H*W int32 image storing the panoptic_id in the format of # category_id * label_divisor + instance_id. We reserve -1 for # VOID label. label_divisor = metadata.label_divisor segments_info = [] for panoptic_label in np.unique(panoptic_seg.numpy()): if panoptic_label == -1: # VOID region. 
continue pred_class = panoptic_label // label_divisor isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values() segments_info.append( { "id": int(panoptic_label), "category_id": int(pred_class), "isthing": bool(isthing), } ) del metadata self._seg = panoptic_seg self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True) areas = areas.numpy() sorted_idxs = np.argsort(-areas) self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs] self._seg_ids = self._seg_ids.tolist() for sid, area in zip(self._seg_ids, self._seg_areas): if sid in self._sinfo: self._sinfo[sid]["area"] = float(area) def non_empty_mask(self): """ Returns: (H, W) array, a mask for all pixels that have a prediction """ empty_ids = [] for id in self._seg_ids: if id not in self._sinfo: empty_ids.append(id) if len(empty_ids) == 0: return np.zeros(self._seg.shape, dtype=np.uint8) assert ( len(empty_ids) == 1 ), ">1 ids corresponds to no labels. This is currently not supported" return (self._seg != empty_ids[0]).numpy().astype(np.bool) def semantic_masks(self): for sid in self._seg_ids: sinfo = self._sinfo.get(sid) if sinfo is None or sinfo["isthing"]: # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions. 
continue yield (self._seg == sid).numpy().astype(np.bool), sinfo def instance_masks(self): for sid in self._seg_ids: sinfo = self._sinfo.get(sid) if sinfo is None or not sinfo["isthing"]: continue mask = (self._seg == sid).numpy().astype(np.bool) if mask.sum() > 0: yield mask, sinfo def _create_text_labels(classes, scores, class_names, is_crowd=None): """ Args: classes (list[int] or None): scores (list[float] or None): class_names (list[str] or None): is_crowd (list[bool] or None): Returns: list[str] or None """ #class_names = CLASS_NAMES labels = None if classes is not None: if class_names is not None and len(class_names) > 0: labels = [class_names[i] for i in classes] else: labels = [str(i) for i in classes] if scores is not None: if labels is None: labels = ["{:.0f}%".format(s * 100) for s in scores] else: labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)] if labels is not None and is_crowd is not None: labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)] return labels class VisImage: def __init__(self, img, scale=1.0): """ Args: img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255]. scale (float): scale the input image """ self.img = img self.scale = scale self.width, self.height = img.shape[1], img.shape[0] self._setup_figure(img) def _setup_figure(self, img): """ Args: Same as in :meth:`__init__()`. Returns: fig (matplotlib.pyplot.figure): top level container for all the image plot elements. ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system. 
""" fig = mplfigure.Figure(frameon=False) self.dpi = fig.get_dpi() # add a small 1e-2 to avoid precision lost due to matplotlib's truncation # (https://github.com/matplotlib/matplotlib/issues/15363) fig.set_size_inches( (self.width * self.scale + 1e-2) / self.dpi, (self.height * self.scale + 1e-2) / self.dpi, ) self.canvas = FigureCanvasAgg(fig) # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) ax.axis("off") self.fig = fig self.ax = ax self.reset_image(img) def reset_image(self, img): """ Args: img: same as in __init__ """ img = img.astype("uint8") self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") def save(self, filepath): """ Args: filepath (str): a string that contains the absolute path, including the file name, where the visualized image will be saved. """ self.fig.savefig(filepath) def get_image(self): """ Returns: ndarray: the visualized image of shape (H, W, 3) (RGB) in uint8 type. The shape is scaled w.r.t the input image using the given `scale` argument. """ canvas = self.canvas s, (width, height) = canvas.print_to_buffer() # buf = io.BytesIO() # works for cairo backend # canvas.print_rgba(buf) # width, height = self.width, self.height # s = buf.getvalue() buffer = np.frombuffer(s, dtype="uint8") img_rgba = buffer.reshape(height, width, 4) rgb, alpha = np.split(img_rgba, [3], axis=2) return rgb.astype("uint8") class Visualizer: """ Visualizer that draws data about detection/segmentation on images. It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` that draw primitive objects to images, as well as high-level wrappers like `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` that draw composite data in some pre-defined style. Note that the exact visualization style for the high-level wrappers are subject to change. 
Style such as color, opacity, label contents, visibility of labels, or even the visibility of objects themselves (e.g. when the object is too small) may change according to different heuristics, as long as the results still look visually reasonable. To obtain a consistent style, you can implement custom drawing functions with the abovementioned primitive methods instead. If you need more customized visualization styles, you can process the data yourself following their format documented in tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not intend to satisfy everyone's preference on drawing styles. This visualizer focuses on high rendering quality rather than performance. It is not designed to be used for real-time applications. """ # TODO implement a fast, rasterized version using OpenCV def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE): """ Args: img_rgb: a numpy array of shape (H, W, C), where H and W correspond to the height and width of the image respectively. C is the number of color channels. The image is required to be in RGB format since that is a requirement of the Matplotlib library. The image is also expected to be in the range [0, 255]. metadata (Metadata): dataset metadata (e.g. class names and colors) instance_mode (ColorMode): defines one of the pre-defined style for drawing instances on an image. """ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) if metadata is None: metadata = MetadataCatalog.get("__nonexist__") self.metadata = metadata self.output = VisImage(self.img, scale=scale) self.cpu_device = torch.device("cpu") # too small texts are useless, therefore clamp to 9 self._default_font_size = max( np.sqrt(self.output.height * self.output.width) // 90, 10 // scale ) self._instance_mode = instance_mode self.keypoint_threshold = _KEYPOINT_THRESHOLD def draw_instance_predictions(self, predictions): """ Draw instance-level prediction results on an image. 
Args: predictions (Instances): the output of an instance detection/segmentation model. Following fields will be used to draw: "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). Returns: output (VisImage): image object with visualizations. """ boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None scores = predictions.scores if predictions.has("scores") else None classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None if predictions.has("pred_masks"): masks = np.asarray(predictions.pred_masks) masks = [GenericMask(x, self.output.height, self.output.width) for x in masks] else: masks = None if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): colors = [ self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes ] alpha = 0.8 else: colors = None alpha = 0.5 if self._instance_mode == ColorMode.IMAGE_BW: self.output.reset_image( self._create_grayscale_image( (predictions.pred_masks.any(dim=0) > 0).numpy() if predictions.has("pred_masks") else None ) ) alpha = 0.3 self.overlay_instances( masks=masks, boxes=boxes, labels=labels, keypoints=keypoints, assigned_colors=colors, alpha=alpha, ) return self.output def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8): """ Draw semantic segmentation predictions/labels. Args: sem_seg (Tensor or ndarray): the segmentation of shape (H, W). Each value is the integer label of the pixel. area_threshold (int): segments with less than `area_threshold` are not drawn. alpha (float): the larger it is, the more opaque the segmentations are. Returns: output (VisImage): image object with visualizations. 
""" if isinstance(sem_seg, torch.Tensor): sem_seg = sem_seg.numpy() labels, areas = np.unique(sem_seg, return_counts=True) sorted_idxs = np.argsort(-areas).tolist() labels = labels[sorted_idxs] for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels): try: mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] except (AttributeError, IndexError): mask_color = None binary_mask = (sem_seg == label).astype(np.uint8) text = self.metadata.stuff_classes[label] self.draw_binary_mask( binary_mask, color=mask_color, edge_color=_OFF_WHITE, text=text, alpha=alpha, area_threshold=area_threshold, ) return self.output def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7): """ Draw panoptic prediction annotations or results. Args: panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. segments_info (list[dict] or None): Describe each segment in `panoptic_seg`. If it is a ``list[dict]``, each dict contains keys "id", "category_id". If None, category id of each pixel is computed by ``pixel // metadata.label_divisor``. area_threshold (int): stuff segments with less than `area_threshold` are not drawn. Returns: output (VisImage): image object with visualizations. """ pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) if self._instance_mode == ColorMode.IMAGE_BW: self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask())) # draw mask for all semantic segments first i.e. 
"stuff" for mask, sinfo in pred.semantic_masks(): category_idx = sinfo["category_id"] try: mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] except AttributeError: mask_color = None text = self.metadata.stuff_classes[category_idx] self.draw_binary_mask( mask, color=mask_color, edge_color=_OFF_WHITE, text=text, alpha=alpha, area_threshold=area_threshold, ) # draw mask for all instances second all_instances = list(pred.instance_masks()) if len(all_instances) == 0: return self.output masks, sinfo = list(zip(*all_instances)) category_ids = [x["category_id"] for x in sinfo] try: scores = [x["score"] for x in sinfo] except KeyError: scores = None labels = _create_text_labels( category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo] ) try: colors = [ self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids ] except AttributeError: colors = None self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha) return self.output draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility def draw_dataset_dict(self, dic): """ Draw annotations/segmentaions in Detectron2 Dataset format. Args: dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format. Returns: output (VisImage): image object with visualizations. 
""" annos = dic.get("annotations", None) if annos: if "segmentation" in annos[0]: masks = [x["segmentation"] for x in annos] else: masks = None if "keypoints" in annos[0]: keypts = [x["keypoints"] for x in annos] keypts = np.array(keypts).reshape(len(annos), -1, 3) else: keypts = None boxes = [ BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS) if len(x["bbox"]) == 4 else x["bbox"] for x in annos ] colors = None category_ids = [x["category_id"] for x in annos] if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): colors = [ self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids ] names = self.metadata.get("thing_classes", None) labels = _create_text_labels( category_ids, scores=None, class_names=names, is_crowd=[x.get("iscrowd", 0) for x in annos], ) self.overlay_instances( labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors ) sem_seg = dic.get("sem_seg", None) if sem_seg is None and "sem_seg_file_name" in dic: with PathManager.open(dic["sem_seg_file_name"], "rb") as f: sem_seg = Image.open(f) sem_seg = np.asarray(sem_seg, dtype="uint8") if sem_seg is not None: self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5) pan_seg = dic.get("pan_seg", None) if pan_seg is None and "pan_seg_file_name" in dic: with PathManager.open(dic["pan_seg_file_name"], "rb") as f: pan_seg = Image.open(f) pan_seg = np.asarray(pan_seg) from panopticapi.utils import rgb2id pan_seg = rgb2id(pan_seg) if pan_seg is not None: segments_info = dic["segments_info"] pan_seg = torch.tensor(pan_seg) self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5) return self.output def overlay_instances( self, *, boxes=None, labels=None, masks=None, keypoints=None, assigned_colors=None, alpha=0.5, ): """ Args: boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`, or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, or a :class:`RotatedBoxes`, or an 
Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format for the N objects in a single image, labels (list[str]): the text to be displayed for each instance. masks (masks-like object): Supported types are: * :class:`detectron2.structures.PolygonMasks`, :class:`detectron2.structures.BitMasks`. * list[list[ndarray]]: contains the segmentation masks for all objects in one image. The first level of the list corresponds to individual instances. The second level to all the polygon that compose the instance, and the third level to the polygon coordinates. The third level should have the format of [x0, y0, x1, y1, ..., xn, yn] (n >= 3). * list[ndarray]: each ndarray is a binary mask of shape (H, W). * list[dict]: each dict is a COCO-style RLE. keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), where the N is the number of instances and K is the number of keypoints. The last dimension corresponds to (x, y, visibility or score). assigned_colors (list[matplotlib.colors]): a list of colors, where each color corresponds to each mask or box in the image. Refer to 'matplotlib.colors' for full list of formats that the colors are accepted in. Returns: output (VisImage): image object with visualizations. 
""" num_instances = 0 if boxes is not None: boxes = self._convert_boxes(boxes) num_instances = len(boxes) if masks is not None: masks = self._convert_masks(masks) if num_instances: assert len(masks) == num_instances else: num_instances = len(masks) if keypoints is not None: if num_instances: assert len(keypoints) == num_instances else: num_instances = len(keypoints) keypoints = self._convert_keypoints(keypoints) if labels is not None: assert len(labels) == num_instances if assigned_colors is None: assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] if num_instances == 0: return self.output if boxes is not None and boxes.shape[1] == 5: return self.overlay_rotated_instances( boxes=boxes, labels=labels, assigned_colors=assigned_colors ) # Display in largest to smallest order to reduce occlusion. areas = None if boxes is not None: areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) elif masks is not None: areas = np.asarray([x.area() for x in masks]) if areas is not None: sorted_idxs = np.argsort(-areas).tolist() # Re-order overlapped instances in descending order. boxes = boxes[sorted_idxs] if boxes is not None else None labels = [labels[k] for k in sorted_idxs] if labels is not None else None masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] keypoints = keypoints[sorted_idxs] if keypoints is not None else None for i in range(num_instances): color = assigned_colors[i] if boxes is not None: self.draw_box(boxes[i], edge_color=color) if masks is not None: for segment in masks[i].polygons: self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) if labels is not None: # first get a box if boxes is not None: x0, y0, x1, y1 = boxes[i] text_pos = (x0, y0) # if drawing boxes, put text on the box corner. 
horiz_align = "left" elif masks is not None: # skip small mask without polygon if len(masks[i].polygons) == 0: continue x0, y0, x1, y1 = masks[i].bbox() # draw text in the center (defined by median) when box is not drawn # median is less sensitive to outliers. text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1] horiz_align = "center" else: continue # drawing the box confidence for keypoints isn't very useful. # for small objects, draw text at the side to avoid occlusion instance_area = (y1 - y0) * (x1 - x0) if ( instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale or y1 - y0 < 40 * self.output.scale ): if y1 >= self.output.height - 5: text_pos = (x1, y0) else: text_pos = (x0, y1) height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width) lighter_color = self._change_color_brightness(color, brightness_factor=0.7) font_size = ( np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size ) self.draw_text( labels[i], text_pos, color=lighter_color, horizontal_alignment=horiz_align, font_size=font_size, ) # draw keypoints if keypoints is not None: for keypoints_per_instance in keypoints: self.draw_and_connect_keypoints(keypoints_per_instance) return self.output def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None): """ Args: boxes (ndarray): an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format for the N objects in a single image. labels (list[str]): the text to be displayed for each instance. assigned_colors (list[matplotlib.colors]): a list of colors, where each color corresponds to each mask or box in the image. Refer to 'matplotlib.colors' for full list of formats that the colors are accepted in. Returns: output (VisImage): image object with visualizations. 
""" num_instances = len(boxes) if assigned_colors is None: assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] if num_instances == 0: return self.output # Display in largest to smallest order to reduce occlusion. if boxes is not None: areas = boxes[:, 2] * boxes[:, 3] sorted_idxs = np.argsort(-areas).tolist() # Re-order overlapped instances in descending order. boxes = boxes[sorted_idxs] labels = [labels[k] for k in sorted_idxs] if labels is not None else None colors = [assigned_colors[idx] for idx in sorted_idxs] for i in range(num_instances): self.draw_rotated_box_with_label( boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None ) return self.output def draw_and_connect_keypoints(self, keypoints): """ Draws keypoints of an instance and follows the rules for keypoint connections to draw lines between appropriate keypoints. This follows color heuristics for line color. Args: keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints and the last dimension corresponds to (x, y, probability). Returns: output (VisImage): image object with visualizations. """ visible = {} keypoint_names = self.metadata.get("keypoint_names") for idx, keypoint in enumerate(keypoints): # draw keypoint x, y, prob = keypoint if prob > self.keypoint_threshold: self.draw_circle((x, y), color=_RED) if keypoint_names: keypoint_name = keypoint_names[idx] visible[keypoint_name] = (x, y) if self.metadata.get("keypoint_connection_rules"): for kp0, kp1, color in self.metadata.keypoint_connection_rules: if kp0 in visible and kp1 in visible: x0, y0 = visible[kp0] x1, y1 = visible[kp1] color = tuple(x / 255.0 for x in color) self.draw_line([x0, x1], [y0, y1], color=color) # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip # Note that this strategy is specific to person keypoints. 
# For other keypoints, it should just do nothing
        try:
            ls_x, ls_y = visible["left_shoulder"]
            rs_x, rs_y = visible["right_shoulder"]
            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
        except KeyError:
            # At least one shoulder is not in `visible`: skip the torso lines entirely.
            pass
        else:
            # draw line from nose to mid-shoulder
            nose_x, nose_y = visible.get("nose", (None, None))
            if nose_x is not None:
                self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)

            try:
                # draw line from mid-shoulder to mid-hip
                lh_x, lh_y = visible["left_hip"]
                rh_x, rh_y = visible["right_hip"]
            except KeyError:
                pass
            else:
                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
                self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
        return self.output

    """
    Primitive drawing functions:
    """

    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Draw a text label on a dark, semi-transparent background box.

        Args:
            text (str): class label
            position (tuple): a tuple of the x and y coordinates to place text on image.
            font_size (int, optional): font of the text. If not provided,
                ``self._default_font_size`` is used.
            color: color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`
            rotation: rotation angle in degrees CCW

        Returns:
            output (VisImage): image object with text drawn.
        """
        if not font_size:
            font_size = self._default_font_size

        # since the text background is dark, we don't want the text to be dark
        # Every channel is floored at 0.2 and the brightest channel is raised to >= 0.8.
        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        color[np.argmax(color)] = max(0.8, np.max(color))

        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
            rotation=rotation,
        )
        return self.output

    def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
        """
        Draw an axis-aligned, unfilled rectangle.

        Args:
            box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
                are the coordinates of the box's top left corner. x1 and y1 are the
                coordinates of the box's bottom right corner.
            alpha (float): blending coefficient. Smaller values lead to more transparent masks.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.

        Returns:
            output (VisImage): image object with box drawn.
        """
        x0, y0, x1, y1 = box_coord
        width = x1 - x0
        height = y1 - y0

        # Line width follows the default font size, never thinner than 1.
        linewidth = max(self._default_font_size / 4, 1)

        self.output.ax.add_patch(
            mpl.patches.Rectangle(
                (x0, y0),
                width,
                height,
                fill=False,
                edgecolor=edge_color,
                linewidth=linewidth * self.output.scale,
                alpha=alpha,
                linestyle=line_style,
            )
        )
        return self.output

    def draw_rotated_box_with_label(
        self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
    ):
        """
        Draw a rotated box with label on its top-left corner.

        Args:
            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
                where cnt_x and cnt_y are the center coordinates of the box.
                w and h are the width and height of the box. angle represents how
                many degrees the box is rotated CCW with regard to the 0-degree box.
            alpha (float): blending coefficient. Smaller values lead to
                more transparent masks.
edge_color: color of the outline of the box. Refer to `matplotlib.colors` for full list of formats that are accepted. line_style (string): the string to use to create the outline of the boxes. label (string): label for rotated box. It will not be rendered when set to None. Returns: output (VisImage): image object with box drawn. """ cnt_x, cnt_y, w, h, angle = rotated_box area = w * h # use thinner lines when the box is small linewidth = self._default_font_size / ( 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3 ) theta = angle * math.pi / 180.0 c = math.cos(theta) s = math.sin(theta) rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)] # x: left->right ; y: top->down rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect] for k in range(4): j = (k + 1) % 4 self.draw_line( [rotated_rect[k][0], rotated_rect[j][0]], [rotated_rect[k][1], rotated_rect[j][1]], color=edge_color, linestyle="--" if k == 1 else line_style, linewidth=linewidth, ) if label is not None: text_pos = rotated_rect[1] # topleft corner height_ratio = h / np.sqrt(self.output.height * self.output.width) label_color = self._change_color_brightness(edge_color, brightness_factor=0.7) font_size = ( np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size ) self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle) return self.output def draw_circle(self, circle_coord, color, radius=3): """ Args: circle_coord (list(int) or tuple(int)): contains the x and y coordinates of the center of the circle. color: color of the polygon. Refer to `matplotlib.colors` for a full list of formats that are accepted. radius (int): radius of the circle. Returns: output (VisImage): image object with box drawn. 
""" x, y = circle_coord self.output.ax.add_patch( mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color) ) return self.output def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None): """ Args: x_data (list[int]): a list containing x values of all the points being drawn. Length of list should match the length of y_data. y_data (list[int]): a list containing y values of all the points being drawn. Length of list should match the length of x_data. color: color of the line. Refer to `matplotlib.colors` for a full list of formats that are accepted. linestyle: style of the line. Refer to `matplotlib.lines.Line2D` for a full list of formats that are accepted. linewidth (float or None): width of the line. When it's None, a default value will be computed and used. Returns: output (VisImage): image object with line drawn. """ if linewidth is None: linewidth = self._default_font_size / 3 linewidth = max(linewidth, 1) self.output.ax.add_line( mpl.lines.Line2D( x_data, y_data, linewidth=linewidth * self.output.scale, color=color, linestyle=linestyle, ) ) return self.output def draw_binary_mask( self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0 ): """ Args: binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and W is the image width. Each value in the array is either a 0 or 1 value of uint8 type. color: color of the mask. Refer to `matplotlib.colors` for a full list of formats that are accepted. If None, will pick a random color. edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a full list of formats that are accepted. text (str): if None, will be drawn in the object's center of mass. alpha (float): blending efficient. Smaller values lead to more transparent masks. area_threshold (float): a connected component small than this will not be shown. Returns: output (VisImage): image object with mask drawn. 
""" if color is None: color = random_color(rgb=True, maximum=1) color = mplc.to_rgb(color) has_valid_segment = False binary_mask = binary_mask.astype("uint8") # opencv needs uint8 mask = GenericMask(binary_mask, self.output.height, self.output.width) shape2d = (binary_mask.shape[0], binary_mask.shape[1]) if not mask.has_holes: # draw polygons for regular masks for segment in mask.polygons: area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1])) if area < (area_threshold or 0): continue has_valid_segment = True segment = segment.reshape(-1, 2) self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha) else: # TODO: Use Path/PathPatch to draw vector graphics: # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon rgba = np.zeros(shape2d + (4,), dtype="float32") rgba[:, :, :3] = color rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha has_valid_segment = True self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) if text is not None and has_valid_segment: # TODO sometimes drawn on wrong objects. the heuristics here can improve. lighter_color = self._change_color_brightness(color, brightness_factor=0.7) _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) largest_component_id = np.argmax(stats[1:, -1]) + 1 # draw text on the largest component, as well as other very large components. for cid in range(1, _num_cc): if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: # median is more stable than centroid # center = centroids[largest_component_id] center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] self.draw_text(text, center, color=lighter_color) return self.output def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): """ Args: segment: numpy array of shape Nx2, containing all the points in the polygon. color: color of the polygon. 
Refer to `matplotlib.colors` for a full list of formats that are accepted. edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a full list of formats that are accepted. If not provided, a darker shade of the polygon color will be used instead. alpha (float): blending efficient. Smaller values lead to more transparent masks. Returns: output (VisImage): image object with polygon drawn. """ if edge_color is None: # make edge color darker than the polygon color if alpha > 0.8: edge_color = self._change_color_brightness(color, brightness_factor=-0.7) else: edge_color = color edge_color = mplc.to_rgb(edge_color) + (1,) polygon = mpl.patches.Polygon( segment, fill=True, facecolor=mplc.to_rgb(color) + (alpha,), edgecolor=edge_color, linewidth=max(self._default_font_size // 15 * self.output.scale, 1), ) self.output.ax.add_patch(polygon) return self.output """ Internal methods: """ def _jitter(self, color): """ Randomly modifies given color to produce a slightly different color than the color given. Args: color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color picked. The values in the list are in the [0.0, 1.0] range. Returns: jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color after being jittered. The values in the list are in the [0.0, 1.0] range. """ color = mplc.to_rgb(color) vec = np.random.rand(3) # better to do it in another color space vec = vec / np.linalg.norm(vec) * 0.5 res = np.clip(vec + color, 0, 1) return tuple(res) def _create_grayscale_image(self, mask=None): """ Create a grayscale version of the original image. The colors in masked area, if given, will be kept. """ img_bw = self.img.astype("f4").mean(axis=2) img_bw = np.stack([img_bw] * 3, axis=2) if mask is not None: img_bw[mask] = self.img[mask] return img_bw def _change_color_brightness(self, color, brightness_factor): """ Depending on the brightness_factor, gives a lighter or darker color i.e. 
a color with less or more saturation than the original color. Args: color: color of the polygon. Refer to `matplotlib.colors` for a full list of formats that are accepted. brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of 0 will correspond to no change, a factor in [-1.0, 0) range will result in a darker color and a factor in (0, 1.0] range will result in a lighter color. Returns: modified_color (tuple[double]): a tuple containing the RGB values of the modified color. Each value in the tuple is in the [0.0, 1.0] range. """ assert brightness_factor >= -1.0 and brightness_factor <= 1.0 color = mplc.to_rgb(color) polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) return modified_color def _convert_boxes(self, boxes): """ Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension. """ if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes): return boxes.tensor.detach().numpy() else: return np.asarray(boxes) def _convert_masks(self, masks_or_polygons): """ Convert different format of masks or polygons to a tuple of masks and polygons. 
Returns: list[GenericMask]: """ m = masks_or_polygons if isinstance(m, PolygonMasks): m = m.polygons if isinstance(m, BitMasks): m = m.tensor.numpy() if isinstance(m, torch.Tensor): m = m.numpy() ret = [] for x in m: if isinstance(x, GenericMask): ret.append(x) else: ret.append(GenericMask(x, self.output.height, self.output.width)) return ret def _convert_keypoints(self, keypoints): if isinstance(keypoints, Keypoints): keypoints = keypoints.tensor keypoints = np.asarray(keypoints) return keypoints def get_output(self): """ Returns: output (VisImage): the image output containing the visualizations added to the image. """ return self.output ================================================ FILE: pdf_extract_kit/tasks/layout_detection/models/yolo.py ================================================ import os import cv2 import torch from pdf_extract_kit.registry import MODEL_REGISTRY from pdf_extract_kit.utils.visualization import visualize_bbox from pdf_extract_kit.dataset.dataset import ImageDataset @MODEL_REGISTRY.register('layout_detection_yolo') class LayoutDetectionYOLO: def __init__(self, config): """ Initialize the LayoutDetectionYOLO class. Args: config (dict): Configuration dictionary containing model parameters. 
""" # Mapping from class IDs to class names self.id_to_names = { 0: 'title', 1: 'plain text', 2: 'abandon', 3: 'figure', 4: 'figure_caption', 5: 'table', 6: 'table_caption', 7: 'table_footnote', 8: 'isolate_formula', 9: 'formula_caption' } # Load the YOLO model from the specified path try: from doclayout_yolo import YOLOv10 self.model = YOLOv10(config['model_path']) except AttributeError: from ultralytics import YOLO self.model = YOLO(config['model_path']) # Set model parameters self.img_size = config.get('img_size', 1280) self.conf_thres = config.get('conf_thres', 0.25) self.iou_thres = config.get('iou_thres', 0.45) self.visualize = config.get('visualize', False) self.nc = config.get('nc', 10) self.workers = config.get('workers', 8) self.device = config.get('device', 'cpu') if self.iou_thres > 0: import torchvision self.nms_func = torchvision.ops.nms def predict(self, images, result_path, image_ids=None): """ Predict formulas in images. Args: images (list): List of images to be predicted. result_path (str): Path to save the prediction results. image_ids (list, optional): List of image IDs corresponding to the images. Returns: list: List of prediction results. 
""" results = [] for idx, image in enumerate(images): result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False, device=self.device)[0] if self.visualize: if not os.path.exists(result_path): os.makedirs(result_path) boxes = result.__dict__['boxes'].xyxy classes = result.__dict__['boxes'].cls scores = result.__dict__['boxes'].conf if self.iou_thres > 0: indices = self.nms_func(boxes=torch.Tensor(boxes), scores=torch.Tensor(scores),iou_threshold=self.iou_thres) boxes, scores, classes = boxes[indices], scores[indices], classes[indices] if len(boxes.shape) == 1: boxes = np.expand_dims(boxes, 0) scores = np.expand_dims(scores, 0) classes = np.expand_dims(classes, 0) vis_result = visualize_bbox(image, boxes, classes, scores, self.id_to_names) # Determine the base name of the image if image_ids: base_name = image_ids[idx] else: # base_name = os.path.basename(image) base_name = os.path.splitext(os.path.basename(image))[0] # Remove file extension result_name = f"{base_name}_layout.png" # Save the visualized result cv2.imwrite(os.path.join(result_path, result_name), vis_result) results.append(result) return results ================================================ FILE: pdf_extract_kit/tasks/layout_detection/task.py ================================================ from pdf_extract_kit.registry.registry import TASK_REGISTRY from pdf_extract_kit.tasks.base_task import BaseTask @TASK_REGISTRY.register("layout_detection") class LayoutDetectionTask(BaseTask): def __init__(self, model): super().__init__(model) def predict_images(self, input_data, result_path): """ Predict layouts in images. Args: input_data (str): Path to a single image file or a directory containing image files. result_path (str): Path to save the prediction results. Returns: list: List of prediction results. 
""" images = self.load_images(input_data) # Perform detection return self.model.predict(images, result_path) def predict_pdfs(self, input_data, result_path): """ Predict layouts in PDF files. Args: input_data (str): Path to a single PDF file or a directory containing PDF files. result_path (str): Path to save the prediction results. Returns: list: List of prediction results. """ pdf_images = self.load_pdf_images(input_data) # Perform detection return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys())) ================================================ FILE: pdf_extract_kit/tasks/ocr/__init__.py ================================================ from pdf_extract_kit.tasks.ocr.models.paddle_ocr import ModifiedPaddleOCR # from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "ModifiedPaddleOCR", ] ================================================ FILE: pdf_extract_kit/tasks/ocr/models/paddle_ocr.py ================================================ import time import copy import logging import base64 import cv2 import numpy as np from io import BytesIO from PIL import Image from paddleocr import PaddleOCR from ppocr.utils.logging import get_logger from ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop from pdf_extract_kit.registry import MODEL_REGISTRY logger = get_logger() def img_decode(content: bytes): np_arr = np.frombuffer(content, dtype=np.uint8) return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED) def check_img(img): if isinstance(img, bytes): img = img_decode(img) if isinstance(img, str): image_file = img img, flag_gif, flag_pdf = check_and_read(image_file) if not flag_gif and not flag_pdf: with open(image_file, 'rb') as f: img_str = f.read() img = img_decode(img_str) if img is None: try: buf = BytesIO() image = BytesIO(img_str) im = Image.open(image) rgb = im.convert('RGB') rgb.save(buf, 'jpeg') 
buf.seek(0) image_bytes = buf.read() data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8") image_decode = base64.b64decode(data_base64) img_array = np.frombuffer(image_decode, np.uint8) img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) except: logger.error("error in loading image:{}".format(image_file)) return None if img is None: logger.error("error in loading image:{}".format(image_file)) return None if isinstance(img, np.ndarray) and len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if isinstance(img, Image.Image): img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) return img def sorted_boxes(dt_boxes): """ Sort text boxes in order from top to bottom, left to right args: dt_boxes(array):detected text boxes with shape [4, 2] return: sorted boxes(array) with shape [4, 2] """ num_boxes = dt_boxes.shape[0] sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) _boxes = list(sorted_boxes) for i in range(num_boxes - 1): for j in range(i, -1, -1): if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \ (_boxes[j + 1][0][0] < _boxes[j][0][0]): tmp = _boxes[j] _boxes[j] = _boxes[j + 1] _boxes[j + 1] = tmp else: break return _boxes def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8): """Check if two bounding boxes overlap on the y-axis, and if the height of the overlapping region exceeds 80% of the height of the shorter bounding box.""" _, y0_1, _, y1_1 = bbox1 _, y0_2, _, y1_2 = bbox2 overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2)) height1, height2 = y1_1 - y0_1, y1_2 - y0_2 max_height = max(height1, height2) min_height = min(height1, height2) return (overlap / min_height) > overlap_ratio_threshold def bbox_to_points(bbox): """ change bbox(shape: N * 4) to polygon(shape: N * 8) """ x0, y0, x1, y1 = bbox return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32') def points_to_bbox(points): """ change polygon(shape: N * 8) to bbox(shape: N * 4) """ x0, y0 = points[0] x1, _ = 
points[1] _, y1 = points[2] return [x0, y0, x1, y1] def merge_intervals(intervals): # Sort the intervals based on the start value intervals.sort(key=lambda x: x[0]) merged = [] for interval in intervals: # If the list of merged intervals is empty or if the current # interval does not overlap with the previous, simply append it. if not merged or merged[-1][1] < interval[0]: merged.append(interval) else: # Otherwise, there is overlap, so we merge the current and previous intervals. merged[-1][1] = max(merged[-1][1], interval[1]) return merged def remove_intervals(original, masks): # Merge all mask intervals merged_masks = merge_intervals(masks) result = [] original_start, original_end = original for mask in merged_masks: mask_start, mask_end = mask # If the mask starts after the original range, ignore it if mask_start > original_end: continue # If the mask ends before the original range starts, ignore it if mask_end < original_start: continue # Remove the masked part from the original range if original_start < mask_start: result.append([original_start, mask_start - 1]) original_start = max(mask_end + 1, original_start) # Add the remaining part of the original range, if any if original_start <= original_end: result.append([original_start, original_end]) return result def update_det_boxes(dt_boxes, mfd_res): new_dt_boxes = [] for text_box in dt_boxes: text_bbox = points_to_bbox(text_box) masks_list = [] for mf_box in mfd_res: mf_bbox = mf_box['bbox'] if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox): masks_list.append([mf_bbox[0], mf_bbox[2]]) text_x_range = [text_bbox[0], text_bbox[2]] text_remove_mask_range = remove_intervals(text_x_range, masks_list) temp_dt_box = [] for text_remove_mask in text_remove_mask_range: temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]])) if len(temp_dt_box) > 0: new_dt_boxes.extend(temp_dt_box) return new_dt_boxes def merge_spans_to_line(spans): """ Merge given spans into 
lines. Spans are considered based on their position in the document. If spans overlap sufficiently on the Y-axis, they are merged into the same line; otherwise, a new line is started. Parameters: spans (list): A list of spans, where each span is a dictionary containing at least the key 'bbox', which itself is a list of four integers representing the bounding box: [x0, y0, x1, y1], where (x0, y0) is the top-left corner and (x1, y1) is the bottom-right corner. Returns: list: A list of lines, where each line is a list of spans. """ # Return an empty list if the spans list is empty if len(spans) == 0: return [] else: # Sort spans by the Y0 coordinate spans.sort(key=lambda span: span['bbox'][1]) lines = [] current_line = [spans[0]] for span in spans[1:]: # If the current span overlaps with the last span in the current line on the Y-axis, add it to the current line if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']): current_line.append(span) else: # Otherwise, start a new line lines.append(current_line) current_line = [span] # Add the last line if it exists if current_line: lines.append(current_line) return lines def merge_overlapping_spans(spans): """ Merges overlapping spans on the same line. :param spans: A list of span coordinates [(x1, y1, x2, y2), ...] 
:return: A list of merged spans """ # Return an empty list if the input spans list is empty if not spans: return [] # Sort spans by their starting x-coordinate spans.sort(key=lambda x: x[0]) # Initialize the list of merged spans merged = [] for span in spans: # Unpack span coordinates x1, y1, x2, y2 = span # If the merged list is empty or there's no horizontal overlap, add the span directly if not merged or merged[-1][2] < x1: merged.append(span) else: # If there is horizontal overlap, merge the current span with the previous one last_span = merged.pop() # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2) x1 = min(last_span[0], x1) y1 = min(last_span[1], y1) x2 = max(last_span[2], x2) y2 = max(last_span[3], y2) # Add the merged span back to the list merged.append((x1, y1, x2, y2)) # Return the list of merged spans return merged def merge_det_boxes(dt_boxes): """ Merge detection boxes. This function takes a list of detected bounding boxes, each represented by four corner points. The goal is to merge these bounding boxes into larger text regions. Parameters: dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points. Returns: list: A list containing the merged text regions, where each region is represented by four corner points. 
""" # Convert the detection boxes into a dictionary format with bounding boxes and type dt_boxes_dict_list = [] for text_box in dt_boxes: text_bbox = points_to_bbox(text_box) text_box_dict = { 'bbox': text_bbox, } dt_boxes_dict_list.append(text_box_dict) # Merge adjacent text regions into lines lines = merge_spans_to_line(dt_boxes_dict_list) # Initialize a new list for storing the merged text regions new_dt_boxes = [] for line in lines: line_bbox_list = [] for span in line: line_bbox_list.append(span['bbox']) # Merge overlapping text regions within the same line merged_spans = merge_overlapping_spans(line_bbox_list) # Convert the merged text regions back to point format and add them to the new detection box list for span in merged_spans: new_dt_boxes.append(bbox_to_points(span)) return new_dt_boxes @MODEL_REGISTRY.register('ocr_ppocr') class ModifiedPaddleOCR(PaddleOCR): def __init__(self, config): super().__init__(**config) def predict(self, img, **kwargs): ppocr_res = self.ocr(img, **kwargs)[0] ocr_res = [] for box_ocr_res in ppocr_res: p1, p2, p3, p4 = box_ocr_res[0] text, score = box_ocr_res[1] ocr_res.append({ "category_type": "text", 'poly': p1 + p2 + p3 + p4, 'score': round(score, 2), 'text': text, }) return ocr_res def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)): """ OCR with PaddleOCR args: img: img for OCR, support ndarray, img_path and list or ndarray det: use text detection or not. If False, only rec will be exec. Default is True rec: use text recognition or not. If False, only det will be exec. Default is True cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False. bin: binarize image to black and white. Default is False. inv: invert image colors. Default is False. 
alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white. """ assert isinstance(img, (np.ndarray, list, str, bytes, Image.Image)) if isinstance(img, list) and det == True: logger.error('When input a list of images, det must be false') exit(0) if cls == True and self.use_angle_cls == False: logger.warning( 'Since the angle classifier is not initialized, it will not be used during the forward process' ) img = check_img(img) # for infer pdf file if isinstance(img, list): if self.page_num > len(img) or self.page_num == 0: self.page_num = len(img) imgs = img[:self.page_num] else: imgs = [img] def preprocess_image(_image): _image = alpha_to_color(_image, alpha_color) if inv: _image = cv2.bitwise_not(_image) if bin: _image = binarize_img(_image) return _image if det and rec: ocr_res = [] for idx, img in enumerate(imgs): img = preprocess_image(img) dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res) if not dt_boxes and not rec_res: ocr_res.append(None) continue tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] ocr_res.append(tmp_res) return ocr_res elif det and not rec: ocr_res = [] for idx, img in enumerate(imgs): img = preprocess_image(img) dt_boxes, elapse = self.text_detector(img) if not dt_boxes: ocr_res.append(None) continue tmp_res = [box.tolist() for box in dt_boxes] ocr_res.append(tmp_res) return ocr_res else: ocr_res = [] cls_res = [] for idx, img in enumerate(imgs): if not isinstance(img, list): img = preprocess_image(img) img = [img] if self.use_angle_cls and cls: img, cls_res_tmp, elapse = self.text_classifier(img) if not rec: cls_res.append(cls_res_tmp) rec_res, elapse = self.text_recognizer(img) ocr_res.append(rec_res) if not rec: return cls_res return ocr_res def __call__(self, img, cls=True, mfd_res=None): time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} if img is None: logger.debug("no valid image provided") return None, None, time_dict start = time.time() ori_im = img.copy() dt_boxes, 
elapse = self.text_detector(img) time_dict['det'] = elapse if dt_boxes is None: logger.debug("no dt_boxes found, elapsed : {}".format(elapse)) end = time.time() time_dict['all'] = end - start return None, None, time_dict else: logger.debug("dt_boxes num : {}, elapsed : {}".format( len(dt_boxes), elapse)) img_crop_list = [] dt_boxes = sorted_boxes(dt_boxes) dt_boxes = merge_det_boxes(dt_boxes) if mfd_res: bef = time.time() dt_boxes = update_det_boxes(dt_boxes, mfd_res) aft = time.time() logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format( len(dt_boxes), aft-bef)) for bno in range(len(dt_boxes)): tmp_box = copy.deepcopy(dt_boxes[bno]) if self.args.det_box_type == "quad": img_crop = get_rotate_crop_image(ori_im, tmp_box) else: img_crop = get_minarea_rect_crop(ori_im, tmp_box) img_crop_list.append(img_crop) if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) time_dict['cls'] = elapse logger.debug("cls num : {}, elapsed : {}".format( len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) time_dict['rec'] = elapse logger.debug("rec_res num : {}, elapsed : {}".format( len(rec_res), elapse)) if self.args.save_crop_res: self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res) filter_boxes, filter_rec_res = [], [] for box, rec_result in zip(dt_boxes, rec_res): text, score = rec_result if score >= self.drop_score: filter_boxes.append(box) filter_rec_res.append(rec_result) end = time.time() time_dict['all'] = end - start return filter_boxes, filter_rec_res, time_dict ================================================ FILE: pdf_extract_kit/tasks/ocr/task.py ================================================ import os import json import random from PIL import Image, ImageDraw from pdf_extract_kit.registry.registry import TASK_REGISTRY from pdf_extract_kit.utils.data_preprocess import load_pdf from pdf_extract_kit.tasks.base_task import BaseTask 
@TASK_REGISTRY.register("ocr") class OCRTask(BaseTask): def __init__(self, model): """init the task based on the given model. Args: model: task model, must contains predict function. """ super().__init__(model) def predict_image(self, image): """predict on one image, reture text detection and recognition results. Args: image: PIL.Image.Image, (if the model.predict function support other types, remenber add change-format-function in model.predict) Returns: List[dict]: list of text bbox with it's content Return example: [ { "category_type": "text", "poly": [ 380.6792698635707, 159.85058512958923, 765.1419999999998, 159.85058512958923, 765.1419999999998, 192.51073013642917, 380.6792698635707, 192.51073013642917 ], "text": "this is an example text", "score": 0.97 }, ... ] """ return self.model.predict(image) def prepare_input_files(self, input_path): if os.path.isdir(input_path): file_list = [os.path.join(input_path, fname) for fname in os.listdir(input_path)] else: file_list = [input_path] return file_list def process(self, input_path, save_dir=None, visualize=False): file_list = self.prepare_input_files(input_path) res_list = [] for fpath in file_list: basename = os.path.basename(fpath)[:-4] if fpath.endswith(".pdf") or fpath.endswith(".PDF"): images = load_pdf(fpath) pdf_res = [] for page, img in enumerate(images): page_res = self.predict_image(img) pdf_res.append(page_res) if save_dir: os.makedirs(os.path.join(save_dir, basename), exist_ok=True) self.save_json_result(page_res, os.path.join(save_dir, basename, f"page_{page+1}.json")) if visualize: self.visualize_image(img, page_res, os.path.join(save_dir, basename, f"page_{page+1}.jpg")) res_list.append(pdf_res) else: image = Image.open(fpath) img_res = self.predict_image(image) res_list.append(img_res) if save_dir: os.makedirs(save_dir, exist_ok=True) self.save_json_result(img_res, os.path.join(save_dir, f"{basename}.json")) if visualize: self.visualize_image(image, img_res, os.path.join(save_dir, 
f"{basename}.png")) return res_list def visualize_image(self, image, ocr_res, save_path="", cate2color={}): """plot each result's bbox and category on image. Args: image: PIL.Image.Image ocr_res: list of ocr det and rec, whose format following the results of self.predict_image function save_path: path to save visualized image """ draw = ImageDraw.Draw(image) for res in ocr_res: box_color = cate2color.get(res['category_type'], (0, 255, 0)) x_min, y_min = int(res['poly'][0]), int(res['poly'][1]) x_max, y_max = int(res['poly'][4]), int(res['poly'][5]) draw.rectangle([x_min, y_min, x_max, y_max], fill=None, outline=box_color, width=1) draw.text((x_min, y_min), res['category_type'], (255, 0, 0)) if save_path: image.save(save_path) def save_json_result(self, ocr_res, save_path): """save results to a json file. Args: ocr_res: list of ocr det and rec, whose format following the results of self.predict_image function save_path: path to save visualized image """ with open(save_path, "w", encoding="utf-8") as f: f.write(json.dumps(ocr_res, indent=2, ensure_ascii=False)) ================================================ FILE: pdf_extract_kit/tasks/table_parsing/__init__.py ================================================ from pdf_extract_kit.tasks.table_parsing.models.struct_eqtable import TableParsingStructEqTable from pdf_extract_kit.registry.registry import MODEL_REGISTRY __all__ = [ "TableParsingStructEqTable", ] ================================================ FILE: pdf_extract_kit/tasks/table_parsing/models/struct_eqtable.py ================================================ import torch from PIL import Image from struct_eqtable import build_model from pdf_extract_kit.registry.registry import MODEL_REGISTRY @MODEL_REGISTRY.register("table_parsing_struct_eqtable") class TableParsingStructEqTable: def __init__(self, config): """ Initialize the TableParsingStructEqTable class. Args: config (dict): Configuration dictionary containing model parameters. 
""" assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model." self.model_dir = config.get('model_path', 'U4R/StructTable-InternVL2-1B') self.max_new_tokens = config.get('max_new_tokens', 1024) self.max_time = config.get('max_time', 30) self.lmdeploy = config.get('lmdeploy', False) self.flash_attn = config.get('flash_attn', True) self.batch_size = config.get('batch_size', 1) self.default_format = config.get('output_format', 'latex') # Load the StructEqTable model self.model = build_model( model_ckpt=self.model_dir, max_new_tokens=self.max_new_tokens, max_time=self.max_time, lmdeploy=self.lmdeploy, flash_attn=self.flash_attn, batch_size=self.batch_size, ).cuda() def predict(self, images, result_path, output_format=None, **kwargs): load_images = [Image.open(image_path) for image_path in images] if output_format is None: output_format = self.default_format else: if output_format not in ['latex', 'markdown', 'html']: raise ValueError(f"Output format {output_format} is not supported.") results = self.model( load_images, output_format=output_format ) return results ================================================ FILE: pdf_extract_kit/tasks/table_parsing/task.py ================================================ from pdf_extract_kit.registry.registry import TASK_REGISTRY from pdf_extract_kit.tasks.base_task import BaseTask @TASK_REGISTRY.register("table_parsing") class TableParsingTask(BaseTask): def __init__(self, model): super().__init__(model) def predict(self, input_data, result_path, **kwargs): images = self.load_images(input_data) # Perform layout detection on input_data return self.model.predict(images, result_path, **kwargs) ================================================ FILE: pdf_extract_kit/utils/__init__.py ================================================ ================================================ FILE: pdf_extract_kit/utils/config_loader.py ================================================ import yaml import warnings from 
pdf_extract_kit.registry.registry import TASK_REGISTRY, MODEL_REGISTRY def load_config(config_path): if config_path is None: warnings.warn( ("Configuration path is None. Please provide a valid configuration file path. ") ) return None with open(config_path, 'r') as file: config = yaml.safe_load(file) return config # def initialize_task_and_model(config): # task_name = config['task'] # model_name = config['model'] # model_config = config['model_config'] # TaskClass = TASK_REGISTRY.get(task_name) # ModelClass = MODEL_REGISTRY.get(model_name) # model_instance = ModelClass(model_config) # task_instance = TaskClass(model_instance) # return task_instance def initialize_tasks_and_models(config): task_instances = {} for task_name in config['tasks']: model_name = config['tasks'][task_name]['model'] model_config = config['tasks'][task_name]['model_config'] TaskClass = TASK_REGISTRY.get(task_name) ModelClass = MODEL_REGISTRY.get(model_name) model_instance = ModelClass(model_config) task_instance = TaskClass(model_instance) task_instances[task_name] = task_instance return task_instances ================================================ FILE: pdf_extract_kit/utils/data_preprocess.py ================================================ import fitz from PIL import Image def load_pdf_page(page, dpi): pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72)) image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) if pix.width > 3000 or pix.height > 3000: pix = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) return image def load_pdf(pdf_path, dpi=144): images = [] doc = fitz.open(pdf_path) for i in range(len(doc)): page = doc[i] image = load_pdf_page(page, dpi) images.append(image) return images ================================================ FILE: pdf_extract_kit/utils/merge_blocks_and_spans.py ================================================ # revised from 
https://github.com/opendatalab/MinerU/blob/7f0fe20004af7416db886f4b75c116bcc1c986b4/magic_pdf/pdf_parse_union_core.py#L177 # from fast_langdetect import detect_language # import unicodedata import re def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8): """检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%""" _, y0_1, _, y1_1 = bbox1 _, y0_2, _, y1_2 = bbox2 overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2)) height1, height2 = y1_1 - y0_1, y1_2 - y0_2 max_height = max(height1, height2) min_height = min(height1, height2) return (overlap / min_height) > overlap_ratio_threshold def merge_spans_to_line(spans): if len(spans) == 0: return [] else: # 按照y0坐标排序 spans.sort(key=lambda span: span['bbox'][1]) lines = [] current_line = [spans[0]] for span in spans[1:]: # 如果当前的span类型为"isolated" 或者 当前行中已经有"isolated" # image和table类型,同上 if span['type'] in ['isolated'] or any( s['type'] in ['isolated'] for s in current_line): # 则开始新行 lines.append(current_line) current_line = [span] continue # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行 if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']): current_line.append(span) else: # 否则,开始新行 lines.append(current_line) current_line = [span] # 添加最后一行 if current_line: lines.append(current_line) return lines # 将每一个line中的span从左到右排序 def line_sort_spans_by_left_to_right(lines): line_objects = [] for line in lines: # 按照x0坐标排序 line.sort(key=lambda span: span['bbox'][0]) line_bbox = [ min(span['bbox'][0] for span in line), # x0 min(span['bbox'][1] for span in line), # y0 max(span['bbox'][2] for span in line), # x1 max(span['bbox'][3] for span in line), # y1 ] line_objects.append({ "bbox": line_bbox, "spans": line, }) return line_objects def fix_text_block(block): # 文本block中的公式span都应该转换成行内type for span in block['spans']: if span['type'] == "isolated": span['type'] = "inline" block_lines = merge_spans_to_line(block['spans']) sort_block_lines = line_sort_spans_by_left_to_right(block_lines) block['lines'] = 
sort_block_lines del block['spans'] return block def fix_interline_block(block): block_lines = merge_spans_to_line(block['spans']) sort_block_lines = line_sort_spans_by_left_to_right(block_lines) block['lines'] = sort_block_lines del block['spans'] return block def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2): """ 计算box1和box2的重叠面积占bbox1的比例 """ # Determine the coordinates of the intersection rectangle x_left = max(bbox1[0], bbox2[0]) y_top = max(bbox1[1], bbox2[1]) x_right = min(bbox1[2], bbox2[2]) y_bottom = min(bbox1[3], bbox2[3]) if x_right < x_left or y_bottom < y_top: return 0.0 # The area of overlap area intersection_area = (x_right - x_left) * (y_bottom - y_top) bbox1_area = (bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]) if bbox1_area == 0: return 0 else: return intersection_area / bbox1_area def fill_spans_in_blocks(blocks, spans, radio): ''' 将allspans中的span按位置关系,放入blocks中 ''' block_with_spans = [] for block in blocks: block_type = block["category_type"] L = block['poly'][0] U = block['poly'][1] R = block['poly'][2] D = block['poly'][5] L, R = min(L, R), max(L, R) U, D = min(U, D), max(U, D) block_bbox = [L, U, R, D] block_dict = { 'type': block_type, 'bbox': block_bbox, 'saved_info': block } block_spans = [] for span in spans: span_bbox = span["bbox"] if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > radio: block_spans.append(span) '''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)''' # displayed_list = [] # text_inline_lines = [] # modify_y_axis(block_spans, displayed_list, text_inline_lines) '''模型识别错误的行间公式, type类型转换成行内公式''' # block_spans = modify_inline(block_spans, displayed_list, text_inline_lines) '''bbox去除粘连''' # 去粘连会影响span的bbox,导致后续fill的时候出错 # block_spans = remove_overlap_between_bbox_for_span(block_spans) block_dict['spans'] = block_spans block_with_spans.append(block_dict) # 从spans删除已经放入block_spans中的span if len(block_spans) > 0: for span in block_spans: spans.remove(span) return block_with_spans, spans def 
fix_block_spans(block_with_spans): ''' 1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系 需要将caption和footnote的text_span放入相应img_block和table_block内的 caption_block和footnote_block中 2、同时需要删除block中的spans字段 ''' fix_blocks = [] for block in block_with_spans: block_type = block['type'] # if block_type == BlockType.Image: # block = fix_image_block(block, img_blocks) # elif block_type == BlockType.Table: # block = fix_table_block(block, table_blocks) if block_type == "isolate_formula": block = fix_interline_block(block) else: block = fix_text_block(block) fix_blocks.append(block) return fix_blocks # def detect_lang(text: str) -> str: # if len(text) == 0: # return "" # try: # lang_upper = detect_language(text) # except: # html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]]) # lang_upper = detect_language(html_no_ctrl_chars) # try: # lang = lang_upper.lower() # except: # lang = "" # return lang def detect_lang(string): """ 检查整个字符串是否包含中文 :param string: 需要检查的字符串 :return: bool """ for ch in string: if u'\u4e00' <= ch <= u'\u9fff': return 'zh' return 'en' def ocr_escape_special_markdown_char(content): """ 转义正文里对markdown语法有特殊意义的字符 """ special_chars = ["*", "`", "~", "$"] for char in special_chars: content = content.replace(char, "\\" + char) return content # def split_long_words(text): # segments = text.split(' ') # for i in range(len(segments)): # words = re.findall(r'\w+|[^\w]', segments[i], re.UNICODE) # for j in range(len(words)): # if len(words[j]) > 15: # words[j] = ' '.join(wordninja.split(words[j])) # segments[i] = ''.join(words) # return ' '.join(segments) def merge_para_with_text(para_block): para_text = '' for line in para_block['lines']: line_text = "" line_lang = "" for span in line['spans']: span_type = span['type'] if span_type == "text": line_text += span['content'].strip() if line_text != "": line_lang = detect_lang(line_text) for span in line['spans']: span_type = span['type'] content = '' if span_type == "text": 
content = span['content'] content = ocr_escape_special_markdown_char(content) # language = detect_lang(content) # if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本 # content = ocr_escape_special_markdown_char(split_long_words(content)) # else: # content = ocr_escape_special_markdown_char(content) elif span_type == 'inline': content = f" ${span['content'].strip('$')}$ " elif span_type == 'ignore-formula': content = f" ${span['content'].strip('$')}$ " elif span_type == 'isolated': content = f"\n$$\n{span['content'].strip('$')}\n$$\n" elif span_type == 'footnote': content_ori = span['content'].strip('$') if '^' in content_ori: content = f" ${content_ori}$ " else: content = f" $^{content_ori}$ " if content != '': if 'zh' in line_lang: # 遇到一些一个字一个span的文档,这种单字语言判断不准,需要用整行文本判断 para_text += content.strip() # 中文语境下,content间不需要空格分隔 else: para_text += content.strip() + ' ' # 英文语境下 content间需要空格分隔 return para_text ================================================ FILE: pdf_extract_kit/utils/pdf_utils.py ================================================ from pdf2image import convert_from_path def load_pdf(pdf_path): images = convert_from_path(pdf_path) return images ================================================ FILE: pdf_extract_kit/utils/visualization.py ================================================ import numpy as np import cv2 from PIL import Image def colormap(N=256, normalized=False): """ Generate the color map. Args: N (int): Number of labels (default is 256). normalized (bool): If True, return colors normalized to [0, 1]. Otherwise, return [0, 255]. Returns: np.ndarray: Color map array of shape (N, 3). """ def bitget(byteval, idx): """ Get the bit value at the specified index. Args: byteval (int): The byte value. idx (int): The index of the bit. Returns: int: The bit value (0 or 1). 
""" return ((byteval & (1 << idx)) != 0) cmap = np.zeros((N, 3), dtype=np.uint8) for i in range(N): r = g = b = 0 c = i for j in range(8): r = r | (bitget(c, 0) << (7 - j)) g = g | (bitget(c, 1) << (7 - j)) b = b | (bitget(c, 2) << (7 - j)) c = c >> 3 cmap[i] = np.array([r, g, b]) if normalized: cmap = cmap.astype(np.float32) / 255.0 return cmap def visualize_bbox(image_path, bboxes, classes, scores, id_to_names, alpha=0.3): """ Visualize layout detection results on an image. Args: image_path (str): Path to the input image. bboxes (list): List of bounding boxes, each represented as [x_min, y_min, x_max, y_max]. classes (list): List of class IDs corresponding to the bounding boxes. id_to_names (dict): Dictionary mapping class IDs to class names. alpha (float): Transparency factor for the filled color (default is 0.3). Returns: np.ndarray: Image with visualized layout detection results. """ # Check if image_path is a PIL.Image.Image object if isinstance(image_path, Image.Image): image = np.array(image_path) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Convert RGB to BGR for OpenCV else: image = cv2.imread(image_path) overlay = image.copy() cmap = colormap(N=len(id_to_names), normalized=False) # Iterate over each bounding box for i, bbox in enumerate(bboxes): x_min, y_min, x_max, y_max = map(int, bbox) class_id = int(classes[i]) class_name = id_to_names[class_id] text = class_name + f":{scores[i]:.3f}" color = tuple(int(c) for c in cmap[class_id]) cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1) cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2) # Add the class name with a background rectangle (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2) cv2.rectangle(image, (x_min, y_min - text_height - baseline), (x_min + text_width, y_min), color, -1) cv2.putText(image, text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 255), 2) # Blend the overlay with the original image 
cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) return image ================================================ FILE: pdf_extract_kit/version.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple __version__ = '0.1.0' short_version = __version__ def parse_version_info(version_str: str) -> Tuple: """Parse version from a string. Args: version_str (str): A string represents a version info. Returns: tuple: A sequence of integer and string represents version. """ _version_info = [] for x in version_str.split('.'): if x.isdigit(): _version_info.append(int(x)) elif x.find('rc') != -1: patch_version = x.split('rc') _version_info.append(int(patch_version[0])) _version_info.append(f'rc{patch_version[1]}') return tuple(_version_info) version_info = parse_version_info(__version__) ================================================ FILE: project/pdf2markdown/README.md ================================================ # PDF2Markdown **Demo:(left: input image; right: rendered markdown.)** ![demo](demo.png) 1. Extract PDF features by these tasks: - Layout Detection: Using the YOLOv8 model for region detection, such as images, tables, titles, text, etc.; - Formula Detection: Using YOLOv8 for detecting formulas, including inline formulas and isolated formulas; - Formula Recognition: Using UniMERNet for formula recognition; - Table Recognition: Using StructEqTable for table recognition; - Optical Character Recognition: Using PaddleOCR for text recognition; 2. Convert features to markdown file: Using simple rules to convert the identified result to markdown (*Note: this is a simply convert code and can only support one-column PDFs, see [MinerU](https://github.com/opendatalab/MinerU) for more complex situation*). 
# Usage ``` python project/pdf2markdown/scripts/run_project.py --config project/pdf2markdown/configs/pdf2markdown.yaml ``` ================================================ FILE: project/pdf2markdown/configs/pdf2markdown.yaml ================================================ inputs: assets/demo/formula_detection outputs: outputs/pdf2markdown visualize: True merge2markdown: True tasks: layout_detection: model: layout_detection_yolo model_config: img_size: 1024 conf_thres: 0.25 iou_thres: 0.45 model_path: models/Layout/YOLO/doclayout_yolo_ft.pt formula_detection: model: formula_detection_yolo model_config: img_size: 1280 conf_thres: 0.25 iou_thres: 0.45 batch_size: 1 model_path: models/MFD/YOLO/yolo_v8_ft.pt formula_recognition: model: formula_recognition_unimernet model_config: batch_size: 128 cfg_path: pdf_extract_kit/configs/unimernet.yaml model_path: models/MFR/unimernet_tiny ocr: model: ocr_ppocr model_config: lang: ch show_log: True det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec det_db_box_thresh: 0.3 ================================================ FILE: project/pdf2markdown/scripts/pdf2markdown.py ================================================ import os import re import gc import sys import time import torch from PIL import Image, ImageDraw from torchvision import transforms from torch.utils.data import DataLoader sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..')) from pdf_extract_kit.utils.data_preprocess import load_pdf from pdf_extract_kit.tasks.ocr.task import OCRTask from pdf_extract_kit.dataset.dataset import MathDataset from pdf_extract_kit.registry.registry import TASK_REGISTRY from pdf_extract_kit.utils.merge_blocks_and_spans import ( fill_spans_in_blocks, fix_block_spans, merge_para_with_text ) def latex_rm_whitespace(s: str): """Remove unnecessary whitespace from LaTeX code. """ text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? 
{.*?})' letter = '[a-zA-Z]' noletter = '[\W_^\d]' names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)] s = re.sub(text_reg, lambda match: str(names.pop(0)), s) news = s while True: s = news news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s) news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news) news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news) if news == s: break return s def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0): crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1]) crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5]) # Create a white background with an additional width and height of 50 crop_new_width = crop_xmax - crop_xmin + padding_x * 2 crop_new_height = crop_ymax - crop_ymin + padding_y * 2 return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white') # Crop image crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax) cropped_img = input_pil_img.crop(crop_box) return_image.paste(cropped_img, (padding_x, padding_y)) return_list = [padding_x, padding_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height] return return_image, return_list @TASK_REGISTRY.register("pdf2markdown") class PDF2MARKDOWN(OCRTask): def __init__(self, layout_model, mfd_model, mfr_model, ocr_model): self.layout_model = layout_model self.mfd_model = mfd_model self.mfr_model = mfr_model self.ocr_model = ocr_model if self.mfr_model is not None: assert self.mfd_model is not None, "formula recognition based on formula detection, mfd_model can not be None." 
self.mfr_transform = transforms.Compose([self.mfr_model.vis_processor, ]) self.color_palette = { 'title': (255, 64, 255), 'plain text': (255, 255, 0), 'abandon': (0, 255, 255), 'figure': (255, 215, 135), 'figure_caption': (215, 0, 95), 'table': (100, 0, 48), 'table_caption': (0, 175, 0), 'table_footnote': (95, 0, 95), 'isolate_formula': (175, 95, 0), 'formula_caption': (95, 95, 0), 'inline': (0, 0, 255), 'isolated': (0, 255, 0), 'text': (255, 0, 0) } def convert_format(self, yolo_res, id_to_names, ): """ convert yolo format to pdf-extract format. """ res_list = [] for xyxy, conf, cla in zip(yolo_res.boxes.xyxy.cpu(), yolo_res.boxes.conf.cpu(), yolo_res.boxes.cls.cpu()): xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] new_item = { 'category_type': id_to_names[int(cla.item())], 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], 'score': round(float(conf.item()), 2), } res_list.append(new_item) return res_list def process_single_pdf(self, image_list): """predict on one image, reture text detection and recognition results. Args: image_list: List[PIL.Image.Image] Returns: List[dict]: list of PDF extract results Return example: [ { "layout_dets": [ { "category_type": "text", "poly": [ 380.6792698635707, 159.85058512958923, 765.1419999999998, 159.85058512958923, 765.1419999999998, 192.51073013642917, 380.6792698635707, 192.51073013642917 ], "text": "this is an example text", "score": 0.97 }, ... ], "page_info": { "page_no": 0, "height": 2339, "width": 1654, } }, ... 
] """ pdf_extract_res = [] mf_image_list = [] latex_filling_list = [] for idx, image in enumerate(image_list): img_W, img_H = image.size if self.layout_model is not None: ori_layout_res = self.layout_model.predict([image], "")[0] layout_res = self.convert_format(ori_layout_res, self.layout_model.id_to_names) else: layout_res = [] single_page_res = {'layout_dets': layout_res} single_page_res['page_info'] = dict( page_no = idx, height = img_H, width = img_W ) if self.mfd_model is not None: mfd_res = self.mfd_model.predict([image], "")[0] for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] new_item = { 'category_type': self.mfd_model.id_to_names[int(cla.item())], 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], 'score': round(float(conf.item()), 2), 'latex': '', } single_page_res['layout_dets'].append(new_item) if self.mfr_model is not None: latex_filling_list.append(new_item) bbox_img = image.crop((xmin, ymin, xmax, ymax)) mf_image_list.append(bbox_img) pdf_extract_res.append(single_page_res) del mfd_res torch.cuda.empty_cache() gc.collect() # Formula recognition, collect all formula images in whole pdf file, then batch infer them. 
if self.mfr_model is not None: a = time.time() dataset = MathDataset(mf_image_list, transform=self.mfr_transform) dataloader = DataLoader(dataset, batch_size=self.mfr_model.batch_size, num_workers=0) mfr_res = [] for imgs in dataloader: imgs = imgs.to(self.mfr_model.device) output = self.mfr_model.model.generate({'image': imgs}) mfr_res.extend(output['pred_str']) for res, latex in zip(latex_filling_list, mfr_res): res['latex'] = latex_rm_whitespace(latex) b = time.time() print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2)) # ocr_res = self.ocr_model.predict(image) # ocr and table recognition for idx, image in enumerate(image_list): layout_res = pdf_extract_res[idx]['layout_dets'] pil_img = image.copy() ocr_res_list = [] table_res_list = [] single_page_mfdetrec_res = [] for res in layout_res: if res['category_type'] in self.mfd_model.id_to_names.values(): single_page_mfdetrec_res.append({ "bbox": [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])], }) elif res['category_type'] in [self.layout_model.id_to_names[cid] for cid in [0, 1, 2, 4, 6, 7]]: ocr_res_list.append(res) elif res['category_type'] in [self.layout_model.id_to_names[5]]: table_res_list.append(res) ocr_start = time.time() # Process each area that requires OCR processing for res in ocr_res_list: new_image, useful_list = crop_img(res, pil_img, padding_x=25, padding_y=25) paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list # Adjust the coordinates of the formula area adjusted_mfdetrec_res = [] for mf_res in single_page_mfdetrec_res: mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] # Adjust the coordinates of the formula area to the coordinates relative to the cropping area x0 = mf_xmin - xmin + paste_x y0 = mf_ymin - ymin + paste_y x1 = mf_xmax - xmin + paste_x y1 = mf_ymax - ymin + paste_y # Filter formula blocks outside the graph if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): continue else: 
adjusted_mfdetrec_res.append({ "bbox": [x0, y0, x1, y1], }) # OCR recognition ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] # Integration results if ocr_res: for box_ocr_res in ocr_res: p1, p2, p3, p4 = box_ocr_res[0] text, score = box_ocr_res[1] # Convert the coordinates back to the original coordinate system p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin] layout_res.append({ 'category_type': 'text', 'poly': p1 + p2 + p3 + p4, 'score': round(score, 2), 'text': text, }) ocr_cost = round(time.time() - ocr_start, 2) print(f"ocr cost: {ocr_cost}") return pdf_extract_res def order_blocks(self, blocks): def calculate_oder(poly): xmin, ymin, _, _, xmax, ymax, _, _ = poly return ymin*3000 + xmin return sorted(blocks, key=lambda item: calculate_oder(item['poly'])) def convert2md(self, extract_res): blocks = [] spans = [] for item in extract_res['layout_dets']: if item['category_type'] in ['inline', 'text', 'isolated']: text_key = 'text' if item['category_type'] == 'text' else 'latex' xmin, ymin, _, _, xmax, ymax, _, _ = item['poly'] spans.append( { "type": item['category_type'], "bbox": [xmin, ymin, xmax, ymax], "content": item[text_key] } ) if item['category_type'] == "isolated": item['category_type'] = "isolate_formula" blocks.append(item) else: blocks.append(item) blocks_types = ["title", "plain text", "figure_caption", "table_caption", "table_footnote", "isolate_formula", "formula_caption"] need_fix_bbox = [] final_block = [] for block in blocks: block_type = block["category_type"] if block_type in blocks_types: need_fix_bbox.append(block) else: final_block.append(block) block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6) fix_blocks = fix_block_spans(block_with_spans) for para_block in fix_blocks: result = merge_para_with_text(para_block) if 
para_block['type'] == "isolate_formula": para_block['saved_info']['latex'] = result else: para_block['saved_info']['text'] = result final_block.append(para_block['saved_info']) final_block = self.order_blocks(final_block) md_text = "" for block in final_block: if block['category_type'] == "title": md_text += "\n# "+block['text'] +"\n" elif block['category_type'] in ["isolate_formula"]: md_text += "\n"+block['latex']+"\n" elif block['category_type'] in ["plain text", "figure_caption", "table_caption"]: md_text += " "+block['text']+" " elif block['category_type'] in ["figure", "table"]: continue else: continue return md_text def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False): file_list = self.prepare_input_files(input_path) res_list = [] for fpath in file_list: basename = os.path.basename(fpath)[:-4] if fpath.endswith(".pdf") or fpath.endswith(".PDF"): images = load_pdf(fpath) else: images = [Image.open(fpath)] pdf_extract_res = self.process_single_pdf(images) res_list.append(pdf_extract_res) if save_dir: os.makedirs(save_dir, exist_ok=True) self.save_json_result(pdf_extract_res, os.path.join(save_dir, f"{basename}.json")) if merge2markdown: md_content = [] for extract_res in pdf_extract_res: md_text = self.convert2md(extract_res) md_content.append(md_text) with open(os.path.join(save_dir, f"{basename}.md"), "w") as f: f.write("\n\n".join(md_content)) if visualize: for image, page_res in zip(images, pdf_extract_res): self.visualize_image(image, page_res['layout_dets'], cate2color=self.color_palette) if fpath.endswith(".pdf") or fpath.endswith(".PDF"): first_page = images.pop(0) first_page.save(os.path.join(save_dir, f'{basename}.pdf'), 'PDF', resolution=100, save_all=True, append_images=images) else: images[0].save(os.path.join(save_dir, f"{basename}.png")) return res_list ================================================ FILE: project/pdf2markdown/scripts/run_project.py ================================================ import os import 
sys import os.path as osp import argparse from pdf2markdown import PDF2MARKDOWN sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models from pdf_extract_kit.registry.registry import TASK_REGISTRY TASK_NAME = 'pdf2markdown' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs/pdf_extract') visualize = config.get('visualize', False) merge2markdown = config.get('merge2markdown', False) layout_model = task_instances['layout_detection'].model if 'layout_detection' in task_instances else None mfd_model = task_instances['formula_detection'].model if 'formula_detection' in task_instances else None mfr_model = task_instances['formula_recognition'].model if 'formula_recognition' in task_instances else None ocr_model = task_instances['ocr'].model if 'ocr' in task_instances else None pdf_extract_task = TASK_REGISTRY.get(TASK_NAME)(layout_model, mfd_model, mfr_model, ocr_model) extract_results = pdf_extract_task.process(input_data, save_dir=result_path, visualize=visualize, merge2markdown=merge2markdown) print(f'Task done, results can be found at {result_path}') if __name__ == "__main__": args = parse_args() main(args.config) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" [project] name = "pdf-extract-kit" version = "0.1.0" authors = [ { name="Bin Wang", email="ictwangbin@gmail.com" } ] 
description = "A Comprehensive Toolkit for High-Quality PDF Content Extraction." readme = "README.md" license = { file="LICENSE" } requires-python = ">=3.10" dependencies = [ "PyPDF2", "matplotlib", "pyyaml", "frontend", "pymupdf", opencv-python = "^4.6.0" # Add other common dependencies ] [project.optional-dependencies] layout_detection = [ "transformers", # for layoutlmv3 # Add other dependencies for layout detection ] formula_detection = [ "ultralytics", # for yolov8 # Add other dependencies for formula detection ] # Add additional dependencies for other models ================================================ FILE: requirements/docs.txt ================================================ myst-parser sphinx sphinx-book-theme sphinx-copybutton sphinx-tabs sphinxcontrib-mermaid ================================================ FILE: requirements-cpu.txt ================================================ omegaconf matplotlib PyMuPDF ultralytics>=8.2.85 doclayout-yolo==0.0.2 unimernet==0.2.1 paddlepaddle paddleocr==2.7.3 struct-eqtable ================================================ FILE: requirements.txt ================================================ omegaconf matplotlib PyMuPDF ultralytics>=8.2.85 doclayout-yolo==0.0.2 unimernet==0.2.1 paddlepaddle-gpu paddleocr==2.7.3 struct-eqtable lmdeploy ================================================ FILE: scripts/formula_detection.py ================================================ import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks TASK_NAME = 'formula_detection' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = 
load_config(config_path) task_instances = initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME) # formula_detection_task model_formula_detection = task_instances[TASK_NAME] # for image detection detection_results = model_formula_detection.predict_images(input_data, result_path) # for pdf detection # detection_results = model_formula_detection.predict_pdfs(input_data, result_path) # print(detection_results) print(f'The predicted results can be found at {result_path}') if __name__ == "__main__": args = parse_args() main(args.config) ================================================ FILE: scripts/formula_recognition.py ================================================ import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks TASK_NAME = 'formula_recognition' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME) # formula_detection_task model_formula_recognition = task_instances[TASK_NAME] # for image detection recognition_results = model_formula_recognition.predict(input_data, result_path) print('Recognition results are as follows:') for id, math in enumerate(recognition_results): print(str(id+1)+': ', math) if __name__ == "__main__": args = parse_args() main(args.config) ================================================ FILE: 
scripts/layout_detection.py ================================================ import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks TASK_NAME = 'layout_detection' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME) # layout_detection_task model_layout_detection = task_instances[TASK_NAME] # for image detection detection_results = model_layout_detection.predict_images(input_data, result_path) # for pdf detection # detection_results = model_layout_detection.predict_pdfs(input_data, result_path) # print(detection_results) print(f'The predicted results can be found at {result_path}') if __name__ == "__main__": args = parse_args() main(args.config) ================================================ FILE: scripts/ocr.py ================================================ import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks TASK_NAME = 'ocr' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = 
initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME) visualize = config.get('visualize', False) # formula_detection_task task = task_instances[TASK_NAME] detection_results = task.process(input_data, save_dir=result_path, visualize=visualize) print(f'Task done, results can be found at {result_path}') if __name__ == "__main__": args = parse_args() main(args.config) ================================================ FILE: scripts/run_task.py ================================================ import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks # 确保所有任务模块被导入 def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) # 从配置文件中获取输入数据路径 input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs') # formula_detection_task model_formula_detection = task_instances['formula_detection'] detection_results = model_formula_detection.predict(input_data, result_path) print(detection_results) # formula_recognition_task # model_formula_recognition = task_instances['formula_recognition'] # recognition_results = model_formula_recognition.predict(input_data, result_path) # for id, math in enumerate(recognition_results): # print(str(id+1)+': ', math) # results = task_instance.run(input_data) # print(results) if __name__ == "__main__": args = parse_args() main(args.config) ================================================ FILE: scripts/table_parsing.py 
================================================ import os import sys import os.path as osp import argparse sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..')) from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models import pdf_extract_kit.tasks TASK_NAME = 'table_parsing' def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) # get input and output path from config input_data = config.get('inputs', None) result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME) # table_parsing_task model_table_parsing = task_instances[TASK_NAME] # for image detection parsing_results = model_table_parsing.predict(input_data, result_path) print('Table Parsing results are as follows:') for id, result in enumerate(parsing_results): print(str(id+1)+':\n', result) if __name__ == "__main__": args = parse_args() main(args.config)