[
  {
    "path": ".gitignore",
    "content": "*.ipynb*\n*.ipynb\n\n# local data\noutputs/*\ndata/*\ntemp*\ntest*\n\n# python\n.ipynb_checkpoints\n*.ipynb\n**/__pycache__/\n\n# logs\n*.log\n*.out\n\nmodels/*\n\n# Sphinx documentation\ndocs/*/_build/\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "version: 2\n\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.10\"\n\nformats:\n  - epub\n\npython:\n  install:\n    - requirements: requirements/docs.txt\n\nsphinx:\n  configuration: docs/zh_cn/conf.py\n"
  },
  {
    "path": ".vscode/launch.json",
    "content": "{\n    // 使用 IntelliSense 了解相关属性。 \n    // 悬停以查看现有属性的描述。\n    // 欲了解更多信息，请访问: https://go.microsoft.com/fwlink/?linkid=830387\n    \"version\": \"0.2.0\",\n    \"configurations\": [\n        {\n            \"name\": \"run_mfd\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"${workspaceFolder}/scripts/run_mfd.py\",\n            \"console\": \"integratedTerminal\",\n            \"args\": [\n                \"--config\",\n                \"configs/config_mfd.yaml\" \n            ],\n            \"env\": {\n                \"PYTHONPATH\": \"/Users/bin/anaconda3/envs/mfd_test\"\n            }\n        },\n        {\n            \"name\": \"run_formula_recognition\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"${workspaceFolder}/scripts/formula_recognition.py\",\n            \"console\": \"integratedTerminal\",\n            \"args\": [\n                \"--config\",\n                \"configs/formula_recognition.yaml\"\n            ],\n            \"env\": {\n                \"PYTHONPATH\": \"/Users/bin/anaconda3/envs/mfd_test\"\n            }\n        },\n        {\n            \"name\": \"run_ocr\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"${workspaceFolder}/scripts/ocr.py\",\n            \"console\": \"integratedTerminal\",\n            \"args\": [\n                \"--config\",\n                \"configs/ocr.yaml\"\n            ],\n            \"env\": {\n                \"PYTHONPATH\": \"/Users/bin/anaconda3/envs/mfd_test\"\n            }\n        },\n        {\n            \"name\": \"run_formula_detection\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"${workspaceFolder}/scripts/formula_detection.py\",\n            \"console\": \"integratedTerminal\",\n            \"args\": [\n                \"--config\",\n               
 \"configs/formula_detection.yaml\"\n            ],\n            \"env\": {\n                \"PYTHONPATH\": \"/Users/bin/anaconda3/envs/mfd_test\"\n            }\n        },\n        {\n            \"name\": \"run_layout_detection\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"${workspaceFolder}/scripts/layout_detection.py\",\n            \"console\": \"integratedTerminal\",\n            \"args\": [\n                \"--config\",\n                \"configs/layout_detection.yaml\"\n            ],\n            \"env\": {\n                \"PYTHONPATH\": \"/Users/bin/anaconda3/envs/mfd_test\"\n            }\n        },\n        {\n            \"name\": \"run_layout_detection_layoutlmv3\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"${workspaceFolder}/scripts/layout_detection.py\",\n            \"console\": \"integratedTerminal\",\n            \"args\": [\n                \"--config\",\n                \"configs/layout_detection_layoutlmv3.yaml\"\n            ],\n            \"env\": {\n                \"PYTHONPATH\": \"/Users/bin/anaconda3/envs/mfd_test\"\n            }\n        }\n    ]\n}"
  },
  {
    "path": "LICENSE.md",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU Affero General Public License is a free, copyleft license for\nsoftware and other kinds of works, specifically designed to ensure\ncooperation with the community in the case of network server software.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nour General Public Licenses are intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  Developers that use our General Public Licenses protect your rights\nwith two steps: (1) assert copyright on the software, and (2) offer\nyou this License which gives you legal permission to copy, distribute\nand/or modify the software.\n\n  A secondary benefit of defending all users' freedom is that\nimprovements made in alternate versions of the program, if they\nreceive widespread use, become available for other developers to\nincorporate.  Many developers of free software are heartened and\nencouraged by the resulting cooperation.  
However, in the case of\nsoftware used on network servers, this result may fail to come about.\nThe GNU General Public License permits making a modified version and\nletting the public access it on a server without ever releasing its\nsource code to the public.\n\n  The GNU Affero General Public License is designed specifically to\nensure that, in such cases, the modified source code becomes available\nto the community.  It requires the operator of a network server to\nprovide the source code of the modified version running there to the\nusers of that server.  Therefore, public use of a modified version, on\na publicly accessible server, gives the public access to the source\ncode of the modified version.\n\n  An older license, called the Affero General Public License and\npublished by Affero, was designed to accomplish similar goals.  This is\na different license, not a version of the Affero GPL, but Affero has\nreleased a new version of the Affero GPL which permits relicensing under\nthis license.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. Definitions.\n\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  \"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  
The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  
\"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. 
Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  
This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. 
Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  
If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  
A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  
Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  
You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  
If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  
If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  
For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  
To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  \"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  
You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. 
Remote Network Interaction; Use with the GNU General Public License.\n\n  Notwithstanding any other provision of this License, if you modify the\nProgram, your modified version must prominently offer all users\ninteracting with it remotely through a computer network (if your version\nsupports such interaction) an opportunity to receive the Corresponding\nSource of your version by providing access to the Corresponding Source\nfrom a network server at no charge, through some standard or customary\nmeans of facilitating copying of software.  This Corresponding Source\nshall include the Corresponding Source for any work covered by version 3\nof the GNU General Public License that is incorporated pursuant to the\nfollowing paragraph.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU General Public License into a single\ncombined work, and to convey the resulting work.  The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the work with which it is combined will remain governed by version\n3 of the GNU General Public License.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU Affero General Public License from time to time.  Such new versions\nwill be similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU Affero General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  
If the Program does not specify a version number of the\nGNU Affero General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU Affero General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. 
Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU Affero General Public License as published\n    by the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU Affero General Public License for more details.\n\n    You should have received a copy of the GNU Affero General Public License\n    along with this program.  
If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If your software can interact with users remotely through a computer\nnetwork, you should also make sure that it provides a way for users to\nget its source.  For example, if your program is a web application, its\ninterface could display a \"Source\" link that leads users to an archive\nof the code.  There are many ways you could offer source, and different\nsolutions will be better for different programs; see section 13 for the\nspecific requirements.\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU AGPL, see\n<https://www.gnu.org/licenses/>."
  },
  {
    "path": "README.md",
    "content": "\n<p align=\"center\">\n  <img src=\"assets/readme/pdf-extract-kit_logo.png\" width=\"220px\" style=\"vertical-align:middle;\">\n</p>\n\n<div align=\"center\">\n\nEnglish | [简体中文](./README_zh-CN.md)\n\n[PDF-Extract-Kit-1.0 Tutorial](https://pdf-extract-kit.readthedocs.io/en/latest/get_started/pretrained_model.html)\n\n[[Models (🤗Hugging Face)]](https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0) | [[Models(<img src=\"./assets/readme/modelscope_logo.png\" width=\"20px\">ModelScope)]](https://www.modelscope.cn/models/OpenDataLab/PDF-Extract-Kit-1.0) \n \n🔥🔥🔥 [MinerU: Efficient Document Content Extraction Tool Based on PDF-Extract-Kit](https://github.com/opendatalab/MinerU)\n\n</div>\n\n<p align=\"center\">\n    👋 join us on <a href=\"https://discord.gg/Tdedn9GTXq\" target=\"_blank\">Discord</a> and <a href=\"https://r.vansin.top/?r=MinerU\" target=\"_blank\">WeChat</a>\n</p>\n\n\n## Overview\n\n`PDF-Extract-Kit` is a powerful open-source toolkit designed to efficiently extract high-quality content from complex and diverse PDF documents. 
Here are its main features and advantages:\n\n- **Integration of Leading Document Parsing Models**: Incorporates state-of-the-art models for layout detection, formula detection, formula recognition, OCR, and other core document parsing tasks.\n- **High-Quality Parsing Across Diverse Documents**: Fine-tuned with diverse document annotation data to deliver high-quality results across various complex document types.\n- **Modular Design**: The flexible modular design allows users to easily combine and construct various applications by modifying configuration files and minimal code, making application building as straightforward as stacking blocks.\n- **Comprehensive Evaluation Benchmarks**: Provides diverse and comprehensive PDF evaluation benchmarks, enabling users to choose the most suitable model based on evaluation results.\n\n**Experience PDF-Extract-Kit now and unlock the limitless potential of PDF documents!**\n\n> **Note:** PDF-Extract-Kit is designed for high-quality document processing and functions as a model toolbox.    \n> If you are interested in extracting high-quality document content (e.g., converting PDFs to Markdown), please use [MinerU](https://github.com/opendatalab/MinerU), which combines the high-quality predictions from PDF-Extract-Kit with specialized engineering optimizations for more convenient and efficient content extraction.    \n> If you're a developer looking to create engaging applications such as document translation, document Q&A, or document assistants, you'll find it very convenient to build your own projects using PDF-Extract-Kit. 
In particular, we will periodically update the PDF-Extract-Kit/project directory with interesting applications, so stay tuned!\n\n**We welcome researchers and engineers from the community to contribute outstanding models and innovative applications by submitting PRs to become contributors to the PDF-Extract-Kit project.**\n\n## Model Overview\n\n| **Task Type**     | **Description**                                                                 | **Models**                    |\n|-------------------|---------------------------------------------------------------------------------|-------------------------------|\n| **Layout Detection** | Locate different elements in a document: including images, tables, text, titles, formulas | `DocLayout-YOLO_ft`, `YOLO-v10_ft`, `LayoutLMv3_ft` | \n| **Formula Detection** | Locate formulas in documents: including inline and block formulas            | `YOLOv8_ft`                   |  \n| **Formula Recognition** | Recognize formula images into LaTeX source code                             | `UniMERNet`                   |  \n| **OCR**           | Extract text content from images (including location and recognition)            | `PaddleOCR`                   | \n| **Table Recognition** | Recognize table images into corresponding source code (LaTeX/HTML/Markdown)   | `PaddleOCR+TableMaster`, `StructEqTable` |  \n| **Reading Order** | Sort and concatenate discrete text paragraphs                                    | Coming Soon!                  | \n\n## News and Updates\n- `2024.10.22` 🎉🎉🎉 We are excited to announce that the table recognition model [StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B), which supports LaTeX, HTML, and Markdown output formats, has been officially integrated into `PDF-Extract-Kit 1.0`. 
Please refer to the [table recognition algorithm documentation](https://pdf-extract-kit.readthedocs.io/en/latest/algorithm/table_recognition.html) for usage instructions!\n- `2024.10.17` 🎉🎉🎉 We are excited to announce that the more accurate and faster layout detection model, [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO), has been officially integrated into `PDF-Extract-Kit 1.0`. Please refer to the [layout detection algorithm documentation](https://pdf-extract-kit.readthedocs.io/en/latest/algorithm/layout_detection.html) for usage instructions!\n- `2024.10.10` 🎉🎉🎉 The official release of `PDF-Extract-Kit 1.0`, rebuilt with modularity for more convenient and flexible model usage! Please switch to the [release/0.1.1](https://github.com/opendatalab/PDF-Extract-Kit/tree/release/0.1.1) branch for the old version.\n- `2024.08.01` 🎉🎉🎉 Added the [StructEqTable](demo/TabRec/StructEqTable/README_TABLE.md) module for table content extraction. Welcome to use it!\n- `2024.07.01` 🎉🎉🎉 We released `PDF-Extract-Kit`, a comprehensive toolkit for high-quality PDF content extraction, including `Layout Detection`, `Formula Detection`, `Formula Recognition`, and `OCR`.\n\n## Performance Demonstration\n\nMany current open-source SOTA models are trained and evaluated on academic datasets, achieving high-quality results only on single document types. To enable models to achieve stable and robust high-quality results on diverse documents, we constructed diverse fine-tuning datasets and fine-tuned some SOTA models to obtain practical parsing models. Below are some visual results of the models.\n\n### Layout Detection\n\nWe trained robust `Layout Detection` models using diverse PDF document annotations. Our fine-tuned models achieve accurate extraction results on diverse PDF documents such as papers, textbooks, research reports, and financial reports, and demonstrate high robustness to challenges like blurring and watermarks. 
The visualization example below shows the inference results of the fine-tuned LayoutLMv3 model.\n \n![](assets/readme/layout_example.png)\n\n### Formula Detection\n\nSimilarly, we collected and annotated documents containing formulas in both English and Chinese, and fine-tuned advanced formula detection models. The visualization result below shows the inference results of the fine-tuned YOLO formula detection model:\n\n![](assets/readme/mfd_example.png)\n\n### Formula Recognition\n\n[UniMERNet](https://github.com/opendatalab/UniMERNet) is an algorithm designed for diverse formula recognition in real-world scenarios. By constructing large-scale training data and a carefully designed model architecture, it achieves excellent recognition performance for complex long formulas, handwritten formulas, and noisy screenshot formulas.\n\n### Table Recognition\n\n[StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) is a high-efficiency toolkit that converts table images into LaTeX/HTML/Markdown. 
The latest version, powered by the InternVL2-1B foundation model, improves Chinese recognition accuracy and expands multi-format output options.\n\n#### For more visual and inference results of the models, please refer to the [PDF-Extract-Kit tutorial documentation](https://pdf-extract-kit.readthedocs.io/en/latest/).\n\n## Evaluation Metrics\n\nComing Soon!\n\n## Usage Guide\n\n### Environment Setup\n\n```bash\nconda create -n pdf-extract-kit-1.0 python=3.10\nconda activate pdf-extract-kit-1.0\npip install -r requirements.txt\n```\n> **Note:** If your device does not support GPU, please install the CPU version dependencies using `requirements-cpu.txt` instead of `requirements.txt`.\n\n> **Note:** DocLayout-YOLO currently only supports installation from PyPI. If an error occurs during DocLayout-YOLO installation, please install it with `pip3 install doclayout-yolo==0.0.2 --extra-index-url=https://pypi.org/simple`.\n\n### Model Download\n\nPlease refer to the [Model Weights Download Tutorial](https://pdf-extract-kit.readthedocs.io/en/latest/get_started/pretrained_model.html) to download the required model weights. Note: You can choose to download all the weights or select specific ones. For detailed instructions, please refer to the tutorial.\n\n### Running Demos\n\n#### Layout Detection Model\n\n```bash \npython scripts/layout_detection.py --config=configs/layout_detection.yaml\n```\nLayout detection models support **DocLayout-YOLO** (default model), YOLO-v10, and LayoutLMv3. For YOLO-v10 and LayoutLMv3, please refer to [Layout Detection Algorithm](https://pdf-extract-kit.readthedocs.io/en/latest/algorithm/layout_detection.html). 
You can view the layout detection results in the `outputs/layout_detection` folder.\n\n#### Formula Detection Model\n\n```bash \npython scripts/formula_detection.py --config=configs/formula_detection.yaml\n```\nYou can view the formula detection results in the `outputs/formula_detection` folder.\n\n#### OCR Model\n\n```bash \npython scripts/ocr.py --config=configs/ocr.yaml\n```\nYou can view the OCR results in the `outputs/ocr` folder.\n\n#### Formula Recognition Model\n\n```bash \npython scripts/formula_recognition.py --config=configs/formula_recognition.yaml\n```\nYou can view the formula recognition results in the `outputs/formula_recognition` folder.\n\n#### Table Recognition Model\n\n```bash \npython scripts/table_parsing.py --config configs/table_parsing.yaml\n```\nYou can view the table recognition results in the `outputs/table_parsing` folder.\n\n> **Note:** For more details on using the models, please refer to the [PDF-Extract-Kit-1.0 Tutorial](https://pdf-extract-kit.readthedocs.io/en/latest/get_started/pretrained_model.html).\n\n> This project focuses on using models for `high-quality` content extraction from `diverse` documents and does not involve reconstructing extracted content into new documents, such as PDF to Markdown. For such needs, please refer to our other GitHub project: [MinerU](https://github.com/opendatalab/MinerU).\n\n## To-Do List\n\n- [x] **Table Parsing**: Develop functionality to convert table images into corresponding LaTeX/Markdown format source code.\n- [ ] **Chemical Equation Detection**: Implement automatic detection of chemical equations.\n- [ ] **Chemical Equation/Diagram Recognition**: Develop models to recognize and parse chemical equations and diagrams.\n- [ ] **Reading Order Sorting Model**: Build a model to determine the correct reading order of text in documents.\n\n**PDF-Extract-Kit** aims to provide high-quality PDF content extraction capabilities. 
We encourage the community to propose specific and valuable needs and welcome everyone to participate in continuously improving the PDF-Extract-Kit tool to advance research and industry development.\n\n## License\n\nThis project is open-sourced under the [AGPL-3.0](LICENSE) license.\n\nSince this project uses YOLO code and PyMuPDF for file processing, these components require compliance with the AGPL-3.0 license. Therefore, to ensure adherence to the licensing requirements of these dependencies, this repository as a whole adopts the AGPL-3.0 license.\n\n## Acknowledgement\n\n   - [LayoutLMv3](https://github.com/microsoft/unilm/tree/master/layoutlmv3): Layout detection model\n   - [UniMERNet](https://github.com/opendatalab/UniMERNet): Formula recognition model\n   - [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy): Table recognition model\n   - [YOLO](https://github.com/ultralytics/ultralytics): Formula detection model\n   - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR): OCR model\n   - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO): Layout detection model\n\n## Citation\nIf you find our models / code / papers useful in your research, please consider giving ⭐ and citations 📝, thx :)  \n```bibtex\n@article{wang2024mineru,\n  title={MinerU: An Open-Source Solution for Precise Document Content Extraction},\n  author={Wang, Bin and Xu, Chao and Zhao, Xiaomeng and Ouyang, Linke and Wu, Fan and Zhao, Zhiyuan and Xu, Rui and Liu, Kaiwen and Qu, Yuan and Shang, Fukai and others},\n  journal={arXiv preprint arXiv:2409.18839},\n  year={2024}\n}\n\n@misc{zhao2024doclayoutyoloenhancingdocumentlayout,\n      title={DocLayout-YOLO: Enhancing Document Layout Analysis through Diverse Synthetic Data and Global-to-Local Adaptive Perception}, \n      author={Zhiyuan Zhao and Hengrui Kang and Bin Wang and Conghui He},\n      year={2024},\n      eprint={2410.12628},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV},\n      
url={https://arxiv.org/abs/2410.12628}, \n}\n\n@misc{wang2024unimernet,\n      title={UniMERNet: A Universal Network for Real-World Mathematical Expression Recognition}, \n      author={Bin Wang and Zhuangcheng Gu and Chao Xu and Bo Zhang and Botian Shi and Conghui He},\n      year={2024},\n      eprint={2404.15254},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n\n@article{he2024opendatalab,\n  title={Opendatalab: Empowering general artificial intelligence with open datasets},\n  author={He, Conghui and Li, Wei and Jin, Zhenjiang and Xu, Chao and Wang, Bin and Lin, Dahua},\n  journal={arXiv preprint arXiv:2407.13773},\n  year={2024}\n}\n```\n\n## Star History\n\n<a>\n <picture>\n   <source media=\"(prefers-color-scheme: dark)\" srcset=\"https://api.star-history.com/svg?repos=opendatalab/PDF-Extract-Kit&type=Date&theme=dark\" />\n   <source media=\"(prefers-color-scheme: light)\" srcset=\"https://api.star-history.com/svg?repos=opendatalab/PDF-Extract-Kit&type=Date\" />\n   <img alt=\"Star History Chart\" src=\"https://api.star-history.com/svg?repos=opendatalab/PDF-Extract-Kit&type=Date\" />\n </picture>\n</a>\n\n## Related Links\n- [UniMERNet (Real-World Formula Recognition Algorithm)](https://github.com/opendatalab/UniMERNet)\n- [LabelU (Lightweight Multimodal Annotation Tool)](https://github.com/opendatalab/labelU)\n- [LabelLLM (Open Source LLM Dialogue Annotation Platform)](https://github.com/opendatalab/LabelLLM)\n- [MinerU (One-Stop High-Quality Data Extraction Tool)](https://github.com/opendatalab/MinerU)\n"
  },
  {
    "path": "README_zh-CN.md",
    "content": "\n<p align=\"center\">\n  <img src=\"assets/readme/pdf-extract-kit_logo.png\" width=\"220px\" style=\"vertical-align:middle;\">\n</p>\n\n<div align=\"center\">\n\n[English](./README.md) | 简体中文\n\n[PDF-Extract-Kit-1.0中文教程](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/get_started/pretrained_model.html)\n\n[[Models (🤗Hugging Face)]](https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0) | [[Models(<img src=\"./assets/readme/modelscope_logo.png\" width=\"20px\">ModelScope)]](https://www.modelscope.cn/models/OpenDataLab/PDF-Extract-Kit-1.0) \n \n🔥🔥🔥 [MinerU：基于PDF-Extract-Kit的高效文档内容提取工具](https://github.com/opendatalab/MinerU)\n</div>\n\n<p align=\"center\">\n    👋 join us on <a href=\"https://discord.gg/JYsXDXXN\" target=\"_blank\">Discord</a> and <a href=\"https://r.vansin.top/?r=MinerU\" target=\"_blank\">WeChat</a>\n</p>\n\n\n## 整体介绍\n\n`PDF-Extract-Kit` 是一款功能强大的开源工具箱，旨在从复杂多样的 PDF 文档中高效提取高质量内容。以下是其主要功能和优势：\n\n- **集成文档解析主流模型**：汇聚布局检测、公式检测、公式识别、OCR等文档解析核心任务的众多SOTA模型；\n- **多样性文档下高质量解析结果**：结合多样性文档标注数据在进行模型微调，在复杂多样的文档下提供高质量解析结果；\n- **模块化设计**：模块化设计使用户可以通过修改配置文件及少量代码即可自由组合构建各种应用，让应用构建像搭积木一样简便；  \n- **全面评测基准**：提供多样性全面的PDF评测基准，用户可根据评测结果选择最适合自己的模型。  \n\n**立即体验 PDF-Extract-Kit，解锁 PDF 文档的无限潜力！** \n\n> **注意：** PDF-Extract-Kit 专注于高质量文档处理，适合作为模型工具箱使用。\n> 如果你想提取高质量文档内容(PDF转Markdown)，请直接使用[MinerU](https://github.com/opendatalab/MinerU)，MinerU结合PDF-Extract-Kit的高质量预测结果，进行了专门的工程优化，使得PDF文档内容提取更加便捷高效；  \n> 如果你是一位开发者，希望搭建更多有意思的应用（如文档翻译，文档问答，文档助手等），基于PDF-Extract-Kit自行进行DIY将会十分便捷。特别地，我们会在`PDF-Extract-Kit/project`下面不定期更新一些有趣的应用，敬请期待！  \n\n**我们欢迎社区研究员和工程师贡献优秀模型和创新应用，通过提交 PR 成为 PDF-Extract-Kit 的贡献者。**\n\n\n## 模型概览\n\n| **任务类型** | **任务描述**                                                                    | **模型**                     |\n|--------------|---------------------------------------------------------------------------------|------------------------------|\n| **布局检测** | 定位文档中不同元素位置：包含图像、表格、文本、标题、公式等 | `DocLayout-YOLO_ft`, `YOLO-v10_ft`, `LayoutLMv3_ft` |\n| 
**公式检测** | 定位文档中公式位置：包含行内公式和行间公式                                      | `YOLOv8_ft`                       |\n| **公式识别** | 识别公式图像为latex源码                                                         | `UniMERNet`                  |\n|    **OCR**   | 提取图像中的文本内容（包括定位和识别）                                          | `PaddleOCR`                  |\n| **表格识别** | 识别表格图像为对应源码（Latex/HTML/Markdown）                                   | `PaddleOCR+TableMaster`,`StructEqTable`  |\n| **阅读顺序** | 将离散的文本段落进行排序拼接                                                    |  Coming Soon !                            |\n\n\n\n## 新闻和更新\n- `2024.10.22` 🎉🎉🎉 支持LaTex和HTML等多种输出格式的表格模型[StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B)正式接入`PDF-Extract-Kit 1.0`，请参考[表格识别算法文档](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/algorithm/table_recognition.html)进行使用！\n- `2024.10.17` 🎉🎉🎉 检测结果更准确，速度更快的布局检测模型`DocLayout-YOLO`正式接入`PDF-Extract-Kit 1.0`，请参考[布局检测算法文档](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/algorithm/layout_detection.html)进行使用！\n- `2024.10.10` 🎉🎉🎉 基于模块化重构的`PDF-Extract-Kit 1.0`正式版本正式发布，模型使用更加便捷灵活！老版本请切换至[release/0.1.1](https://github.com/opendatalab/PDF-Extract-Kit/tree/release/0.1.1)分支进行使用。\n- `2024.08.01` 🎉🎉🎉 新增了[StructEqTable](demo/TabRec/StructEqTable/README_TABLE.md)表格识别模块用于表格内容提取，欢迎使用！\n- `2024.07.01` 🎉🎉🎉 我们发布了`PDF-Extract-Kit`，一个用于高质量PDF内容提取的综合工具包，包括`布局检测`、`公式检测`、`公式识别`和`OCR`。\n\n\n\n## 效果展示\n\n当前的一些开源SOTA模型多基于学术数据集进行训练评测，仅能在单一的文档类型上获取高质量结果。为了使得模型能够在多样性文档上也能获得稳定鲁棒的高质量结果，我们构建多样性的微调数据集，并在一些SOTA模型上微调已得到可实用解析模型。下边是一些模型的可视化结果。\n\n### 布局检测\n\n结合多样性PDF文档标注，我们训练了鲁棒的`布局检测`模型。在论文、教材、研报、财报等多样性的PDF文档上，我们微调后的模型都能得到准确的提取结果，对于扫描模糊、水印等情况也有较高鲁棒性。下面可视化示例是经过微调后的LayoutLMv3模型的推理结果。\n\n![](assets/readme/layout_example.png)\n\n\n### 公式检测\n\n同样的，我们收集了包含公式的中英文文档进行标注，基于先进的公式检测模型进行微调，下面可视化结果是微调后的YOLO公式检测模型的推理结果：\n\n![](assets/readme/mfd_example.png)\n\n\n### 
公式识别\n\n[UniMERNet](https://github.com/opendatalab/UniMERNet)是针对真实场景下多样性公式识别的算法，通过构建大规模训练数据及精心设计的结构，使得其可以对复杂长公式、手写公式、含噪声的截图公式均有不错的识别效果。\n\n### 表格识别\n\n[StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)是一个高效表格内容提取工具，能够将表格图像转换为LaTeX/HTML/Markdown格式，最新版本使用InternVL2-1B基础模型，提高了中文识别准确度并增加了多格式输出能力。\n\n#### 更多模型的可视化结果及推理结果可以参考[PDF-Extract-Kit教程文档](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/)\n\n\n## 评测指标\n\nComing Soon! \n\n## 使用教程\n\n### 环境安装\n\n```bash\nconda create -n pdf-extract-kit-1.0 python=3.10\nconda activate pdf-extract-kit-1.0\npip install -r requirements.txt\n```\n> **注意：** 如果你的设备不支持 GPU，请使用 `requirements-cpu.txt` 安装 CPU 版本的依赖。\n\n> **注意：** 目前doclayout-yolo仅支持从pypi源安装，如果出现doclayout-yolo无法安装，请通过 `pip3 install doclayout-yolo==0.0.2 --extra-index-url=https://pypi.org/simple` 安装。\n\n### 模型下载\n\n参考[模型权重下载教程](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/get_started/pretrained_model.html)下载所需模型权重。注：可以选择全部下载，也可以选择部分下载，具体操作参考教程。\n\n\n### Demo运行\n\n#### 布局检测模型\n\n```bash \npython scripts/layout_detection.py --config=configs/layout_detection.yaml\n```\n布局检测模型支持**DocLayout-YOLO**（默认模型），YOLO-v10，以及LayoutLMv3。对于YOLO-v10和LayoutLMv3的布局检测，请参考[Layout Detection Algorithm](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/algorithm/layout_detection.html)。你可以在 `outputs/layout_detection` 文件夹下查看布局检测结果。\n\n#### 公式检测模型\n\n```bash \npython scripts/formula_detection.py --config=configs/formula_detection.yaml\n```\n你可以在 `outputs/formula_detection` 文件夹下查看公式检测结果。\n\n\n#### 文本识别（OCR）模型\n\n```bash \npython scripts/ocr.py --config=configs/ocr.yaml\n```\n你可以在 `outputs/ocr` 文件夹下查看OCR结果。\n\n\n#### 公式识别模型\n\n```bash \npython scripts/formula_recognition.py --config=configs/formula_recognition.yaml\n```\n你可以在 `outputs/formula_recognition` 文件夹下查看公式识别结果。\n\n\n#### 表格识别模型\n\n```bash \npython scripts/table_parsing.py --config configs/table_parsing.yaml\n```\n你可以在 `outputs/table_parsing` 文件夹下查看表格内容识别结果。\n\n\n> **注意：** 更多模型使用细节请查看[PDF-Extract-Kit-1.0 
中文教程](https://pdf-extract-kit.readthedocs.io/zh-cn/latest/get_started/pretrained_model.html).\n\n> 本项目专注使用模型对`多样性`文档进行`高质量`内容提取，不涉及提取后内容拼接成新文档，如PDF转Markdown。如果有此类需求，请参考我们另一个Github项目: [MinerU](https://github.com/opendatalab/MinerU)\n\n\n## 待办事项\n\n- [x] **表格解析**：开发能够将表格图像转换成对应的LaTeX/Markdown格式源码的功能。  \n- [ ] **化学方程式检测**：实现对化学方程式的自动检测。  \n- [ ] **化学方程式/图解识别**：开发识别并解析化学方程式的模型。  \n- [ ] **阅读顺序排序模型**：构建模型以确定文档中文本的正确阅读顺序。  \n\n**PDF-Extract-Kit** 旨在提供高质量PDF文件的提取能力。我们鼓励社区提出具体且有价值的需求，并欢迎大家共同参与，以不断改进PDF-Extract-Kit工具，推动科研及产业发展。\n\n\n## 协议\n\n本项目采用 [AGPL-3.0](LICENSE) 协议开源。\n\n由于本项目中使用了 YOLO 代码和 PyMuPDF 进行文件处理，这些组件都需要遵循 AGPL-3.0 协议。因此，为了确保遵守这些依赖项的许可证要求，本仓库整体采用 AGPL-3.0 协议。\n\n\n## 致谢\n\n   - [LayoutLMv3](https://github.com/microsoft/unilm/tree/master/layoutlmv3): 布局检测模型\n   - [UniMERNet](https://github.com/opendatalab/UniMERNet): 公式识别模型\n   - [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy): 表格识别模型\n   - [YOLO](https://github.com/ultralytics/ultralytics): 公式检测模型\n   - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR): OCR模型\n   - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO): 布局检测模型\n\n\n## Citation\n\n如果你觉得我们模型/代码/技术报告对你有帮助，请给我们⭐和引用📝,谢谢 :)  \n```bibtex\n@article{wang2024mineru,\n  title={MinerU: An Open-Source Solution for Precise Document Content Extraction},\n  author={Wang, Bin and Xu, Chao and Zhao, Xiaomeng and Ouyang, Linke and Wu, Fan and Zhao, Zhiyuan and Xu, Rui and Liu, Kaiwen and Qu, Yuan and Shang, Fukai and others},\n  journal={arXiv preprint arXiv:2409.18839},\n  year={2024}\n}\n\n@misc{wang2024unimernet,\n      title={UniMERNet: A Universal Network for Real-World Mathematical Expression Recognition}, \n      author={Bin Wang and Zhuangcheng Gu and Chao Xu and Bo Zhang and Botian Shi and Conghui He},\n      year={2024},\n      eprint={2404.15254},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n\n@misc{zhao2024doclayoutyoloenhancingdocumentlayout,\n      title={DocLayout-YOLO: Enhancing Document 
Layout Analysis through Diverse Synthetic Data and Global-to-Local Adaptive Perception}, \n      author={Zhiyuan Zhao and Hengrui Kang and Bin Wang and Conghui He},\n      year={2024},\n      eprint={2410.12628},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV},\n      url={https://arxiv.org/abs/2410.12628}, \n}\n\n@article{he2024opendatalab,\n  title={Opendatalab: Empowering general artificial intelligence with open datasets},\n  author={He, Conghui and Li, Wei and Jin, Zhenjiang and Xu, Chao and Wang, Bin and Lin, Dahua},\n  journal={arXiv preprint arXiv:2407.13773},\n  year={2024}\n}\n```\n\n\n## Star历史\n\n<a>\n <picture>\n   <source media=\"(prefers-color-scheme: dark)\" srcset=\"https://api.star-history.com/svg?repos=opendatalab/PDF-Extract-Kit&type=Date&theme=dark\" />\n   <source media=\"(prefers-color-scheme: light)\" srcset=\"https://api.star-history.com/svg?repos=opendatalab/PDF-Extract-Kit&type=Date\" />\n   <img alt=\"Star History Chart\" src=\"https://api.star-history.com/svg?repos=opendatalab/PDF-Extract-Kit&type=Date\" />\n </picture>\n</a>\n\n## 友情链接\n- [UniMERNet（真实场景公式识别算法）](https://github.com/opendatalab/UniMERNet)\n- [LabelU（轻量级多模态标注工具）](https://github.com/opendatalab/labelU)\n- [LabelLLM（开源LLM对话标注平台）](https://github.com/opendatalab/LabelLLM)\n- [MinerU（一站式高质量数据提取工具）](https://github.com/opendatalab/MinerU)"
  },
  {
    "path": "configs/config.yaml",
    "content": "inputs: assets/demo/formula_detection_pdfs\noutputs: outputs/formula_detection_pdfs\ntasks:\n  formula_detection:\n    model: formula_detection_yolo\n    model_config:\n      img_size: 1280\n      conf_thres: 0.25\n      iou_thres: 0.45\n      model_path: models/MFD/weights.pt\n      visualize: True\n  formula_recognition:\n    model: formula_recognition_unimernet\n    model_config:\n      cfg_path: pdf_extract_kit/configs/unimernet.yaml\n      model_path: models/MFR/UniMERNet\n      visualize: True"
  },
  {
    "path": "configs/formula_detection.yaml",
    "content": "inputs: assets/demo/formula_detection\noutputs: outputs/formula_detection\ntasks:\n  formula_detection:\n    model: formula_detection_yolo\n    model_config:\n      img_size: 1280\n      conf_thres: 0.25\n      iou_thres: 0.45\n      batch_size: 1\n      model_path: models/MFD/YOLO/yolo_v8_ft.pt\n      visualize: True"
  },
  {
    "path": "configs/formula_recognition.yaml",
    "content": "inputs: assets/demo/formula_recognition\noutputs: outputs/formula_recognition\ntasks:\n  formula_recognition:\n    model: formula_recognition_unimernet\n    model_config:\n      cfg_path: pdf_extract_kit/configs/unimernet.yaml\n      model_path: models/MFR/unimernet_tiny\n      visualize: False"
  },
  {
    "path": "configs/layout_detection.yaml",
    "content": "inputs: assets/demo/layout_detection\noutputs: outputs/layout_detection\ntasks:\n  layout_detection:\n    model: layout_detection_yolo\n    model_config:\n      img_size: 1024\n      conf_thres: 0.25\n      iou_thres: 0.45\n      model_path: models/Layout/YOLO/doclayout_yolo_ft.pt\n      visualize: True"
  },
  {
    "path": "configs/layout_detection_layoutlmv3.yaml",
    "content": "inputs: assets/demo/layout_detection\noutputs: outputs/layout_detection\ntasks:\n  layout_detection:\n    model: layout_detection_layoutlmv3\n    model_config:\n      model_path: models/Layout/LayoutLMv3/model_final.pth"
  },
  {
    "path": "configs/layout_detection_yolo.yaml",
    "content": "inputs: assets/demo/layout_detection\noutputs: outputs/layout_detection\ntasks:\n  layout_detection:\n    model: layout_detection_yolo\n    model_config:\n      img_size: 1024\n      conf_thres: 0.25\n      iou_thres: 0.45\n      model_path: models/Layout/YOLO/doclayout_yolo_ft.pt\n      visualize: True\n      device: 0"
  },
  {
    "path": "configs/ocr.yaml",
    "content": "inputs: assets/demo/ocr\noutputs: outputs/ocr\nvisualize: True\ntasks:\n  ocr:\n    model: ocr_ppocr\n    model_config:\n      lang: ch\n      show_log: True\n      det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det\n      rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec\n      det_db_box_thresh: 0.3"
  },
  {
    "path": "configs/table_parsing.yaml",
    "content": "inputs: assets/demo/table_parsing\noutputs: outputs/table_parsing\ntasks:\n  table_parsing:\n    model: table_parsing_struct_eqtable\n    model_config:\n      model_path: models/TabRec/StructEqTable\n      max_new_tokens: 1024\n      max_time: 30\n      output_format: latex\n      lmdeploy: False\n      flash_attn: True"
  },
  {
    "path": "docs/en/.readthedocs.yaml",
    "content": "version: 2\n\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.10\"\n\nformats:\n  - epub\n\npython:\n  install:\n    - requirements: requirements/docs.txt\n\nsphinx:\n  configuration: docs/en/conf.py\n"
  },
  {
    "path": "docs/en/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/en/algorithm/formula_detection.rst",
    "content": "..  _algorithm_formula_detection:\n\n===========================\nFormula Detection Algorithm\n===========================\n\nIntroduction\n====================\n\nFormula detection involves identifying the positions of all formulas (including inline and block formulas) in a given input image.\n\n.. note::\n\n   Formula detection is technically a subtask of layout detection. However, due to its complexity, we recommend using a dedicated formula detection model to decouple it. This approach typically makes data annotation easier and improves detection performance.\n\nModel Usage\n====================\n\nWith the environment properly set up, simply run the formula detection algorithm script by executing ``scripts/formula_detection.py``.\n\n.. code:: shell\n\n   $ python scripts/formula_detection.py --config configs/formula_detection.yaml\n\nModel Configuration\n--------------------\n\n.. code:: yaml\n\n   inputs: assets/demo/formula_detection\n   outputs: outputs/formula_detection\n   tasks:\n      formula_detection:\n         model: formula_detection_yolo\n         model_config:\n            img_size: 1280\n            conf_thres: 0.25\n            iou_thres: 0.45\n            batch_size: 1\n            model_path: models/MFD/YOLO/yolo_v8_ft.pt\n            visualize: True\n\n- inputs/outputs: Define the input file path and the visualization output directory, respectively.\n- tasks: Define the task type, currently only a formula detection task is included.\n- model: Define the specific model type: currently, only the YOLO formula detection model is available.\n- model_config: Define the model configuration.\n- img_size: Define the image's longer side size; the shorter side will be scaled proportionally.\n- conf_thres: Define the confidence threshold; only targets above this threshold will be detected.\n- iou_thres: Define the IoU threshold to remove targets with an overlap greater than this value.\n- batch_size: Define the batch size; the number of images inferred 
simultaneously. Generally, the larger the batch size, the faster the inference speed. A better GPU allows for a larger batch size.\n- model_path: Path to the model weights.\n- visualize: Whether to visualize the model results. Visualized results will be saved in the outputs directory.\n\nDiverse Input Support\n---------------------\n\nThe formula detection script in PDF-Extract-Kit supports various input formats such as ``a single image``, ``a directory of image files``, ``a single PDF file``, and ``a directory of PDF files``.\n\n.. note::\n\n   Modify the ``inputs`` path in ``configs/formula_detection.yaml`` according to your actual data format:\n\n   - Single image: path/to/image\n   - Image directory: path/to/images\n   - Single PDF file: path/to/pdf\n   - PDF directory: path/to/pdfs\n\n.. note::\n\n   When using a PDF as input, you need to change ``predict_images`` to ``predict_pdfs`` in ``formula_detection.py``.\n\n   .. code:: python\n\n      # for image detection\n      detection_results = model_formula_detection.predict_images(input_data, result_path)\n\n   Change to:\n\n   .. code:: python\n\n      # for pdf detection\n      detection_results = model_formula_detection.predict_pdfs(input_data, result_path)\n\n\nViewing Visualization Results\n-----------------------------\n\nWhen the ``visualize`` option in the config file is set to ``True``, visualization results will be saved in the ``outputs/formula_detection`` directory.\n\n.. note::\n\n   Visualization facilitates the analysis of model results. However, for large-scale tasks, it is recommended to disable visualization (set ``visualize`` to ``False``) to reduce memory and disk usage."
  },
  {
    "path": "docs/en/algorithm/formula_recognition.rst",
    "content": "..  _algorithm_formula_recognition:\n\n=============================\nFormula Recognition Algorithm\n=============================\n\nIntroduction\n=================\n\nFormula recognition involves recognizing the content of a given input formula image and converting it to ``LaTeX`` format.\n\nModel Usage\n=================\n\nWith the environment properly configured, you can run the formula recognition algorithm script by executing ``scripts/formula_recognition.py``.\n\n.. code:: shell\n\n   $ python scripts/formula_recognition.py --config configs/formula_recognition.yaml\n\nModel Configuration\n-------------------\n\n.. code:: yaml\n\n   inputs: assets/demo/formula_recognition\n   outputs: outputs/formula_recognition\n   tasks:\n      formula_recognition:\n         model: formula_recognition_unimernet\n         model_config:\n            cfg_path: pdf_extract_kit/configs/unimernet.yaml\n            model_path: models/MFR/unimernet_tiny\n            visualize: False\n\n- inputs/outputs: Define the input file path and the directory for LaTeX prediction results, respectively.\n- tasks: Define the task type, currently only containing a formula recognition task.\n- model: Define the specific model type: currently, only the `UniMERNet <https://github.com/opendatalab/UniMERNet>`_ formula recognition model is provided.\n- model_config: Define the model configuration.\n- cfg_path: Path to the UniMERNet configuration file.\n- model_path: Path to the model weights.\n- visualize: Whether to visualize the model results. Visualized results will be saved in the outputs directory.\n\nSupport for Diverse Inputs\n--------------------------\n\nThe formula recognition script in PDF-Extract-Kit supports ``single formula images`` and ``document images with corresponding formula regions``.\n\nViewing Visualization Results\n-----------------------------\n\nWhen ``visualize`` is set to ``True`` in the config file, ``LaTeX`` prediction results will be saved in the outputs directory."
  },
  {
    "path": "docs/en/algorithm/layout_detection.rst",
    "content": ".. _algorithm_layout_detection:\n\n=================\nLayout Detection Algorithm\n=================\n\nIntroduction\n=================\n\nLayout detection is a fundamental task in document content extraction, aiming to locate different types of regions on a page, such as images, tables, text, and headings, to facilitate high-quality content extraction. For text and heading regions, OCR models can be used for text recognition, while table regions can be converted using table recognition models.\n\nModel Usage\n=================\n\nLayout detection supports following models：\n\n.. raw:: html\n\n    <style type=\"text/css\">\n    .tg  {border-collapse:collapse;border-color:#9ABAD9;border-spacing:0;}\n    .tg td{background-color:#EBF5FF;border-color:#9ABAD9;border-style:solid;border-width:1px;color:#444;\n      font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;word-break:normal;}\n    .tg th{background-color:#409cff;border-color:#9ABAD9;border-style:solid;border-width:1px;color:#fff;\n      font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}\n    .tg .tg-f8tz{background-color:#409cff;border-color:inherit;text-align:left;vertical-align:top}\n    .tg .tg-0lax{text-align:left;vertical-align:top}\n    .tg .tg-0pky{border-color:inherit;text-align:left;vertical-align:top}\n    </style>\n    <table class=\"tg\"><thead>\n      <tr>\n        <th class=\"tg-0lax\">Model</th>\n        <th class=\"tg-f8tz\">Description</th>\n        <th class=\"tg-f8tz\">Characteristics</th>\n        <th class=\"tg-f8tz\">Model weight</th>\n        <th class=\"tg-f8tz\">Config file</th>\n      </tr></thead>\n    <tbody>\n      <tr>\n        <td class=\"tg-0lax\">DocLayout-YOLO</td>\n        <td class=\"tg-0pky\">Improved based on YOLO-v10：<br>1. Generate diverse pre-training data，enhance generalization ability across multiple document types<br>2. 
Model architecture improvements that enhance perception of scale-varying instances<br>Details in <a href=\"https://github.com/opendatalab/DocLayout-YOLO\" target=\"_blank\" rel=\"noopener noreferrer\">DocLayout-YOLO</a></td>\n        <td class=\"tg-0pky\">Speed: Fast, Accuracy: High</td>\n        <td class=\"tg-0pky\"><a href=\"https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/blob/main/models/Layout/YOLO/doclayout_yolo_ft.pt\" target=\"_blank\" rel=\"noopener noreferrer\">doclayout_yolo_ft.pt</a></td>\n        <td class=\"tg-0pky\">layout_detection.yaml</td>\n      </tr>\n      <tr>\n        <td class=\"tg-0lax\">YOLO-v10</td>\n        <td class=\"tg-0pky\">Base YOLO-v10 model</td>\n        <td class=\"tg-0pky\">Speed: Fast, Accuracy: Moderate</td>\n        <td class=\"tg-0pky\"><a href=\"https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/blob/main/models/Layout/YOLO/yolov10l_ft.pt\" target=\"_blank\" rel=\"noopener noreferrer\">yolov10l_ft.pt</a></td>\n        <td class=\"tg-0pky\">layout_detection_yolo.yaml</td>\n      </tr>\n      <tr>\n        <td class=\"tg-0lax\">LayoutLMv3</td>\n        <td class=\"tg-0pky\">Base LayoutLMv3 model</td>\n        <td class=\"tg-0pky\">Speed: Slow, Accuracy: High</td>\n        <td class=\"tg-0pky\"><a href=\"https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/tree/main/models/Layout/LayoutLMv3\" target=\"_blank\" rel=\"noopener noreferrer\">layoutlmv3_ft</a></td>\n        <td class=\"tg-0pky\">layout_detection_layoutlmv3.yaml</td>\n      </tr>\n    </tbody></table>\n\nOnce the environment is set up, you can perform layout detection by executing ``scripts/layout_detection.py`` directly.\n\n**Run demo**\n\n.. code:: shell\n\n   $ python scripts/layout_detection.py --config configs/layout_detection.yaml\n\nModel Configuration\n-------------------\n\n**1. DocLayout-YOLO / YOLO-v10**\n\n.. 
code:: yaml\n\n    inputs: assets/demo/layout_detection\n    outputs: outputs/layout_detection\n    tasks:\n      layout_detection:\n        model: layout_detection_yolo\n        model_config:\n          img_size: 1024\n          conf_thres: 0.25\n          iou_thres: 0.45\n          model_path: path/to/doclayout_yolo_model\n          visualize: True\n\n- inputs/outputs: Define the input file path and the directory for visualization output.\n- tasks: Define the task type, currently only a layout detection task is included.\n- model: Specify the specific model type, e.g., layout_detection_yolo.\n- model_config: Define the model configuration.\n- img_size: Define the image long edge size; the short edge will be scaled proportionally based on the long edge, with the default long edge being 1024.\n- conf_thres: Define the confidence threshold, detecting only targets above this threshold.\n- iou_thres: Define the IoU threshold, removing targets with an overlap greater than this threshold.\n- model_path: Path to the model weights.\n- visualize: Whether to visualize the model results; visualized results will be saved in the outputs directory.\n\n\n**2. layoutlmv3**\n\n.. note::\n   \n   LayoutLMv3 cannot run directly by default. Please follow the steps below to modify the configuration:\n\n   1. **Detectron2 Environment Setup**\n\n   .. code-block:: bash\n\n      # For Linux\n      pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-linux_x86_64.whl\n\n      # For macOS\n      pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-macosx_10_9_universal2.whl\n\n      # For Windows\n      pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-win_amd64.whl\n\n   2. 
**Enable LayoutLMv3 Registration Code**\n\n   Uncomment the lines at the following links:\n   \n   - `line 2 <https://github.com/opendatalab/PDF-Extract-Kit/blob/main/pdf_extract_kit/tasks/layout_detection/__init__.py#L2>`_\n   - `line 8 <https://github.com/opendatalab/PDF-Extract-Kit/blob/main/pdf_extract_kit/tasks/layout_detection/__init__.py#L8>`_\n\n   .. code-block:: python\n\n      from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO\n      from pdf_extract_kit.tasks.layout_detection.models.layoutlmv3 import LayoutDetectionLayoutlmv3\n      from pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n      __all__ = [\n         \"LayoutDetectionYOLO\",\n         \"LayoutDetectionLayoutlmv3\",\n      ]\n\n\n.. code:: yaml\n\n    inputs: assets/demo/layout_detection\n    outputs: outputs/layout_detection\n    tasks:\n      layout_detection:\n        model: layout_detection_layoutlmv3\n        model_config:\n          model_path: path/to/layoutlmv3_model\n\n- inputs/outputs: Define the input file path and the directory for visualization output.\n- tasks: Define the task type, currently only a layout detection task is included.\n- model: Specify the specific model type, e.g., layout_detection_layoutlmv3.\n- model_config: Define the model configuration.\n- model_path: Path to the model weights.\n\n\n\nDiverse Input Support\n---------------------\n\nThe layout detection script in PDF-Extract-Kit supports input formats such as a ``single image``, a ``directory containing only image files``, a ``single PDF file``, and a ``directory containing only PDF files``.\n\n.. note::\n\n   Modify the ``inputs`` path in ``configs/layout_detection.yaml`` according to your actual data format:\n\n   - Single image: path/to/image\n   - Image directory: path/to/images\n   - Single PDF file: path/to/pdf\n   - PDF directory: path/to/pdfs\n\n.. 
note::\n\n   When using a PDF as input, you need to change ``predict_images`` to ``predict_pdfs`` in ``layout_detection.py``.\n\n   .. code:: python\n\n      # for image detection\n      detection_results = model_layout_detection.predict_images(input_data, result_path)\n\n   Change to:\n\n   .. code:: python\n\n      # for pdf detection\n      detection_results = model_layout_detection.predict_pdfs(input_data, result_path)\n\nViewing Visualization Results\n-----------------------------\n\nWhen ``visualize`` is set to ``True`` in the config file, the visualization results will be saved in the ``outputs`` directory.\n\n.. note::\n\n   Visualization is helpful for analyzing model results, but for large-scale tasks, it is recommended to turn off visualization (set ``visualize`` to ``False``) to reduce memory and disk usage."
  },
  {
    "path": "docs/en/algorithm/ocr.rst",
    "content": "..  _algorithm_ocr:\n\n=============================================\nOCR (Optical Character Recognition) Algorithm\n=============================================\n\nIntroduction\n====================\n\nOCR (Optical Character Recognition) involves identifying the positions and contents of all text blocks in an image.\n\n\nModel Usage\n====================\n\nWith the environment properly set up, simply run the OCR algorithm script by executing ``scripts/ocr.py``.\n\n.. code:: shell\n\n   $ python scripts/ocr.py --config configs/ocr.yaml\n\n\nModel Configuration\n--------------------\n\n.. code:: yaml\n\n   inputs: assets/demo/ocr\n   outputs: outputs/ocr\n   visualize: True\n   tasks:\n      ocr:\n         model: ocr_ppocr\n         model_config:\n            lang: ch\n            show_log: True\n            det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det\n            rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec\n            det_db_box_thresh: 0.3\n\n- inputs/outputs: Define the input path and the output path, respectively.\n- visualize: Whether to visualize the model results. 
Visualized results will be saved in the outputs directory.\n- tasks: Define the task type, currently only an OCR task is included.\n- model: Define the specific model type; currently, only the PaddleOCR model is available.\n- model_config: Define the model configuration.\n- lang: Define the language; the default language ``ch`` supports both English and Chinese.\n- show_log: Whether to print running logs.\n- det_model_dir: Path to the PaddleOCR detection model. If the specified path does not exist, the model weights will be automatically downloaded to that path.\n- rec_model_dir: Path to the PaddleOCR recognition model. If the specified path does not exist, the model weights will be automatically downloaded to that path.\n- det_db_box_thresh: Confidence threshold for filtering detections; bounding boxes whose confidence is lower than this threshold are discarded.\n\n\nDiverse Input Support\n---------------------\n\nThe OCR script in PDF-Extract-Kit supports various input formats such as ``a single image/PDF`` and ``a directory of image/PDF files``.\n\n\nViewing Visualization Results\n-----------------------------\n\nWhen the ``visualize`` option in the config file is set to ``True``, visualization results will be saved in the ``outputs`` directory.\n\n.. note::\n\n   Visualization facilitates the analysis of model results. However, for large-scale tasks, it is recommended to disable visualization (set ``visualize`` to ``False``) to reduce memory and disk usage."
  },
  {
    "path": "docs/en/algorithm/reading_order.rst",
    "content": "..  _algorithm_reading_oder:\n\n=======================\nReading Order Algorithm\n=======================\n\nComing soon."
  },
  {
    "path": "docs/en/algorithm/table_recognition.rst",
    "content": "..  _algorithm_table_recognition:\n\n===========================\nTable Recognition Algorithm\n===========================\n\nIntroduction\n=================\n\nTable recognition refers to the process of inputting a table image, identifying the table structure and content, and converting it into formats such as ``LaTeX`` or ``HTML``.\n\nModel Usage\n=================\n\nWith the environment properly configured, you can run the table recognition algorithm script by directly executing ``scripts/table_parsing.py``.\n\n.. code:: shell\n\n   $ python scripts/table_parsing.py --config configs/table_parsing.yaml\n\nModel Configuration\n-------------------\n\n.. code:: yaml\n\n    inputs: assets/demo/table_parsing\n    outputs: outputs/table_parsing\n    tasks:\n      table_parsing:\n        model: table_parsing_struct_eqtable\n        model_config:\n          model_path: models/TabRec/StructEqTable\n          max_new_tokens: 1024\n          max_time: 30\n          output_format: latex\n          lmdeploy: False\n          flash_attn: True\n\n- inputs/outputs: Define the input file path and the table recognition result directory, respectively.\n- tasks: Define the task type; currently only a table recognition task is included.\n- model: Define the specific model type: currently the `StructEqTable <https://github.com/UniModal4Reasoning/StructEqTable-Deploy>`_ table recognition model is used.\n- model_config: Define the model configuration.\n- model_path: Path to the model weights.\n- max_new_tokens: Maximum number of tokens to generate; the default is 1024 and the maximum supported is 4096.\n- max_time: Maximum runtime for the model (in seconds).\n- output_format: Output format; the default is ``latex``, and ``html`` and ``markdown`` are also supported.\n- lmdeploy: Whether to use LMDeploy for deployment; currently set to False.\n- flash_attn: Whether to use flash attention; only available on Ampere or newer GPUs.\n\nDiverse Input Support\n---------------------\n\nThe table recognition script in 
PDF-Extract-Kit supports ``single table images`` and ``multiple table images`` as input.\n\n.. note::\n\n   The StructEqTable model only supports running on GPU devices.\n\n.. note::\n    \n    Adjust ``max_new_tokens`` and ``max_time`` according to the table content; the defaults are 1024 and 30, respectively.\n\n.. note::\n    \n    ``lmdeploy`` is an option for accelerated inference. If set to True, LMDeploy will be used for accelerated inference deployment.\n    To use LMDeploy deployment, you need to install LMDeploy. For installation methods, refer to `LMDeploy <https://github.com/InternLM/lmdeploy>`_."
  },
  {
    "path": "docs/en/conf copy.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n\nimport os\nimport subprocess\nimport sys\n\n# def install(package):\n#     subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", package])\n\n# # 安装 requirements.txt 中的依赖项\n# requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))\n# if os.path.exists(requirements_path):\n#     with open(requirements_path) as f:\n#         packages = f.readlines()\n#     for package in packages:\n#         install(package.strip())\n\nfrom sphinx.ext import autodoc\n\nsys.path.insert(0, os.path.abspath('../..'))\n\n# -- Project information -----------------------------------------------------\n\nproject = 'PDF-Extract-Kit'\ncopyright = '2024, OpenDataLab'\nauthor = 'PDF-Extract-Kit Contributors'\n\n# The full version, including alpha/beta/rc tags\nversion_file = '../../pdf_extract_kit/version.py'\nwith open(version_file) as f:\n    exec(compile(f.read(), version_file, 'exec'))\n__version__ = locals()['__version__']\n# The short X.Y version\nversion = __version__\n# The full version, including alpha/beta/rc tags\nrelease = __version__\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.napoleon',\n    'sphinx.ext.viewcode',\n    'sphinx.ext.intersphinx',\n    'sphinx_copybutton',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosummary',\n    'myst_parser',\n    'sphinxarg.ext',\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# Exclude the prompt \"$\" when copying code\ncopybutton_prompt_text = r'\\$ '\ncopybutton_prompt_is_regexp = True\n\nlanguage = 'zh_CN'\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_book_theme'\nhtml_logo = '_static/image/logo.png'\nhtml_theme_options = {\n    'path_to_docs': 'docs/zh_cn',\n    'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit',\n    'use_repository_button': True,\n}\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\n# html_static_path = ['_static']\n\n# Mock out external dependencies here.\nautodoc_mock_imports = [\n    'cpuinfo',\n    'torch',\n    'transformers',\n    'psutil',\n    'prometheus_client',\n    'sentencepiece',\n    'vllm.cuda_utils',\n    'vllm._C',\n    'numpy',\n    'tqdm',\n]\n\n\nclass MockedClassDocumenter(autodoc.ClassDocumenter):\n    \"\"\"Remove note about base class when a class is derived from object.\"\"\"\n\n    def add_line(self, line: str, source: str, *lineno: int) -> None:\n        if line == '   Bases: :py:class:`object`':\n            return\n        super().add_line(line, source, *lineno)\n\n\nautodoc.ClassDocumenter = MockedClassDocumenter\n\nnavigation_with_keys = False\n"
  },
  {
    "path": "docs/en/conf.bak",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n\nimport os\nimport sys\n\nfrom sphinx.ext import autodoc\n\nsys.path.insert(0, os.path.abspath('../..'))\n\n# -- Project information -----------------------------------------------------\n\nproject = 'PDF-Extract-Kit'\ncopyright = '2024, PDF-Extract-Kit Contributors'\nauthor = 'PDF-Extract-Kit Contributors'\n\n# The full version, including alpha/beta/rc tags\nversion_file = '../../pdf_extract_kit/version.py'\nwith open(version_file) as f:\n    exec(compile(f.read(), version_file, 'exec'))\n__version__ = locals()['__version__']\n# The short X.Y version\nversion = __version__\n# The full version, including alpha/beta/rc tags\nrelease = __version__\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.napoleon',\n    'sphinx.ext.viewcode',\n    'sphinx.ext.intersphinx',\n    'sphinx_copybutton',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosummary',\n    'myst_parser',\n    'sphinxarg.ext',\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# Exclude the prompt \"$\" when copying code\ncopybutton_prompt_text = r'\\$ '\ncopybutton_prompt_is_regexp = True\n\nlanguage = 'en'\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_book_theme'\nhtml_logo = '_static/image/logo.png'\nhtml_theme_options = {\n    'path_to_docs': 'docs/en',\n    'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit',\n    'use_repository_button': True,\n}\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\n# html_static_path = ['_static']\n\n# Mock out external dependencies here.\nautodoc_mock_imports = [\n    'cpuinfo',\n    'torch',\n    'transformers',\n    'psutil',\n    'prometheus_client',\n    'sentencepiece',\n    'vllm.cuda_utils',\n    'vllm._C',\n    'numpy',\n    'tqdm',\n]\n\n\nclass MockedClassDocumenter(autodoc.ClassDocumenter):\n    \"\"\"Remove note about base class when a class is derived from object.\"\"\"\n\n    def add_line(self, line: str, source: str, *lineno: int) -> None:\n        if line == '   Bases: :py:class:`object`':\n            return\n        super().add_line(line, source, *lineno)\n\n\nautodoc.ClassDocumenter = MockedClassDocumenter\n\nnavigation_with_keys = False\n"
  },
  {
    "path": "docs/en/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n\nimport os\nimport subprocess\nimport sys\n\ndef install(package):\n    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", package])\n\n# 安装 requirements.txt 中的依赖项\nrequirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))\nif os.path.exists(requirements_path):\n    with open(requirements_path) as f:\n        packages = f.readlines()\n    for package in packages:\n        install(package.strip())\n        \nfrom sphinx.ext import autodoc\nsys.path.insert(0, os.path.abspath('../..'))\n\n# -- Project information -----------------------------------------------------\n\nproject = 'PDF-Extract-Kit'\ncopyright = '2024, PDF-Extract-Kit Contributors'\nauthor = 'OpenDataLab'\n\n# The full version, including alpha/beta/rc tags\nversion_file = '../../pdf_extract_kit/version.py'\nwith open(version_file) as f:\n    exec(compile(f.read(), version_file, 'exec'))\n__version__ = locals()['__version__']\n# The short X.Y version\nversion = __version__\n# The full version, including alpha/beta/rc tags\nrelease = __version__\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.napoleon',\n    'sphinx.ext.viewcode',\n    'sphinx.ext.intersphinx',\n    'sphinx_copybutton',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosummary',\n    'myst_parser',\n    'sphinxarg.ext',\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# Exclude the prompt \"$\" when copying code\ncopybutton_prompt_text = r'\\$ '\ncopybutton_prompt_is_regexp = True\n\nlanguage = 'en'\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_book_theme'\nhtml_logo = '_static/image/logo.png'\nhtml_theme_options = {\n    'path_to_docs': 'docs/en',\n    'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit',\n    'use_repository_button': True,\n}\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\n# html_static_path = ['_static']\n\n# Mock out external dependencies here.\nautodoc_mock_imports = [\n    'cpuinfo',\n    'torch',\n    'transformers',\n    'psutil',\n    'prometheus_client',\n    'sentencepiece',\n    'vllm.cuda_utils',\n    'vllm._C',\n    'numpy',\n    'tqdm',\n]\n\n\nclass MockedClassDocumenter(autodoc.ClassDocumenter):\n    \"\"\"Remove note about base class when a class is derived from object.\"\"\"\n\n    def add_line(self, line: str, source: str, *lineno: int) -> None:\n        if line == '   Bases: :py:class:`object`':\n            return\n        super().add_line(line, source, *lineno)\n\n\nautodoc.ClassDocumenter = MockedClassDocumenter\n\nnavigation_with_keys = False"
  },
  {
    "path": "docs/en/evaluation/formula_detection.rst",
    "content": "============================\nFormula Detection Evaluation\n============================\n\nXXX"
  },
  {
    "path": "docs/en/evaluation/formula_recognition.rst",
    "content": "==============================\nFormula Recognition Evaluation\n==============================\n\nXXX"
  },
  {
    "path": "docs/en/evaluation/layout_detection.rst",
    "content": "===========================\nLayout Detection Evaluation\n===========================\n\nXXX"
  },
  {
    "path": "docs/en/evaluation/ocr.rst",
    "content": "=====================\nOCR Evaluation\n=====================\n\nXXX"
  },
  {
    "path": "docs/en/evaluation/pdf_extract.rst",
    "content": "==============================================\nPDF Content Extraction Evaluation [End-to-End]\n==============================================\n\nXXX"
  },
  {
    "path": "docs/en/evaluation/reading_order.rst",
    "content": "========================\nReading Order Evaluation\n========================\n\nXXX"
  },
  {
    "path": "docs/en/evaluation/table_recognition.rst",
    "content": "============================\nTable Recognition Evaluation\n============================\n\nXXX\n"
  },
  {
    "path": "docs/en/get_started/installation.rst",
    "content": "==================================\nInstallation\n==================================\n\nIn this section, we will demonstrate how to install PDF-Extract-Kit.\n\nBest Practices\n==============\n\nWe recommend users follow our best practices for installing PDF-Extract-Kit. It is recommended to use a Python 3.10 conda virtual environment for the installation.\n\n**Step 1.** Create a Python 3.10 virtual environment using conda.\n\n.. code-block:: console\n\n    $ conda create -n pdf-extract-kit-1.0 python=3.10 -y\n    $ conda activate pdf-extract-kit-1.0\n\n**Step 2.** Install the dependencies for PDF-Extract-Kit.\n\n.. code-block:: console\n\n    $ # For GPU devices\n    $ pip install -r requirements.txt\n    $ # For CPU-only devices\n    $ pip install -r requirements-cpu.txt\n\n.. note::\n\n    For the convenience of user environment configuration, requirements.txt only includes the environment needed for the current best models, which currently include:\n   \n    - Layout Detection: YOLO series (YOLOv10, DocLayout-YOLO)  \n    - Formula Detection: YOLO series (YOLOv8)  \n    - Formula Recognition: UniMERNet  \n    - OCR: PaddleOCR  \n\n    For other models, such as LayoutLMv3, additional environment setup is required. For details, see \\ :ref:`Layout Detection Algorithms <algorithm_layout_detection>`."
  },
  {
    "path": "docs/en/get_started/pretrained_model.rst",
    "content": "==================================\nModel Weights Download\n==================================\n\nBefore using the PDF-Extract-Kit, we need to download the required model weights. You can download all models or specific model files (e.g., formula detection MFD) according to your needs.\n\n[Recommended] Method 1: ``snapshot_download``\n=============================================\n\nHuggingFace\n------------\n\n``huggingface_hub.snapshot_download`` supports downloading specific model weights from the HuggingFace Hub and allows multithreading. You can use the following code to download model weights in parallel:\n\n.. code:: python\n\n   from huggingface_hub import snapshot_download\n\n   snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', max_workers=20)\n\nIf you want to download a single algorithm model (e.g., the YOLO model for the formula detection task), use the following code:\n\n.. code:: python\n\n   from huggingface_hub import snapshot_download\n\n   snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', allow_patterns='models/MFD/YOLO/*')\n\n.. note::\n\n   Here, ``repo_id`` represents the name of the model on HuggingFace Hub, ``local_dir`` indicates the desired local storage path, ``max_workers`` specifies the maximum number of parallel downloads, and ``allow_patterns`` specifies the files you want to download.\n\n.. tip::\n\n   If ``local_dir`` is not specified, the weights will be downloaded to the default cache path of HuggingFace (``~/.cache/huggingface/hub``). To change the default cache path, set the ``HF_HOME`` environment variable:\n\n   .. code:: console\n\n      $ # Default is `~/.cache/huggingface/`\n      $ export HF_HOME=/path/to/your/cache\n\n.. tip::\n\n   If the download speed is slow (e.g., unable to reach maximum bandwidth), install the ``hf_transfer`` package (``pip install hf_transfer``) and set ``export HF_HUB_ENABLE_HF_TRANSFER=1`` for higher download speeds.\n\nModelScope\n-----------\n\n``modelscope.snapshot_download`` supports downloading specified model weights. You can use the following code to download the model:\n\n.. code:: python\n\n   from modelscope import snapshot_download\n\n   snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./')\n\nIf you want to download a single algorithm model (e.g., the YOLO model for the formula detection task), use the following code:\n\n.. code:: python\n\n   from modelscope import snapshot_download\n\n   snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./', allow_patterns='models/MFD/YOLO/*')\n\n.. note::\n\n   Here, ``model_id`` represents the name of the model in the ModelScope library, ``cache_dir`` indicates the desired local storage path, and ``allow_patterns`` specifies the files you want to download.\n\n.. note::\n\n   ``modelscope.snapshot_download`` does not support multithreaded parallel downloads.\n\n.. tip::\n\n   If ``cache_dir`` is not specified, the weights will be downloaded to the default cache path of ModelScope (``~/.cache/modelscope/hub``).\n\n   To change the default cache path, set the ``MODELSCOPE_CACHE`` environment variable:\n\n   .. code:: console\n\n      $ # Default is ~/.cache/modelscope/hub/\n      $ export MODELSCOPE_CACHE=/path/to/your/cache\n\n\n\nMethod 2: Git LFS\n===================\n\nThe remote model repositories of HuggingFace and ModelScope are Git repositories managed by Git LFS. Therefore, once Git LFS is installed, we can use ``git clone`` to download the weights:\n\n.. code:: console\n\n   $ git lfs install\n   $ # From HuggingFace\n   $ git clone https://huggingface.co/opendatalab/pdf-extract-kit-1.0\n   $ # From ModelScope\n   $ git clone https://www.modelscope.cn/opendatalab/pdf-extract-kit-1.0.git"
  },
  {
    "path": "docs/en/get_started/quickstart.rst",
    "content": "==================================\nQuick Start\n==================================\n\nOnce the PDF-Extract-Kit environment is set up and the models are downloaded, we can start using PDF-Extract-Kit.\n\nLayout Detection Example\n========================\n\nLayout detection offers several models: ``LayoutLMv3``, ``YOLOv10``, and ``DocLayout-YOLO``. Compared to ``LayoutLMv3``, ``YOLOv10`` is faster. ``DocLayout-YOLO`` is based on YOLOv10 and includes diverse document pre-training and model optimization, offering both speed and high accuracy.\n\n**1. Using Layout Detection Models**\n\n.. code-block:: console\n\n    $ python scripts/layout_detection.py --config configs/layout_detection.yaml\n\nAfter execution, we can view the detection results in the ``outputs/layout_detection`` directory.\n\n.. note::\n\n    The ``layout_detection.yaml`` file sets the input, output, and model configuration. For a more detailed tutorial on layout detection, see :ref:`Layout Detection Algorithm <algorithm_layout_detection>`.\n\nFormula Detection Example\n=========================\n\n.. code-block:: console\n\n    $ python scripts/formula_detection.py --config configs/formula_detection.yaml\n\nAfter execution, we can view the detection results in the ``outputs/formula_detection`` directory.\n\n.. note::\n\n    The ``formula_detection.yaml`` file sets the input, output, and model configuration. For a more detailed tutorial on formula detection, see :ref:`Formula Detection Algorithm <algorithm_formula_detection>`."
  },
  {
    "path": "docs/en/index.rst",
    "content": ".. xtuner documentation master file, created by\n   sphinx-quickstart on Tue Jan  9 16:33:06 2024.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nWelcome to the PDF-Extract-Kit Documentation\n==============================================\n\n.. figure:: ./_static/image/logo.png\n  :align: center\n  :alt: pdf-extract-kit\n  :class: no-scaled-link\n\n.. raw:: html\n\n   <p style=\"text-align:center\">\n   <strong>High-Quality Document Parsing Toolkit\n   </strong>\n   </p>\n\n   <p style=\"text-align:center\">\n   <script async defer src=\"https://buttons.github.io/buttons.js\"></script>\n   <a class=\"github-button\" href=\"https://github.com/opendatalab/PDF-Extract-Kit\" data-show-count=\"true\" data-size=\"large\" aria-label=\"Star\">Star</a>\n   <a class=\"github-button\" href=\"https://github.com/opendatalab/PDF-Extract-Kit/subscription\" data-icon=\"octicon-eye\" data-size=\"large\" aria-label=\"Watch\">Watch</a>\n   <a class=\"github-button\" href=\"https://github.com/opendatalab/PDF-Extract-Kit/fork\" data-icon=\"octicon-repo-forked\" data-size=\"large\" aria-label=\"Fork\">Fork</a>\n   </p>\n\n\nTutorial\n-------------\n.. toctree::\n   :maxdepth: 2\n   :caption: Getting Started\n\n   get_started/installation.rst\n   get_started/pretrained_model.rst\n   get_started/quickstart.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Core Algorithm Modules\n\n   algorithm/layout_detection.rst\n   algorithm/formula_detection.rst\n   algorithm/formula_recognition.rst\n   algorithm/ocr.rst\n   algorithm/table_recognition.rst\n   algorithm/reading_order.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Task Extensions\n\n   task_extend/code.rst\n   task_extend/doc.rst\n   task_extend/evaluation.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Supported Models\n\n   models/supported.md\n\n\n.. 
toctree::\n   :maxdepth: 2\n   :caption: Model Performance Evaluation\n\n   evaluation/layout_detection.rst\n   evaluation/formula_detection.rst\n   evaluation/formula_recognition.rst\n   evaluation/ocr.rst\n   evaluation/table_recognition.rst\n   evaluation/reading_order.rst\n   evaluation/pdf_extract.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: PDF Projects\n\n   project/pdf_extract.md\n   project/doc_translate.md\n   project/speed_up.md"
  },
  {
    "path": "docs/en/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=.\nset BUILDDIR=_build\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.https://www.sphinx-doc.org/\n\texit /b 1\n)\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\n\n:end\npopd\n"
  },
  {
    "path": "docs/en/models/supported.md",
    "content": "# The Supported Models\n\n"
  },
  {
    "path": "docs/en/notes/changelog.md",
    "content": "<!--\n\n## vX.X.X (YYYY.MM.DD)\n\n### 亮点\n\n### 新功能和改进\n\n### Bug 修复\n\n### 贡献者\n\n-->\n\n# Changelog\n\n## v1.0.0 (2024-10-10)\n\nThe PDF-Extract-Kit-1.0 has been refactored with a more streamlined and user-friendly modular design! 🔥🔥🔥\n\n## v0.1.0 (2024-07-01)\n\nOfficial release of PDF-Extract-Kit! 🔥🔥🔥\n\n### Highlights\n\n- PDF-Extract-Kit-1.0 offers a high-quality layout detection model, DocLayout-YOLO."
  },
  {
    "path": "docs/en/project/doc_translate.rst",
    "content": "============================\nDocument Translation Project\n============================\n\nXXXX\nXXXX"
  },
  {
    "path": "docs/en/project/pdf_extract.rst",
    "content": "===================================\nDocument Content Extraction Project\n===================================\n\nIntroduction\n====================\n\nDocument content extraction aims to extract all of the information in a document file and convert it into a machine-readable result (such as a Markdown file). Its subtasks include layout detection, formula detection, formula recognition, OCR, and other tasks.\n\n\nProject Usage\n====================\n\nWith the environment properly set up, simply run the project by executing ``project/pdf2markdown/scripts/run_project.py``.\n\n.. code:: shell\n\n   $ python project/pdf2markdown/scripts/run_project.py --config project/pdf2markdown/configs/pdf2markdown.yaml\n\n\nProject Configuration\n---------------------\n\n.. code:: yaml\n\n    inputs: assets/demo/formula_detection\n    outputs: outputs/pdf2markdown\n    visualize: True\n    merge2markdown: True\n    tasks:\n        layout_detection:\n            model: layout_detection_yolo\n            model_config:\n                img_size: 1024\n                conf_thres: 0.25\n                iou_thres: 0.45\n                model_path: models/Layout/YOLO/doclayout_yolo_ft.pt\n        formula_detection:\n            model: formula_detection_yolo\n            model_config:\n                img_size: 1280\n                conf_thres: 0.25\n                iou_thres: 0.45\n                batch_size: 1\n                model_path: models/MFD/YOLO/yolo_v8_ft.pt\n        formula_recognition:\n            model: formula_recognition_unimernet\n            model_config:\n                batch_size: 128\n                cfg_path: pdf_extract_kit/configs/unimernet.yaml\n                model_path: models/MFR/unimernet_tiny\n        ocr:\n            model: ocr_ppocr\n            model_config:\n                lang: ch\n                show_log: True\n                det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det\n                rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec\n                det_db_box_thresh: 0.3\n\n- inputs/outputs: Define the input path and the output path, respectively.\n- visualize: Whether to visualize the project results. Visualized results will be saved in the outputs directory.\n- merge2markdown: Whether to merge the results into markdown documents. Only simple single-column text is supported. For markdown conversion of documents with more complex layouts, please refer to `MinerU <https://github.com/opendatalab/MinerU>`_.\n- tasks: Define the task types. PDF document extraction includes layout detection, formula detection, formula recognition, and OCR tasks.\n- For details about the parameter meanings of each task and model, see the tutorial documentation of each task.\n\n\nDiverse Input Support\n---------------------\n\nThe document content extraction script in PDF-Extract-Kit supports various input formats, such as ``a single image/PDF`` and ``a directory of image/PDF files``.\n\n\nOutput Result\n--------------------\n\nThe extracted results of PDF documents are stored in the outputs path in JSON format. The structure of the JSON is as follows:\n\n.. code:: json\n\n    [\n        {\n            \"layout_dets\": [\n                {\n                    \"category_type\": \"text\",\n                    \"poly\": [\n                        380.6792698635707,\n                        159.85058512958923,\n                        765.1419999999998,\n                        159.85058512958923,\n                        765.1419999999998,\n                        192.51073013642917,\n                        380.6792698635707,\n                        192.51073013642917\n                    ],\n                    \"text\": \"this is an example text\",\n                    \"score\": 0.97\n                },\n                ...\n            ],\n            \"page_info\": {\n                \"page_no\": 0,\n                \"height\": 2339,\n                \"width\": 1654\n            }\n        },\n        ...\n    ]\n\n- layout_dets: Content extraction results for a single PDF page or image\n- category_type: The category of a single content block, such as headings, images, inline formulas, and so on\n- poly: The location of a single content block, given as four (x, y) corner points\n- text: Text content of a single content block\n- score: Confidence score\n- page_info: Page information, including page number and page size\n- page_no: Page number, counting from 0\n- height: Page height\n- width: Page width\n\nIf the ``merge2markdown`` parameter is True, an additional markdown file will be saved."
  },
  {
    "path": "docs/en/project/speed_up.rst",
    "content": "==========================\nModel Acceleration Project\n==========================\n\nXXXX\nXXXX"
  },
  {
    "path": "docs/en/switch_language.md",
    "content": "## <a href='https://pdf-extract-kit.readthedocs.io/en/latest/'>English</a>\n\n## <a href='https://pdf-extract-kit.readthedocs.io/zh_CN/latest/'>简体中文</a>\n"
  },
  {
    "path": "docs/en/task_extend/code.rst",
    "content": "==================================\nCode Implementation\n==================================\n\nThe core code of the PDF-Extract-Kit project is implemented in the `pdf_extract_kit` directory, which contains the following modules:\n\n- configs: Configuration files for specific modules, such as `pdf_extract_kit/configs/unimernet.yaml`. If the configuration is simple, it is recommended to define it in the `yaml` file's `model_config` in `repo_root/configs` for easier user modification.\n\n- dataset: A custom `ImageDataset` class used for loading and preprocessing image data. It supports various input types and can perform unified preprocessing operations on images (such as resizing, converting to tensors, etc.) to accelerate subsequent model inference.\n\n- evaluation: A module for evaluating model results, supporting evaluations for various task types such as `layout detection`, `formula detection`, `formula recognition`, etc., allowing users to fairly compare different tasks and models.\n\n- registry: The `Registry` class is a generic registry class that provides functions for registering, retrieving, and listing registered items. Users can use this class to create different types of registries, such as task registries, model registries, etc.\n\n- tasks: The core task module contains many different types of tasks, such as `layout detection`, `formula detection`, `formula recognition`, etc. Users typically only need to add code here to add new tasks and models.\n\n.. 
note::\n    Based on the above modular design, users generally only need to implement their new task class and corresponding model in `tasks` to extend new modules (in most cases, only the corresponding model needs to be implemented, as the task is already defined), and then register it in `registry`.\n\nBelow we take adding a YOLO-based `layout detection` model as an example to introduce how to add new tasks and models.\n\nTask Definition and Registration\n==============\n\nFirst, we add a `layout_detection` directory under `tasks`, and then add a `task.py` file in that directory to define the layout detection task class, as follows:\n\n.. code-block:: python\n\n    from pdf_extract_kit.registry.registry import TASK_REGISTRY\n    from pdf_extract_kit.tasks.base_task import BaseTask\n\n    @TASK_REGISTRY.register(\"layout_detection\")\n    class LayoutDetectionTask(BaseTask):\n        def __init__(self, model):\n            super().__init__(model)\n\n        def predict_images(self, input_data, result_path):\n            \"\"\"\n            Predict layouts in images.\n\n            Args:\n                input_data (str): Path to a single image file or a directory containing image files.\n                result_path (str): Path to save the prediction results.\n\n            Returns:\n                list: List of prediction results.\n            \"\"\"\n            images = self.load_images(input_data)\n            # Perform detection\n            return self.model.predict(images, result_path)\n\n        def predict_pdfs(self, input_data, result_path):\n            \"\"\"\n            Predict layouts in PDF files.\n\n            Args:\n                input_data (str): Path to a single PDF file or a directory containing PDF files.\n                result_path (str): Path to save the prediction results.\n\n            Returns:\n                list: List of prediction results.\n            \"\"\"\n            pdf_images = self.load_pdf_images(input_data)\n           
 # Perform detection\n            return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys()))\n\nAs you can see, the task definition includes the following key points:\n\n* Use the `@TASK_REGISTRY.register(\"layout_detection\")` syntax to directly register the layout task class under `TASK_REGISTRY`.\n* The `__init__` initialization function takes `model` as an argument, specifically referring to the `BaseTask` class.\n* Implement inference functions. Considering that layout detection usually processes images and PDF files, two functions `predict_images` and `predict_pdfs` are provided for users to choose flexibly.\n\nModel Definition and Registration\n==============\n\nNext, we implement the specific model by creating a `models` directory under `task` and adding `yolo.py` for YOLO model definition, as follows:\n\n.. code-block:: python\n\n    import os\n    import cv2\n    import torch\n    from torch.utils.data import DataLoader, Dataset\n    from ultralytics import YOLO\n    from pdf_extract_kit.registry import MODEL_REGISTRY\n    from pdf_extract_kit.utils.visualization import  visualize_bbox\n    from pdf_extract_kit.dataset.dataset import ImageDataset\n    import torchvision.transforms as transforms\n\n    @MODEL_REGISTRY.register('layout_detection_yolo')\n    class LayoutDetectionYOLO:\n        def __init__(self, config):\n            \"\"\"\n            Initialize the LayoutDetectionYOLO class.\n\n            Args:\n                config (dict): Configuration dictionary containing model parameters.\n            \"\"\"\n            # Mapping from class IDs to class names\n            self.id_to_names = {\n                0: 'title', \n                1: 'plain text',\n                2: 'abandon', \n                3: 'figure', \n                4: 'figure_caption', \n                5: 'table', \n                6: 'table_caption', \n                7: 'table_footnote', \n                8: 'isolate_formula', \n               
 9: 'formula_caption'\n            }\n\n            # Load the YOLO model from the specified path\n            self.model = YOLO(config['model_path'])\n\n            # Set model parameters\n            self.img_size = config.get('img_size', 1280)\n            self.pdf_dpi = config.get('pdf_dpi', 200)\n            self.conf_thres = config.get('conf_thres', 0.25)\n            self.iou_thres = config.get('iou_thres', 0.45)\n            self.visualize = config.get('visualize', False)\n            self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')\n            self.batch_size = config.get('batch_size', 1)\n\n        def predict(self, images, result_path, image_ids=None):\n            \"\"\"\n            Predict layouts in images.\n\n            Args:\n                images (list): List of images to be predicted.\n                result_path (str): Path to save the prediction results.\n                image_ids (list, optional): List of image IDs corresponding to the images.\n\n            Returns:\n                list: List of prediction results.\n            \"\"\"\n            results = []\n            for idx, image in enumerate(images):\n                result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False)[0]\n                if self.visualize:\n                    if not os.path.exists(result_path):\n                        os.makedirs(result_path)\n                    boxes = result.__dict__['boxes'].xyxy\n                    classes = result.__dict__['boxes'].cls\n                    vis_result = visualize_bbox(image, boxes, classes, self.id_to_names)\n\n                    # Determine the base name of the image\n                    if image_ids:\n                        base_name = image_ids[idx]\n                    else:\n                        base_name = os.path.basename(image)\n                    \n                    result_name = f\"{base_name}_MFD.png\"\n       
             \n                    # Save the visualized result                \n                    cv2.imwrite(os.path.join(result_path, result_name), vis_result)\n                results.append(result)\n            return results\n\nAs you can see, the model definition includes the following key points:\n\n* Use the `@MODEL_REGISTRY.register('layout_detection_yolo')` syntax to directly register the YOLO layout model under `MODEL_REGISTRY`.\n* The initialization function needs to implement:\n    + The `id_to_names` category mapping for visualization.\n    + Model parameter configuration.\n    + Model initialization.\n* The model inference function needs to implement various types of model inference: it supports image lists and `PIL.Image` class, allowing users to perform inference directly based on image paths or image streams.\n\nAfter implementing the above class definition, add `LayoutDetectionYOLO` to the `__all__` in `__init__.py` under the `layout_detection` task.\n\n.. code-block:: python\n\n    from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO\n    from pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n    __all__ = [\n        \"LayoutDetectionYOLO\",\n    ]\n\n.. note:: \n    For the same task, we support multiple models. Users can choose which one to use based on evaluation results, considering model `accuracy`, `speed`, and `scenario adaptability`.\n\nAfter implementing the tasks and models, you can add a script program `layout_detection.py` under `repo_root/scripts`.\n\nExample Script\n==============\n\n.. 
code-block:: python\n\n    import os\n    import sys\n    import os.path as osp\n    import argparse\n\n    sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\n    from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\n    import pdf_extract_kit.tasks  # Ensure all task modules are imported\n\n    TASK_NAME = 'layout_detection'\n\n    def parse_args():\n        parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n        parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n        return parser.parse_args()\n\n    def main(config_path):\n        config = load_config(config_path)\n        task_instances = initialize_tasks_and_models(config)\n\n        # get input and output path from config\n        input_data = config.get('inputs', None)\n        result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)\n\n        # layout_detection_task\n        model_layout_detection = task_instances[TASK_NAME]\n\n        # for image detection\n        detection_results = model_layout_detection.predict_images(input_data, result_path)\n\n        # for pdf detection\n        # detection_results = model_layout_detection.predict_pdfs(input_data, result_path)\n\n        # print(detection_results)\n        print(f'The predicted results can be found at {result_path}')\n\n    if __name__ == \"__main__\":\n        args = parse_args()\n        main(args.config)\n\nSupport Type Extension\n==============\n\nBatch Processing Extension\n=============="
  },
  {
    "path": "docs/en/task_extend/doc.rst",
    "content": "==================================\nDocumentation Supplement\n==================================\n\n"
  },
  {
    "path": "docs/en/task_extend/evaluation.rst",
    "content": "==================================\nModel Performance Evaluation\n==================================\n\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "sphinx\nsphinx_rtd_theme\nmyst-parser\nsphinx-copybutton\nsphinx-argparse\nsphinx-book-theme"
  },
  {
    "path": "docs/zh_cn/.readthedocs.yaml",
    "content": "version: 2\n\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.10\"\n\nformats:\n  - epub\n\npython:\n  install:\n    - requirements: requirements/docs.txt\n\nsphinx:\n  configuration: docs/zh_cn/conf.py\n"
  },
  {
    "path": "docs/zh_cn/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/zh_cn/algorithm/formula_detection.rst",
    "content": "..  _algorithm_formula_detection:\n\n====================\n公式检测算法\n====================\n\n简介\n====================\n\n公式检测是针对给定的输入图像，检测出图像中所有包含公式的位置（包含行内公式和行间公式）。\n\n.. note::\n\n   公式检测实际上属于布局检测子任务，但由于公式检测的复杂性，我们建议使用单独的公式检测模型解耦。\n   这样通常使得数据标注更加方便，且公式检测效果也更好。\n\n模型使用\n====================\n\n在配置好环境的情况下，直接执行 ``scripts/formula_detection.py`` 即可运行公式检测算法脚本。\n\n.. code:: shell\n\n   $ python scripts/formula_detection.py --config configs/formula_detection.yaml\n\n模型配置\n--------------------\n\n.. code:: yaml\n\n   inputs: assets/demo/formula_detection\n   outputs: outputs/formula_detection\n   tasks:\n      formula_detection:\n         model: formula_detection_yolo\n         model_config:\n            img_size: 1280\n            conf_thres: 0.25\n            iou_thres: 0.45\n            batch_size: 1\n            model_path: models/MFD/yolov8/weights.pt\n            visualize: True\n\n- inputs/outputs: 分别定义输入文件路径和可视化输出目录\n- tasks: 定义任务类型，当前只包含一个公式检测任务\n- model: 定义具体模型类型: 当前仅提供YOLO公式检测模型\n- model_config: 定义模型配置\n- img_size: 定义图像长边大小，短边会根据长边等比例缩放\n- conf_thres: 定义置信度阈值，仅检测大于该阈值的目标\n- iou_thres: 定义IoU阈值，去除重叠度大于该阈值的目标\n- batch_size: 定义批量大小，推理时每次同时推理的图像数，一般情况下越大推理速度越快，显卡越好该数值可以设置得越大\n- model_path: 模型权重路径\n- visualize: 是否对模型结果进行可视化，可视化结果会保存在outputs目录下。\n\n多样化输入支持\n--------------------\n\nPDF-Extract-Kit中的公式检测脚本支持 ``单个图像`` 、 ``只包含图像文件的目录`` 、 ``单个PDF文件`` 、 ``只包含PDF文件的目录`` 等输入形式。\n\n.. note::\n\n   根据自己实际数据形式，修改 ``configs/formula_detection.yaml`` 中 ``inputs`` 的路径即可：\n\n   - 单个图像: path/to/image\n   - 图像文件夹: path/to/images\n   - 单个PDF文件: path/to/pdf\n   - PDF文件夹: path/to/pdfs\n\n.. note::\n\n   当使用PDF作为输入时，需要将 ``formula_detection.py`` 中的 ``predict_images`` 修改为 ``predict_pdfs`` 。\n\n   .. code:: python\n\n      # for image detection\n      detection_results = model_formula_detection.predict_images(input_data, result_path)\n\n   .. code:: python\n\n      # for pdf detection\n      detection_results = model_formula_detection.predict_pdfs(input_data, result_path)\n\n\n可视化结果查看\n--------------------\n\n当config文件中 ``visualize`` 设置为 ``True`` 时，可视化结果会保存在 ``outputs/formula_detection`` 目录下。\n\n.. note::\n\n   可视化可以方便对模型结果进行分析，但当进行大批量任务时，建议关掉可视化(设置 ``visualize`` 为 ``False`` )，减少内存和磁盘占用。"
  },
  {
    "path": "docs/zh_cn/algorithm/formula_recognition.rst",
    "content": "..  _algorithm_formula_recognition:\n\n============\n公式识别算法\n============\n\n简介\n=================\n\n公式识别是指给定输入公式图像，识别公式图像内容并转为 ``LaTeX`` 格式。\n\n模型使用\n=================\n\n在配置好环境的情况下，直接执行 ``scripts/formula_recognition.py`` 即可运行公式识别算法脚本。\n\n.. code:: shell\n\n   $ python scripts/formula_recognition.py --config configs/formula_recognition.yaml\n\n模型配置\n-----------------\n\n.. code:: yaml\n\n   inputs: assets/demo/formula_recognition\n   outputs: outputs/formula_recognition\n   tasks:\n      formula_recognition:\n         model: formula_recognition_unimernet\n         model_config:\n            cfg_path: pdf_extract_kit/configs/unimernet.yaml\n            model_path: models/MFR/unimernet_tiny\n            visualize: False\n\n- inputs/outputs: 分别定义输入文件路径和LaTeX预测结果目录\n- tasks: 定义任务类型，当前只包含一个公式识别任务\n- model: 定义具体模型类型: 当前仅提供 `UniMERNet <https://github.com/opendatalab/UniMERNet>`_ 公式识别模型\n- model_config: 定义模型配置\n- cfg_path: UniMERNet配置文件路径\n- model_path: 模型权重路径\n- visualize: 是否对模型结果进行可视化，可视化结果会保存在outputs目录下。\n\n多样化输入支持\n-----------------\n\nPDF-Extract-Kit中的公式识别脚本支持 ``单个公式图像`` 、 ``文档图像及对应公式区域`` 等输入形式。\n\n可视化结果查看\n-----------------\n\n当config文件中 ``visualize`` 设置为 ``True`` 时， ``LaTeX`` 预测结果会保存在 ``outputs`` 目录下。"
  },
  {
    "path": "docs/zh_cn/algorithm/layout_detection.rst",
    "content": ".. _algorithm_layout_detection:\n\n=================\n布局检测算法\n=================\n\n简介\n=================\n\n``布局检测`` 是文档内容提取的基础任务，目标对页面中不同类型的区域进行定位：如 ``图像`` 、 ``表格`` 、 ``文本`` 、 ``标题`` 等，方便后续高质量内容提取。对于 ``文本`` 、 ``标题`` 等区域，可以基于 ``OCR模型`` 进行文字识别，对于表格区域可以基于表格识别模型进行转换。\n\n模型使用\n=================\n\n布局检测模型支持以下模型：\n\n.. raw:: html\n\n    <style type=\"text/css\">\n    .tg  {border-collapse:collapse;border-color:#9ABAD9;border-spacing:0;}\n    .tg td{background-color:#EBF5FF;border-color:#9ABAD9;border-style:solid;border-width:1px;color:#444;\n      font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;word-break:normal;}\n    .tg th{background-color:#409cff;border-color:#9ABAD9;border-style:solid;border-width:1px;color:#fff;\n      font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}\n    .tg .tg-f8tz{background-color:#409cff;border-color:inherit;text-align:left;vertical-align:top}\n    .tg .tg-0lax{text-align:left;vertical-align:top}\n    .tg .tg-0pky{border-color:inherit;text-align:left;vertical-align:top}\n    </style>\n    <table class=\"tg\"><thead>\n      <tr>\n        <th class=\"tg-0lax\">模型</th>\n        <th class=\"tg-f8tz\">简述</th>\n        <th class=\"tg-f8tz\">特点</th>\n        <th class=\"tg-f8tz\">模型权重</th>\n        <th class=\"tg-f8tz\">配置文件</th>\n      </tr></thead>\n    <tbody>\n      <tr>\n        <td class=\"tg-0lax\">DocLayout-YOLO</td>\n        <td class=\"tg-0pky\">基于YOLO-v10模型改进：<br>1. 生成多样性预训练数据，提升对多种类型文档泛化性<br>2. 
模型结构改进，提升对多尺度目标感知能力<br>详见<a href=\"https://github.com/opendatalab/DocLayout-YOLO\" target=\"_blank\" rel=\"noopener noreferrer\">DocLayout-YOLO</a></td>\n        <td class=\"tg-0pky\">速度快、精度高</td>\n        <td class=\"tg-0pky\"><a href=\"https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/blob/main/models/Layout/YOLO/doclayout_yolo_ft.pt\" target=\"_blank\" rel=\"noopener noreferrer\">doclayout_yolo_ft.pt</a></td>\n        <td class=\"tg-0pky\">layout_detection.yaml</td>\n      </tr>\n      <tr>\n        <td class=\"tg-0lax\">YOLO-v10</td>\n        <td class=\"tg-0pky\">基础YOLO-v10模型</td>\n        <td class=\"tg-0pky\">速度快，精度一般</td>\n        <td class=\"tg-0pky\"><a href=\"https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/blob/main/models/Layout/YOLO/yolov10l_ft.pt\" target=\"_blank\" rel=\"noopener noreferrer\">yolov10l_ft.pt</a></td>\n        <td class=\"tg-0pky\">layout_detection_yolo.yaml</td>\n      </tr>\n      <tr>\n        <td class=\"tg-0lax\">LayoutLMv3</td>\n        <td class=\"tg-0pky\">基础LayoutLMv3模型</td>\n        <td class=\"tg-0pky\">速度慢，精度较好</td>\n        <td class=\"tg-0pky\"><a href=\"https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/tree/main/models/Layout/LayoutLMv3\" target=\"_blank\" rel=\"noopener noreferrer\">layoutlmv3_ft</a></td>\n        <td class=\"tg-0pky\">layout_detection_layoutlmv3.yaml</td>\n      </tr>\n    </tbody></table>\n\n\n在配置好环境的情况下，直接执行 ``scripts/layout_detection.py`` 即可运行布局检测算法脚本。\n\n\n**执行布局检测程序**\n\n.. code:: shell\n\n   $ python scripts/layout_detection.py --config configs/layout_detection.yaml\n\n模型配置\n-----------------\n\n**1. DocLayout-YOLO / YOLO-v10**\n\n.. 
code:: yaml\n\n    inputs: assets/demo/layout_detection\n    outputs: outputs/layout_detection\n    tasks:\n      layout_detection:\n        model: layout_detection_yolo\n        model_config:\n          img_size: 1024\n          conf_thres: 0.25\n          iou_thres: 0.45\n          model_path: path/to/doclayout_yolo_model\n          visualize: True\n\n- inputs/outputs: 分别定义输入文件路径和可视化输出目录\n- tasks: 定义任务类型，当前只包含一个布局检测任务\n- model: 定义具体模型类型，例如 ``layout_detection_yolo``\n- model_config: 定义模型配置\n- img_size: 定义图像长边大小，短边会根据长边等比例缩放，默认长边保持1024\n- conf_thres: 定义置信度阈值，仅检测大于该阈值的目标\n- iou_thres: 定义IoU阈值，去除重叠度大于该阈值的目标\n- model_path: 模型权重路径\n- visualize: 是否对模型结果进行可视化，可视化结果会保存在outputs目录下\n\n\n**2. LayoutLMv3**\n\n.. note::\n\n   LayoutLMv3 默认情况下不能直接运行。运行时请将配置文件修改为configs/layout_detection_layoutlmv3.yaml，并且请按照以下步骤进行配置修改：\n\n   1. **Detectron2 环境配置**\n\n   .. code-block:: bash\n\n      # 对于 Linux\n      pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-linux_x86_64.whl\n\n      # 对于 macOS\n      pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-macosx_10_9_universal2.whl\n\n      # 对于 Windows\n      pip install https://wheels-1251341229.cos.ap-shanghai.myqcloud.com/assets/whl/detectron2/detectron2-0.6-cp310-cp310-win_amd64.whl\n\n   2. **启用 LayoutLMv3 注册代码**\n\n   请取消注释以下链接中的代码行：\n   \n   - `第2行 <https://github.com/opendatalab/PDF-Extract-Kit/blob/main/pdf_extract_kit/tasks/layout_detection/__init__.py#L2>`_\n   - `第8行 <https://github.com/opendatalab/PDF-Extract-Kit/blob/main/pdf_extract_kit/tasks/layout_detection/__init__.py#L8>`_\n\n   .. 
code-block:: python\n\n      from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO\n      from pdf_extract_kit.tasks.layout_detection.models.layoutlmv3 import LayoutDetectionLayoutlmv3\n      from pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n      __all__ = [\n         \"LayoutDetectionYOLO\",\n         \"LayoutDetectionLayoutlmv3\",\n      ]\n\n.. code:: yaml\n\n    inputs: assets/demo/layout_detection\n    outputs: outputs/layout_detection\n    tasks:\n      layout_detection:\n        model: layout_detection_layoutlmv3\n        model_config:\n          model_path: path/to/layoutlmv3_model\n\n- inputs/outputs: 分别定义输入文件路径和可视化输出目录\n- tasks: 定义任务类型，当前只包含一个布局检测任务\n- model: 定义具体模型类型，例如 ``layout_detection_layoutlmv3``\n- model_config: 定义模型配置\n- model_path: 模型权重路径\n\n\n多样化输入支持\n-----------------\n\nPDF-Extract-Kit中的布局检测脚本支持 ``单个图像`` 、 ``只包含图像文件的目录`` 、 ``单个PDF文件`` 、 ``只包含PDF文件的目录`` 等输入形式。\n\n.. note::\n\n   根据自己实际数据形式，修改 ``configs/layout_detection.yaml`` 中 ``inputs`` 的路径即可：\n\n   - 单个图像: path/to/image\n   - 图像文件夹: path/to/images\n   - 单个PDF文件: path/to/pdf\n   - PDF文件夹: path/to/pdfs\n\n.. note::\n\n   当使用PDF作为输入时，需要将 ``layout_detection.py`` 中的 ``predict_images`` 修改为 ``predict_pdfs`` 。\n\n   .. code:: python\n\n      # for image detection\n      detection_results = model_layout_detection.predict_images(input_data, result_path)\n\n   .. code:: python\n\n      # for pdf detection\n      detection_results = model_layout_detection.predict_pdfs(input_data, result_path)\n\n可视化结果查看\n-----------------\n\n当config文件中 ``visualize`` 设置为 ``True`` 时，可视化结果会保存在 ``outputs`` 目录下。\n\n.. note::\n\n   可视化可以方便对模型结果进行分析，但当进行大批量任务时，建议关掉可视化(设置 ``visualize`` 为 ``False`` )，减少内存和磁盘占用。"
  },
  {
    "path": "docs/zh_cn/algorithm/ocr.rst",
    "content": "..  _algorithm_ocr:\n\n==========================\n光学字符识别(OCR)算法\n==========================\n\n简介\n====================\n\n光学字符识别(OCR)是指对图片中的文字块进行检测和识别。\n\n\n模型使用\n====================\n\n在配置好环境的情况下，直接执行 ``scripts/ocr.py`` 即可运行OCR算法脚本。\n\n.. code:: shell\n\n   $ python scripts/ocr.py --config configs/ocr.yaml\n\n\n模型配置\n--------------------\n\n.. code:: yaml\n\n   inputs: assets/demo/ocr\n   outputs: outputs/ocr\n   visualize: True\n   tasks:\n      ocr:\n         model: ocr_ppocr\n         model_config:\n            lang: ch\n            show_log: True\n            det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det\n            rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec\n            det_db_box_thresh: 0.3\n\n- inputs/outputs: 分别定义输入文件路径和输出路径\n- visualize: 是否对模型结果进行可视化，可视化结果会保存在outputs目录下。\n- tasks: 定义任务类型，当前只包含一个OCR任务\n- model: 定义具体模型类型，当前仅提供PaddleOCR模型\n- model_config: 定义模型配置\n- lang: 定义语种，默认语种ch支持中英文文字的检测和识别\n- show_log: 是否打印检测识别过程的日志\n- det_model_dir: 定义PaddleOCR检测模型的路径，指定路径不存在时，会自动下载模型权重到该路径\n- rec_model_dir: 定义PaddleOCR识别模型的路径，指定路径不存在时，会自动下载模型权重到该路径\n- det_db_box_thresh: 检测框筛选阈值，置信度低于该阈值的框会被舍弃\n\n\n多样化输入支持\n--------------------\n\nPDF-Extract-Kit中的OCR脚本支持 ``单个图像/PDF文件`` 、 ``包含图像/PDF文件的目录`` 等输入形式。\n\n\n可视化结果查看\n--------------------\n\n当config文件中 ``visualize`` 设置为 ``True`` 时，可视化结果会保存在 ``outputs`` 参数指定的目录下。\n\n.. note::\n\n   可视化可以方便对模型结果进行分析，但当进行大批量任务时，建议关掉可视化(设置 ``visualize`` 为 ``False`` )，减少内存和磁盘占用。"
  },
  {
    "path": "docs/zh_cn/algorithm/reading_order.rst",
    "content": "..  _algorithm_reading_oder:\n\n==============\n阅读顺序算法\n==============\n\nComing soon."
  },
  {
    "path": "docs/zh_cn/algorithm/table_recognition.rst",
    "content": "..  _algorithm_table_recognition:\n\n============\n表格识别算法\n============\n\n简介\n=================\n\n表格识别是指输入表格图像，识别表格结构和内容，并将其转换为 ``LaTeX`` 或 ``HTML`` 等格式。\n\n模型使用\n=================\n\n在配置好环境的情况下，直接执行 ``scripts/table_parsing.py`` 即可运行表格识别算法脚本。\n\n.. code:: shell\n\n   $ python scripts/table_parsing.py --config configs/table_parsing.yaml\n\n模型配置\n-----------------\n\n.. code:: yaml\n\n    inputs: assets/demo/table_parsing\n    outputs: outputs/table_parsing\n    tasks:\n      table_parsing:\n        model: table_parsing_struct_eqtable\n        model_config:\n          model_path: models/TabRec/StructEqTable\n          max_new_tokens: 1024\n          max_time: 30\n          output_format: latex\n          lmdeploy: False\n          flash_attn: True\n\n- inputs/outputs: 分别定义输入文件路径和表格识别结果目录\n- tasks: 定义任务类型，当前只包含一个表格识别任务\n- model: 定义具体模型类型: 当前使用 `StructEqTable  <https://github.com/UniModal4Reasoning/StructEqTable-Deploy>`_ 表格识别模型\n- model_config: 定义模型配置\n- model_path: 模型权重路径\n- max_new_tokens: 生成的最大token数量, 默认为1024, 最大支持4096\n- max_time: 模型运行的最大时间（秒）\n- output_format: 输出格式，默认设置为 ``latex``, 可选有 ``html`` 和 ``markdown``\n- lmdeploy: 是否使用 LMDeploy 进行部署，当前设置为 False\n- flash_attn: 是否使用flash attention，仅适用于Ampere GPU\n\n\n多样化输入支持\n-----------------\n\nPDF-Extract-Kit中的表格识别脚本支持 ``单个表格图像`` 和 ``多个表格图像`` 作为输入。\n\n.. note::\n\n   StructEqTable表格模型仅支持GPU设备下运行\n\n.. note::\n    \n    根据表格内容调整 ``max_new_tokens`` 和 ``max_time``, 默认分别为1024和30。\n\n.. note::\n    \n    lmdeploy为加速推理的选项，如果设置为True，将使用LMDeploy进行加速推理部署。\n    使用LMDeploy部署需要安装LMDeploy，安装方法参考 `LMDeploy <https://github.com/InternLM/lmdeploy>`_ 。\n\n"
  },
  {
    "path": "docs/zh_cn/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n\nimport os\nimport subprocess\nimport sys\n\ndef install(package):\n    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", package])\n\n# 安装 requirements.txt 中的依赖项\nrequirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))\nif os.path.exists(requirements_path):\n    with open(requirements_path) as f:\n        packages = f.readlines()\n    for package in packages:\n        install(package.strip())\n\nfrom sphinx.ext import autodoc\nsys.path.insert(0, os.path.abspath('../..'))\n\n# -- Project information -----------------------------------------------------\n\nproject = 'PDF-Extract-Kit'\ncopyright = '2024, OpenDataLab'\nauthor = 'PDF-Extract-Kit Contributors'\n\n# The full version, including alpha/beta/rc tags\nversion_file = '../../pdf_extract_kit/version.py'\nwith open(version_file) as f:\n    exec(compile(f.read(), version_file, 'exec'))\n__version__ = locals()['__version__']\n# The short X.Y version\nversion = __version__\n# The full version, including alpha/beta/rc tags\nrelease = __version__\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.napoleon',\n    'sphinx.ext.viewcode',\n    'sphinx.ext.intersphinx',\n    'sphinx_copybutton',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosummary',\n    'myst_parser',\n    'sphinxarg.ext',\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# Exclude the prompt \"$\" when copying code\ncopybutton_prompt_text = r'\\$ '\ncopybutton_prompt_is_regexp = True\n\nlanguage = 'zh_CN'\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_book_theme'\nhtml_logo = '_static/image/logo.png'\nhtml_theme_options = {\n    'path_to_docs': 'docs/zh_cn',\n    'repository_url': 'https://github.com/opendatalab/PDF-Extract-Kit',\n    'use_repository_button': True,\n}\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\n# html_static_path = ['_static']\n\n# Mock out external dependencies here.\nautodoc_mock_imports = [\n    'cpuinfo',\n    'torch',\n    'transformers',\n    'psutil',\n    'prometheus_client',\n    'sentencepiece',\n    'vllm.cuda_utils',\n    'vllm._C',\n    'numpy',\n    'tqdm',\n]\n\n\nclass MockedClassDocumenter(autodoc.ClassDocumenter):\n    \"\"\"Remove note about base class when a class is derived from object.\"\"\"\n\n    def add_line(self, line: str, source: str, *lineno: int) -> None:\n        if line == '   Bases: :py:class:`object`':\n            return\n        super().add_line(line, source, *lineno)\n\n\nautodoc.ClassDocumenter = MockedClassDocumenter\n\nnavigation_with_keys = False"
  },
  {
    "path": "docs/zh_cn/evaluation/formula_detection.rst",
    "content": "=====================\n公式检测算法评测\n=====================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/evaluation/formula_recognition.rst",
    "content": "=====================\n公式识别算法评测\n=====================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/evaluation/layout_detection.rst",
    "content": "=====================\n布局检测算法评测\n=====================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/evaluation/ocr.rst",
    "content": "=====================\nOCR算法评测\n=====================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/evaluation/pdf_extract.rst",
    "content": "============================\nPDF内容提取评测【端到端】\n============================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/evaluation/reading_order.rst",
    "content": "=====================\n阅读顺序算法评测\n=====================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/evaluation/table_recognition.rst",
    "content": "=====================\n表格识别算法评测\n=====================\n\nComing soon!\n"
  },
  {
    "path": "docs/zh_cn/get_started/installation.rst",
    "content": "==================================\n安装\n==================================\n\n本节中，我们将演示如何安装 PDF-Extract-Kit。\n\n最佳实践\n========\n\n我们推荐用户参照我们的最佳实践安装 PDF-Extract-Kit。\n推荐使用 Python-3.10 的 conda 虚拟环境安装 PDF-Extract-Kit。\n\n**步骤 1.** 使用 conda 先构建一个 Python-3.10 的虚拟环境\n\n.. code-block:: console\n\n    $ conda create -n pdf-extract-kit-1.0 python=3.10 -y\n    $ conda activate pdf-extract-kit-1.0\n\n**步骤 2.** 安装 PDF-Extract-Kit 的依赖项\n\n.. code-block:: console\n\n    $ # 对于GPU设备\n    $ pip install -r requirements.txt\n    $ # 对于CPU设备\n    $ pip install -r requirements-cpu.txt\n\n.. note::\n\n    考虑到用户环境配置的便捷性，我们在 requirements.txt 中只包含当前最好模型所需的环境，目前包含：\n\n    - 布局检测：YOLO系列（YOLOv10, DocLayout-YOLO）\n    - 公式检测：YOLO系列（YOLOv8）\n    - 公式识别：UniMERNet\n    - OCR：PaddleOCR\n\n    对于其他模型，如LayoutLMv3，需要单独安装环境，具体见\\ :ref:`布局检测算法 <algorithm_layout_detection>`"
  },
  {
    "path": "docs/zh_cn/get_started/pretrained_model.rst",
    "content": "==================================\n模型权重下载\n==================================\n\n在使用PDF-Extract-Kit前，我们需要下载所需要的模型权重。可以根据自己需求下载全部模型或者特定的模型文件（如公式检测MFD）。\n\n[推荐] 方法 1：``snapshot_download``\n========================================\n\nHuggingFace\n------------\n\n``huggingface_hub.snapshot_download`` 支持下载特定的 HuggingFace Hub\n模型权重，并且允许多线程。您可以利用下列代码并行下载模型权重：\n\n.. code:: python\n\n   from huggingface_hub import snapshot_download\n\n   snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', max_workers=20)\n\n如果想仅下载单个算法模型（如公式检测任务的YOLO模型），可以使用如下代码：\n\n.. code:: python\n\n   from huggingface_hub import snapshot_download\n\n   snapshot_download(repo_id='opendatalab/pdf-extract-kit-1.0', local_dir='./', allow_patterns='models/MFD/YOLO/*')\n\n.. note::\n\n   其中，\\ ``repo_id`` 表示模型在 HuggingFace Hub 的名字、\\ ``local_dir`` 表示期望存储到的本地路径、\\ ``max_workers`` 表示下载的最大并行数，\\ ``allow_patterns`` 表示想要下载的文件。\n\n.. tip::\n\n   如果未指定 ``local_dir``\\ ，则将下载至 HuggingFace 的默认 cache 路径中（\\ ``~/.cache/huggingface/hub``\\ ）。若要修改默认 cache 路径，需要修改相关环境变量：\n\n   .. code:: console\n\n      $ # 默认为 ~/.cache/huggingface/\n      $ export HF_HOME=XXXX\n\n.. tip::\n\n   如果觉得下载较慢（例如无法达到最大带宽等情况），可以尝试设置\\ ``export HF_HUB_ENABLE_HF_TRANSFER=1`` 以获得更高的下载速度（需先执行 ``pip install hf_transfer`` ）。\n\nModelScope\n-----------\n\n``modelscope.snapshot_download``\n支持下载指定的模型权重，您可以利用下列命令下载模型：\n\n.. code:: python\n\n   from modelscope import snapshot_download\n\n   snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./')\n\n如果想仅下载单个算法模型（如公式检测任务的YOLO模型），可以使用如下代码：\n\n.. code:: python\n\n   from modelscope import snapshot_download\n\n   snapshot_download(model_id='opendatalab/pdf-extract-kit-1.0', cache_dir='./', allow_patterns='models/MFD/YOLO/*')\n\n.. note::\n\n   其中，\\ ``model_id`` 表示模型在 ModelScope 模型库的名字，\\ ``cache_dir`` 表示期望存储到的本地路径，\\ ``allow_patterns`` 表示想要下载的文件。\n\n.. note::\n\n   ``modelscope.snapshot_download`` 不支持多线程并行下载。\n\n.. 
tip::\n\n   如果未指定 ``cache_dir``\\ ，则将下载至 ModelScope 的默认 cache 路径中（\\ ``~/.cache/modelscope/hub``\\ ）。\n\n   若要修改默认 cache 路径，需要修改相关环境变量：\n\n   .. code:: console\n\n      $ # 默认为 ~/.cache/modelscope/hub/\n      $ export MODELSCOPE_CACHE=XXXX\n\n\n方法 2： Git LFS\n===================\n\nHuggingFace 和 ModelScope 的远程模型仓库就是一个由 Git LFS 管理的 Git\n仓库。因此，我们可以利用 ``git clone`` 完成权重的下载：\n\n.. code:: console\n\n   $ git lfs install\n   $ # From HuggingFace\n   $ git clone https://huggingface.co/opendatalab/pdf-extract-kit-1.0\n   $ # From ModelScope\n   $ git clone https://www.modelscope.cn/opendatalab/pdf-extract-kit-1.0.git\n"
  },
  {
    "path": "docs/zh_cn/get_started/quickstart.rst",
    "content": "==================================\n快速开始\n==================================\n\n配置好PDF-Extract-Kit环境，并下载好模型后，我们可以开始使用PDF-Extract-Kit了。\n\n\n\n布局检测示例\n==============\n\n布局检测提供了多种模型: ``LayoutLMv3``、 ``YOLOv10``、 ``DocLayout-YOLO``。相比于 ``LayoutLMv3``， ``YOLOv10`` 速度更快； ``DocLayout-YOLO`` 则是在 ``YOLOv10`` 的基础上进行多样性文档预训练及模型优化，速度快，精度高。\n\n**1. 使用布局检测模型**\n\n.. code-block:: console\n\n    $ python scripts/layout_detection.py --config configs/layout_detection.yaml\n\n执行完之后，我们可以在 ``outputs/layout_detection`` 目录下查看检测结果。\n\n.. note::\n\n    ``layout_detection.yaml`` 设置输入、输出及模型配置，布局检测更详细教程见\\ :ref:`布局检测算法 <algorithm_layout_detection>` \\ 。\n\n\n公式检测示例\n==============\n\n\n.. code-block:: console\n\n    $ python scripts/formula_detection.py --config configs/formula_detection.yaml\n\n执行完之后，我们可以在 ``outputs/formula_detection`` 目录下查看检测结果。\n\n.. note::\n\n    ``formula_detection.yaml`` 设置输入、输出及模型配置，公式检测更详细教程见 \\ :ref:`公式检测算法 <algorithm_formula_detection>` \\ 。\n"
  },
  {
    "path": "docs/zh_cn/index.rst",
    "content": ".. xtuner documentation master file, created by\n   sphinx-quickstart on Tue Jan  9 16:33:06 2024.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\n欢迎来到 PDF-Extract-Kit 的中文文档\n==============================================\n\n.. figure:: ./_static/image/logo.png\n  :align: center\n  :alt: pdf-extract-kit\n  :class: no-scaled-link\n\n.. raw:: html\n\n   <p style=\"text-align:center\">\n   <strong>高质量文档解析工具箱\n   </strong>\n   </p>\n\n   <p style=\"text-align:center\">\n   <script async defer src=\"https://buttons.github.io/buttons.js\"></script>\n   <a class=\"github-button\" href=\"https://github.com/opendatalab/PDF-Extract-Kit\" data-show-count=\"true\" data-size=\"large\" aria-label=\"Star\">Star</a>\n   <a class=\"github-button\" href=\"https://github.com/opendatalab/PDF-Extract-Kit/subscription\" data-icon=\"octicon-eye\" data-size=\"large\" aria-label=\"Watch\">Watch</a>\n   <a class=\"github-button\" href=\"https://github.com/opendatalab/PDF-Extract-Kit/fork\" data-icon=\"octicon-repo-forked\" data-size=\"large\" aria-label=\"Fork\">Fork</a>\n   </p>\n\n\n文档\n-------------\n.. toctree::\n   :maxdepth: 2\n   :caption: 快速上手\n\n   get_started/installation.rst\n   get_started/pretrained_model.rst\n   get_started/quickstart.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 基础算法模块\n\n   algorithm/layout_detection.rst\n   algorithm/formula_detection.rst\n   algorithm/formula_recognition.rst\n   algorithm/ocr.rst\n   algorithm/table_recognition.rst\n   algorithm/reading_order.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 新任务拓展\n\n   task_extend/code.rst\n   task_extend/doc.rst\n   task_extend/evaluation.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 支持的模型列表\n\n   models/supported.md\n\n\n.. 
toctree::\n   :maxdepth: 2\n   :caption: 模型性能评测\n\n   evaluation/layout_detection.rst\n   evaluation/formula_detection.rst\n   evaluation/formula_recognition.rst\n   evaluation/ocr.rst\n   evaluation/table_recognition.rst\n   evaluation/reading_order.rst\n   evaluation/pdf_extract.rst\n\n.. toctree::\n   :maxdepth: 2\n   :caption: PDF项目\n\n   project/pdf_extract.rst\n   project/doc_translate.rst\n   project/speed_up.rst"
  },
  {
    "path": "docs/zh_cn/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=.\nset BUILDDIR=_build\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.https://www.sphinx-doc.org/\n\texit /b 1\n)\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\n\n:end\npopd\n"
  },
  {
    "path": "docs/zh_cn/models/supported.md",
    "content": "# 已支持的模型\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/notes/changelog.md",
    "content": "<!--\n\n## vX.X.X (YYYY.MM.DD)\n\n### 亮点\n\n### 新功能和改进\n\n### Bug 修复\n\n### 贡献者\n\n-->\n\n# 变更日志\n\n\n## v0.2.0 (2024.09.30)\n\nPDF-Extract-Kit 代码重构，模块化设计更加简洁易用! 🔥🔥🔥\n\n## v0.1.0 (2024.07.01)\n\nPDF-Extract-Kit 正式发布！🔥🔥🔥\n\n### 亮点\n\n- PDF-Extract-Kit提供高质量布局检测模型 DocLayout-YOLO\n- PDF-Extract-Kit提供高质量公式检测模型 YOLOv8"
  },
  {
    "path": "docs/zh_cn/project/doc_translate.rst",
    "content": "=================\n文档翻译项目\n=================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/project/pdf_extract.rst",
    "content": "=================\n文档内容提取项目\n=================\n\n简介\n====================\n\n文档内容提取是利用布局检测，公式检测，公式识别，OCR等模型，提取文档中的信息，并转换为markdown文本。\n\n\n项目使用\n====================\n\n在配置好环境的情况下，直接执行 ``project/pdf2markdown/scripts/run_project.py`` 即可运行文档内容提取项目。\n\n.. code:: shell\n\n   $ python project/pdf2markdown/scripts/run_project.py --config project/pdf2markdown/configs/pdf2markdown.yaml\n\n\n项目配置\n--------------------\n\n.. code:: yaml\n\n    inputs: assets/demo/formula_detection\n    outputs: outputs/pdf2markdown\n    visualize: True\n    merge2markdown: True\n    tasks:\n        layout_detection:\n            model: layout_detection_yolo\n            model_config:\n                img_size: 1024\n                conf_thres: 0.25\n                iou_thres: 0.45\n                model_path: models/Layout/YOLO/doclayout_yolo_ft.pt\n        formula_detection:\n            model: formula_detection_yolo\n            model_config:\n                img_size: 1280\n                conf_thres: 0.25\n                iou_thres: 0.45\n                batch_size: 1\n                model_path: models/MFD/YOLO/yolo_v8_ft.pt\n        formula_recognition:\n            model: formula_recognition_unimernet\n            model_config:\n                batch_size: 128\n                cfg_path: pdf_extract_kit/configs/unimernet.yaml\n                model_path: models/MFR/unimernet_tiny\n        ocr:\n            model: ocr_ppocr\n            model_config:\n                lang: ch\n                show_log: True\n                det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det\n                rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec\n                det_db_box_thresh: 0.3\n\n- inputs/outputs: 分别定义输入文件路径和输出路径\n- visualize: 是否对模型结果进行可视化，可视化结果会保存在outputs目录下。\n- merge2markdown: 是否将结果合并为markdown文档，这里只支持简单的单栏文本从上往下进行拼接，更复杂布局文档的markdown转换请参考 `MinerU <https://github.com/opendatalab/MinerU>`_\n- tasks: 定义任务类型，PDF文档提取包含了布局检测、公式检测、公式识别、OCR等任务\n- 
具体每个任务和模型的参数含义请参考各任务的教程文档\n\n\n多样化输入支持\n--------------------\n\nPDF文档内容提取支持 ``单个图像/PDF文件`` 、 ``包含图像/PDF文件的目录`` 等输入形式。\n\n\n输出结果\n--------------------\n\nPDF文档提取的结果以json形式保存在 ``outputs`` 路径下，json的格式如下所示：\n\n.. code:: json\n\n    [\n        {\n            \"layout_dets\": [\n                {\n                    \"category_type\": \"text\",\n                    \"poly\": [\n                        380.6792698635707,\n                        159.85058512958923,\n                        765.1419999999998,\n                        159.85058512958923,\n                        765.1419999999998,\n                        192.51073013642917,\n                        380.6792698635707,\n                        192.51073013642917\n                    ],\n                    \"text\": \"this is an example text\",\n                    \"score\": 0.97\n                },\n                ...\n            ],\n            \"page_info\": {\n                \"page_no\": 0,\n                \"height\": 2339,\n                \"width\": 1654\n            }\n        },\n        ...\n    ]\n\n- layout_dets: 单页PDF或图片的内容提取结果\n- category_type: 单个内容块的所属类别，比如标题、图片、行内公式等等\n- poly: 单个内容块的位置坐标，以 8 个数值依次给出四个顶点的 (x, y)；如示例所示，顺序为左上、右上、右下、左下\n- text: 该文本块的文本内容\n- score: 检测的置信度\n- page_info: 页面信息，包含页码和页面尺寸\n- page_no: 页码，从0开始计数\n- height: 页面尺寸: 高\n- width: 页面尺寸: 宽\n\n如果 ``merge2markdown`` 参数为 ``True`` ，则会额外保存一个 markdown 文件。"
  },
  {
    "path": "docs/zh_cn/project/speed_up.rst",
    "content": "=================\n模型加速项目\n=================\n\nComing soon!"
  },
  {
    "path": "docs/zh_cn/switch_language.md",
    "content": "## <a href='https://pdf-extract-kit.readthedocs.io/en/latest/'>English</a>\n\n## <a href='https://pdf-extract-kit.readthedocs.io/zh_CN/latest/'>简体中文</a>\n"
  },
  {
    "path": "docs/zh_cn/task_extend/code.rst",
    "content": "==================================\n代码实现\n==================================\n\nPDF-Extract-Kit项目的核心代码实现在pdf_extract_kit目录下，该路径下包含下述几个模块：\n\n- configs: 特定模块的配置文件，如 ``pdf_extract_kit/configs/unimernet.yaml`` ，如果本身配置简单，建议放在 ``repo_root/configs`` 的 ``yaml`` 文件中的 ``model_config`` 里进行定义，方便用户修改。\n\n- dataset: 自定义的 ``ImageDataset`` 类，用于加载和预处理图像数据。它支持多种输入类型，并且可以对图像进行统一的预处理操作（如调整大小、转换为张量等），以便于后续的模型推理加速。\n\n- evaluation: 模型结果评测模块，支持多种任务类型评测，如 ``布局检测`` 、 ``公式检测`` 、 ``公式识别`` 等等，方便用户对不同任务、不同模型进行公平对比。\n\n- registry: ``Registry`` 类是一个通用的注册表类，提供了注册、获取和列出注册项的功能。用户可以使用该类创建不同类型的注册表，例如任务注册表、模型注册表等。\n\n- tasks: 最核心的任务模块，包含了许多不同类型的任务，如 ``布局检测`` 、 ``公式检测`` 、 ``公式识别`` 等等，用户添加新任务和新模型一般仅需要在这里进行代码添加。\n\n\n.. note::\n    基于上述的模块化设计，用户拓展新模块一般只需要在tasks里实现自己的新任务类及对应模型（更多情况下仅需要实现对应模型，任务已经定义好），然后在registry里注册即可。\n\n\n下面我们以添加基于 ``YOLO``的 ``布局检测`` 模型为例，介绍如何添加新任务和新模型.\n\n任务定义及注册\n==============\n\n首先我们在 ``tasks`` 下添加一个 ``layout_detection`` 目录，然后在该目录下添加一个 ``task.py`` 文件用于定义布局检测任务类，具体如下：\n\n.. code-block:: python\n\n    from pdf_extract_kit.registry.registry import TASK_REGISTRY\n    from pdf_extract_kit.tasks.base_task import BaseTask\n\n\n    @TASK_REGISTRY.register(\"layout_detection\")\n    class LayoutDetectionTask(BaseTask):\n        def __init__(self, model):\n            super().__init__(model)\n\n        def predict_images(self, input_data, result_path):\n            \"\"\"\n            Predict layouts in images.\n\n            Args:\n                input_data (str): Path to a single image file or a directory containing image files.\n                result_path (str): Path to save the prediction results.\n\n            Returns:\n                list: List of prediction results.\n            \"\"\"\n            images = self.load_images(input_data)\n            # Perform detection\n            return self.model.predict(images, result_path)\n\n        def predict_pdfs(self, input_data, result_path):\n            \"\"\"\n            Predict layouts in PDF files.\n\n            
Args:\n                input_data (str): Path to a single PDF file or a directory containing PDF files.\n                result_path (str): Path to save the prediction results.\n\n            Returns:\n                list: List of prediction results.\n            \"\"\"\n            pdf_images = self.load_pdf_images(input_data)\n            # Perform detection\n            return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys()))\n\n可以看到，任务定义包含下面几个要点：\n\n* 使用 ``@TASK_REGISTRY.register(\"layout_detection\")`` 语法直接将布局任务类注册到 ``TASK_REGISTRY`` 下 ；\n* ``__init__`` 初始化函数传入 ``model`` , 具体参考 ``BaseTask`` 类\n* 实现推理函数，这里考虑到布局检测通常会处理图像类及PDF文件，所以提供了两个函数 ``predict_images`` 和 ``predict_pdfs`` ，方便用户灵活选择。\n\n模型定义及注册\n==============\n\n接下来我们实现具体模型，在task下面新建models目录，并添加yolo.py用于YOLO模型定义，具体定义如下：\n\n.. code-block:: python\n\n    import os\n    import cv2\n    import torch\n    from torch.utils.data import DataLoader, Dataset\n    from ultralytics import YOLO\n    from pdf_extract_kit.registry import MODEL_REGISTRY\n    from pdf_extract_kit.utils.visualization import  visualize_bbox\n    from pdf_extract_kit.dataset.dataset import ImageDataset\n    import torchvision.transforms as transforms\n\n\n    @MODEL_REGISTRY.register('layout_detection_yolo')\n    class LayoutDetectionYOLO:\n        def __init__(self, config):\n            \"\"\"\n            Initialize the LayoutDetectionYOLO class.\n\n            Args:\n                config (dict): Configuration dictionary containing model parameters.\n            \"\"\"\n            # Mapping from class IDs to class names\n            self.id_to_names = {\n                0: 'title', \n                1: 'plain text',\n                2: 'abandon', \n                3: 'figure', \n                4: 'figure_caption', \n                5: 'table', \n                6: 'table_caption', \n                7: 'table_footnote', \n                8: 'isolate_formula', \n                9: 'formula_caption'\n       
     }\n\n            # Load the YOLO model from the specified path\n            self.model = YOLO(config['model_path'])\n\n            # Set model parameters\n            self.img_size = config.get('img_size', 1280)\n            self.pdf_dpi = config.get('pdf_dpi', 200)\n            self.conf_thres = config.get('conf_thres', 0.25)\n            self.iou_thres = config.get('iou_thres', 0.45)\n            self.visualize = config.get('visualize', False)\n            self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')\n            self.batch_size = config.get('batch_size', 1)\n\n        def predict(self, images, result_path, image_ids=None):\n            \"\"\"\n            Predict layouts in images.\n\n            Args:\n                images (list): List of images to be predicted.\n                result_path (str): Path to save the prediction results.\n                image_ids (list, optional): List of image IDs corresponding to the images.\n\n            Returns:\n                list: List of prediction results.\n            \"\"\"\n            results = []\n            for idx, image in enumerate(images):\n                result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False)[0]\n                if self.visualize:\n                    if not os.path.exists(result_path):\n                        os.makedirs(result_path)\n                    boxes = result.__dict__['boxes'].xyxy\n                    classes = result.__dict__['boxes'].cls\n                    vis_result = visualize_bbox(image, boxes, classes, self.id_to_names)\n\n                    # Determine the base name of the image\n                    if image_ids:\n                        base_name = image_ids[idx]\n                    else:\n                        base_name = os.path.basename(image)\n                    \n                    result_name = f\"{base_name}_MFD.png\"\n                    \n               
     # Save the visualized result                \n                    cv2.imwrite(os.path.join(result_path, result_name), vis_result)\n                results.append(result)\n            return results\n\n\n可以看到，模型定义包含下面几个要点：\n\n* 使用 ``@MODEL_REGISTRY.register('layout_detection_yolo')`` 语法直接将yolo布局模型注册到 ``MODEL_REGISTRY`` 下；\n* 初始化函数需要实现：\n    + id_to_names的类别映射，用于可视化展示\n    + 模型参数配置\n    + 模型初始化\n* 模型推理函数需要实现多种类型的模型推理：这里支持图像列表和PIL.Image类，可以方便用户直接基于图像路径或者图像流进行推理。\n\n实现上述类定义后，将 ``LayoutDetectionYOLO`` 添加到 ``layout_detection`` 任务下 ``__init__.py`` 的 ``__all__`` 中即可。\n\n.. code-block:: python\n\n    from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO\n    from pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n\n    __all__ = [\n        \"LayoutDetectionYOLO\",\n    ]\n\n\n.. note:: \n    对于同一个任务，我们支持多种模型，用户具体选择哪个可以根据评测结果进行选择，结合模型 ``精度`` 、 ``速度`` 和 ``场景适配程度`` 进行选择。\n\n\n实现了任务和模型后，可以在 repo_root/scripts下添加脚本程序 ``layout_detection.py``\n\n示例脚本\n==============\n\n.. 
code-block:: python\n\n    import os\n    import sys\n    import os.path as osp\n    import argparse\n\n    sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\n    from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\n    import pdf_extract_kit.tasks  # 确保所有任务模块被导入\n\n    TASK_NAME = 'layout_detection'\n\n\n    def parse_args():\n        parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n        parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n        return parser.parse_args()\n\n    def main(config_path):\n        config = load_config(config_path)\n        task_instances = initialize_tasks_and_models(config)\n\n        # get input and output path from config\n        input_data = config.get('inputs', None)\n        result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)\n\n        # layout_detection_task\n        model_layout_detection = task_instances[TASK_NAME]\n\n        # for image detection\n        detection_results = model_layout_detection.predict_images(input_data, result_path)\n\n        # for pdf detection\n        # detection_results = model_layout_detection.predict_pdfs(input_data, result_path)\n\n        # print(detection_results)\n        print(f'The predicted results can be found at {result_path}')\n\n\n    if __name__ == \"__main__\":\n        args = parse_args()\n        main(args.config)\n\n支持类型拓展\n==============\n\n\n批处理拓展\n==============\n"
  },
  {
    "path": "docs/zh_cn/task_extend/doc.rst",
    "content": "==================================\n文档补充\n==================================\n\n在实现新的任务和模块后，需要在文档中补充相关内容，以便用户了解如何使用。\n\n具体可以参考布局检测任务使用文档：\\ :ref:`布局检测算法 <algorithm_layout_detection>` \n\n\n主要补充下述几个部分：\n\n* 任务简介  \n* 模型使用方式  \n* 配置文件解释  \n* 多样化输入支持（如果有）  \n* 可视化结果查看  "
  },
  {
    "path": "docs/zh_cn/task_extend/evaluation.rst",
    "content": "==================================\n模型评测\n==================================\n\nComing soon!"
  },
  {
    "path": "pdf_extract_kit/__init__.py",
    "content": "import os\nimport sys\n\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\n\nroot_dir = os.path.abspath(os.path.join(current_dir, '..'))\n\nif root_dir not in sys.path:\n    sys.path.insert(0, root_dir)"
  },
  {
    "path": "pdf_extract_kit/configs/unimernet.yaml",
    "content": "model:\n  arch: unimernet\n  model_type: unimernet\n  model_config:\n    model_name: ./models/unimernet_tiny\n    max_seq_len: 1536\n\n  load_pretrained: True\n  pretrained: './models/unimernet_tiny/pytorch_model.pth'\n  tokenizer_config:\n    path: ./models/unimernet_tiny\n\ndatasets:\n  formula_rec_eval:\n    vis_processor:\n      eval:\n        name: \"formula_image_eval\"\n        image_size:\n          - 192\n          - 672\n   \nrun:\n  runner: runner_iter\n  task: unimernet_train\n\n  batch_size_train: 64\n  batch_size_eval: 64\n  num_workers: 1\n\n  iters_per_inner_epoch: 2000\n  max_iters: 60000\n\n  seed: 42\n  output_dir: \"../output/demo\"\n\n  evaluate: True\n  test_splits: [ \"eval\" ]\n\n  device: \"cuda\"\n  world_size: 1\n  dist_url: \"env://\"\n  distributed: True\n  distributed_type: ddp  # or fsdp when train llm\n\n  generate_cfg:\n    temperature: 0.0"
  },
  {
    "path": "pdf_extract_kit/dataset/__init__.py",
    "content": ""
  },
  {
    "path": "pdf_extract_kit/dataset/dataset.py",
    "content": "import numpy as np\nimport torch\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nimport torchvision.transforms as transforms\n\n\nclass ResizeLongestSide:\n    def __init__(self, size):\n        self.size = size\n\n    def __call__(self, img):\n        # Get the original dimensions\n        width, height = img.size\n        # Determine the scaling factor\n        if width > height:\n            new_width = self.size\n            new_height = int(height * (self.size / float(width)))\n        else:\n            new_height = self.size\n            new_width = int(width * (self.size / float(height)))\n        # Resize the image\n        return img.resize((new_width, new_height), Image.BILINEAR)\n\n\nclass ImageDataset(Dataset):\n    def __init__(self, images, image_ids=None, img_size=1280):\n        \"\"\"\n        Initialize the ImageDataset class.\n        \n        Args:\n        - images (list): List of image paths or PIL.Image.Image objects.\n        - image_ids (list, optional): List of corresponding image IDs. 
If None, assumes images are paths.\n        - img_size (int): Size to which images' longest side will be resized.\n        \"\"\"\n        self.images = images\n        self.image_ids = image_ids if image_ids is not None else images\n        self.img_size = img_size\n        self.transform = transforms.Compose([\n            ResizeLongestSide(self.img_size),\n            transforms.ToTensor()\n        ])\n\n    def __len__(self):\n        \"\"\"\n        Return the size of the dataset.\n        \n        Returns:\n        int: Number of images in the dataset.\n        \"\"\"\n        return len(self.images)\n\n    def __getitem__(self, idx):\n        \"\"\"\n        Get an image and its corresponding ID by index.\n        \n        Args:\n        - idx (int): Index of the image to retrieve.\n        \n        Returns:\n        tuple: Transformed image tensor and corresponding image ID.\n        \"\"\"\n        image = self.images[idx]\n        image_id = self.image_ids[idx]\n\n        # Check if the image is a path or a PIL.Image object\n        if isinstance(image, str):\n            image = Image.open(image).convert('RGB')\n        elif isinstance(image, Image.Image):\n            image = image.convert('RGB')\n        else:\n            raise ValueError(\"Image must be a file path or a PIL.Image object\")\n\n        # Apply transformations\n        image = self.transform(image)\n\n        return image, image_id\n    \n    \nclass MathDataset(Dataset):\n    def __init__(self, image_paths, transform=None):\n        self.image_paths = image_paths\n        self.transform = transform\n\n    def __len__(self):\n        return len(self.image_paths)\n\n    def __getitem__(self, idx):\n        # if not pil image, then convert to pil image\n        if isinstance(self.image_paths[idx], str):\n            raw_image = Image.open(self.image_paths[idx])\n        else:\n            raw_image = self.image_paths[idx]\n        if self.transform:\n            image = 
self.transform(raw_image)\n        return image\n"
  },
  {
    "path": "pdf_extract_kit/registry/__init__.py",
    "content": "from .registry import TASK_REGISTRY, MODEL_REGISTRY"
  },
  {
    "path": "pdf_extract_kit/registry/registry.py",
    "content": "class Registry:\n    def __init__(self):\n        self._registry = {}\n\n    def register(self, name):\n        def decorator(item):\n            if name in self._registry:\n                raise ValueError(f\"Item {name} already registered.\")\n            self._registry[name] = item\n            return item\n        return decorator\n\n    def get(self, name):\n        if name not in self._registry:\n            raise ValueError(f\"Item {name} not found in registry.\")\n        return self._registry[name]\n\n    def list_items(self):\n        return list(self._registry.keys())\n\n# Create global registries for tasks and models\nTASK_REGISTRY = Registry()\nMODEL_REGISTRY = Registry()"
  },
  {
    "path": "pdf_extract_kit/tasks/__init__.py",
    "content": "from pdf_extract_kit.tasks.base_task import BaseTask\nfrom pdf_extract_kit.tasks.formula_detection.task import FormulaDetectionTask\nfrom pdf_extract_kit.tasks.formula_recognition.task import FormulaRecognitionTask\nfrom pdf_extract_kit.tasks.layout_detection.task import LayoutDetectionTask\nfrom pdf_extract_kit.tasks.ocr.task import OCRTask\nfrom pdf_extract_kit.tasks.table_parsing.task import TableParsingTask\n\nfrom pdf_extract_kit.registry.registry import TASK_REGISTRY\n\n__all__ = [\n    \"BaseTask\",\n    \"FormulaDetectionTask\",\n    \"FormulaRecognitionTask\",\n    \"LayoutDetectionTask\",\n    \"OCRTask\",\n    \"TableParsingTask\",\n]\n\ndef load_task(name, cfg=None):\n    \"\"\"\n    Example\n\n    >>> task = load_task(\"formula_detection\", cfg=None)\n    \"\"\"\n    task_class = TASK_REGISTRY.get(name)\n    task_instance = task_class(cfg)\n\n    return task_instance\n"
  },
  {
    "path": "pdf_extract_kit/tasks/base_task.py",
    "content": "import os\nfrom pdf_extract_kit.utils.data_preprocess import load_pdf\n\n\nclass BaseTask:\n    def __init__(self, model):\n        self.model = model\n\n    def load_images(self, input_data):\n        \"\"\"\n        Loads images from a single image path or a directory containing multiple images.\n\n        Args:\n            input_data (str): Path to a single image file or a directory containing image files.\n\n        Returns:\n            list: List of paths to all images to be predicted.\n        \"\"\"\n        images = []\n\n        if os.path.isdir(input_data):\n            # If input_data is a directory, check for nested directories\n            for root, dirs, files in os.walk(input_data):\n                if dirs:\n                    raise ValueError(\"Input directory should not contain nested directories: {}\".format(input_data))\n                for file in files:\n                    if file.lower().endswith(('.png', '.jpg', '.jpeg')):\n                        image_path = os.path.join(root, file)\n                        images.append(image_path)\n                images = sorted(images)\n                break  # Only process the top-level directory\n        else:\n            # Determine the type of input data and process accordingly\n            if input_data.lower().endswith(('.png', '.jpg', '.jpeg')):\n                # If input is a single image file\n                images = [input_data]\n            else:\n                raise ValueError(\"Unsupported input data format: {}\".format(input_data))\n\n        return images\n\n    def load_pdf_images(self, input_data):\n        \"\"\"\n        Loads images from a single PDF file or directory containing multiple PDF files.\n\n        Args:\n            input_data (str): Path to a single PDF file or a directory containing PDF files.\n\n        Returns:\n            dict: Dictionary with image IDs (formed by PDF path and page number) as keys and corresponding PIL.Image objects as 
values.\n                  Note: Loading multiple PDFs at once is not recommended due to high memory consumption. Consider processing one PDF at a time externally using loops or multithreading.\n        \"\"\"\n        pdf_images = {}\n\n        if os.path.isdir(input_data):\n            # If input_data is a directory, check for nested directories\n            for root, dirs, files in os.walk(input_data):\n                if dirs:\n                    raise ValueError(\"Input directory should not contain nested directories: {}\".format(input_data))\n                for file in files:\n                    if file.lower().endswith(('.pdf')):\n                        pdf_path = os.path.join(root, file)\n                        images = load_pdf(pdf_path)\n                        for i, img in enumerate(images):\n                            img_id = f\"{os.path.splitext(file)[0]}_page_{i+1:04d}\"\n                            pdf_images[img_id] = img\n                # images = sorted(images)\n                break  # Only process the top-level directory\n        else:\n            # Determine the type of input data and process accordingly\n            if input_data.lower().endswith(('.pdf')):\n                # If input is a single PDF file\n                images = load_pdf(input_data)\n                for i, img in enumerate(images):\n                    img_id = f\"{os.path.splitext(os.path.basename(input_data))[0]}_page_{i+1:04d}\"\n                    pdf_images[img_id] = img\n            else:\n                raise ValueError(\"Unsupported input data format: {}\".format(input_data))\n\n        return pdf_images"
  },
  {
    "path": "pdf_extract_kit/tasks/formula_detection/__init__.py",
    "content": "from pdf_extract_kit.tasks.formula_detection.models.yolo import FormulaDetectionYOLO\n\nfrom pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n__all__ = [\n    \"FormulaDetectionYOLO\",\n]\n"
  },
  {
    "path": "pdf_extract_kit/tasks/formula_detection/models/yolo.py",
    "content": "import os\nimport cv2\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\nfrom ultralytics import YOLO\nfrom pdf_extract_kit.registry import MODEL_REGISTRY\nfrom pdf_extract_kit.utils.visualization import visualize_bbox\nfrom pdf_extract_kit.dataset.dataset import ImageDataset\nimport torchvision.transforms as transforms\n\n\n@MODEL_REGISTRY.register('formula_detection_yolo')\nclass FormulaDetectionYOLO:\n    def __init__(self, config):\n        \"\"\"\n        Initialize the FormulaDetectionYOLO class.\n\n        Args:\n            config (dict): Configuration dictionary containing model parameters.\n        \"\"\"\n        # Mapping from class IDs to class names\n        self.id_to_names = {\n            0: 'inline',\n            1: 'isolated'\n        }\n\n        # Load the YOLO model from the specified path\n        self.model = YOLO(config['model_path'])\n\n        # Set model parameters\n        self.img_size = config.get('img_size', 1280)\n        self.pdf_dpi = config.get('pdf_dpi', 200)\n        self.conf_thres = config.get('conf_thres', 0.25)\n        self.iou_thres = config.get('iou_thres', 0.45)\n        self.visualize = config.get('visualize', False)\n        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')\n        self.batch_size = config.get('batch_size', 1)\n\n    def predict(self, images, result_path, image_ids=None):\n        \"\"\"\n        Predict formulas in images.\n\n        Args:\n            images (list): List of images to be predicted.\n            result_path (str): Path to save the prediction results.\n            image_ids (list, optional): List of image IDs corresponding to the images.\n\n        Returns:\n            list: List of prediction results.\n        \"\"\"\n        results = []\n        for idx, image in enumerate(images):\n            result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False)[0]\n          
  if self.visualize:\n                if not os.path.exists(result_path):\n                    os.makedirs(result_path)\n                boxes = result.__dict__['boxes'].xyxy\n                classes = result.__dict__['boxes'].cls\n                scores = result.__dict__['boxes'].conf\n                \n                vis_result = visualize_bbox(image, boxes, classes, scores, self.id_to_names)\n\n                # Determine the base name of the image\n                if image_ids:\n                    base_name = image_ids[idx]\n                else:\n                    # base_name = os.path.basename(image)                    \n                    base_name = os.path.splitext(os.path.basename(image))[0]  # Remove file extension\n\n                \n                result_name = f\"{base_name}_MFD.png\"\n                \n                # Save the visualized result                \n                cv2.imwrite(os.path.join(result_path, result_name), vis_result)\n            results.append(result)\n        return results"
  },
  {
    "path": "pdf_extract_kit/tasks/formula_detection/task.py",
    "content": "from pdf_extract_kit.registry.registry import TASK_REGISTRY\nfrom pdf_extract_kit.tasks.base_task import BaseTask\n\n@TASK_REGISTRY.register(\"formula_detection\")\nclass FormulaDetectionTask(BaseTask):\n    def __init__(self, model):\n        super().__init__(model)\n\n    def predict_images(self, input_data, result_path):\n        \"\"\"\n        Predict formulas in images.\n\n        Args:\n            input_data (str): Path to a single image file or a directory containing image files.\n            result_path (str): Path to save the prediction results.\n\n        Returns:\n            list: List of prediction results.\n        \"\"\"\n        images = self.load_images(input_data)\n        # Perform detection\n        return self.model.predict(images, result_path)\n\n    def predict_pdfs(self, input_data, result_path):\n        \"\"\"\n        Predict formulas in PDF files.\n\n        Args:\n            input_data (str): Path to a single PDF file or a directory containing PDF files.\n            result_path (str): Path to save the prediction results.\n\n        Returns:\n            list: List of prediction results.\n        \"\"\"\n        pdf_images = self.load_pdf_images(input_data)\n        # Perform detection\n        return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys()))"
  },
  {
    "path": "pdf_extract_kit/tasks/formula_recognition/__init__.py",
    "content": "from pdf_extract_kit.tasks.formula_recognition.models.unimernet import FormulaRecognitionUniMERNet\n\nfrom pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n__all__ = [\n    \"FormulaRecognitionUniMERNet\",\n]\n"
  },
  {
    "path": "pdf_extract_kit/tasks/formula_recognition/models/unimernet.py",
    "content": "import os\nimport logging\nimport argparse\n\nimport cv2\nimport torch\nimport numpy as np\nfrom PIL import Image\nimport unimernet.tasks as tasks\nfrom unimernet.common.config import Config\nfrom unimernet.processors import load_processor\n\nfrom pdf_extract_kit.registry import MODEL_REGISTRY\n\n\n@MODEL_REGISTRY.register('formula_recognition_unimernet')\nclass FormulaRecognitionUniMERNet:\n    def __init__(self, config):\n        \"\"\"\n        Initialize the FormulaRecognitionUniMERNet class.\n\n        Args:\n            config (dict): Configuration dictionary containing model parameters.\n        \"\"\"\n        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        self.model_dir = config['model_path']\n        self.cfg_path = config.get('cfg_path', \"pdf_extract_kit/configs/unimernet.yaml\")\n        self.batch_size = config.get('batch_size', 1)\n\n        # Load the UniMERNet model\n        self.model, self.vis_processor = self.load_model_and_processor()\n\n    def load_model_and_processor(self):\n        try:\n            args = argparse.Namespace(cfg_path=self.cfg_path, options=None)\n            cfg = Config(args)\n            cfg.config.model.pretrained = os.path.join(self.model_dir, \"pytorch_model.pth\")\n            cfg.config.model.model_config.model_name = self.model_dir\n            cfg.config.model.tokenizer_config.path = self.model_dir\n            task = tasks.setup_task(cfg)\n            model = task.build_model(cfg).to(self.device)\n            vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)\n            return model, vis_processor\n        except Exception as e:\n            logging.error(f\"Error loading model and processor: {e}\")\n            raise\n    \n    def predict(self, images, result_path):\n        results = []\n        for image_path in images:\n            # Read the image using OpenCV\n            open_cv_image = 
cv2.imread(image_path)\n            if open_cv_image is None:\n                logging.error(f\"Error: Unable to open image at {image_path}\")\n                continue\n            # Convert the OpenCV image to PIL.Image format\n            raw_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))\n\n            try:\n                # Process the image using the visual processor and prepare it for the model\n                image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)\n\n                # Generate the prediction using the model\n                output = self.model.generate({\"image\": image})\n                pred = output[\"pred_str\"][0]\n                logging.info(f'Prediction for {image_path}:\\n{pred}')\n\n                # cv2.imshow('Original Image', open_cv_image)\n                # cv2.waitKey(0)\n                # cv2.destroyAllWindows()\n\n                results.append(pred)\n            except Exception as e:\n                logging.error(f\"Error processing image {image_path}: {e}\")\n    \n        return results"
  },
  {
    "path": "pdf_extract_kit/tasks/formula_recognition/task.py",
    "content": "from pdf_extract_kit.registry.registry import TASK_REGISTRY\nfrom pdf_extract_kit.tasks.base_task import BaseTask\n\n\n@TASK_REGISTRY.register(\"formula_recognition\")\nclass FormulaRecognitionTask(BaseTask):\n    def __init__(self, model):\n        super().__init__(model)\n\n    def predict(self, input_data, result_path, bboxes=None):\n        images = self.load_images(input_data)\n        # Perform recognition\n        return self.model.predict(images, result_path)"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/__init__.py",
    "content": "from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO\n# from pdf_extract_kit.tasks.layout_detection.models.layoutlmv3 import LayoutDetectionLayoutlmv3\nfrom pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n\n__all__ = [\n    \"LayoutDetectionYOLO\",\n    # \"LayoutDetectionLayoutlmv3\",\n]\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/__init__.py",
    "content": ""
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3.py",
    "content": "import os\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\nfrom pdf_extract_kit.registry.registry import MODEL_REGISTRY\nfrom pdf_extract_kit.utils.visualization import visualize_bbox\n\nfrom .layoutlmv3_util.model_init import Layoutlmv3_Predictor\n\n@MODEL_REGISTRY.register(\"layout_detection_layoutlmv3\")\nclass LayoutDetectionLayoutlmv3:\n    def __init__(self, config):\n        \"\"\"\n        Initialize the LayoutDetectionYOLO class.\n\n        Args:\n            config (dict): Configuration dictionary containing model parameters.\n        \"\"\"\n        # Mapping from class IDs to class names\n        self.id_to_names = {\n            0: 'title', \n            1: 'plain text',\n            2: 'abandon', \n            3: 'figure', \n            4: 'figure_caption', \n            5: 'table', \n            6: 'table_caption', \n            7: 'table_footnote', \n            8: 'isolate_formula', \n            9: 'formula_caption'\n        }\n        self.model = Layoutlmv3_Predictor(config.get('model_path', None))\n        self.visualize = config.get('visualize', False)\n\n    def predict(self, images, result_path, image_ids=None):\n        \"\"\"\n        Predict layouts in images.\n\n        Args:\n            images (list): List of images to be predicted.\n            result_path (str): Path to save the prediction results.\n            image_ids (list, optional): List of image IDs corresponding to the images.\n\n        Returns:\n            list: List of prediction results.\n        \"\"\"\n        if not os.path.exists(result_path):\n            os.makedirs(result_path)\n        \n        results = []\n        for idx, im_file in enumerate(images):\n            if isinstance(im_file, Image.Image):\n                im = im_file.convert(\"RGB\")  # extracted PDF pages\n            elif isinstance(im_file, str):\n                im = Image.open(im_file).convert(\"RGB\")  # image path\n            layout_res = self.model(np.array(im), 
ignore_catids=[])\n            poly = np.array([det[\"poly\"] for det in layout_res[\"layout_dets\"]])\n            boxes = poly[:, [0,1,4,5]] \n            scores = np.array([det[\"score\"] for det in layout_res[\"layout_dets\"]])\n            classes = np.array([det[\"category_id\"] for det in layout_res[\"layout_dets\"]])\n            \n            if self.visualize:\n                vis_result = visualize_bbox(im_file, boxes, classes, scores, self.id_to_names)\n                # Determine the base name of the image\n                if image_ids:\n                    base_name = image_ids[idx]\n                else:\n                    base_name = os.path.splitext(os.path.basename(im_file))[0]  # Remove file extension\n                result_name = f\"{base_name}_layout.png\"\n                # Save the visualized result                \n                cv2.imwrite(os.path.join(result_path, result_name), vis_result)\n\n            # append result\n            results.append({\n                \"im_path\": im_file,\n                \"boxes\": boxes,\n                \"scores\": scores,\n                \"classes\": classes,\n            })\n        return results\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/backbone.py",
    "content": "# --------------------------------------------------------------------------------\n# VIT: Multi-Path Vision Transformer for Dense Prediction\n# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).\n# All Rights Reserved.\n# Written by Youngwan Lee\n# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the\n# LICENSE file in the root directory of this source tree.\n# --------------------------------------------------------------------------------\n# References:\n# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# CoaT: https://github.com/mlpc-ucsd/CoaT\n# --------------------------------------------------------------------------------\n\n\nimport torch\n\nfrom detectron2.layers import (\n    ShapeSpec,\n)\nfrom detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN\nfrom detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool\n\nfrom .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16\nfrom .deit import deit_base_patch16, mae_base_patch16\nfrom .layoutlmft.models.layoutlmv3 import LayoutLMv3Model\nfrom transformers import AutoConfig\n\n__all__ = [\n    \"build_vit_fpn_backbone\",\n]\n\n\nclass VIT_Backbone(Backbone):\n    \"\"\"\n    Implement VIT backbone.\n    \"\"\"\n\n    def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,\n                 config_path=None, image_only=False, cfg=None):\n        super().__init__()\n        self._out_features = out_features\n        if 'base' in name:\n            self._out_feature_strides = {\"layer3\": 4, \"layer5\": 8, \"layer7\": 16, \"layer11\": 32}\n            self._out_feature_channels = {\"layer3\": 768, \"layer5\": 768, \"layer7\": 768, \"layer11\": 768}\n        else:\n            self._out_feature_strides = {\"layer7\": 4, \"layer11\": 8, \"layer15\": 16, \"layer23\": 32}\n            self._out_feature_channels = 
{\"layer7\": 1024, \"layer11\": 1024, \"layer15\": 1024, \"layer23\": 1024}\n\n        if name == 'beit_base_patch16':\n            model_func = beit_base_patch16\n        elif name == 'dit_base_patch16':\n            model_func = dit_base_patch16\n        elif name == \"deit_base_patch16\":\n            model_func = deit_base_patch16\n        elif name == \"mae_base_patch16\":\n            model_func = mae_base_patch16\n        elif name == \"dit_large_patch16\":\n            model_func = dit_large_patch16\n        elif name == \"beit_large_patch16\":\n            model_func = beit_large_patch16\n\n        if 'beit' in name or 'dit' in name:\n            if pos_type == \"abs\":\n                self.backbone = model_func(img_size=img_size,\n                                           out_features=out_features,\n                                           drop_path_rate=drop_path,\n                                           use_abs_pos_emb=True,\n                                           **model_kwargs)\n            elif pos_type == \"shared_rel\":\n                self.backbone = model_func(img_size=img_size,\n                                           out_features=out_features,\n                                           drop_path_rate=drop_path,\n                                           use_shared_rel_pos_bias=True,\n                                           **model_kwargs)\n            elif pos_type == \"rel\":\n                self.backbone = model_func(img_size=img_size,\n                                           out_features=out_features,\n                                           drop_path_rate=drop_path,\n                                           use_rel_pos_bias=True,\n                                           **model_kwargs)\n            else:\n                raise ValueError()\n        elif \"layoutlmv3\" in name:\n            config = AutoConfig.from_pretrained(config_path)\n            # disable relative bias as DiT\n            
config.has_spatial_attention_bias = False\n            config.has_relative_attention_bias = False\n            self.backbone = LayoutLMv3Model(config, detection=True,\n                                               out_features=out_features, image_only=image_only)\n        else:\n            self.backbone = model_func(img_size=img_size,\n                                       out_features=out_features,\n                                       drop_path_rate=drop_path,\n                                       **model_kwargs)\n        self.name = name\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.\n\n        Returns:\n            dict[str->Tensor]: names and the corresponding features\n        \"\"\"\n        if \"layoutlmv3\" in self.name:\n            return self.backbone.forward(\n                input_ids=x[\"input_ids\"] if \"input_ids\" in x else None,\n                bbox=x[\"bbox\"] if \"bbox\" in x else None,\n                images=x[\"images\"] if \"images\" in x else None,\n                attention_mask=x[\"attention_mask\"] if \"attention_mask\" in x else None,\n                # output_hidden_states=True,\n            )\n        assert x.dim() == 4, f\"VIT takes an input of shape (N, C, H, W). 
Got {x.shape} instead!\"\n        return self.backbone.forward_features(x)\n\n    def output_shape(self):\n        return {\n            name: ShapeSpec(\n                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]\n            )\n            for name in self._out_features\n        }\n\n\ndef build_VIT_backbone(cfg):\n    \"\"\"\n    Create a VIT instance from config.\n\n    Args:\n        cfg: a detectron2 CfgNode\n\n    Returns:\n        A VIT backbone instance.\n    \"\"\"\n    # fmt: off\n    name = cfg.MODEL.VIT.NAME\n    out_features = cfg.MODEL.VIT.OUT_FEATURES\n    drop_path = cfg.MODEL.VIT.DROP_PATH\n    img_size = cfg.MODEL.VIT.IMG_SIZE\n    pos_type = cfg.MODEL.VIT.POS_TYPE\n\n    model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace(\"`\", \"\"))\n\n    if 'layoutlmv3' in name:\n        if cfg.MODEL.CONFIG_PATH != '':\n            config_path = cfg.MODEL.CONFIG_PATH\n        else:\n            config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '')  # layoutlmv3 pre-trained models\n            config_path = config_path.replace('model_final.pth', '')  # detection fine-tuned models\n    else:\n        config_path = None\n\n    return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,\n                        config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)\n\n\n@BACKBONE_REGISTRY.register()\ndef build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):\n    \"\"\"\n    Create a VIT w/ FPN backbone.\n\n    Args:\n        cfg: a detectron2 CfgNode\n\n    Returns:\n        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.\n    \"\"\"\n    bottom_up = build_VIT_backbone(cfg)\n    in_features = cfg.MODEL.FPN.IN_FEATURES\n    out_channels = cfg.MODEL.FPN.OUT_CHANNELS\n    backbone = FPN(\n        bottom_up=bottom_up,\n        in_features=in_features,\n        out_channels=out_channels,\n        norm=cfg.MODEL.FPN.NORM,\n        
top_block=LastLevelMaxPool(),\n        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,\n    )\n    return backbone\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/beit.py",
    "content": "\"\"\" Vision Transformer (ViT) in PyTorch\n\nA PyTorch implement of Vision Transformers as described in\n'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929\n\nThe official jax code is released and available at https://github.com/google-research/vision_transformer\n\nStatus/TODO:\n* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.\n* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.\n* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.\n* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.\n\nAcknowledgments:\n* The paper authors for releasing code and weights, thanks!\n* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out\nfor some einops/einsum fun\n* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT\n* Bert reference code checks against Huggingface Transformers and Tensorflow Bert\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\nimport warnings\nimport math\nimport torch\nfrom functools import partial\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),\n        **kwargs\n    }\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n\n    def __init__(self, drop_prob=None):\n        
super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n\n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n     
       self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token inside the window\n            coords_h = torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n            # trunc_normal_(self.relative_position_bias_table, std=.0)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None, training_window_size=None):\n        B, N, C = x.shape\n   
     qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            if training_window_size == self.window_size:\n                relative_position_bias = \\\n                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                        self.window_size[0] * self.window_size[1] + 1,\n                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n                attn = attn + relative_position_bias.unsqueeze(0)\n            else:\n                training_window_size = tuple(training_window_size.tolist())\n                new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3\n                # new_num_relative_dis 为 所有可能的相对位置选项，包含cls-cls，tok-cls，与cls-tok\n                new_relative_position_bias_table = F.interpolate(\n                    self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,\n                                                                                 2 * self.window_size[0] - 1,\n                                                                                 2 * self.window_size[1] - 1),\n                    size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), 
mode='bicubic',\n                    align_corners=False)\n                new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,\n                                                                                         new_num_relative_distance - 3).permute(\n                    1, 0)\n                new_relative_position_bias_table = torch.cat(\n                    [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)\n\n                # get pair-wise relative position index for each token inside the window\n                coords_h = torch.arange(training_window_size[0])\n                coords_w = torch.arange(training_window_size[1])\n                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n                relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0\n                relative_coords[:, :, 1] += training_window_size[1] - 1\n                relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1\n                relative_position_index = \\\n                    torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,\n                                dtype=relative_coords.dtype)\n                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n                relative_position_index[0, 0:] = new_num_relative_distance - 3\n                relative_position_index[0:, 0] = new_num_relative_distance - 2\n                relative_position_index[0, 0] = new_num_relative_distance - 1\n\n                relative_position_bias = \\\n                    
new_relative_position_bias_table[relative_position_index.view(-1)].view(\n                        training_window_size[0] * training_window_size[1] + 1,\n                        training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n                attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values is not None:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None, training_window_size=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(\n                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,\n                                                            training_window_size=training_window_size))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches_w = self.patch_shape[0]\n        self.num_patches_h = self.patch_shape[1]\n        # the so-called patch_shape is the patch shape during pre-training\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def 
forward(self, x, position_embedding=None, **kwargs):\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #     f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        Hp, Wp = x.shape[2], x.shape[3]\n\n        if position_embedding is not None:\n            # interpolate the position embedding to the corresponding size\n            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,\n                                                                                                                  1, 2)\n            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')\n            x = x + position_embedding\n\n        x = x.flatten(2).transpose(1, 2)\n        return x, (Hp, Wp)\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature Map Embedding\n    Extract feature map from CNN, flatten, project to embedding dim.\n    \"\"\"\n\n    def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]\n                
feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Linear(feature_dim, embed_dim)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_heads = num_heads\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = 
self.num_relative_distance - 3\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self, training_window_size):\n        if training_window_size == self.window_size:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        else:\n            training_window_size = tuple(training_window_size.tolist())\n            new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3\n            # new_num_relative_dis 为 所有可能的相对位置选项，包含cls-cls，tok-cls，与cls-tok\n            new_relative_position_bias_table = F.interpolate(\n                self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,\n                                                                             2 * self.window_size[0] - 1,\n                                                                             2 * self.window_size[1] - 1),\n                size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',\n                align_corners=False)\n            new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,\n                                                                                     new_num_relative_distance - 3).permute(\n                1, 0)\n            new_relative_position_bias_table = torch.cat(\n                
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)\n\n            # get pair-wise relative position index for each token inside the window\n            coords_h = torch.arange(training_window_size[0])\n            coords_w = torch.arange(training_window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += training_window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,\n                            dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = new_num_relative_distance - 3\n            relative_position_index[0:, 0] = new_num_relative_distance - 2\n            relative_position_index[0, 0] = new_num_relative_distance - 1\n\n            relative_position_bias = \\\n                new_relative_position_bias_table[relative_position_index.view(-1)].view(\n                    training_window_size[0] * training_window_size[1] + 1,\n                    training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\n        return relative_position_bias\n\n\nclass BEiT(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n\n    def 
__init__(self,\n                 img_size=[224, 224],\n                 patch_size=16,\n                 in_chans=3,\n                 num_classes=80,\n                 embed_dim=768,\n                 depth=12,\n                 num_heads=12,\n                 mlp_ratio=4.,\n                 qkv_bias=False,\n                 qk_scale=None,\n                 drop_rate=0.,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.,\n                 hybrid_backbone=None,\n                 norm_layer=None,\n                 init_values=None,\n                 use_abs_pos_emb=False,\n                 use_rel_pos_bias=False,\n                 use_shared_rel_pos_bias=False,\n                 use_checkpoint=True,\n                 pretrained=None,\n                 out_features=None,\n                 ):\n\n        super(BEiT, self).__init__()\n\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.use_checkpoint = use_checkpoint\n\n        if hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            self.patch_embed = PatchEmbed(\n                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        self.out_features = out_features\n        self.out_indices = [int(name[5:]) for name in out_features]\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            self.pos_embed = None\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        
self.use_shared_rel_pos_bias = use_shared_rel_pos_bias\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n\n        # trunc_normal_(self.mask_token, std=.02)\n\n        if patch_size == 16:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                # nn.SyncBatchNorm(embed_dim),\n                nn.BatchNorm2d(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n        elif patch_size == 8:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Identity()\n\n            self.fpn3 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=2, stride=2),\n            )\n\n            self.fpn4 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=4, stride=4),\n            )\n\n        if self.pos_embed is not None:\n  
          trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    '''\n    def init_weights(self):\n        \"\"\"Initialize the weights in backbone.\n\n        Args:\n            pretrained (str, optional): Path to pre-trained weights.\n                Defaults to None.\n        \"\"\"\n        logger = get_root_logger()\n\n        if self.pos_embed is not None:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n\n        if self.init_cfg is None:\n            logger.warn(f'No pre-trained weights for '\n                        f'{self.__class__.__name__}, '\n                        f'training start from scratch')\n        else:\n            assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n                                                  f'specify `Pretrained` in ' \\\n                                                  f'`init_cfg` in ' \\\n                                                  f'{self.__class__.__name__} '\n            logger.info(f\"Will load ckpt from {self.init_cfg['checkpoint']}\")\n            load_checkpoint(self,\n                            
filename=self.init_cfg['checkpoint'],\n                            strict=False,\n                            logger=logger,\n                            beit_spec_expand_rel_pos = self.use_rel_pos_bias,\n                            )\n    '''\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def forward_features(self, x):\n        B, C, H, W = x.shape\n        x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)\n        # Hp, Wp are HW for patches\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        if self.pos_embed is not None:\n            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]\n        x = torch.cat((cls_tokens, x), dim=1)\n        x = self.pos_drop(x)\n\n        features = []\n        training_window_size = torch.tensor([Hp, Wp])\n\n        rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None\n\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)\n            else:\n                x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)\n            if i in self.out_indices:\n                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)\n                features.append(xp.contiguous())\n\n        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n        for i in range(len(features)):\n            features[i] = ops[i](features[i])\n\n        feat_out = {}\n\n        for name, value in zip(self.out_features, features):\n            feat_out[name] = value\n\n        return feat_out\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        return x\n\n\ndef 
beit_base_patch16(pretrained=False, **kwargs):\n    model = BEiT(\n        patch_size=16,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_values=None,\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\ndef beit_large_patch16(pretrained=False, **kwargs):\n    model = BEiT(\n        patch_size=16,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_values=None,\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\ndef dit_base_patch16(pretrained=False, **kwargs):\n    model = BEiT(\n        patch_size=16,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_values=0.1,\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\ndef dit_large_patch16(pretrained=False, **kwargs):\n    model = BEiT(\n        patch_size=16,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_values=1e-5,\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\nif __name__ == '__main__':\n    model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)\n    model = model.to(\"cuda:0\")\n    input1 = torch.rand(2, 3, 512, 762).to(\"cuda:0\")\n    input2 = torch.rand(2, 3, 800, 1200).to(\"cuda:0\")\n    input3 = torch.rand(2, 3, 720, 1000).to(\"cuda:0\")\n    output1 = model(input1)\n    output2 = model(input2)\n    output3 = model(input3)\n    print(\"all done\")\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/deit.py",
    "content": "\"\"\"\nMostly copy-paste from DINO and timm library:\nhttps://github.com/facebookresearch/dino\nhttps://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n\"\"\"\nimport warnings\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import trunc_normal_, drop_path, to_2tuple\nfrom functools import partial\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),\n        **kwargs\n    }\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n\n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim 
= dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,\n                                      C // self.num_heads).permute(2, 0, 3, 1, 4)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n        self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n\n        self.num_patches_w, self.num_patches_h = self.window_size\n\n        self.num_patches = self.window_size[0] * self.window_size[1]\n        self.img_size = img_size\n        self.patch_size = patch_size\n\n        self.proj = nn.Conv2d(in_chans, embed_dim,\n                              kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        x = self.proj(x)\n        return x\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature Map Embedding\n    Extract feature map from CNN, flatten, project to embedding dim.\n    \"\"\"\n\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that 
isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(\n                    1, in_chans, img_size[0], img_size[1]))[-1]\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Linear(feature_dim, embed_dim)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass ViT(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n\n    def __init__(self,\n                 model_name='vit_base_patch16_224',\n                 img_size=384,\n                 patch_size=16,\n                 in_chans=3,\n                 embed_dim=1024,\n                 depth=24,\n                 num_heads=16,\n                 num_classes=19,\n                 mlp_ratio=4.,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop_rate=0.1,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.,\n                 hybrid_backbone=None,\n                 norm_layer=partial(nn.LayerNorm, eps=1e-6),\n                 norm_cfg=None,\n                 pos_embed_interp=False,\n                 random_init=False,\n                 align_corners=False,\n                 use_checkpoint=False,\n                 num_extra_tokens=1,\n                 out_features=None,\n                 **kwargs,\n                 ):\n\n        super(ViT, self).__init__()\n        self.model_name = model_name\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.in_chans = 
in_chans\n        self.embed_dim = embed_dim\n        self.depth = depth\n        self.num_heads = num_heads\n        self.num_classes = num_classes\n        self.mlp_ratio = mlp_ratio\n        self.qkv_bias = qkv_bias\n        self.qk_scale = qk_scale\n        self.drop_rate = drop_rate\n        self.attn_drop_rate = attn_drop_rate\n        self.drop_path_rate = drop_path_rate\n        self.hybrid_backbone = hybrid_backbone\n        self.norm_layer = norm_layer\n        self.norm_cfg = norm_cfg\n        self.pos_embed_interp = pos_embed_interp\n        self.random_init = random_init\n        self.align_corners = align_corners\n        self.use_checkpoint = use_checkpoint\n        self.num_extra_tokens = num_extra_tokens\n        self.out_features = out_features\n        self.out_indices = [int(name[5:]) for name in out_features]\n\n        # self.num_stages = self.depth\n        # self.out_indices = tuple(range(self.num_stages))\n\n        if self.hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)\n        else:\n            self.patch_embed = PatchEmbed(\n                img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)\n        self.num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n\n        if self.num_extra_tokens == 2:\n            self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n\n        self.pos_embed = nn.Parameter(torch.zeros(\n            1, self.num_patches + self.num_extra_tokens, self.embed_dim))\n        self.pos_drop = nn.Dropout(p=self.drop_rate)\n\n        # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches\n        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,\n                                                self.depth)]  # stochastic 
depth decay rule\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,\n                qk_scale=self.qk_scale,\n                drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)\n            for i in range(self.depth)])\n\n        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here\n        # self.repr = nn.Linear(embed_dim, representation_size)\n        # self.repr_act = nn.Tanh()\n\n        if patch_size == 16:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                nn.SyncBatchNorm(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n        elif patch_size == 8:\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Identity()\n\n            self.fpn3 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=2, stride=2),\n            )\n\n            self.fpn4 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=4, stride=4),\n            )\n\n        trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        if self.num_extra_tokens==2:\n            trunc_normal_(self.dist_token, std=0.2)\n        self.apply(self._init_weights)\n        # self.fix_init_weight()\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n  
      for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    '''\n    def init_weights(self):\n        logger = get_root_logger()\n\n        trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n        if self.init_cfg is None:\n            logger.warn(f'No pre-trained weights for '\n                        f'{self.__class__.__name__}, '\n                        f'training start from scratch')\n        else:\n            assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n                                                  f'specify `Pretrained` in ' \\\n                                                  f'`init_cfg` in ' \\\n                                                  f'{self.__class__.__name__} '\n            logger.info(f\"Will load ckpt from {self.init_cfg['checkpoint']}\")\n            load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)\n    '''\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def _conv_filter(self, state_dict, patch_size=16):\n        \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n        out_dict = {}\n        for k, v in state_dict.items():\n            if 'patch_embed.proj.weight' in k:\n                v = v.reshape((v.shape[0], 3, patch_size, patch_size))\n            
out_dict[k] = v\n        return out_dict\n\n    def to_2D(self, x):\n        n, hw, c = x.shape\n        h = w = int(math.sqrt(hw))\n        x = x.transpose(1, 2).reshape(n, c, h, w)\n        return x\n\n    def to_1D(self, x):\n        n, c, h, w = x.shape\n        x = x.reshape(n, c, -1).transpose(1, 2)\n        return x\n\n    def interpolate_pos_encoding(self, x, w, h):\n        npatch = x.shape[1] - self.num_extra_tokens\n        N = self.pos_embed.shape[1] - self.num_extra_tokens\n        if npatch == N and w == h:\n            return self.pos_embed\n\n        class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]\n\n        patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]\n\n        dim = x.shape[-1]\n        w0 = w // self.patch_embed.patch_size[0]\n        h0 = h // self.patch_embed.patch_size[1]\n        # we add a small number to avoid floating point error in the interpolation\n        # see discussion at https://github.com/facebookresearch/dino/issues/8\n        w0, h0 = w0 + 0.1, h0 + 0.1\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),\n            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),\n            mode='bicubic',\n        )\n        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n\n        return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)\n\n    def prepare_tokens(self, x, mask=None):\n        B, nc, w, h = x.shape\n        # patch linear embedding\n        x = self.patch_embed(x)\n\n        # mask image modeling\n        if mask is not None:\n            x = self.mask_model(x, mask)\n        x = x.flatten(2).transpose(1, 2)\n\n        # add the [CLS] token to the embed patch tokens\n        all_tokens = [self.cls_token.expand(B, -1, -1)]\n\n        if 
self.num_extra_tokens == 2:\n            dist_tokens = self.dist_token.expand(B, -1, -1)\n            all_tokens.append(dist_tokens)\n        all_tokens.append(x)\n\n        x = torch.cat(all_tokens, dim=1)\n\n        # add positional encoding to each token\n        x = x + self.interpolate_pos_encoding(x, w, h)\n\n        return self.pos_drop(x)\n\n    def forward_features(self, x):\n        # print(f\"==========shape of x is {x.shape}==========\")\n        B, _, H, W = x.shape\n        Hp, Wp = H // self.patch_size, W // self.patch_size\n        x = self.prepare_tokens(x)\n\n        features = []\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n            if i in self.out_indices:\n                xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)\n                features.append(xp.contiguous())\n\n        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n        for i in range(len(features)):\n            features[i] = ops[i](features[i])\n\n        feat_out = {}\n\n        for name, value in zip(self.out_features, features):\n            feat_out[name] = value\n\n        return feat_out\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        return x\n\n\ndef deit_base_patch16(pretrained=False, **kwargs):\n    model = ViT(\n        patch_size=16,\n        drop_rate=0.,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        num_classes=1000,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        use_checkpoint=True,\n        num_extra_tokens=2,\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\ndef mae_base_patch16(pretrained=False, **kwargs):\n    model = ViT(\n        patch_size=16,\n        drop_rate=0.,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        num_classes=1000,\n        mlp_ratio=4.,\n        
qkv_bias=True,\n        use_checkpoint=True,\n        num_extra_tokens=1,\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/__init__.py",
    "content": "from .models import (\n    LayoutLMv3Config,\n    LayoutLMv3ForTokenClassification,\n    LayoutLMv3ForQuestionAnswering,\n    LayoutLMv3ForSequenceClassification,\n    LayoutLMv3Tokenizer,\n)\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/__init__.py",
    "content": "# flake8: noqa\nfrom .data_collator import DataCollatorForKeyValueExtraction\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/cord.py",
    "content": "'''\nReference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py\n'''\n\n\nimport json\nimport os\nfrom pathlib import Path\nimport datasets\nfrom .image_utils import load_image, normalize_bbox\nlogger = datasets.logging.get_logger(__name__)\n_CITATION = \"\"\"\\\n@article{park2019cord,\n  title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},\n  author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk}\n  booktitle={Document Intelligence Workshop at Neural Information Processing Systems}\n  year={2019}\n}\n\"\"\"\n_DESCRIPTION = \"\"\"\\\nhttps://github.com/clovaai/cord/\n\"\"\"\n\ndef quad_to_box(quad):\n    # test 87 is wrongly annotated\n    box = (\n        max(0, quad[\"x1\"]),\n        max(0, quad[\"y1\"]),\n        quad[\"x3\"],\n        quad[\"y3\"]\n    )\n    if box[3] < box[1]:\n        bbox = list(box)\n        tmp = bbox[3]\n        bbox[3] = bbox[1]\n        bbox[1] = tmp\n        box = tuple(bbox)\n    if box[2] < box[0]:\n        bbox = list(box)\n        tmp = bbox[2]\n        bbox[2] = bbox[0]\n        bbox[0] = tmp\n        box = tuple(bbox)\n    return box\n\ndef _get_drive_url(url):\n    base_url = 'https://drive.google.com/uc?id='\n    split_url = url.split('/')\n    return base_url + split_url[5]\n\n_URLS = [\n    _get_drive_url(\"https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/\"),\n    _get_drive_url(\"https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/\")\n    # If you failed to download the dataset through the automatic downloader,\n    # you can download it manually and modify the code to get the local dataset.\n    # Or you can use the following links. 
Please follow the original LICENSE of CORD for usage.\n    # \"https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip\",\n    # \"https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip\"\n]\n\nclass CordConfig(datasets.BuilderConfig):\n    \"\"\"BuilderConfig for CORD\"\"\"\n    def __init__(self, **kwargs):\n        \"\"\"BuilderConfig for CORD.\n        Args:\n          **kwargs: keyword arguments forwarded to super.\n        \"\"\"\n        super(CordConfig, self).__init__(**kwargs)\n\nclass Cord(datasets.GeneratorBasedBuilder):\n    BUILDER_CONFIGS = [\n        CordConfig(name=\"cord\", version=datasets.Version(\"1.0.0\"), description=\"CORD dataset\"),\n    ]\n\n    def _info(self):\n        return datasets.DatasetInfo(\n            description=_DESCRIPTION,\n            features=datasets.Features(\n                {\n                    \"id\": datasets.Value(\"string\"),\n                    \"words\": datasets.Sequence(datasets.Value(\"string\")),\n                    \"bboxes\": datasets.Sequence(datasets.Sequence(datasets.Value(\"int64\"))),\n                    \"ner_tags\": datasets.Sequence(\n                        datasets.features.ClassLabel(\n                            
names=[\"O\",\"B-MENU.NM\",\"B-MENU.NUM\",\"B-MENU.UNITPRICE\",\"B-MENU.CNT\",\"B-MENU.DISCOUNTPRICE\",\"B-MENU.PRICE\",\"B-MENU.ITEMSUBTOTAL\",\"B-MENU.VATYN\",\"B-MENU.ETC\",\"B-MENU.SUB_NM\",\"B-MENU.SUB_UNITPRICE\",\"B-MENU.SUB_CNT\",\"B-MENU.SUB_PRICE\",\"B-MENU.SUB_ETC\",\"B-VOID_MENU.NM\",\"B-VOID_MENU.PRICE\",\"B-SUB_TOTAL.SUBTOTAL_PRICE\",\"B-SUB_TOTAL.DISCOUNT_PRICE\",\"B-SUB_TOTAL.SERVICE_PRICE\",\"B-SUB_TOTAL.OTHERSVC_PRICE\",\"B-SUB_TOTAL.TAX_PRICE\",\"B-SUB_TOTAL.ETC\",\"B-TOTAL.TOTAL_PRICE\",\"B-TOTAL.TOTAL_ETC\",\"B-TOTAL.CASHPRICE\",\"B-TOTAL.CHANGEPRICE\",\"B-TOTAL.CREDITCARDPRICE\",\"B-TOTAL.EMONEYPRICE\",\"B-TOTAL.MENUTYPE_CNT\",\"B-TOTAL.MENUQTY_CNT\",\"I-MENU.NM\",\"I-MENU.NUM\",\"I-MENU.UNITPRICE\",\"I-MENU.CNT\",\"I-MENU.DISCOUNTPRICE\",\"I-MENU.PRICE\",\"I-MENU.ITEMSUBTOTAL\",\"I-MENU.VATYN\",\"I-MENU.ETC\",\"I-MENU.SUB_NM\",\"I-MENU.SUB_UNITPRICE\",\"I-MENU.SUB_CNT\",\"I-MENU.SUB_PRICE\",\"I-MENU.SUB_ETC\",\"I-VOID_MENU.NM\",\"I-VOID_MENU.PRICE\",\"I-SUB_TOTAL.SUBTOTAL_PRICE\",\"I-SUB_TOTAL.DISCOUNT_PRICE\",\"I-SUB_TOTAL.SERVICE_PRICE\",\"I-SUB_TOTAL.OTHERSVC_PRICE\",\"I-SUB_TOTAL.TAX_PRICE\",\"I-SUB_TOTAL.ETC\",\"I-TOTAL.TOTAL_PRICE\",\"I-TOTAL.TOTAL_ETC\",\"I-TOTAL.CASHPRICE\",\"I-TOTAL.CHANGEPRICE\",\"I-TOTAL.CREDITCARDPRICE\",\"I-TOTAL.EMONEYPRICE\",\"I-TOTAL.MENUTYPE_CNT\",\"I-TOTAL.MENUQTY_CNT\"]\n                        )\n                    ),\n                    \"image\": datasets.Array3D(shape=(3, 224, 224), dtype=\"uint8\"),\n                    \"image_path\": datasets.Value(\"string\"),\n                }\n            ),\n            supervised_keys=None,\n            citation=_CITATION,\n            homepage=\"https://github.com/clovaai/cord/\",\n        )\n\n    def _split_generators(self, dl_manager):\n        \"\"\"Returns SplitGenerators.\"\"\"\n        \"\"\"Uses local files located with data_dir\"\"\"\n        downloaded_file = dl_manager.download_and_extract(_URLS)\n        # move files from the second URL together 
with files from the first one.\n        dest = Path(downloaded_file[0])/\"CORD\"\n        for split in [\"train\", \"dev\", \"test\"]:\n            for file_type in [\"image\", \"json\"]:\n                if split == \"test\" and file_type == \"json\":\n                    continue\n                files = (Path(downloaded_file[1])/\"CORD\"/split/file_type).iterdir()\n                for f in files:\n                    os.rename(f, dest/split/file_type/f.name)\n        return [\n            datasets.SplitGenerator(\n                name=datasets.Split.TRAIN, gen_kwargs={\"filepath\": dest/\"train\"}\n            ),\n            datasets.SplitGenerator(\n                name=datasets.Split.VALIDATION, gen_kwargs={\"filepath\": dest/\"dev\"}\n            ),\n            datasets.SplitGenerator(\n                name=datasets.Split.TEST, gen_kwargs={\"filepath\": dest/\"test\"}\n            ),\n        ]\n\n    def get_line_bbox(self, bboxs):\n        x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]\n        y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]\n\n        x0, y0, x1, y1 = min(x), min(y), max(x), max(y)\n\n        assert x1 >= x0 and y1 >= y0\n        bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]\n        return bbox\n\n    def _generate_examples(self, filepath):\n        logger.info(\"⏳ Generating examples from = %s\", filepath)\n        ann_dir = os.path.join(filepath, \"json\")\n        img_dir = os.path.join(filepath, \"image\")\n        for guid, file in enumerate(sorted(os.listdir(ann_dir))):\n            words = []\n            bboxes = []\n            ner_tags = []\n            file_path = os.path.join(ann_dir, file)\n            with open(file_path, \"r\", encoding=\"utf8\") as f:\n                data = json.load(f)\n            image_path = os.path.join(img_dir, file)\n            image_path = image_path.replace(\"json\", \"png\")\n            image, size = 
load_image(image_path)\n            for item in data[\"valid_line\"]:\n                cur_line_bboxes = []\n                line_words, label = item[\"words\"], item[\"category\"]\n                line_words = [w for w in line_words if w[\"text\"].strip() != \"\"]\n                if len(line_words) == 0:\n                    continue\n                if label == \"other\":\n                    for w in line_words:\n                        words.append(w[\"text\"])\n                        ner_tags.append(\"O\")\n                        cur_line_bboxes.append(normalize_bbox(quad_to_box(w[\"quad\"]), size))\n                else:\n                    words.append(line_words[0][\"text\"])\n                    ner_tags.append(\"B-\" + label.upper())\n                    cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0][\"quad\"]), size))\n                    for w in line_words[1:]:\n                        words.append(w[\"text\"])\n                        ner_tags.append(\"I-\" + label.upper())\n                        cur_line_bboxes.append(normalize_bbox(quad_to_box(w[\"quad\"]), size))\n                # by default: --segment_level_layout 1\n                # if do not want to use segment_level_layout, comment the following line\n                cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)\n                bboxes.extend(cur_line_bboxes)\n            # yield guid, {\"id\": str(guid), \"words\": words, \"bboxes\": bboxes, \"ner_tags\": ner_tags, \"image\": image}\n            yield guid, {\"id\": str(guid), \"words\": words, \"bboxes\": bboxes, \"ner_tags\": ner_tags,\n                         \"image\": image, \"image_path\": image_path}\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/data_collator.py",
    "content": "import torch\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nfrom transformers import BatchEncoding, PreTrainedTokenizerBase\nfrom transformers.data.data_collator import (\n    DataCollatorMixin,\n    _torch_collate_batch,\n)\nfrom transformers.file_utils import PaddingStrategy\n\nfrom typing import NewType\nInputDataClass = NewType(\"InputDataClass\", Any)\n\ndef pre_calc_rel_mat(segment_ids):\n    valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),\n                             device=segment_ids.device, dtype=torch.bool)\n    for i in range(segment_ids.shape[0]):\n        for j in range(segment_ids.shape[1]):\n            valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]\n\n    return valid_span\n\n@dataclass\nclass DataCollatorForKeyValueExtraction(DataCollatorMixin):\n    \"\"\"\n    Data collator that will dynamically pad the inputs received, as well as the labels.\n    Args:\n        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):\n            The tokenizer used for encoding the data.\n        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):\n            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n            among:\n            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single\n              sequence if provided).\n            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n              maximum acceptable input length for the model if that argument is not provided.\n            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of\n              different lengths).\n        max_length 
(:obj:`int`, `optional`):\n            Maximum length of the returned list and optionally padding length (see above).\n        pad_to_multiple_of (:obj:`int`, `optional`):\n            If set will pad the sequence to a multiple of the provided value.\n            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=\n            7.5 (Volta).\n        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):\n            The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).\n    \"\"\"\n\n    tokenizer: PreTrainedTokenizerBase\n    padding: Union[bool, str, PaddingStrategy] = True\n    max_length: Optional[int] = None\n    pad_to_multiple_of: Optional[int] = None\n    label_pad_token_id: int = -100\n\n    def __call__(self, features):\n        label_name = \"label\" if \"label\" in features[0].keys() else \"labels\"\n        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None\n\n        images = None\n        if \"images\" in features[0]:\n            images = torch.stack([torch.tensor(d.pop(\"images\")) for d in features])\n            IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1\n\n        batch = self.tokenizer.pad(\n            features,\n            padding=self.padding,\n            max_length=self.max_length,\n            pad_to_multiple_of=self.pad_to_multiple_of,\n            # Conversion to tensors will fail if we have labels as they are not of the same length yet.\n            return_tensors=\"pt\" if labels is None else None,\n        )\n\n        if images is not None:\n            batch[\"images\"] = images\n            batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v\n                     for k, v in batch.items()}\n            visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)\n 
           batch[\"attention_mask\"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)\n\n        if labels is None:\n            return batch\n\n        has_bbox_input = \"bbox\" in features[0]\n        has_position_input = \"position_ids\" in features[0]\n        padding_idx=self.tokenizer.pad_token_id\n        sequence_length = torch.tensor(batch[\"input_ids\"]).shape[1]\n        padding_side = self.tokenizer.padding_side\n        if padding_side == \"right\":\n            batch[\"labels\"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]\n            if has_bbox_input:\n                batch[\"bbox\"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch[\"bbox\"]]\n            if has_position_input:\n                batch[\"position_ids\"] = [position_id + [padding_idx] * (sequence_length - len(position_id))\n                                          for position_id in batch[\"position_ids\"]]\n\n        else:\n            batch[\"labels\"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]\n            if has_bbox_input:\n                batch[\"bbox\"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch[\"bbox\"]]\n            if has_position_input:\n                batch[\"position_ids\"] = [[padding_idx] * (sequence_length - len(position_id))\n                                          + position_id for position_id in batch[\"position_ids\"]]\n\n        if 'segment_ids' in batch:\n            assert 'position_ids' in batch\n            for i in range(len(batch['segment_ids'])):\n                batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [\n                    batch['segment_ids'][i][-1] + 2] * IMAGE_LEN\n\n        batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in 
batch.items()}\n\n        if 'segment_ids' in batch:\n            valid_span = pre_calc_rel_mat(\n                segment_ids=batch['segment_ids']\n            )\n            batch['valid_span'] = valid_span\n            del batch['segment_ids']\n\n        if images is not None:\n            visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100\n            batch[\"labels\"] = torch.cat([batch['labels'], visual_labels], dim=1)\n\n        return batch\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/funsd.py",
    "content": "# coding=utf-8\n'''\nReference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py\n'''\nimport json\nimport os\n\nimport datasets\n\nfrom .image_utils import load_image, normalize_bbox\n\n\nlogger = datasets.logging.get_logger(__name__)\n\n\n_CITATION = \"\"\"\\\n@article{Jaume2019FUNSDAD,\n  title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},\n  author={Guillaume Jaume and H. K. Ekenel and J. Thiran},\n  journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},\n  year={2019},\n  volume={2},\n  pages={1-6}\n}\n\"\"\"\n\n_DESCRIPTION = \"\"\"\\\nhttps://guillaumejaume.github.io/FUNSD/\n\"\"\"\n\n\nclass FunsdConfig(datasets.BuilderConfig):\n    \"\"\"BuilderConfig for FUNSD\"\"\"\n\n    def __init__(self, **kwargs):\n        \"\"\"BuilderConfig for FUNSD.\n\n        Args:\n          **kwargs: keyword arguments forwarded to super.\n        \"\"\"\n        super(FunsdConfig, self).__init__(**kwargs)\n\n\nclass Funsd(datasets.GeneratorBasedBuilder):\n    \"\"\"Conll2003 dataset.\"\"\"\n\n    BUILDER_CONFIGS = [\n        FunsdConfig(name=\"funsd\", version=datasets.Version(\"1.0.0\"), description=\"FUNSD dataset\"),\n    ]\n\n    def _info(self):\n        return datasets.DatasetInfo(\n            description=_DESCRIPTION,\n            features=datasets.Features(\n                {\n                    \"id\": datasets.Value(\"string\"),\n                    \"tokens\": datasets.Sequence(datasets.Value(\"string\")),\n                    \"bboxes\": datasets.Sequence(datasets.Sequence(datasets.Value(\"int64\"))),\n                    \"ner_tags\": datasets.Sequence(\n                        datasets.features.ClassLabel(\n                            names=[\"O\", \"B-HEADER\", \"I-HEADER\", \"B-QUESTION\", \"I-QUESTION\", \"B-ANSWER\", \"I-ANSWER\"]\n                        )\n                    ),\n                    \"image\": datasets.Array3D(shape=(3, 224, 224), 
dtype=\"uint8\"),\n                    \"image_path\": datasets.Value(\"string\"),\n                }\n            ),\n            supervised_keys=None,\n            homepage=\"https://guillaumejaume.github.io/FUNSD/\",\n            citation=_CITATION,\n        )\n\n    def _split_generators(self, dl_manager):\n        \"\"\"Returns SplitGenerators.\"\"\"\n        downloaded_file = dl_manager.download_and_extract(\"https://guillaumejaume.github.io/FUNSD/dataset.zip\")\n        return [\n            datasets.SplitGenerator(\n                name=datasets.Split.TRAIN, gen_kwargs={\"filepath\": f\"{downloaded_file}/dataset/training_data/\"}\n            ),\n            datasets.SplitGenerator(\n                name=datasets.Split.TEST, gen_kwargs={\"filepath\": f\"{downloaded_file}/dataset/testing_data/\"}\n            ),\n        ]\n\n    def get_line_bbox(self, bboxs):\n        x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]\n        y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]\n\n        x0, y0, x1, y1 = min(x), min(y), max(x), max(y)\n\n        assert x1 >= x0 and y1 >= y0\n        bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]\n        return bbox\n\n    def _generate_examples(self, filepath):\n        logger.info(\"⏳ Generating examples from = %s\", filepath)\n        ann_dir = os.path.join(filepath, \"annotations\")\n        img_dir = os.path.join(filepath, \"images\")\n        for guid, file in enumerate(sorted(os.listdir(ann_dir))):\n            tokens = []\n            bboxes = []\n            ner_tags = []\n\n            file_path = os.path.join(ann_dir, file)\n            with open(file_path, \"r\", encoding=\"utf8\") as f:\n                data = json.load(f)\n            image_path = os.path.join(img_dir, file)\n            image_path = image_path.replace(\"json\", \"png\")\n            image, size = load_image(image_path)\n            for item in data[\"form\"]:\n                
cur_line_bboxes = []\n                words, label = item[\"words\"], item[\"label\"]\n                words = [w for w in words if w[\"text\"].strip() != \"\"]\n                if len(words) == 0:\n                    continue\n                if label == \"other\":\n                    for w in words:\n                        tokens.append(w[\"text\"])\n                        ner_tags.append(\"O\")\n                        cur_line_bboxes.append(normalize_bbox(w[\"box\"], size))\n                else:\n                    tokens.append(words[0][\"text\"])\n                    ner_tags.append(\"B-\" + label.upper())\n                    cur_line_bboxes.append(normalize_bbox(words[0][\"box\"], size))\n                    for w in words[1:]:\n                        tokens.append(w[\"text\"])\n                        ner_tags.append(\"I-\" + label.upper())\n                        cur_line_bboxes.append(normalize_bbox(w[\"box\"], size))\n                # by default: --segment_level_layout 1\n                # if do not want to use segment_level_layout, comment the following line\n                cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)\n                # box = normalize_bbox(item[\"box\"], size)\n                # cur_line_bboxes = [box for _ in range(len(words))]\n                bboxes.extend(cur_line_bboxes)\n            yield guid, {\"id\": str(guid), \"tokens\": tokens, \"bboxes\": bboxes, \"ner_tags\": ner_tags,\n                         \"image\": image, \"image_path\": image_path}"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/image_utils.py",
    "content": "import torchvision.transforms.functional as F\nimport warnings\nimport math\nimport random\nimport numpy as np\nfrom PIL import Image\nimport torch\n\nfrom detectron2.data.detection_utils import read_image\nfrom detectron2.data.transforms import ResizeTransform, TransformList\n\ndef normalize_bbox(bbox, size):\n    return [\n        int(1000 * bbox[0] / size[0]),\n        int(1000 * bbox[1] / size[1]),\n        int(1000 * bbox[2] / size[0]),\n        int(1000 * bbox[3] / size[1]),\n    ]\n\n\ndef load_image(image_path):\n    image = read_image(image_path, format=\"BGR\")\n    h = image.shape[0]\n    w = image.shape[1]\n    img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])\n    image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1)  # copy to make it writeable\n    return image, (w, h)\n\n\ndef crop(image, i, j, h, w, boxes=None):\n    cropped_image = F.crop(image, i, j, h, w)\n\n    if boxes is not None:\n        # Currently we cannot use this case since when some boxes is out of the cropped image,\n        # it may be better to drop out these boxes along with their text input (instead of min or clamp)\n        # which haven't been implemented here\n        max_size = torch.as_tensor([w, h], dtype=torch.float32)\n        cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])\n        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)\n        cropped_boxes = cropped_boxes.clamp(min=0)\n        boxes = cropped_boxes.reshape(-1, 4)\n\n    return cropped_image, boxes\n\n\ndef resize(image, size, interpolation, boxes=None):\n    # It seems that we do not need to resize boxes here, since the boxes will be resized to 1000x1000 finally,\n    # which is compatible with a square image size of 224x224\n    rescaled_image = F.resize(image, size, interpolation)\n\n    if boxes is None:\n        return rescaled_image, None\n\n    ratios = tuple(float(s) / float(s_orig) for s, s_orig 
in zip(rescaled_image.size, image.size))\n    ratio_width, ratio_height = ratios\n\n    # boxes = boxes.copy()\n    scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])\n\n    return rescaled_image, scaled_boxes\n\n\ndef clamp(num, min_value, max_value):\n    return max(min(num, max_value), min_value)\n\n\ndef get_bb(bb, page_size):\n    bbs = [float(j) for j in bb]\n    xs, ys = [], []\n    for i, b in enumerate(bbs):\n        if i % 2 == 0:\n            xs.append(b)\n        else:\n            ys.append(b)\n    (width, height) = page_size\n    return_bb = [\n        clamp(min(xs), 0, width - 1),\n        clamp(min(ys), 0, height - 1),\n        clamp(max(xs), 0, width - 1),\n        clamp(max(ys), 0, height - 1),\n    ]\n    return_bb = [\n            int(1000 * return_bb[0] / width),\n            int(1000 * return_bb[1] / height),\n            int(1000 * return_bb[2] / width),\n            int(1000 * return_bb[3] / height),\n        ]\n    return return_bb\n\n\nclass ToNumpy:\n\n    def __call__(self, pil_img):\n        np_img = np.array(pil_img, dtype=np.uint8)\n        if np_img.ndim < 3:\n            np_img = np.expand_dims(np_img, axis=-1)\n        np_img = np.rollaxis(np_img, 2)  # HWC to CHW\n        return np_img\n\n\nclass ToTensor:\n\n    def __init__(self, dtype=torch.float32):\n        self.dtype = dtype\n\n    def __call__(self, pil_img):\n        np_img = np.array(pil_img, dtype=np.uint8)\n        if np_img.ndim < 3:\n            np_img = np.expand_dims(np_img, axis=-1)\n        np_img = np.rollaxis(np_img, 2)  # HWC to CHW\n        return torch.from_numpy(np_img).to(dtype=self.dtype)\n\n\n_pil_interpolation_to_str = {\n    F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',\n    F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',\n    F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',\n    F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',\n    F.InterpolationMode.HAMMING: 
'F.InterpolationMode.HAMMING',\n    F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',\n}\n\n\ndef _pil_interp(method):\n    if method == 'bicubic':\n        return F.InterpolationMode.BICUBIC\n    elif method == 'lanczos':\n        return F.InterpolationMode.LANCZOS\n    elif method == 'hamming':\n        return F.InterpolationMode.HAMMING\n    else:\n        # default bilinear, do we want to allow nearest?\n        return F.InterpolationMode.BILINEAR\n\n\nclass Compose:\n    \"\"\"Composes several transforms together. This transform does not support torchscript.\n    Please, see the note below.\n\n    Args:\n        transforms (list of ``Transform`` objects): list of transforms to compose.\n\n    Example:\n        >>> transforms.Compose([\n        >>>     transforms.CenterCrop(10),\n        >>>     transforms.PILToTensor(),\n        >>>     transforms.ConvertImageDtype(torch.float),\n        >>> ])\n\n    .. note::\n        In order to script the transformations, please use ``torch.nn.Sequential`` as below.\n\n        >>> transforms = torch.nn.Sequential(\n        >>>     transforms.CenterCrop(10),\n        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n        >>> )\n        >>> scripted_transforms = torch.jit.script(transforms)\n\n        Make sure to use only scriptable transformations, i.e. 
that work with ``torch.Tensor``, does not require\n        `lambda` functions or ``PIL.Image``.\n\n    \"\"\"\n\n    def __init__(self, transforms):\n        self.transforms = transforms\n\n    def __call__(self, img, augmentation=False, box=None):\n        for t in self.transforms:\n            img = t(img, augmentation, box)\n        return img\n\n\nclass RandomResizedCropAndInterpolationWithTwoPic:\n    \"\"\"Crop the given PIL Image to random size and aspect ratio with random interpolation.\n    A crop of random size (default: of 0.08 to 1.0) of the original size and a random\n    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop\n    is finally resized to given size.\n    This is popularly used to train the Inception networks.\n    Args:\n        size: expected output size of each edge\n        scale: range of size of the origin size cropped\n        ratio: range of aspect ratio of the origin aspect ratio cropped\n        interpolation: Default: PIL.Image.BILINEAR\n    \"\"\"\n\n    def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.),\n                 interpolation='bilinear', second_interpolation='lanczos'):\n        if isinstance(size, tuple):\n            self.size = size\n        else:\n            self.size = (size, size)\n        if second_size is not None:\n            if isinstance(second_size, tuple):\n                self.second_size = second_size\n            else:\n                self.second_size = (second_size, second_size)\n        else:\n            self.second_size = None\n        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):\n            warnings.warn(\"range should be of kind (min, max)\")\n\n        self.interpolation = _pil_interp(interpolation)\n        self.second_interpolation = _pil_interp(second_interpolation)\n        self.scale = scale\n        self.ratio = ratio\n\n    @staticmethod\n    def get_params(img, scale, ratio):\n        \"\"\"Get parameters for ``crop`` for a random sized crop.\n        Args:\n            img (PIL Image): Image to be cropped.\n            scale (tuple): range of size of the origin size cropped\n            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped\n        Returns:\n            tuple: params (i, j, h, w) to be passed to ``crop`` for a random\n                sized crop.\n        \"\"\"\n        area = img.size[0] * img.size[1]\n\n        for attempt in range(10):\n            target_area = random.uniform(*scale) * area\n            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))\n            aspect_ratio = math.exp(random.uniform(*log_ratio))\n\n            w = int(round(math.sqrt(target_area * aspect_ratio)))\n            h = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if w <= img.size[0] and h <= img.size[1]:\n                i = random.randint(0, img.size[1] - h)\n                j = random.randint(0, img.size[0] - w)\n                return i, j, h, w\n\n        # Fallback to central crop\n        in_ratio = img.size[0] / img.size[1]\n        if in_ratio < 
min(ratio):\n            w = img.size[0]\n            h = int(round(w / min(ratio)))\n        elif in_ratio > max(ratio):\n            h = img.size[1]\n            w = int(round(h * max(ratio)))\n        else:  # whole image\n            w = img.size[0]\n            h = img.size[1]\n        i = (img.size[1] - h) // 2\n        j = (img.size[0] - w) // 2\n        return i, j, h, w\n\n    def __call__(self, img, augmentation=False, box=None):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be cropped and resized.\n        Returns:\n            PIL Image: Randomly cropped and resized image.\n        \"\"\"\n        if augmentation:\n            i, j, h, w = self.get_params(img, self.scale, self.ratio)\n            img = F.crop(img, i, j, h, w)\n            # img, box = crop(img, i, j, h, w, box)\n        img = F.resize(img, self.size, self.interpolation)\n        second_img = F.resize(img, self.second_size, self.second_interpolation) \\\n            if self.second_size is not None else None\n        return img, second_img\n\n    def __repr__(self):\n        if isinstance(self.interpolation, (tuple, list)):\n            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])\n        else:\n            interpolate_str = _pil_interpolation_to_str[self.interpolation]\n        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)\n        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))\n        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))\n        format_string += ', interpolation={0}'.format(interpolate_str)\n        if self.second_size is not None:\n            format_string += ', second_size={0}'.format(self.second_size)\n            format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])\n        format_string += ')'\n        return format_string\n\n\ndef pil_loader(path: str) -> 
Image.Image:\n    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n    with open(path, 'rb') as f:\n        img = Image.open(f)\n        return img.convert('RGB')\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/data/xfund.py",
    "content": "import os\nimport json\n\nimport torch\nfrom torch.utils.data.dataset import Dataset\nfrom torchvision import transforms\nfrom PIL import Image\n\nfrom .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic\n\nXFund_label2ids = {\n    \"O\":0,\n    'B-HEADER':1,\n    'I-HEADER':2,\n    'B-QUESTION':3,\n    'I-QUESTION':4,\n    'B-ANSWER':5,\n    'I-ANSWER':6,\n}\n\nclass xfund_dataset(Dataset):\n    def box_norm(self, box, width, height):\n        def clip(min_num, num, max_num):\n            return min(max(num, min_num), max_num)\n\n        x0, y0, x1, y1 = box\n        x0 = clip(0, int((x0 / width) * 1000), 1000)\n        y0 = clip(0, int((y0 / height) * 1000), 1000)\n        x1 = clip(0, int((x1 / width) * 1000), 1000)\n        y1 = clip(0, int((y1 / height) * 1000), 1000)\n        assert x1 >= x0\n        assert y1 >= y0\n        return [x0, y0, x1, y1]\n\n    def get_segment_ids(self, bboxs):\n        segment_ids = []\n        for i in range(len(bboxs)):\n            if i == 0:\n                segment_ids.append(0)\n            else:\n                if bboxs[i - 1] == bboxs[i]:\n                    segment_ids.append(segment_ids[-1])\n                else:\n                    segment_ids.append(segment_ids[-1] + 1)\n        return segment_ids\n\n    def get_position_ids(self, segment_ids):\n        position_ids = []\n        for i in range(len(segment_ids)):\n            if i == 0:\n                position_ids.append(2)\n            else:\n                if segment_ids[i] == segment_ids[i - 1]:\n                    position_ids.append(position_ids[-1] + 1)\n                else:\n                    position_ids.append(2)\n        return position_ids\n\n    def load_data(\n            self,\n            data_file,\n    ):\n        # re-org data format\n        total_data = {\"id\": [], \"lines\": [], \"bboxes\": [], \"ner_tags\": [], \"image_path\": []}\n        for i in range(len(data_file['documents'])):\n            
width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][\n                'height']\n\n            cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []\n            for j in range(len(data_file['documents'][i]['document'])):\n                cur_item = data_file['documents'][i]['document'][j]\n                cur_doc_lines.append(cur_item['text'])\n                cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))\n                cur_doc_ner_tags.append(cur_item['label'])\n            total_data['id'] += [len(total_data['id'])]\n            total_data['lines'] += [cur_doc_lines]\n            total_data['bboxes'] += [cur_doc_bboxes]\n            total_data['ner_tags'] += [cur_doc_ner_tags]\n            total_data['image_path'] += [data_file['documents'][i]['img']['fname']]\n\n        # tokenize text and get bbox/label\n        total_input_ids, total_bboxs, total_label_ids = [], [], []\n        for i in range(len(total_data['lines'])):\n            cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []\n            for j in range(len(total_data['lines'][i])):\n                cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']\n                if len(cur_input_ids) == 0: continue\n\n                cur_label = total_data['ner_tags'][i][j].upper()\n                if cur_label == 'OTHER':\n                    cur_labels = [\"O\"] * len(cur_input_ids)\n                    for k in range(len(cur_labels)):\n                        cur_labels[k] = self.label2ids[cur_labels[k]]\n                else:\n                    cur_labels = [cur_label] * len(cur_input_ids)\n                    cur_labels[0] = self.label2ids['B-' + cur_labels[0]]\n                    for k in range(1, len(cur_labels)):\n                        cur_labels[k] = self.label2ids['I-' + cur_labels[k]]\n      
          assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)\n                cur_doc_input_ids += cur_input_ids\n                cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)\n                cur_doc_labels += cur_labels\n            assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)\n            assert len(cur_doc_input_ids) > 0\n\n            total_input_ids.append(cur_doc_input_ids)\n            total_bboxs.append(cur_doc_bboxs)\n            total_label_ids.append(cur_doc_labels)\n        assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)\n\n        # split text to several slices because of over-length\n        input_ids, bboxs, labels = [], [], []\n        segment_ids, position_ids = [], []\n        image_path = []\n        for i in range(len(total_input_ids)):\n            start = 0\n            cur_iter = 0\n            while start < len(total_input_ids[i]):\n                end = min(start + 510, len(total_input_ids[i]))\n\n                input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])\n                bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])\n                labels.append([-100] + total_label_ids[i][start: end] + [-100])\n\n                cur_segment_ids = self.get_segment_ids(bboxs[-1])\n                cur_position_ids = self.get_position_ids(cur_segment_ids)\n                segment_ids.append(cur_segment_ids)\n                position_ids.append(cur_position_ids)\n                image_path.append(os.path.join(self.args.data_dir, \"images\", total_data['image_path'][i]))\n\n                start = end\n                cur_iter += 1\n\n        assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)\n        assert len(segment_ids) == len(image_path)\n\n        res = {\n            
'input_ids': input_ids,\n            'bbox': bboxs,\n            'labels': labels,\n            'segment_ids': segment_ids,\n            'position_ids': position_ids,\n            'image_path': image_path,\n        }\n        return res\n\n    def __init__(\n            self,\n            args,\n            tokenizer,\n            mode\n    ):\n        self.args = args\n        self.mode = mode\n        self.cur_la = args.language\n        self.tokenizer = tokenizer\n        self.label2ids = XFund_label2ids\n\n\n        self.common_transform = Compose([\n            RandomResizedCropAndInterpolationWithTwoPic(\n                size=args.input_size, interpolation=args.train_interpolation,\n            ),\n        ])\n\n        self.patch_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(\n                mean=torch.tensor((0.5, 0.5, 0.5)),\n                std=torch.tensor((0.5, 0.5, 0.5)))\n        ])\n\n        data_file = json.load(\n            open(os.path.join(args.data_dir, \"{}.{}.json\".format(self.cur_la, 'train' if mode == 'train' else 'val')),\n                 'r'))\n\n        self.feature = self.load_data(data_file)\n\n    def __len__(self):\n        return len(self.feature['input_ids'])\n\n    def __getitem__(self, index):\n        input_ids = self.feature[\"input_ids\"][index]\n\n        # attention_mask = self.feature[\"attention_mask\"][index]\n        attention_mask = [1] * len(input_ids)\n        labels = self.feature[\"labels\"][index]\n        bbox = self.feature[\"bbox\"][index]\n        segment_ids = self.feature['segment_ids'][index]\n        position_ids = self.feature['position_ids'][index]\n\n        img = pil_loader(self.feature['image_path'][index])\n        for_patches, _ = self.common_transform(img, augmentation=False)\n        patch = self.patch_transform(for_patches)\n\n        assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)\n\n        
res = {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"labels\": labels,\n            \"bbox\": bbox,\n            \"segment_ids\": segment_ids,\n            \"position_ids\": position_ids,\n            \"images\": patch,\n        }\n        return res\n\ndef pil_loader(path: str) -> Image.Image:\n    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n    with open(path, 'rb') as f:\n        img = Image.open(f)\n        return img.convert('RGB')"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/__init__.py",
    "content": "from .layoutlmv3 import (\n    LayoutLMv3Config,\n    LayoutLMv3ForTokenClassification,\n    LayoutLMv3ForQuestionAnswering,\n    LayoutLMv3ForSequenceClassification,\n    LayoutLMv3Tokenizer,\n)\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/__init__.py",
    "content": "from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \\\n    AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer\nfrom transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter\n\nfrom .configuration_layoutlmv3 import LayoutLMv3Config\nfrom .modeling_layoutlmv3 import (\n    LayoutLMv3ForTokenClassification,\n    LayoutLMv3ForQuestionAnswering,\n    LayoutLMv3ForSequenceClassification,\n    LayoutLMv3Model,\n)\nfrom .tokenization_layoutlmv3 import LayoutLMv3Tokenizer\nfrom .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast\n\n\n#AutoConfig.register(\"layoutlmv3\", LayoutLMv3Config)\n#AutoModel.register(LayoutLMv3Config, LayoutLMv3Model)\n#AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification)\n#AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering)\n#AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification)\n#AutoTokenizer.register(\n#    LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast\n#)\nSLOW_TO_FAST_CONVERTERS.update({\"LayoutLMv3Tokenizer\": RobertaConverter})\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py",
    "content": "# coding=utf-8\nfrom transformers.models.bert.configuration_bert import BertConfig\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\nLAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {\n    \"layoutlmv3-base\": \"https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json\",\n    \"layoutlmv3-large\": \"https://huggingface.co/microsoft/layoutlmv3-large/resolve/main/config.json\",\n    # See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3\n}\n\n\nclass LayoutLMv3Config(BertConfig):\n    model_type = \"layoutlmv3\"\n\n    def __init__(\n        self,\n        pad_token_id=1,\n        bos_token_id=0,\n        eos_token_id=2,\n        max_2d_position_embeddings=1024,\n        coordinate_size=None,\n        shape_size=None,\n        has_relative_attention_bias=False,\n        rel_pos_bins=32,\n        max_rel_pos=128,\n        has_spatial_attention_bias=False,\n        rel_2d_pos_bins=64,\n        max_rel_2d_pos=256,\n        visual_embed=True,\n        mim=False,\n        wpa_task=False,\n        discrete_vae_weight_path='',\n        discrete_vae_type='dall-e',\n        input_size=224,\n        second_input_size=112,\n        device='cuda',\n        **kwargs\n    ):\n        \"\"\"Constructs RobertaConfig.\"\"\"\n        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)\n        self.max_2d_position_embeddings = max_2d_position_embeddings\n        self.coordinate_size = coordinate_size\n        self.shape_size = shape_size\n        self.has_relative_attention_bias = has_relative_attention_bias\n        self.rel_pos_bins = rel_pos_bins\n        self.max_rel_pos = max_rel_pos\n        self.has_spatial_attention_bias = has_spatial_attention_bias\n        self.rel_2d_pos_bins = rel_2d_pos_bins\n        self.max_rel_2d_pos = max_rel_2d_pos\n        self.visual_embed = visual_embed\n        self.mim = mim\n        self.wpa_task = 
wpa_task\n        self.discrete_vae_weight_path = discrete_vae_weight_path\n        self.discrete_vae_type = discrete_vae_type\n        self.input_size = input_size\n        self.second_input_size = second_input_size\n        self.device = device\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py",
    "content": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch LayoutLMv3 model. \"\"\"\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers import apply_chunking_to_forward\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    MaskedLMOutput,\n    TokenClassifierOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer\nfrom transformers.models.roberta.modeling_roberta import (\n    RobertaIntermediate,\n    RobertaLMHead,\n    RobertaOutput,\n    RobertaSelfOutput,\n)\nfrom transformers.utils import logging\n\nfrom .configuration_layoutlmv3 import LayoutLMv3Config\nfrom timm.models.layers import to_2tuple\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = 
to_2tuple(patch_size)\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        # The following variables are used in detection mycheckpointer.py\n        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.num_patches_w = self.patch_shape[0]\n        self.num_patches_h = self.patch_shape[1]\n\n    def forward(self, x, position_embedding=None):\n        x = self.proj(x)\n\n        if position_embedding is not None:\n            # interpolate the position embedding to the corresponding size\n            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3, 1, 2)\n            Hp, Wp = x.shape[2], x.shape[3]\n            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')\n            x = x + position_embedding\n\n        x = x.flatten(2).transpose(1, 2)\n        return x\n\nclass LayoutLMv3Embeddings(nn.Module):\n    \"\"\"\n    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.\n    \"\"\"\n\n    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)\n        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)\n\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1)))\n\n        # End copy\n        self.padding_idx = config.pad_token_id\n        
self.position_embeddings = nn.Embedding(\n            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx\n        )\n\n        self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)\n        self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)\n        self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)\n        self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)\n\n    def _calc_spatial_position_embeddings(self, bbox):\n        try:\n            assert torch.all(0 <= bbox) and torch.all(bbox <= 1023)\n            left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])\n            upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])\n            right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])\n            lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])\n        except IndexError as e:\n            raise IndexError(\"The :obj:`bbox` coordinate values should be within 0-1000 range.\") from e\n\n        h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))\n        w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))\n\n        # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)\n        spatial_position_embeddings = torch.cat(\n            [\n                left_position_embeddings,\n                upper_position_embeddings,\n                right_position_embeddings,\n                lower_position_embeddings,\n                h_position_embeddings,\n                w_position_embeddings,\n            ],\n            dim=-1,\n        )\n        return spatial_position_embeddings\n\n    def create_position_ids_from_input_ids(self, input_ids, 
padding_idx, past_key_values_length=0):\n        \"\"\"\n        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols\n        are ignored. This is modified from fairseq's `utils.make_positions`.\n\n        Args:\n            x: torch.Tensor x:\n\n        Returns: torch.Tensor\n        \"\"\"\n        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.\n        mask = input_ids.ne(padding_idx).int()\n        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask\n        return incremental_indices.long() + padding_idx\n\n    def forward(\n        self,\n        input_ids=None,\n        bbox=None,\n        token_type_ids=None,\n        position_ids=None,\n        inputs_embeds=None,\n        past_key_values_length=0,\n    ):\n        if position_ids is None:\n            if input_ids is not None:\n                # Create the position ids from the input token ids. 
Any padded tokens remain padded.\n                position_ids = self.create_position_ids_from_input_ids(\n                    input_ids, self.padding_idx, past_key_values_length).to(input_ids.device)\n            else:\n                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)\n\n        if input_ids is not None:\n            input_shape = input_ids.size()\n        else:\n            input_shape = inputs_embeds.size()[:-1]\n\n        if token_type_ids is None:\n            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n        token_type_embeddings = self.token_type_embeddings(token_type_ids)\n\n        embeddings = inputs_embeds + token_type_embeddings\n        position_embeddings = self.position_embeddings(position_ids)\n        embeddings += position_embeddings\n\n        spatial_position_embeddings = self._calc_spatial_position_embeddings(bbox)\n\n        embeddings = embeddings + spatial_position_embeddings\n\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n    def create_position_ids_from_inputs_embeds(self, inputs_embeds):\n        \"\"\"\n        We are provided embeddings directly. 
We cannot infer which are padded so just generate sequential position ids.\n\n        Args:\n            inputs_embeds: torch.Tensor≈\n\n        Returns: torch.Tensor\n        \"\"\"\n        input_shape = inputs_embeds.size()[:-1]\n        sequence_length = input_shape[1]\n\n        position_ids = torch.arange(\n            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device\n        )\n        return position_ids.unsqueeze(0).expand(input_shape)\n\n\nclass LayoutLMv3PreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = LayoutLMv3Config\n    base_model_prefix = \"layoutlmv3\"\n\n    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, nn.Linear):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n\nclass LayoutLMv3SelfAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, \"embedding_size\"):\n            raise ValueError(\n                f\"The 
hidden size ({config.hidden_size}) is not a multiple of the number of attention \"\n                f\"heads ({config.num_attention_heads})\"\n            )\n\n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        self.key = nn.Linear(config.hidden_size, self.all_head_size)\n        self.value = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n        self.has_relative_attention_bias = config.has_relative_attention_bias\n        self.has_spatial_attention_bias = config.has_spatial_attention_bias\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def cogview_attn(self, attention_scores, alpha=32):\n        '''\n        https://arxiv.org/pdf/2105.13290.pdf\n        Section 2.4 Stabilization of training: Precision Bottleneck Relaxation (PB-Relax).\n        A replacement of the original nn.Softmax(dim=-1)(attention_scores)\n        Seems the new attention_probs will result in a slower speed and a little bias\n        Can use torch.allclose(standard_attention_probs, cogview_attention_probs, atol=1e-08) for comparison\n        The smaller atol (e.g., 1e-08), the better.\n        '''\n        scaled_attention_scores = attention_scores / alpha\n        max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)\n        # max_value = scaled_attention_scores.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1)\n        new_attention_scores = (scaled_attention_scores - max_value) * alpha\n        return nn.Softmax(dim=-1)(new_attention_scores)\n\n    def forward(\n        self,\n       
 hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n        rel_pos=None,\n        rel_2d_pos=None,\n    ):\n        mixed_query_layer = self.query(hidden_states)\n\n        # If this is instantiated as a cross-attention module, the keys\n        # and values come from an encoder; the attention mask needs to be\n        # such that the encoder's padding tokens are not attended to.\n        is_cross_attention = encoder_hidden_states is not None\n\n        if is_cross_attention and past_key_value is not None:\n            # reuse k,v, cross_attentions\n            key_layer = past_key_value[0]\n            value_layer = past_key_value[1]\n            attention_mask = encoder_attention_mask\n        elif is_cross_attention:\n            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))\n            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))\n            attention_mask = encoder_attention_mask\n        elif past_key_value is not None:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)\n            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)\n        else:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.\n        # Changing the computational order into QT(K/√d) alleviates the problem. 
(https://arxiv.org/pdf/2105.13290.pdf)\n        attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))\n\n        if self.has_relative_attention_bias and self.has_spatial_attention_bias:\n            attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)\n        elif self.has_relative_attention_bias:\n            attention_scores += rel_pos / math.sqrt(self.attention_head_size)\n\n        # if self.has_relative_attention_bias:\n        #     attention_scores += rel_pos\n        # if self.has_spatial_attention_bias:\n        #     attention_scores += rel_2d_pos\n\n        # attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        if attention_mask is not None:\n            # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)\n            attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        # attention_probs = nn.Softmax(dim=-1)(attention_scores)  # comment the line below and use this line for speedup\n        attention_probs = self.cogview_attn(attention_scores)  # to stablize training\n        # assert torch.allclose(attention_probs, nn.Softmax(dim=-1)(attention_scores), atol=1e-8)\n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        attention_probs = self.dropout(attention_probs)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attention_probs = attention_probs * head_mask\n\n        context_layer = torch.matmul(attention_probs, value_layer)\n\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        outputs = 
(context_layer, attention_probs) if output_attentions else (context_layer,)\n\n        return outputs\n\n\nclass LayoutLMv3Attention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.self = LayoutLMv3SelfAttention(config)\n        self.output = RobertaSelfOutput(config)\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads\n        )\n\n        # Prune linear layers\n        self.self.query = prune_linear_layer(self.self.query, index)\n        self.self.key = prune_linear_layer(self.self.key, index)\n        self.self.value = prune_linear_layer(self.self.value, index)\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n\n        # Update hyper params and store pruned heads\n        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n        rel_pos=None,\n        rel_2d_pos=None,\n    ):\n        self_outputs = self.self(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            encoder_hidden_states,\n            encoder_attention_mask,\n            past_key_value,\n            output_attentions,\n            rel_pos=rel_pos,\n            rel_2d_pos=rel_2d_pos,\n        )\n        attention_output = self.output(self_outputs[0], hidden_states)\n        outputs = (attention_output,) + self_outputs[1:]  # add attentions 
if we output them\n        return outputs\n\n\nclass LayoutLMv3Layer(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.attention = LayoutLMv3Attention(config)\n        assert not config.is_decoder and not config.add_cross_attention, \\\n            \"This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder.\"\n        self.intermediate = RobertaIntermediate(config)\n        self.output = RobertaOutput(config)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n        rel_pos=None,\n        rel_2d_pos=None,\n    ):\n        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2\n        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None\n        self_attention_outputs = self.attention(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            output_attentions=output_attentions,\n            past_key_value=self_attn_past_key_value,\n            rel_pos=rel_pos,\n            rel_2d_pos=rel_2d_pos,\n        )\n        attention_output = self_attention_outputs[0]\n\n        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights\n\n        layer_output = apply_chunking_to_forward(\n            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output\n        )\n        outputs = (layer_output,) + outputs\n\n        return outputs\n\n    def feed_forward_chunk(self, attention_output):\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        
return layer_output\n\n\nclass LayoutLMv3Encoder(nn.Module):\n    def __init__(self, config, detection=False, out_features=None):\n        super().__init__()\n        self.config = config\n        self.detection = detection\n        self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])\n        self.gradient_checkpointing = False\n\n        self.has_relative_attention_bias = config.has_relative_attention_bias\n        self.has_spatial_attention_bias = config.has_spatial_attention_bias\n\n        if self.has_relative_attention_bias:\n            self.rel_pos_bins = config.rel_pos_bins\n            self.max_rel_pos = config.max_rel_pos\n            self.rel_pos_onehot_size = config.rel_pos_bins\n            self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False)\n\n        if self.has_spatial_attention_bias:\n            self.max_rel_2d_pos = config.max_rel_2d_pos\n            self.rel_2d_pos_bins = config.rel_2d_pos_bins\n            self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins\n            self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)\n            self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)\n\n        if self.detection:\n            self.gradient_checkpointing = True\n            embed_dim = self.config.hidden_size\n            self.out_features = out_features\n            self.out_indices = [int(name[5:]) for name in out_features]\n            self.fpn1 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n                # nn.SyncBatchNorm(embed_dim),\n                nn.BatchNorm2d(embed_dim),\n                nn.GELU(),\n                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),\n            )\n\n            self.fpn2 = nn.Sequential(\n                nn.ConvTranspose2d(embed_dim, embed_dim, 
kernel_size=2, stride=2),\n            )\n\n            self.fpn3 = nn.Identity()\n\n            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)\n            self.ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]\n\n    def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):\n        ret = 0\n        if bidirectional:\n            num_buckets //= 2\n            ret += (relative_position > 0).long() * num_buckets\n            n = torch.abs(relative_position)\n        else:\n            n = torch.max(-relative_position, torch.zeros_like(relative_position))\n        # now n is in the range [0, inf)\n\n        # half of the buckets are for exact increments in positions\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).to(torch.long)\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def _cal_1d_pos_emb(self, hidden_states, position_ids, valid_span):\n        VISUAL_NUM = 196 + 1\n\n        rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)\n\n        if valid_span is not None:\n            # for the text part, if two words are not in the same line,\n            # set their distance to the max value (position_ids.shape[-1])\n            rel_pos_mat[(rel_pos_mat > 0) & (valid_span == False)] = position_ids.shape[1]\n            rel_pos_mat[(rel_pos_mat < 0) & (valid_span == False)] = -position_ids.shape[1]\n\n            # image-text, minimum distance\n            rel_pos_mat[:, -VISUAL_NUM:, :-VISUAL_NUM] = 0\n            rel_pos_mat[:, :-VISUAL_NUM, -VISUAL_NUM:] = 0\n\n        
rel_pos = self.relative_position_bucket(\n            rel_pos_mat,\n            num_buckets=self.rel_pos_bins,\n            max_distance=self.max_rel_pos,\n        )\n        rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)\n        rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)\n        rel_pos = rel_pos.contiguous()\n        return rel_pos\n\n    def _cal_2d_pos_emb(self, hidden_states, bbox):\n        position_coord_x = bbox[:, :, 0]\n        position_coord_y = bbox[:, :, 3]\n        rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)\n        rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)\n        rel_pos_x = self.relative_position_bucket(\n            rel_pos_x_2d_mat,\n            num_buckets=self.rel_2d_pos_bins,\n            max_distance=self.max_rel_2d_pos,\n        )\n        rel_pos_y = self.relative_position_bucket(\n            rel_pos_y_2d_mat,\n            num_buckets=self.rel_2d_pos_bins,\n            max_distance=self.max_rel_2d_pos,\n        )\n        rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)\n        rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)\n        rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2)\n        rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2)\n        rel_pos_x = rel_pos_x.contiguous()\n        rel_pos_y = rel_pos_y.contiguous()\n        rel_2d_pos = rel_pos_x + rel_pos_y\n        return rel_2d_pos\n\n    def forward(\n        self,\n        hidden_states,\n        bbox=None,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=False,\n        output_hidden_states=False,\n        return_dict=True,\n        
position_ids=None,\n        Hp=None,\n        Wp=None,\n        valid_span=None,\n    ):\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None\n\n        next_decoder_cache = () if use_cache else None\n\n        rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids, valid_span) if self.has_relative_attention_bias else None\n        rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None\n\n        if self.detection:\n            feat_out = {}\n            j = 0\n\n        for i, layer_module in enumerate(self.layer):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            past_key_value = past_key_values[i] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                if use_cache:\n                    logger.warning(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n                        # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos)\n                        # The above line will cause error:\n                        # RuntimeError: Trying to backward through the graph a second time\n                        # (or directly access saved tensors after they have already been freed).\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer_module),\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                    rel_pos,\n                    rel_2d_pos\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                    rel_pos=rel_pos,\n                    rel_2d_pos=rel_2d_pos,\n                )\n\n            hidden_states = layer_outputs[0]\n            if use_cache:\n                next_decoder_cache += (layer_outputs[-1],)\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n                if self.config.add_cross_attention:\n                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)\n\n            if self.detection 
and i in self.out_indices:\n                xp = hidden_states[:, -Hp*Wp:, :].permute(0, 2, 1).reshape(len(hidden_states), -1, Hp, Wp)\n                feat_out[self.out_features[j]] = self.ops[j](xp.contiguous())\n                j += 1\n\n        if self.detection:\n            return feat_out\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_decoder_cache,\n                    all_hidden_states,\n                    all_self_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=next_decoder_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass LayoutLMv3Model(LayoutLMv3PreTrainedModel):\n    \"\"\"\n    \"\"\"\n\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta\n    def __init__(self, config, detection=False, out_features=None, image_only=False):\n        super().__init__(config)\n        self.config = config\n        assert not config.is_decoder and not config.add_cross_attention, \\\n            \"This version do not support decoder. 
Please refer to RoBERTa for implementation of is_decoder.\"\n        self.detection = detection\n        if not self.detection:\n            self.image_only = False\n        else:\n            assert config.visual_embed\n            self.image_only = image_only\n\n        if not self.image_only:\n            self.embeddings = LayoutLMv3Embeddings(config)\n        self.encoder = LayoutLMv3Encoder(config, detection=detection, out_features=out_features)\n\n        if config.visual_embed:\n            embed_dim = self.config.hidden_size\n            # use the default pre-training parameters for fine-tuning (e.g., input_size)\n            # when the input_size is larger in fine-tuning, we will interpolate the position embedding in forward\n            self.patch_embed = PatchEmbed(embed_dim=embed_dim)\n\n            patch_size = 16\n            size = int(self.config.input_size / patch_size)\n            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n            self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, embed_dim))\n            self.pos_drop = nn.Dropout(p=0.)\n\n            self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n            self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n            if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:\n                self._init_visual_bbox(img_size=(size, size))\n\n            from functools import partial\n            norm_layer = partial(nn.LayerNorm, eps=1e-6)\n            self.norm = norm_layer(embed_dim)\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        self.embeddings.word_embeddings = value\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.layer[layer].attention.prune_heads(heads)\n\n    def _init_visual_bbox(self, img_size=(14, 14), max_len=1000):\n        visual_bbox_x = torch.div(torch.arange(0, max_len * (img_size[1] + 1), max_len),\n                                  img_size[1], rounding_mode='trunc')\n        visual_bbox_y = torch.div(torch.arange(0, max_len * (img_size[0] + 1), max_len),\n                                  img_size[0], rounding_mode='trunc')\n        visual_bbox = torch.stack(\n            [\n                visual_bbox_x[:-1].repeat(img_size[0], 1),\n                visual_bbox_y[:-1].repeat(img_size[1], 1).transpose(0, 1),\n                visual_bbox_x[1:].repeat(img_size[0], 1),\n                visual_bbox_y[1:].repeat(img_size[1], 1).transpose(0, 1),\n            ],\n            dim=-1,\n        ).view(-1, 4)\n\n        cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])\n        self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)\n\n    def _calc_visual_bbox(self, device, dtype, bsz):  # , img_size=(14, 14), max_len=1000):\n        visual_bbox = self.visual_bbox.repeat(bsz, 1, 1)\n        visual_bbox = visual_bbox.to(device).type(dtype)\n        return visual_bbox\n\n    def forward_image(self, x):\n        if self.detection:\n            x = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)\n        else:\n            x = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        if self.pos_embed is not None and self.detection:\n            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]\n\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is 
not None and not self.detection:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        x = self.norm(x)\n        return x\n\n    # Copied from transformers.models.bert.modeling_bert.BertModel.forward\n    def forward(\n        self,\n        input_ids=None,\n        bbox=None,\n        attention_mask=None,\n        token_type_ids=None,\n        valid_span=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        images=None,\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        use_cache = False\n\n        # if input_ids is not None and inputs_embeds is not None:\n        #     raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        if input_ids is not None:\n            input_shape = input_ids.size()\n            batch_size, seq_length = input_shape\n            device = input_ids.device\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size, seq_length = input_shape\n            device = inputs_embeds.device\n        elif images is not None:\n            batch_size = len(images)\n            device = images.device\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds or images\")\n\n        if not self.image_only:\n            # past_key_values_length\n            past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0\n\n            if attention_mask is None:\n                attention_mask 
= torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)\n            if token_type_ids is None:\n                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        # extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)\n\n        encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        if not self.image_only:\n            if bbox is None:\n                bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)\n\n            embedding_output = self.embeddings(\n                input_ids=input_ids,\n                bbox=bbox,\n                position_ids=position_ids,\n                token_type_ids=token_type_ids,\n                inputs_embeds=inputs_embeds,\n                past_key_values_length=past_key_values_length,\n            )\n\n        final_bbox = final_position_ids = None\n        Hp = Wp = None\n        if images is not None:\n            patch_size = 16\n            Hp, Wp = int(images.shape[2] / patch_size), int(images.shape[3] / patch_size)\n            visual_emb = self.forward_image(images)\n            if self.detection:\n                visual_attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device)\n                if self.image_only:\n                  
  attention_mask = visual_attention_mask\n                else:\n                    attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)\n            elif self.image_only:\n                attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device)\n\n            if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:\n                if self.config.has_spatial_attention_bias:\n                    visual_bbox = self._calc_visual_bbox(device, dtype=torch.long, bsz=batch_size)\n                    if self.image_only:\n                        final_bbox = visual_bbox\n                    else:\n                        final_bbox = torch.cat([bbox, visual_bbox], dim=1)\n\n                visual_position_ids = torch.arange(0, visual_emb.shape[1], dtype=torch.long, device=device).repeat(\n                    batch_size, 1)\n                if self.image_only:\n                    final_position_ids = visual_position_ids\n                else:\n                    position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)\n                    position_ids = position_ids.expand_as(input_ids)\n                    final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)\n\n            if self.image_only:\n                embedding_output = visual_emb\n            else:\n                embedding_output = torch.cat([embedding_output, visual_emb], dim=1)\n            embedding_output = self.LayerNorm(embedding_output)\n            embedding_output = self.dropout(embedding_output)\n        elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:\n            if self.config.has_spatial_attention_bias:\n                final_bbox = bbox\n            if self.config.has_relative_attention_bias:\n                position_ids = self.embeddings.position_ids[:, :input_shape[1]]\n                position_ids = 
position_ids.expand_as(input_ids)\n                final_position_ids = position_ids\n\n        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device)\n\n        encoder_outputs = self.encoder(\n            embedding_output,\n            bbox=final_bbox,\n            position_ids=final_position_ids,\n            attention_mask=extended_attention_mask,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_extended_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            Hp=Hp,\n            Wp=Wp,\n            valid_span=valid_span,\n        )\n\n        if self.detection:\n            return encoder_outputs\n\n        sequence_output = encoder_outputs[0]\n        pooled_output = None\n\n        if not return_dict:\n            return (sequence_output, pooled_output) + encoder_outputs[1:]\n\n        return BaseModelOutputWithPoolingAndCrossAttentions(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n            past_key_values=encoder_outputs.past_key_values,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n            cross_attentions=encoder_outputs.cross_attentions,\n        )\n\n\nclass LayoutLMv3ClassificationHead(nn.Module):\n    \"\"\"\n    Head for sentence-level classification tasks.\n    Reference: RobertaClassificationHead\n    \"\"\"\n\n    def __init__(self, config, pool_feature=False):\n        super().__init__()\n        self.pool_feature = pool_feature\n        if pool_feature:\n            self.dense = nn.Linear(config.hidden_size*3, config.hidden_size)\n        else:\n            self.dense = nn.Linear(config.hidden_size, 
config.hidden_size)\n        classifier_dropout = (\n            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob\n        )\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)\n\n    def forward(self, x):\n        # x = features[:, 0, :]  # take <s> token (equiv. to [CLS])\n        x = self.dropout(x)\n        x = self.dense(x)\n        x = torch.tanh(x)\n        x = self.dropout(x)\n        x = self.out_proj(x)\n        return x\n\n\nclass LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.layoutlmv3 = LayoutLMv3Model(config)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        if config.num_labels < 10:\n            self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n        else:\n            self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)\n\n        self.init_weights()\n\n    def forward(\n        self,\n        input_ids=None,\n        bbox=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        valid_span=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        images=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the token classification loss. 
Indices should be in ``[0, ..., config.num_labels -\n            1]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.layoutlmv3(\n            input_ids,\n            bbox=bbox,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            images=images,\n            valid_span=valid_span,\n        )\n\n        sequence_output = outputs[0]\n\n        sequence_output = self.dropout(sequence_output)\n        logits = self.classifier(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            # Only keep active parts of the loss\n            if attention_mask is not None:\n                active_loss = attention_mask.view(-1) == 1\n                active_logits = logits.view(-1, self.num_labels)\n                active_labels = torch.where(\n                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n                )\n                loss = loss_fct(active_logits, active_labels)\n            else:\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    
_keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.layoutlmv3 = LayoutLMv3Model(config)\n        # self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)\n        self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)\n\n        self.init_weights()\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        valid_span=None,\n        head_mask=None,\n        inputs_embeds=None,\n        start_positions=None,\n        end_positions=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        bbox=None,\n        images=None,\n    ):\n        r\"\"\"\n        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n            sequence are not taken into account for computing the loss.\n        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). 
Position outside of the\n            sequence are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.layoutlmv3(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            bbox=bbox,\n            images=images,\n            valid_span=valid_span,\n        )\n\n        sequence_output = outputs[0]\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        total_loss = None\n        if start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1)\n            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions = start_positions.clamp(0, ignored_index)\n            end_positions = end_positions.clamp(0, ignored_index)\n\n            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if not return_dict:\n            output = (start_logits, end_logits) + 
outputs[2:]\n            return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.config = config\n        self.layoutlmv3 = LayoutLMv3Model(config)\n        self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)\n\n        self.init_weights()\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        valid_span=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        bbox=None,\n        images=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,\n            config.num_labels - 1]`. 
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),\n            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.layoutlmv3(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            bbox=bbox,\n            images=images,\n            valid_span=valid_span,\n        )\n\n        sequence_output = outputs[0][:, 0, :]\n        logits = self.classifier(sequence_output)\n\n        loss = None\n        if labels is not None:\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                
loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(logits, labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py",
    "content": "# coding=utf-8\n# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tokenization classes for LayoutLMv3, refer to RoBERTa.\"\"\"\n\nfrom transformers.models.roberta import RobertaTokenizer\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\n    \"vocab_file\": \"vocab.json\",\n    \"merges_file\": \"merges.txt\",\n}\n\nclass LayoutLMv3Tokenizer(RobertaTokenizer):\n    vocab_files_names = VOCAB_FILES_NAMES\n    # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py",
    "content": "# coding=utf-8\n# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fast Tokenization classes for LayoutLMv3, refer to RoBERTa.\"\"\"\n\n\nfrom transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast\nfrom transformers.utils import logging\n\nfrom .tokenization_layoutlmv3 import LayoutLMv3Tokenizer\n\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"vocab.json\", \"merges_file\": \"merges.txt\", \"tokenizer_file\": \"tokenizer.json\"}\n\n\nclass LayoutLMv3TokenizerFast(RobertaTokenizerFast):\n    vocab_files_names = VOCAB_FILES_NAMES\n    # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n    slow_tokenizer_class = LayoutLMv3Tokenizer\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmv3_base_inference.yaml",
    "content": "AUG:\n  DETR: true\nCACHE_DIR: ~/cache/huggingface\nCUDNN_BENCHMARK: false\nDATALOADER:\n  ASPECT_RATIO_GROUPING: true\n  FILTER_EMPTY_ANNOTATIONS: false\n  NUM_WORKERS: 4\n  REPEAT_THRESHOLD: 0.0\n  SAMPLER_TRAIN: TrainingSampler\nDATASETS:\n  PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000\n  PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000\n  PROPOSAL_FILES_TEST: []\n  PROPOSAL_FILES_TRAIN: []\n  TEST:\n  - scihub_train\n  TRAIN:\n  - scihub_train\nGLOBAL:\n  HACK: 1.0\nICDAR_DATA_DIR_TEST: ''\nICDAR_DATA_DIR_TRAIN: ''\nINPUT:\n  CROP:\n    ENABLED: true\n    SIZE:\n    - 384\n    - 600\n    TYPE: absolute_range\n  FORMAT: RGB\n  MASK_FORMAT: polygon\n  MAX_SIZE_TEST: 1333\n  MAX_SIZE_TRAIN: 1333\n  MIN_SIZE_TEST: 800\n  MIN_SIZE_TRAIN:\n  - 480\n  - 512\n  - 544\n  - 576\n  - 608\n  - 640\n  - 672\n  - 704\n  - 736\n  - 768\n  - 800\n  MIN_SIZE_TRAIN_SAMPLING: choice\n  RANDOM_FLIP: horizontal\nMODEL:\n  ANCHOR_GENERATOR:\n    ANGLES:\n    - - -90\n      - 0\n      - 90\n    ASPECT_RATIOS:\n    - - 0.5\n      - 1.0\n      - 2.0\n    NAME: DefaultAnchorGenerator\n    OFFSET: 0.0\n    SIZES:\n    - - 32\n    - - 64\n    - - 128\n    - - 256\n    - - 512\n  BACKBONE:\n    FREEZE_AT: 2\n    NAME: build_vit_fpn_backbone\n  CONFIG_PATH: ''\n  DEVICE: cuda\n  FPN:\n    FUSE_TYPE: sum\n    IN_FEATURES:\n    - layer3\n    - layer5\n    - layer7\n    - layer11\n    NORM: ''\n    OUT_CHANNELS: 256\n  IMAGE_ONLY: true\n  KEYPOINT_ON: false\n  LOAD_PROPOSALS: false\n  MASK_ON: true\n  META_ARCHITECTURE: VLGeneralizedRCNN\n  PANOPTIC_FPN:\n    COMBINE:\n      ENABLED: true\n      INSTANCES_CONFIDENCE_THRESH: 0.5\n      OVERLAP_THRESH: 0.5\n      STUFF_AREA_LIMIT: 4096\n    INSTANCE_LOSS_WEIGHT: 1.0\n  PIXEL_MEAN:\n  - 127.5\n  - 127.5\n  - 127.5\n  PIXEL_STD:\n  - 127.5\n  - 127.5\n  - 127.5\n  PROPOSAL_GENERATOR:\n    MIN_SIZE: 0\n    NAME: RPN\n  RESNETS:\n    DEFORM_MODULATED: false\n    DEFORM_NUM_GROUPS: 1\n    DEFORM_ON_PER_STAGE:\n    - false\n    - false\n    - false\n    
- false\n    DEPTH: 50\n    NORM: FrozenBN\n    NUM_GROUPS: 1\n    OUT_FEATURES:\n    - res4\n    RES2_OUT_CHANNELS: 256\n    RES5_DILATION: 1\n    STEM_OUT_CHANNELS: 64\n    STRIDE_IN_1X1: true\n    WIDTH_PER_GROUP: 64\n  RETINANET:\n    BBOX_REG_LOSS_TYPE: smooth_l1\n    BBOX_REG_WEIGHTS:\n    - 1.0\n    - 1.0\n    - 1.0\n    - 1.0\n    FOCAL_LOSS_ALPHA: 0.25\n    FOCAL_LOSS_GAMMA: 2.0\n    IN_FEATURES:\n    - p3\n    - p4\n    - p5\n    - p6\n    - p7\n    IOU_LABELS:\n    - 0\n    - -1\n    - 1\n    IOU_THRESHOLDS:\n    - 0.4\n    - 0.5\n    NMS_THRESH_TEST: 0.5\n    NORM: ''\n    NUM_CLASSES: 10\n    NUM_CONVS: 4\n    PRIOR_PROB: 0.01\n    SCORE_THRESH_TEST: 0.05\n    SMOOTH_L1_LOSS_BETA: 0.1\n    TOPK_CANDIDATES_TEST: 1000\n  ROI_BOX_CASCADE_HEAD:\n    BBOX_REG_WEIGHTS:\n    - - 10.0\n      - 10.0\n      - 5.0\n      - 5.0\n    - - 20.0\n      - 20.0\n      - 10.0\n      - 10.0\n    - - 30.0\n      - 30.0\n      - 15.0\n      - 15.0\n    IOUS:\n    - 0.5\n    - 0.6\n    - 0.7\n  ROI_BOX_HEAD:\n    BBOX_REG_LOSS_TYPE: smooth_l1\n    BBOX_REG_LOSS_WEIGHT: 1.0\n    BBOX_REG_WEIGHTS:\n    - 10.0\n    - 10.0\n    - 5.0\n    - 5.0\n    CLS_AGNOSTIC_BBOX_REG: true\n    CONV_DIM: 256\n    FC_DIM: 1024\n    NAME: FastRCNNConvFCHead\n    NORM: ''\n    NUM_CONV: 0\n    NUM_FC: 2\n    POOLER_RESOLUTION: 7\n    POOLER_SAMPLING_RATIO: 0\n    POOLER_TYPE: ROIAlignV2\n    SMOOTH_L1_BETA: 0.0\n    TRAIN_ON_PRED_BOXES: false\n  ROI_HEADS:\n    BATCH_SIZE_PER_IMAGE: 512\n    IN_FEATURES:\n    - p2\n    - p3\n    - p4\n    - p5\n    IOU_LABELS:\n    - 0\n    - 1\n    IOU_THRESHOLDS:\n    - 0.5\n    NAME: CascadeROIHeads\n    NMS_THRESH_TEST: 0.5\n    NUM_CLASSES: 10\n    POSITIVE_FRACTION: 0.25\n    PROPOSAL_APPEND_GT: true\n    SCORE_THRESH_TEST: 0.05\n  ROI_KEYPOINT_HEAD:\n    CONV_DIMS:\n    - 512\n    - 512\n    - 512\n    - 512\n    - 512\n    - 512\n    - 512\n    - 512\n    LOSS_WEIGHT: 1.0\n    MIN_KEYPOINTS_PER_IMAGE: 1\n    NAME: KRCNNConvDeconvUpsampleHead\n    
NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS: true\n    NUM_KEYPOINTS: 17\n    POOLER_RESOLUTION: 14\n    POOLER_SAMPLING_RATIO: 0\n    POOLER_TYPE: ROIAlignV2\n  ROI_MASK_HEAD:\n    CLS_AGNOSTIC_MASK: false\n    CONV_DIM: 256\n    NAME: MaskRCNNConvUpsampleHead\n    NORM: ''\n    NUM_CONV: 4\n    POOLER_RESOLUTION: 14\n    POOLER_SAMPLING_RATIO: 0\n    POOLER_TYPE: ROIAlignV2\n  RPN:\n    BATCH_SIZE_PER_IMAGE: 256\n    BBOX_REG_LOSS_TYPE: smooth_l1\n    BBOX_REG_LOSS_WEIGHT: 1.0\n    BBOX_REG_WEIGHTS:\n    - 1.0\n    - 1.0\n    - 1.0\n    - 1.0\n    BOUNDARY_THRESH: -1\n    CONV_DIMS:\n    - -1\n    HEAD_NAME: StandardRPNHead\n    IN_FEATURES:\n    - p2\n    - p3\n    - p4\n    - p5\n    - p6\n    IOU_LABELS:\n    - 0\n    - -1\n    - 1\n    IOU_THRESHOLDS:\n    - 0.3\n    - 0.7\n    LOSS_WEIGHT: 1.0\n    NMS_THRESH: 0.7\n    POSITIVE_FRACTION: 0.5\n    POST_NMS_TOPK_TEST: 1000\n    POST_NMS_TOPK_TRAIN: 2000\n    PRE_NMS_TOPK_TEST: 1000\n    PRE_NMS_TOPK_TRAIN: 2000\n    SMOOTH_L1_BETA: 0.0\n  SEM_SEG_HEAD:\n    COMMON_STRIDE: 4\n    CONVS_DIM: 128\n    IGNORE_VALUE: 255\n    IN_FEATURES:\n    - p2\n    - p3\n    - p4\n    - p5\n    LOSS_WEIGHT: 1.0\n    NAME: SemSegFPNHead\n    NORM: GN\n    NUM_CLASSES: 10\n  VIT:\n    DROP_PATH: 0.1\n    IMG_SIZE:\n    - 224\n    - 224\n    NAME: layoutlmv3_base\n    OUT_FEATURES:\n    - layer3\n    - layer5\n    - layer7\n    - layer11\n    POS_TYPE: abs\n  WEIGHTS: \nOUTPUT_DIR: \nSCIHUB_DATA_DIR_TRAIN: ~/publaynet/layout_scihub/train\nSEED: 42\nSOLVER:\n  AMP:\n    ENABLED: true\n  BACKBONE_MULTIPLIER: 1.0\n  BASE_LR: 0.0002\n  BIAS_LR_FACTOR: 1.0\n  CHECKPOINT_PERIOD: 2000\n  CLIP_GRADIENTS:\n    CLIP_TYPE: full_model\n    CLIP_VALUE: 1.0\n    ENABLED: true\n    NORM_TYPE: 2.0\n  GAMMA: 0.1\n  GRADIENT_ACCUMULATION_STEPS: 1\n  IMS_PER_BATCH: 32\n  LR_SCHEDULER_NAME: WarmupCosineLR\n  MAX_ITER: 20000\n  MOMENTUM: 0.9\n  NESTEROV: false\n  OPTIMIZER: ADAMW\n  REFERENCE_WORLD_SIZE: 0\n  STEPS:\n  - 10000\n  WARMUP_FACTOR: 0.01\n  
WARMUP_ITERS: 333\n  WARMUP_METHOD: linear\n  WEIGHT_DECAY: 0.05\n  WEIGHT_DECAY_BIAS: null\n  WEIGHT_DECAY_NORM: 0.0\nTEST:\n  AUG:\n    ENABLED: false\n    FLIP: true\n    MAX_SIZE: 4000\n    MIN_SIZES:\n    - 400\n    - 500\n    - 600\n    - 700\n    - 800\n    - 900\n    - 1000\n    - 1100\n    - 1200\n  DETECTIONS_PER_IMAGE: 100\n  EVAL_PERIOD: 1000\n  EXPECTED_RESULTS: []\n  KEYPOINT_OKS_SIGMAS: []\n  PRECISE_BN:\n    ENABLED: false\n    NUM_ITER: 200\nVERSION: 2\nVIS_PERIOD: 0\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/model_init.py",
    "content": "from .visualizer import Visualizer\nfrom .rcnn_vl import *\nfrom .backbone import *\n\nfrom detectron2.config import get_cfg\nfrom detectron2.config import CfgNode as CN\nfrom detectron2.data import MetadataCatalog, DatasetCatalog\nfrom detectron2.data.datasets import register_coco_instances\nfrom detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor\n\ndef add_vit_config(cfg):\n    \"\"\"\n    Add config for VIT.\n    \"\"\"\n    _C = cfg\n\n    _C.MODEL.VIT = CN()\n\n    # CoaT model name.\n    _C.MODEL.VIT.NAME = \"\"\n\n    # Output features from CoaT backbone.\n    _C.MODEL.VIT.OUT_FEATURES = [\"layer3\", \"layer5\", \"layer7\", \"layer11\"]\n\n    _C.MODEL.VIT.IMG_SIZE = [224, 224]\n\n    _C.MODEL.VIT.POS_TYPE = \"shared_rel\"\n\n    _C.MODEL.VIT.DROP_PATH = 0.\n\n    _C.MODEL.VIT.MODEL_KWARGS = \"{}\"\n\n    _C.SOLVER.OPTIMIZER = \"ADAMW\"\n\n    _C.SOLVER.BACKBONE_MULTIPLIER = 1.0\n\n    _C.AUG = CN()\n\n    _C.AUG.DETR = False\n\n    _C.MODEL.IMAGE_ONLY = True\n    _C.PUBLAYNET_DATA_DIR_TRAIN = \"\"\n    _C.PUBLAYNET_DATA_DIR_TEST = \"\"\n    _C.FOOTNOTE_DATA_DIR_TRAIN = \"\"\n    _C.FOOTNOTE_DATA_DIR_VAL = \"\"\n    _C.SCIHUB_DATA_DIR_TRAIN = \"\"\n    _C.SCIHUB_DATA_DIR_TEST = \"\"\n    _C.JIAOCAI_DATA_DIR_TRAIN = \"\"\n    _C.JIAOCAI_DATA_DIR_TEST = \"\"\n    _C.ICDAR_DATA_DIR_TRAIN = \"\"\n    _C.ICDAR_DATA_DIR_TEST = \"\"\n    _C.M6DOC_DATA_DIR_TEST = \"\"\n    _C.DOCSTRUCTBENCH_DATA_DIR_TEST = \"\"\n    _C.DOCSTRUCTBENCHv2_DATA_DIR_TEST = \"\"\n    _C.CACHE_DIR = \"\"\n    _C.MODEL.CONFIG_PATH = \"\"\n\n    # effective update steps would be MAX_ITER/GRADIENT_ACCUMULATION_STEPS\n    # maybe need to set MAX_ITER *= GRADIENT_ACCUMULATION_STEPS\n    _C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1\n\n\ndef setup(args):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg = get_cfg()\n    # add_coat_config(cfg)\n    add_vit_config(cfg)\n    
cfg.merge_from_file(args.config_file)\n    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2  # set threshold for this model\n    cfg.merge_from_list(args.opts)\n    cfg.freeze()\n    default_setup(cfg, args)\n    \n    register_coco_instances(\n        \"scihub_train\",\n        {},\n        cfg.SCIHUB_DATA_DIR_TRAIN + \".json\",\n        cfg.SCIHUB_DATA_DIR_TRAIN\n    )\n    \n    return cfg\n\n\nclass DotDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(DotDict, self).__init__(*args, **kwargs)\n\n    def __getattr__(self, key):\n        if key not in self.keys():\n            return None\n        value = self[key]\n        if isinstance(value, dict):\n            value = DotDict(value)\n        return value\n    \n    def __setattr__(self, key, value):\n        self[key] = value\n        \nclass Layoutlmv3_Predictor(object):\n    def __init__(self, weights):\n        layout_args = {\n            \"config_file\": \"pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/layoutlmv3_base_inference.yaml\",\n            \"resume\": False,\n            \"eval_only\": False,\n            \"num_gpus\": 1,\n            \"num_machines\": 1,\n            \"machine_rank\": 0,\n            \"dist_url\": \"tcp://127.0.0.1:57823\",\n            \"opts\": [\"MODEL.WEIGHTS\", weights],\n        }\n        layout_args = DotDict(layout_args)\n\n        cfg = setup(layout_args)\n        self.mapping = [\"title\", \"plain text\", \"abandon\", \"figure\", \"figure_caption\", \"table\", \"table_caption\", \"table_footnote\", \"isolate_formula\", \"formula_caption\"]\n        MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping\n        self.predictor = DefaultPredictor(cfg)\n        \n    def __call__(self, image, ignore_catids=[]):\n        page_layout_result = {\n            \"layout_dets\": []\n        }\n        outputs = self.predictor(image)\n        boxes = outputs[\"instances\"].to(\"cpu\")._fields[\"pred_boxes\"].tensor.tolist()\n        
labels = outputs[\"instances\"].to(\"cpu\")._fields[\"pred_classes\"].tolist()\n        scores = outputs[\"instances\"].to(\"cpu\")._fields[\"scores\"].tolist()\n        for bbox_idx in range(len(boxes)):\n            if labels[bbox_idx] in ignore_catids:\n                continue\n            page_layout_result[\"layout_dets\"].append({\n                \"category_id\": labels[bbox_idx],\n                \"poly\": [\n                    boxes[bbox_idx][0], boxes[bbox_idx][1],\n                    boxes[bbox_idx][2], boxes[bbox_idx][1],\n                    boxes[bbox_idx][2], boxes[bbox_idx][3],\n                    boxes[bbox_idx][0], boxes[bbox_idx][3],\n                ],\n                \"score\": scores[bbox_idx]\n            })\n        return page_layout_result"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/rcnn_vl.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport logging\nimport numpy as np\nfrom typing import Dict, List, Optional, Tuple\nimport torch\nfrom torch import nn\n\nfrom detectron2.config import configurable\nfrom detectron2.structures import ImageList, Instances\nfrom detectron2.utils.events import get_event_storage\n\nfrom detectron2.modeling.backbone import Backbone, build_backbone\nfrom detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY\n\nfrom detectron2.modeling.meta_arch import GeneralizedRCNN\n\nfrom detectron2.modeling.postprocessing import detector_postprocess\nfrom detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image\nfrom contextlib import contextmanager\nfrom itertools import count\n\n@META_ARCH_REGISTRY.register()\nclass VLGeneralizedRCNN(GeneralizedRCNN):\n    \"\"\"\n    Generalized R-CNN. Any models that contains the following three components:\n    1. Per-image feature extraction (aka backbone)\n    2. Region proposal generation\n    3. 
Per-region feature extraction and prediction\n    \"\"\"\n\n    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):\n        \"\"\"\n        Args:\n            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .\n                Each item in the list contains the inputs for one image.\n                For now, each item in the list is a dict that contains:\n\n                * image: Tensor, image in (C, H, W) format.\n                * instances (optional): groundtruth :class:`Instances`\n                * proposals (optional): :class:`Instances`, precomputed proposals.\n\n                Other information that's included in the original dicts, such as:\n\n                * \"height\", \"width\" (int): the output resolution of the model, used in inference.\n                  See :meth:`postprocess` for details.\n\n        Returns:\n            list[dict]:\n                Each dict is the output for one input image.\n                The dict contains one key \"instances\" whose value is a :class:`Instances`.\n                The :class:`Instances` object has the following keys:\n                \"pred_boxes\", \"pred_classes\", \"scores\", \"pred_masks\", \"pred_keypoints\"\n        \"\"\"\n        if not self.training:\n            return self.inference(batched_inputs)\n\n        images = self.preprocess_image(batched_inputs)\n        if \"instances\" in batched_inputs[0]:\n            gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\n        else:\n            gt_instances = None\n\n        # features = self.backbone(images.tensor)\n        input = self.get_batch(batched_inputs, images)\n        features = self.backbone(input)\n\n        if self.proposal_generator is not None:\n            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)\n        else:\n            assert \"proposals\" in batched_inputs[0]\n            proposals = [x[\"proposals\"].to(self.device) for x in 
batched_inputs]\n            proposal_losses = {}\n\n        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)\n        if self.vis_period > 0:\n            storage = get_event_storage()\n            if storage.iter % self.vis_period == 0:\n                self.visualize_training(batched_inputs, proposals)\n\n        losses = {}\n        losses.update(detector_losses)\n        losses.update(proposal_losses)\n        return losses\n\n    def inference(\n        self,\n        batched_inputs: List[Dict[str, torch.Tensor]],\n        detected_instances: Optional[List[Instances]] = None,\n        do_postprocess: bool = True,\n    ):\n        \"\"\"\n        Run inference on the given inputs.\n\n        Args:\n            batched_inputs (list[dict]): same as in :meth:`forward`\n            detected_instances (None or list[Instances]): if not None, it\n                contains an `Instances` object per image. The `Instances`\n                object contains \"pred_boxes\" and \"pred_classes\" which are\n                known boxes in the image.\n                The inference will then skip the detection of bounding boxes,\n                and only predict other per-ROI outputs.\n            do_postprocess (bool): whether to apply post-processing on the outputs.\n\n        Returns:\n            When do_postprocess=True, same as in :meth:`forward`.\n            Otherwise, a list[Instances] containing raw network outputs.\n        \"\"\"\n        assert not self.training\n\n        images = self.preprocess_image(batched_inputs)\n        # features = self.backbone(images.tensor)\n        input = self.get_batch(batched_inputs, images)\n        features = self.backbone(input)\n\n        if detected_instances is None:\n            if self.proposal_generator is not None:\n                proposals, _ = self.proposal_generator(images, features, None)\n            else:\n                assert \"proposals\" in batched_inputs[0]\n                
proposals = [x[\"proposals\"].to(self.device) for x in batched_inputs]\n\n            results, _ = self.roi_heads(images, features, proposals, None)\n        else:\n            detected_instances = [x.to(self.device) for x in detected_instances]\n            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)\n\n        if do_postprocess:\n            assert not torch.jit.is_scripting(), \"Scripting is not supported for postprocess.\"\n            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)\n        else:\n            return results\n\n    def get_batch(self, examples, images):\n        if len(examples) >= 1 and \"bbox\" not in examples[0]:  # image_only\n            return {\"images\": images.tensor}\n\n        return input\n\n    def _batch_inference(self, batched_inputs, detected_instances=None):\n        \"\"\"\n        Execute inference on a list of inputs,\n        using batch size = self.batch_size (e.g., 2), instead of the length of the list.\n\n        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`\n        \"\"\"\n        if detected_instances is None:\n            detected_instances = [None] * len(batched_inputs)\n\n        outputs = []\n        inputs, instances = [], []\n        for idx, input, instance in zip(count(), batched_inputs, detected_instances):\n            inputs.append(input)\n            instances.append(instance)\n            if len(inputs) == 2 or idx == len(batched_inputs) - 1:\n                outputs.extend(\n                    self.inference(\n                        inputs,\n                        instances if instances[0] is not None else None,\n                        do_postprocess=True,  # False\n                    )\n                )\n                inputs, instances = [], []\n        return outputs\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/layoutlmv3_util/visualizer.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport colorsys\nimport logging\nimport math\nimport numpy as np\nfrom enum import Enum, unique\nimport cv2\nimport matplotlib as mpl\nimport matplotlib.colors as mplc\nimport matplotlib.figure as mplfigure\nimport pycocotools.mask as mask_util\nimport torch\nfrom matplotlib.backends.backend_agg import FigureCanvasAgg\nfrom PIL import Image\n\nfrom detectron2.data import MetadataCatalog\nfrom detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes\nfrom detectron2.utils.file_io import PathManager\n\nfrom detectron2.utils.colormap import random_color\n\nimport pdb\n\nlogger = logging.getLogger(__name__)\n\n__all__ = [\"ColorMode\", \"VisImage\", \"Visualizer\"]\n\n\n_SMALL_OBJECT_AREA_THRESH = 1000\n_LARGE_MASK_AREA_THRESH = 120000\n_OFF_WHITE = (1.0, 1.0, 240.0 / 255)\n_BLACK = (0, 0, 0)\n_RED = (1.0, 0, 0)\n\n_KEYPOINT_THRESHOLD = 0.05\n\n#CLASS_NAMES = [\"footnote\", \"footer\", \"header\"]\n\n@unique\nclass ColorMode(Enum):\n    \"\"\"\n    Enum of different color modes to use for instance visualizations.\n    \"\"\"\n\n    IMAGE = 0\n    \"\"\"\n    Picks a random color for every instance and overlay segmentations with low opacity.\n    \"\"\"\n    SEGMENTATION = 1\n    \"\"\"\n    Let instances of the same category have similar colors\n    (from metadata.thing_colors), and overlay them with\n    high opacity. 
This provides more attention on the quality of segmentation.\n    \"\"\"\n    IMAGE_BW = 2\n    \"\"\"\n    Same as IMAGE, but convert all areas without masks to gray-scale.\n    Only available for drawing per-instance mask predictions.\n    \"\"\"\n\n\nclass GenericMask:\n    \"\"\"\n    Attribute:\n        polygons (list[ndarray]): list[ndarray]: polygons for this mask.\n            Each ndarray has format [x, y, x, y, ...]\n        mask (ndarray): a binary mask\n    \"\"\"\n\n    def __init__(self, mask_or_polygons, height, width):\n        self._mask = self._polygons = self._has_holes = None\n        self.height = height\n        self.width = width\n\n        m = mask_or_polygons\n        if isinstance(m, dict):\n            # RLEs\n            assert \"counts\" in m and \"size\" in m\n            if isinstance(m[\"counts\"], list):  # uncompressed RLEs\n                h, w = m[\"size\"]\n                assert h == height and w == width\n                m = mask_util.frPyObjects(m, h, w)\n            self._mask = mask_util.decode(m)[:, :]\n            return\n\n        if isinstance(m, list):  # list[ndarray]\n            self._polygons = [np.asarray(x).reshape(-1) for x in m]\n            return\n\n        if isinstance(m, np.ndarray):  # assumed to be a binary mask\n            assert m.shape[1] != 2, m.shape\n            assert m.shape == (\n                height,\n                width,\n            ), f\"mask shape: {m.shape}, target dims: {height}, {width}\"\n            self._mask = m.astype(\"uint8\")\n            return\n\n        raise ValueError(\"GenericMask cannot handle object {} of type '{}'\".format(m, type(m)))\n\n    @property\n    def mask(self):\n        if self._mask is None:\n            self._mask = self.polygons_to_mask(self._polygons)\n        return self._mask\n\n    @property\n    def polygons(self):\n        if self._polygons is None:\n            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)\n        return 
self._polygons\n\n    @property\n    def has_holes(self):\n        if self._has_holes is None:\n            if self._mask is not None:\n                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)\n            else:\n                self._has_holes = False  # if original format is polygon, does not have holes\n        return self._has_holes\n\n    def mask_to_polygons(self, mask):\n        # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level\n        # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.\n        # Internal contours (holes) are placed in hierarchy-2.\n        # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.\n        mask = np.ascontiguousarray(mask)  # some versions of cv2 does not support incontiguous arr\n        res = cv2.findContours(mask.astype(\"uint8\"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)\n        hierarchy = res[-1]\n        if hierarchy is None:  # empty mask\n            return [], False\n        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0\n        res = res[-2]\n        res = [x.flatten() for x in res]\n        # These coordinates from OpenCV are integers in range [0, W-1 or H-1].\n        # We add 0.5 to turn them into real-value coordinate space. 
A better solution\n        # would be to first +0.5 and then dilate the returned polygon by 0.5.\n        res = [x + 0.5 for x in res if len(x) >= 6]\n        return res, has_holes\n\n    def polygons_to_mask(self, polygons):\n        rle = mask_util.frPyObjects(polygons, self.height, self.width)\n        rle = mask_util.merge(rle)\n        return mask_util.decode(rle)[:, :]\n\n    def area(self):\n        return self.mask.sum()\n\n    def bbox(self):\n        p = mask_util.frPyObjects(self.polygons, self.height, self.width)\n        p = mask_util.merge(p)\n        bbox = mask_util.toBbox(p)\n        bbox[2] += bbox[0]\n        bbox[3] += bbox[1]\n        return bbox\n\n\nclass _PanopticPrediction:\n    \"\"\"\n    Unify different panoptic annotation/prediction formats\n    \"\"\"\n\n    def __init__(self, panoptic_seg, segments_info, metadata=None):\n        if segments_info is None:\n            assert metadata is not None\n            # If \"segments_info\" is None, we assume \"panoptic_img\" is a\n            # H*W int32 image storing the panoptic_id in the format of\n            # category_id * label_divisor + instance_id. 
We reserve -1 for\n            # VOID label.\n            label_divisor = metadata.label_divisor\n            segments_info = []\n            for panoptic_label in np.unique(panoptic_seg.numpy()):\n                if panoptic_label == -1:\n                    # VOID region.\n                    continue\n                pred_class = panoptic_label // label_divisor\n                isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()\n                segments_info.append(\n                    {\n                        \"id\": int(panoptic_label),\n                        \"category_id\": int(pred_class),\n                        \"isthing\": bool(isthing),\n                    }\n                )\n        del metadata\n\n        self._seg = panoptic_seg\n\n        self._sinfo = {s[\"id\"]: s for s in segments_info}  # seg id -> seg info\n        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)\n        areas = areas.numpy()\n        sorted_idxs = np.argsort(-areas)\n        self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]\n        self._seg_ids = self._seg_ids.tolist()\n        for sid, area in zip(self._seg_ids, self._seg_areas):\n            if sid in self._sinfo:\n                self._sinfo[sid][\"area\"] = float(area)\n\n    def non_empty_mask(self):\n        \"\"\"\n        Returns:\n            (H, W) array, a mask for all pixels that have a prediction\n        \"\"\"\n        empty_ids = []\n        for id in self._seg_ids:\n            if id not in self._sinfo:\n                empty_ids.append(id)\n        if len(empty_ids) == 0:\n            return np.zeros(self._seg.shape, dtype=np.uint8)\n        assert (\n            len(empty_ids) == 1\n        ), \">1 ids corresponds to no labels. 
This is currently not supported\"\n        return (self._seg != empty_ids[0]).numpy().astype(np.bool)\n\n    def semantic_masks(self):\n        for sid in self._seg_ids:\n            sinfo = self._sinfo.get(sid)\n            if sinfo is None or sinfo[\"isthing\"]:\n                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.\n                continue\n            yield (self._seg == sid).numpy().astype(np.bool), sinfo\n\n    def instance_masks(self):\n        for sid in self._seg_ids:\n            sinfo = self._sinfo.get(sid)\n            if sinfo is None or not sinfo[\"isthing\"]:\n                continue\n            mask = (self._seg == sid).numpy().astype(np.bool)\n            if mask.sum() > 0:\n                yield mask, sinfo\n\n\ndef _create_text_labels(classes, scores, class_names, is_crowd=None):\n    \"\"\"\n    Args:\n        classes (list[int] or None):\n        scores (list[float] or None):\n        class_names (list[str] or None):\n        is_crowd (list[bool] or None):\n\n    Returns:\n        list[str] or None\n    \"\"\"\n    #class_names = CLASS_NAMES\n    labels = None\n    if classes is not None:\n        if class_names is not None and len(class_names) > 0:\n            labels = [class_names[i] for i in classes]\n        else:\n            labels = [str(i) for i in classes]\n            \n    if scores is not None:\n        if labels is None:\n            labels = [\"{:.0f}%\".format(s * 100) for s in scores]\n        else:\n            labels = [\"{} {:.0f}%\".format(l, s * 100) for l, s in zip(labels, scores)]\n    if labels is not None and is_crowd is not None:\n        labels = [l + (\"|crowd\" if crowd else \"\") for l, crowd in zip(labels, is_crowd)]\n    return labels\n\n\nclass VisImage:\n    def __init__(self, img, scale=1.0):\n        \"\"\"\n        Args:\n            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].\n            scale (float): scale the input image\n        
\"\"\"\n        self.img = img\n        self.scale = scale\n        self.width, self.height = img.shape[1], img.shape[0]\n        self._setup_figure(img)\n\n    def _setup_figure(self, img):\n        \"\"\"\n        Args:\n            Same as in :meth:`__init__()`.\n\n        Returns:\n            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.\n            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.\n        \"\"\"\n        fig = mplfigure.Figure(frameon=False)\n        self.dpi = fig.get_dpi()\n        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation\n        # (https://github.com/matplotlib/matplotlib/issues/15363)\n        fig.set_size_inches(\n            (self.width * self.scale + 1e-2) / self.dpi,\n            (self.height * self.scale + 1e-2) / self.dpi,\n        )\n        self.canvas = FigureCanvasAgg(fig)\n        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)\n        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])\n        ax.axis(\"off\")\n        self.fig = fig\n        self.ax = ax\n        self.reset_image(img)\n\n    def reset_image(self, img):\n        \"\"\"\n        Args:\n            img: same as in __init__\n        \"\"\"\n        img = img.astype(\"uint8\")\n        self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation=\"nearest\")\n\n    def save(self, filepath):\n        \"\"\"\n        Args:\n            filepath (str): a string that contains the absolute path, including the file name, where\n                the visualized image will be saved.\n        \"\"\"\n        self.fig.savefig(filepath)\n\n    def get_image(self):\n        \"\"\"\n        Returns:\n            ndarray:\n                the visualized image of shape (H, W, 3) (RGB) in uint8 type.\n                The shape is scaled w.r.t the input image using the given `scale` argument.\n        \"\"\"\n        canvas = self.canvas\n        s, 
(width, height) = canvas.print_to_buffer()\n        # buf = io.BytesIO()  # works for cairo backend\n        # canvas.print_rgba(buf)\n        # width, height = self.width, self.height\n        # s = buf.getvalue()\n\n        buffer = np.frombuffer(s, dtype=\"uint8\")\n\n        img_rgba = buffer.reshape(height, width, 4)\n        rgb, alpha = np.split(img_rgba, [3], axis=2)\n        return rgb.astype(\"uint8\")\n\n\nclass Visualizer:\n    \"\"\"\n    Visualizer that draws data about detection/segmentation on images.\n\n    It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`\n    that draw primitive objects to images, as well as high-level wrappers like\n    `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`\n    that draw composite data in some pre-defined style.\n\n    Note that the exact visualization style for the high-level wrappers are subject to change.\n    Style such as color, opacity, label contents, visibility of labels, or even the visibility\n    of objects themselves (e.g. when the object is too small) may change according\n    to different heuristics, as long as the results still look visually reasonable.\n\n    To obtain a consistent style, you can implement custom drawing functions with the\n    abovementioned primitive methods instead. If you need more customized visualization\n    styles, you can process the data yourself following their format documented in\n    tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not\n    intend to satisfy everyone's preference on drawing styles.\n\n    This visualizer focuses on high rendering quality rather than performance. 
It is not\n    designed to be used for real-time applications.\n    \"\"\"\n\n    # TODO implement a fast, rasterized version using OpenCV\n\n    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):\n        \"\"\"\n        Args:\n            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to\n                the height and width of the image respectively. C is the number of\n                color channels. The image is required to be in RGB format since that\n                is a requirement of the Matplotlib library. The image is also expected\n                to be in the range [0, 255].\n            metadata (Metadata): dataset metadata (e.g. class names and colors)\n            instance_mode (ColorMode): defines one of the pre-defined style for drawing\n                instances on an image.\n        \"\"\"\n        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)\n        if metadata is None:\n            metadata = MetadataCatalog.get(\"__nonexist__\")\n        self.metadata = metadata\n        self.output = VisImage(self.img, scale=scale)\n        self.cpu_device = torch.device(\"cpu\")\n\n        # too small texts are useless, therefore clamp to 9\n        self._default_font_size = max(\n            np.sqrt(self.output.height * self.output.width) // 90, 10 // scale\n        )\n        self._instance_mode = instance_mode\n        self.keypoint_threshold = _KEYPOINT_THRESHOLD\n\n    def draw_instance_predictions(self, predictions):\n        \"\"\"\n        Draw instance-level prediction results on an image.\n\n        Args:\n            predictions (Instances): the output of an instance detection/segmentation\n                model. 
Following fields will be used to draw:\n                \"pred_boxes\", \"pred_classes\", \"scores\", \"pred_masks\" (or \"pred_masks_rle\").\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        boxes = predictions.pred_boxes if predictions.has(\"pred_boxes\") else None\n        scores = predictions.scores if predictions.has(\"scores\") else None\n        classes = predictions.pred_classes.tolist() if predictions.has(\"pred_classes\") else None\n        labels = _create_text_labels(classes, scores, self.metadata.get(\"thing_classes\", None))\n        keypoints = predictions.pred_keypoints if predictions.has(\"pred_keypoints\") else None\n\n        if predictions.has(\"pred_masks\"):\n            masks = np.asarray(predictions.pred_masks)\n            masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]\n        else:\n            masks = None\n\n        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(\"thing_colors\"):\n            colors = [\n                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes\n            ]\n            alpha = 0.8\n        else:\n            colors = None\n            alpha = 0.5\n\n        if self._instance_mode == ColorMode.IMAGE_BW:\n            self.output.reset_image(\n                self._create_grayscale_image(\n                    (predictions.pred_masks.any(dim=0) > 0).numpy()\n                    if predictions.has(\"pred_masks\")\n                    else None\n                )\n            )\n            alpha = 0.3\n\n        self.overlay_instances(\n            masks=masks,\n            boxes=boxes,\n            labels=labels,\n            keypoints=keypoints,\n            assigned_colors=colors,\n            alpha=alpha,\n        )\n        return self.output\n\n    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):\n        \"\"\"\n        Draw semantic segmentation 
predictions/labels.\n\n        Args:\n            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).\n                Each value is the integer label of the pixel.\n            area_threshold (int): segments with less than `area_threshold` are not drawn.\n            alpha (float): the larger it is, the more opaque the segmentations are.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        if isinstance(sem_seg, torch.Tensor):\n            sem_seg = sem_seg.numpy()\n        labels, areas = np.unique(sem_seg, return_counts=True)\n        sorted_idxs = np.argsort(-areas).tolist()\n        labels = labels[sorted_idxs]\n        for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):\n            try:\n                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]\n            except (AttributeError, IndexError):\n                mask_color = None\n\n            binary_mask = (sem_seg == label).astype(np.uint8)\n            text = self.metadata.stuff_classes[label]\n            self.draw_binary_mask(\n                binary_mask,\n                color=mask_color,\n                edge_color=_OFF_WHITE,\n                text=text,\n                alpha=alpha,\n                area_threshold=area_threshold,\n            )\n        return self.output\n\n    def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):\n        \"\"\"\n        Draw panoptic prediction annotations or results.\n\n        Args:\n            panoptic_seg (Tensor): of shape (height, width) where the values are ids for each\n                segment.\n            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.\n                If it is a ``list[dict]``, each dict contains keys \"id\", \"category_id\".\n                If None, category id of each pixel is computed by\n                ``pixel // metadata.label_divisor``.\n           
 area_threshold (int): stuff segments with less than `area_threshold` are not drawn.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)\n\n        if self._instance_mode == ColorMode.IMAGE_BW:\n            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))\n\n        # draw mask for all semantic segments first i.e. \"stuff\"\n        for mask, sinfo in pred.semantic_masks():\n            category_idx = sinfo[\"category_id\"]\n            try:\n                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]\n            except AttributeError:\n                mask_color = None\n\n            text = self.metadata.stuff_classes[category_idx]\n            self.draw_binary_mask(\n                mask,\n                color=mask_color,\n                edge_color=_OFF_WHITE,\n                text=text,\n                alpha=alpha,\n                area_threshold=area_threshold,\n            )\n\n        # draw mask for all instances second\n        all_instances = list(pred.instance_masks())\n        if len(all_instances) == 0:\n            return self.output\n        masks, sinfo = list(zip(*all_instances))\n        category_ids = [x[\"category_id\"] for x in sinfo]\n\n        try:\n            scores = [x[\"score\"] for x in sinfo]\n        except KeyError:\n            scores = None\n        labels = _create_text_labels(\n            category_ids, scores, self.metadata.thing_classes, [x.get(\"iscrowd\", 0) for x in sinfo]\n        )\n\n        try:\n            colors = [\n                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids\n            ]\n        except AttributeError:\n            colors = None\n        self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)\n\n        return self.output\n\n    
draw_panoptic_seg_predictions = draw_panoptic_seg  # backward compatibility\n\n    def draw_dataset_dict(self, dic):\n        \"\"\"\n        Draw annotations/segmentaions in Detectron2 Dataset format.\n\n        Args:\n            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        annos = dic.get(\"annotations\", None)\n        if annos:\n            if \"segmentation\" in annos[0]:\n                masks = [x[\"segmentation\"] for x in annos]\n            else:\n                masks = None\n            if \"keypoints\" in annos[0]:\n                keypts = [x[\"keypoints\"] for x in annos]\n                keypts = np.array(keypts).reshape(len(annos), -1, 3)\n            else:\n                keypts = None\n\n            boxes = [\n                BoxMode.convert(x[\"bbox\"], x[\"bbox_mode\"], BoxMode.XYXY_ABS)\n                if len(x[\"bbox\"]) == 4\n                else x[\"bbox\"]\n                for x in annos\n            ]\n\n            colors = None\n            category_ids = [x[\"category_id\"] for x in annos]\n            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(\"thing_colors\"):\n                colors = [\n                    self._jitter([x / 255 for x in self.metadata.thing_colors[c]])\n                    for c in category_ids\n                ]\n            names = self.metadata.get(\"thing_classes\", None)\n            labels = _create_text_labels(\n                category_ids,\n                scores=None,\n                class_names=names,\n                is_crowd=[x.get(\"iscrowd\", 0) for x in annos],\n            )\n            self.overlay_instances(\n                labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors\n            )\n\n        sem_seg = dic.get(\"sem_seg\", None)\n        if sem_seg is None and 
\"sem_seg_file_name\" in dic:\n            with PathManager.open(dic[\"sem_seg_file_name\"], \"rb\") as f:\n                sem_seg = Image.open(f)\n                sem_seg = np.asarray(sem_seg, dtype=\"uint8\")\n        if sem_seg is not None:\n            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)\n\n        pan_seg = dic.get(\"pan_seg\", None)\n        if pan_seg is None and \"pan_seg_file_name\" in dic:\n            with PathManager.open(dic[\"pan_seg_file_name\"], \"rb\") as f:\n                pan_seg = Image.open(f)\n                pan_seg = np.asarray(pan_seg)\n                from panopticapi.utils import rgb2id\n\n                pan_seg = rgb2id(pan_seg)\n        if pan_seg is not None:\n            segments_info = dic[\"segments_info\"]\n            pan_seg = torch.tensor(pan_seg)\n            self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)\n        return self.output\n\n    def overlay_instances(\n        self,\n        *,\n        boxes=None,\n        labels=None,\n        masks=None,\n        keypoints=None,\n        assigned_colors=None,\n        alpha=0.5,\n    ):\n        \"\"\"\n        Args:\n            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,\n                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,\n                or a :class:`RotatedBoxes`,\n                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format\n                for the N objects in a single image,\n            labels (list[str]): the text to be displayed for each instance.\n            masks (masks-like object): Supported types are:\n\n                * :class:`detectron2.structures.PolygonMasks`,\n                  :class:`detectron2.structures.BitMasks`.\n                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.\n                  The first level of the list corresponds to individual instances. 
The second\n                  level to all the polygon that compose the instance, and the third level\n                  to the polygon coordinates. The third level should have the format of\n                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).\n                * list[ndarray]: each ndarray is a binary mask of shape (H, W).\n                * list[dict]: each dict is a COCO-style RLE.\n            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),\n                where the N is the number of instances and K is the number of keypoints.\n                The last dimension corresponds to (x, y, visibility or score).\n            assigned_colors (list[matplotlib.colors]): a list of colors, where each color\n                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'\n                for full list of formats that the colors are accepted in.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        num_instances = 0\n        if boxes is not None:\n            boxes = self._convert_boxes(boxes)\n            num_instances = len(boxes)\n        if masks is not None:\n            masks = self._convert_masks(masks)\n            if num_instances:\n                assert len(masks) == num_instances\n            else:\n                num_instances = len(masks)\n        if keypoints is not None:\n            if num_instances:\n                assert len(keypoints) == num_instances\n            else:\n                num_instances = len(keypoints)\n            keypoints = self._convert_keypoints(keypoints)\n        if labels is not None:\n            assert len(labels) == num_instances\n        if assigned_colors is None:\n            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]\n        if num_instances == 0:\n            return self.output\n        if boxes is not None and boxes.shape[1] == 5:\n            return 
self.overlay_rotated_instances(\n                boxes=boxes, labels=labels, assigned_colors=assigned_colors\n            )\n\n        # Display in largest to smallest order to reduce occlusion.\n        areas = None\n        if boxes is not None:\n            areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)\n        elif masks is not None:\n            areas = np.asarray([x.area() for x in masks])\n\n        if areas is not None:\n            sorted_idxs = np.argsort(-areas).tolist()\n            # Re-order overlapped instances in descending order.\n            boxes = boxes[sorted_idxs] if boxes is not None else None\n            labels = [labels[k] for k in sorted_idxs] if labels is not None else None\n            masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None\n            assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]\n            keypoints = keypoints[sorted_idxs] if keypoints is not None else None\n\n        for i in range(num_instances):\n            color = assigned_colors[i]\n            if boxes is not None:\n                self.draw_box(boxes[i], edge_color=color)\n\n            if masks is not None:\n                for segment in masks[i].polygons:\n                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)\n\n            if labels is not None:\n                # first get a box\n                if boxes is not None:\n                    x0, y0, x1, y1 = boxes[i]\n                    text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.\n                    horiz_align = \"left\"\n                elif masks is not None:\n                    # skip small mask without polygon\n                    if len(masks[i].polygons) == 0:\n                        continue\n\n                    x0, y0, x1, y1 = masks[i].bbox()\n\n                    # draw text in the center (defined by median) when box is not drawn\n                    # median is less sensitive to outliers.\n  
                  text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]\n                    horiz_align = \"center\"\n                else:\n                    continue  # drawing the box confidence for keypoints isn't very useful.\n                # for small objects, draw text at the side to avoid occlusion\n                instance_area = (y1 - y0) * (x1 - x0)\n                if (\n                    instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale\n                    or y1 - y0 < 40 * self.output.scale\n                ):\n                    if y1 >= self.output.height - 5:\n                        text_pos = (x1, y0)\n                    else:\n                        text_pos = (x0, y1)\n\n                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)\n                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)\n                font_size = (\n                    np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)\n                    * 0.5\n                    * self._default_font_size\n                )\n                self.draw_text(\n                    labels[i],\n                    text_pos,\n                    color=lighter_color,\n                    horizontal_alignment=horiz_align,\n                    font_size=font_size,\n                )\n\n        # draw keypoints\n        if keypoints is not None:\n            for keypoints_per_instance in keypoints:\n                self.draw_and_connect_keypoints(keypoints_per_instance)\n\n        return self.output\n\n    def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):\n        \"\"\"\n        Args:\n            boxes (ndarray): an Nx5 numpy array of\n                (x_center, y_center, width, height, angle_degrees) format\n                for the N objects in a single image.\n            labels (list[str]): the text to be displayed for each instance.\n            assigned_colors 
(list[matplotlib.colors]): a list of colors, where each color\n                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'\n                for full list of formats that the colors are accepted in.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        num_instances = len(boxes)\n\n        if assigned_colors is None:\n            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]\n        if num_instances == 0:\n            return self.output\n\n        # Display in largest to smallest order to reduce occlusion.\n        if boxes is not None:\n            areas = boxes[:, 2] * boxes[:, 3]\n\n        sorted_idxs = np.argsort(-areas).tolist()\n        # Re-order overlapped instances in descending order.\n        boxes = boxes[sorted_idxs]\n        labels = [labels[k] for k in sorted_idxs] if labels is not None else None\n        colors = [assigned_colors[idx] for idx in sorted_idxs]\n\n        for i in range(num_instances):\n            self.draw_rotated_box_with_label(\n                boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None\n            )\n\n        return self.output\n\n    def draw_and_connect_keypoints(self, keypoints):\n        \"\"\"\n        Draws keypoints of an instance and follows the rules for keypoint connections\n        to draw lines between appropriate keypoints. 
This follows color heuristics for\n        line color.\n\n        Args:\n            keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints\n                and the last dimension corresponds to (x, y, probability).\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        visible = {}\n        keypoint_names = self.metadata.get(\"keypoint_names\")\n        for idx, keypoint in enumerate(keypoints):\n            # draw keypoint\n            x, y, prob = keypoint\n            if prob > self.keypoint_threshold:\n                self.draw_circle((x, y), color=_RED)\n                if keypoint_names:\n                    keypoint_name = keypoint_names[idx]\n                    visible[keypoint_name] = (x, y)\n\n        if self.metadata.get(\"keypoint_connection_rules\"):\n            for kp0, kp1, color in self.metadata.keypoint_connection_rules:\n                if kp0 in visible and kp1 in visible:\n                    x0, y0 = visible[kp0]\n                    x1, y1 = visible[kp1]\n                    color = tuple(x / 255.0 for x in color)\n                    self.draw_line([x0, x1], [y0, y1], color=color)\n\n        # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip\n        # Note that this strategy is specific to person keypoints.\n        # For other keypoints, it should just do nothing\n        try:\n            ls_x, ls_y = visible[\"left_shoulder\"]\n            rs_x, rs_y = visible[\"right_shoulder\"]\n            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2\n        except KeyError:\n            pass\n        else:\n            # draw line from nose to mid-shoulder\n            nose_x, nose_y = visible.get(\"nose\", (None, None))\n            if nose_x is not None:\n                self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)\n\n            try:\n                # draw line from mid-shoulder to 
mid-hip\n                lh_x, lh_y = visible[\"left_hip\"]\n                rh_x, rh_y = visible[\"right_hip\"]\n            except KeyError:\n                pass\n            else:\n                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2\n                self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)\n        return self.output\n\n    \"\"\"\n    Primitive drawing functions:\n    \"\"\"\n\n    def draw_text(\n        self,\n        text,\n        position,\n        *,\n        font_size=None,\n        color=\"g\",\n        horizontal_alignment=\"center\",\n        rotation=0,\n    ):\n        \"\"\"\n        Args:\n            text (str): class label\n            position (tuple): a tuple of the x and y coordinates to place text on image.\n            font_size (int, optional): font of the text. If not provided, a font size\n                proportional to the image width is calculated and used.\n            color: color of the text. 
Refer to `matplotlib.colors` for full list\n                of formats that are accepted.\n            horizontal_alignment (str): see `matplotlib.text.Text`\n            rotation: rotation angle in degrees CCW\n\n        Returns:\n            output (VisImage): image object with text drawn.\n        \"\"\"\n        if not font_size:\n            font_size = self._default_font_size\n\n        # since the text background is dark, we don't want the text to be dark\n        color = np.maximum(list(mplc.to_rgb(color)), 0.2)\n        color[np.argmax(color)] = max(0.8, np.max(color))\n\n        x, y = position\n        self.output.ax.text(\n            x,\n            y,\n            text,\n            size=font_size * self.output.scale,\n            family=\"sans-serif\",\n            bbox={\"facecolor\": \"black\", \"alpha\": 0.8, \"pad\": 0.7, \"edgecolor\": \"none\"},\n            verticalalignment=\"top\",\n            horizontalalignment=horizontal_alignment,\n            color=color,\n            zorder=10,\n            rotation=rotation,\n        )\n        return self.output\n\n    def draw_box(self, box_coord, alpha=0.5, edge_color=\"g\", line_style=\"-\"):\n        \"\"\"\n        Args:\n            box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0\n                are the coordinates of the image's top left corner. x1 and y1 are the\n                coordinates of the image's bottom right corner.\n            alpha (float): blending efficient. Smaller values lead to more transparent masks.\n            edge_color: color of the outline of the box. 
Refer to `matplotlib.colors`\n                for full list of formats that are accepted.\n            line_style (string): the string to use to create the outline of the boxes.\n\n        Returns:\n            output (VisImage): image object with box drawn.\n        \"\"\"\n        x0, y0, x1, y1 = box_coord\n        width = x1 - x0\n        height = y1 - y0\n\n        linewidth = max(self._default_font_size / 4, 1)\n\n        self.output.ax.add_patch(\n            mpl.patches.Rectangle(\n                (x0, y0),\n                width,\n                height,\n                fill=False,\n                edgecolor=edge_color,\n                linewidth=linewidth * self.output.scale,\n                alpha=alpha,\n                linestyle=line_style,\n            )\n        )\n        return self.output\n\n    def draw_rotated_box_with_label(\n        self, rotated_box, alpha=0.5, edge_color=\"g\", line_style=\"-\", label=None\n    ):\n        \"\"\"\n        Draw a rotated box with label on its top-left corner.\n\n        Args:\n            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),\n                where cnt_x and cnt_y are the center coordinates of the box.\n                w and h are the width and height of the box. angle represents how\n                many degrees the box is rotated CCW with regard to the 0-degree box.\n            alpha (float): blending efficient. Smaller values lead to more transparent masks.\n            edge_color: color of the outline of the box. Refer to `matplotlib.colors`\n                for full list of formats that are accepted.\n            line_style (string): the string to use to create the outline of the boxes.\n            label (string): label for rotated box. 
It will not be rendered when set to None.\n\n        Returns:\n            output (VisImage): image object with box drawn.\n        \"\"\"\n        cnt_x, cnt_y, w, h, angle = rotated_box\n        area = w * h\n        # use thinner lines when the box is small\n        linewidth = self._default_font_size / (\n            6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3\n        )\n\n        theta = angle * math.pi / 180.0\n        c = math.cos(theta)\n        s = math.sin(theta)\n        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]\n        # x: left->right ; y: top->down\n        rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]\n        for k in range(4):\n            j = (k + 1) % 4\n            self.draw_line(\n                [rotated_rect[k][0], rotated_rect[j][0]],\n                [rotated_rect[k][1], rotated_rect[j][1]],\n                color=edge_color,\n                linestyle=\"--\" if k == 1 else line_style,\n                linewidth=linewidth,\n            )\n\n        if label is not None:\n            text_pos = rotated_rect[1]  # topleft corner\n\n            height_ratio = h / np.sqrt(self.output.height * self.output.width)\n            label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)\n            font_size = (\n                np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size\n            )\n            self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)\n\n        return self.output\n\n    def draw_circle(self, circle_coord, color, radius=3):\n        \"\"\"\n        Args:\n            circle_coord (list(int) or tuple(int)): contains the x and y coordinates\n                of the center of the circle.\n            color: color of the polygon. 
Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            radius (int): radius of the circle.\n\n        Returns:\n            output (VisImage): image object with box drawn.\n        \"\"\"\n        x, y = circle_coord\n        self.output.ax.add_patch(\n            mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)\n        )\n        return self.output\n\n    def draw_line(self, x_data, y_data, color, linestyle=\"-\", linewidth=None):\n        \"\"\"\n        Args:\n            x_data (list[int]): a list containing x values of all the points being drawn.\n                Length of list should match the length of y_data.\n            y_data (list[int]): a list containing y values of all the points being drawn.\n                Length of list should match the length of x_data.\n            color: color of the line. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            linestyle: style of the line. Refer to `matplotlib.lines.Line2D`\n                for a full list of formats that are accepted.\n            linewidth (float or None): width of the line. 
When it's None,\n                a default value will be computed and used.\n\n        Returns:\n            output (VisImage): image object with line drawn.\n        \"\"\"\n        if linewidth is None:\n            linewidth = self._default_font_size / 3\n        linewidth = max(linewidth, 1)\n        self.output.ax.add_line(\n            mpl.lines.Line2D(\n                x_data,\n                y_data,\n                linewidth=linewidth * self.output.scale,\n                color=color,\n                linestyle=linestyle,\n            )\n        )\n        return self.output\n\n    def draw_binary_mask(\n        self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0\n    ):\n        \"\"\"\n        Args:\n            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and\n                W is the image width. Each value in the array is either a 0 or 1 value of uint8\n                type.\n            color: color of the mask. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted. If None, will pick a random color.\n            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a\n                full list of formats that are accepted.\n            text (str): if None, will be drawn in the object's center of mass.\n            alpha (float): blending efficient. 
Smaller values lead to more transparent masks.\n            area_threshold (float): a connected component small than this will not be shown.\n\n        Returns:\n            output (VisImage): image object with mask drawn.\n        \"\"\"\n        if color is None:\n            color = random_color(rgb=True, maximum=1)\n        color = mplc.to_rgb(color)\n\n        has_valid_segment = False\n        binary_mask = binary_mask.astype(\"uint8\")  # opencv needs uint8\n        mask = GenericMask(binary_mask, self.output.height, self.output.width)\n        shape2d = (binary_mask.shape[0], binary_mask.shape[1])\n\n        if not mask.has_holes:\n            # draw polygons for regular masks\n            for segment in mask.polygons:\n                area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))\n                if area < (area_threshold or 0):\n                    continue\n                has_valid_segment = True\n                segment = segment.reshape(-1, 2)\n                self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)\n        else:\n            # TODO: Use Path/PathPatch to draw vector graphics:\n            # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon\n            rgba = np.zeros(shape2d + (4,), dtype=\"float32\")\n            rgba[:, :, :3] = color\n            rgba[:, :, 3] = (mask.mask == 1).astype(\"float32\") * alpha\n            has_valid_segment = True\n            self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))\n\n        if text is not None and has_valid_segment:\n            # TODO sometimes drawn on wrong objects. 
the heuristics here can improve.\n            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)\n            _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)\n            largest_component_id = np.argmax(stats[1:, -1]) + 1\n\n            # draw text on the largest component, as well as other very large components.\n            for cid in range(1, _num_cc):\n                if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:\n                    # median is more stable than centroid\n                    # center = centroids[largest_component_id]\n                    center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]\n                    self.draw_text(text, center, color=lighter_color)\n        return self.output\n\n    def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):\n        \"\"\"\n        Args:\n            segment: numpy array of shape Nx2, containing all the points in the polygon.\n            color: color of the polygon. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a\n                full list of formats that are accepted. If not provided, a darker shade\n                of the polygon color will be used instead.\n            alpha (float): blending efficient. 
Smaller values lead to more transparent masks.\n\n        Returns:\n            output (VisImage): image object with polygon drawn.\n        \"\"\"\n        if edge_color is None:\n            # make edge color darker than the polygon color\n            if alpha > 0.8:\n                edge_color = self._change_color_brightness(color, brightness_factor=-0.7)\n            else:\n                edge_color = color\n        edge_color = mplc.to_rgb(edge_color) + (1,)\n\n        polygon = mpl.patches.Polygon(\n            segment,\n            fill=True,\n            facecolor=mplc.to_rgb(color) + (alpha,),\n            edgecolor=edge_color,\n            linewidth=max(self._default_font_size // 15 * self.output.scale, 1),\n        )\n        self.output.ax.add_patch(polygon)\n        return self.output\n\n    \"\"\"\n    Internal methods:\n    \"\"\"\n\n    def _jitter(self, color):\n        \"\"\"\n        Randomly modifies given color to produce a slightly different color than the color given.\n\n        Args:\n            color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color\n                picked. The values in the list are in the [0.0, 1.0] range.\n\n        Returns:\n            jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the\n                color after being jittered. 
The values in the list are in the [0.0, 1.0] range.\n        \"\"\"\n        color = mplc.to_rgb(color)\n        vec = np.random.rand(3)\n        # better to do it in another color space\n        vec = vec / np.linalg.norm(vec) * 0.5\n        res = np.clip(vec + color, 0, 1)\n        return tuple(res)\n\n    def _create_grayscale_image(self, mask=None):\n        \"\"\"\n        Create a grayscale version of the original image.\n        The colors in masked area, if given, will be kept.\n        \"\"\"\n        img_bw = self.img.astype(\"f4\").mean(axis=2)\n        img_bw = np.stack([img_bw] * 3, axis=2)\n        if mask is not None:\n            img_bw[mask] = self.img[mask]\n        return img_bw\n\n    def _change_color_brightness(self, color, brightness_factor):\n        \"\"\"\n        Depending on the brightness_factor, gives a lighter or darker color i.e. a color with\n        less or more saturation than the original color.\n\n        Args:\n            color: color of the polygon. Refer to `matplotlib.colors` for a full list of\n                formats that are accepted.\n            brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of\n                0 will correspond to no change, a factor in [-1.0, 0) range will result in\n                a darker color and a factor in (0, 1.0] range will result in a lighter color.\n\n        Returns:\n            modified_color (tuple[double]): a tuple containing the RGB values of the\n                modified color. 
Each value in the tuple is in the [0.0, 1.0] range.\n        \"\"\"\n        assert brightness_factor >= -1.0 and brightness_factor <= 1.0\n        color = mplc.to_rgb(color)\n        polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))\n        modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])\n        modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness\n        modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness\n        modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])\n        return modified_color\n\n    def _convert_boxes(self, boxes):\n        \"\"\"\n        Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.\n        \"\"\"\n        if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):\n            return boxes.tensor.detach().numpy()\n        else:\n            return np.asarray(boxes)\n\n    def _convert_masks(self, masks_or_polygons):\n        \"\"\"\n        Convert different format of masks or polygons to a tuple of masks and polygons.\n\n        Returns:\n            list[GenericMask]:\n        \"\"\"\n\n        m = masks_or_polygons\n        if isinstance(m, PolygonMasks):\n            m = m.polygons\n        if isinstance(m, BitMasks):\n            m = m.tensor.numpy()\n        if isinstance(m, torch.Tensor):\n            m = m.numpy()\n        ret = []\n        for x in m:\n            if isinstance(x, GenericMask):\n                ret.append(x)\n            else:\n                ret.append(GenericMask(x, self.output.height, self.output.width))\n        return ret\n\n    def _convert_keypoints(self, keypoints):\n        if isinstance(keypoints, Keypoints):\n            keypoints = keypoints.tensor\n        keypoints = np.asarray(keypoints)\n        return keypoints\n\n    def get_output(self):\n        \"\"\"\n        Returns:\n            output (VisImage): the image output 
containing the visualizations added\n            to the image.\n        \"\"\"\n        return self.output\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/models/yolo.py",
    "content": "import os\nimport cv2\nimport numpy as np\nimport torch\nfrom pdf_extract_kit.registry import MODEL_REGISTRY\nfrom pdf_extract_kit.utils.visualization import visualize_bbox\nfrom pdf_extract_kit.dataset.dataset import ImageDataset\n\n@MODEL_REGISTRY.register('layout_detection_yolo')\nclass LayoutDetectionYOLO:\n    def __init__(self, config):\n        \"\"\"\n        Initialize the LayoutDetectionYOLO class.\n\n        Args:\n            config (dict): Configuration dictionary containing model parameters.\n                'model_path' is required; every other key read below falls back\n                to a default when absent.\n        \"\"\"\n        # Mapping from class IDs to class names\n        self.id_to_names = {\n            0: 'title', \n            1: 'plain text',\n            2: 'abandon', \n            3: 'figure', \n            4: 'figure_caption', \n            5: 'table', \n            6: 'table_caption', \n            7: 'table_footnote', \n            8: 'isolate_formula', \n            9: 'formula_caption'\n        }\n\n        # Load the YOLO model from the specified path. Prefer doclayout_yolo's\n        # YOLOv10 and fall back to ultralytics when doclayout_yolo is not\n        # installed (ImportError) or raises AttributeError while loading.\n        try:\n            from doclayout_yolo import YOLOv10\n            self.model = YOLOv10(config['model_path'])\n        except (ImportError, AttributeError):\n            from ultralytics import YOLO\n            self.model = YOLO(config['model_path'])\n\n        # Set model parameters\n        self.img_size = config.get('img_size', 1280)\n        self.conf_thres = config.get('conf_thres', 0.25)\n        self.iou_thres = config.get('iou_thres', 0.45)\n        self.visualize = config.get('visualize', False)\n        self.nc = config.get('nc', 10)\n        self.workers = config.get('workers', 8)\n        self.device = config.get('device', 'cpu')\n        \n        if self.iou_thres > 0:\n            import torchvision\n            self.nms_func = torchvision.ops.nms\n\n    def predict(self, images, result_path, image_ids=None):\n        \"\"\"\n        Predict layouts in images.\n\n        Args:\n            images (list): List of images to be predicted.\n            result_path (str): Directory where visualized results are saved\n                (only used when 'visualize' is enabled).\n            image_ids (list, optional): List of image IDs corresponding to the images,\n                used to name the visualization files. When omitted, each image must be\n                a file path whose stem names its visualization file.\n\n        Returns:\n            list: List of prediction results (one per image).\n        \"\"\"\n        results = []\n        for idx, image in enumerate(images):\n            result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False, device=self.device)[0]\n            if self.visualize:\n                os.makedirs(result_path, exist_ok=True)\n                boxes = result.__dict__['boxes'].xyxy\n                classes = result.__dict__['boxes'].cls\n                scores = result.__dict__['boxes'].conf\n\n                if self.iou_thres > 0:\n                    # Extra class-agnostic NMS; it only filters the boxes drawn below,\n                    # the returned `result` is not modified.\n                    indices = self.nms_func(boxes=torch.Tensor(boxes), scores=torch.Tensor(scores),iou_threshold=self.iou_thres)\n                    boxes, scores, classes = boxes[indices], scores[indices], classes[indices]\n                    if len(boxes.shape) == 1:\n                        # Restore the leading instance axis if indexing collapsed it.\n                        boxes = np.expand_dims(boxes, 0)\n                        scores = np.expand_dims(scores, 0)\n                        classes = np.expand_dims(classes, 0)\n                \n                vis_result = visualize_bbox(image, boxes, classes, scores, self.id_to_names)\n\n                # Determine the base name of the image\n                if image_ids:\n                    base_name = image_ids[idx]\n                else:\n                    base_name = os.path.splitext(os.path.basename(image))[0]  # Remove file extension\n                \n                result_name = f\"{base_name}_layout.png\"\n                \n                # Save the visualized result\n                cv2.imwrite(os.path.join(result_path, result_name), vis_result)\n            results.append(result)\n        return results\n"
  },
  {
    "path": "pdf_extract_kit/tasks/layout_detection/task.py",
    "content": "from pdf_extract_kit.registry.registry import TASK_REGISTRY\nfrom pdf_extract_kit.tasks.base_task import BaseTask\n\n\n@TASK_REGISTRY.register(\"layout_detection\")\nclass LayoutDetectionTask(BaseTask):\n    \"\"\"\n    Layout-detection task: loads inputs with the inherited load_images /\n    load_pdf_images helpers and runs the wrapped model on them.\n    \"\"\"\n\n    def __init__(self, model):\n        super().__init__(model)\n\n    def predict_images(self, input_data, result_path):\n        \"\"\"\n        Detect page layouts in a single image or a directory of images.\n\n        Args:\n            input_data (str): Path to an image file, or to a directory of image files.\n            result_path (str): Where the model should save its prediction results.\n\n        Returns:\n            list: The prediction results returned by the model.\n        \"\"\"\n        loaded_images = self.load_images(input_data)\n        return self.model.predict(loaded_images, result_path)\n\n    def predict_pdfs(self, input_data, result_path):\n        \"\"\"\n        Detect page layouts in a single PDF or a directory of PDFs.\n\n        Args:\n            input_data (str): Path to a PDF file, or to a directory of PDF files.\n            result_path (str): Where the model should save its prediction results.\n\n        Returns:\n            list: The prediction results returned by the model.\n        \"\"\"\n        page_images = self.load_pdf_images(input_data)\n        # Split the mapping into parallel lists: the images, then their keys as IDs.\n        images = list(page_images.values())\n        image_ids = list(page_images.keys())\n        return self.model.predict(images, result_path, image_ids)\n"
  },
  {
    "path": "pdf_extract_kit/tasks/ocr/__init__.py",
    "content": "# Public exports of the OCR task package.\nfrom pdf_extract_kit.tasks.ocr.models.paddle_ocr import ModifiedPaddleOCR\n\n__all__ = [\"ModifiedPaddleOCR\"]\n"
  },
  {
    "path": "pdf_extract_kit/tasks/ocr/models/paddle_ocr.py",
    "content": "import time\nimport copy\nimport logging\nimport base64\nimport cv2\nimport numpy as np\nfrom io import BytesIO\nfrom PIL import Image\n\nfrom paddleocr import PaddleOCR\nfrom ppocr.utils.logging import get_logger\nfrom ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img\nfrom tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop\nfrom pdf_extract_kit.registry import MODEL_REGISTRY\nlogger = get_logger()\n\ndef img_decode(content: bytes):\n    np_arr = np.frombuffer(content, dtype=np.uint8)\n    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)\n\ndef check_img(img):\n    if isinstance(img, bytes):\n        img = img_decode(img)\n    if isinstance(img, str):\n        image_file = img\n        img, flag_gif, flag_pdf = check_and_read(image_file)\n        if not flag_gif and not flag_pdf:\n            with open(image_file, 'rb') as f:\n                img_str = f.read()\n                img = img_decode(img_str)\n            if img is None:\n                try:\n                    buf = BytesIO()\n                    image = BytesIO(img_str)\n                    im = Image.open(image)\n                    rgb = im.convert('RGB')\n                    rgb.save(buf, 'jpeg')\n                    buf.seek(0)\n                    image_bytes = buf.read()\n                    data_base64 = str(base64.b64encode(image_bytes),\n                                      encoding=\"utf-8\")\n                    image_decode = base64.b64decode(data_base64)\n                    img_array = np.frombuffer(image_decode, np.uint8)\n                    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)\n                except:\n                    logger.error(\"error in loading image:{}\".format(image_file))\n                    return None\n        if img is None:\n            logger.error(\"error in loading image:{}\".format(image_file))\n            return None\n    if isinstance(img, np.ndarray) and len(img.shape) 
== 2:\n        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)\n    if isinstance(img, Image.Image):\n        img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)\n    return img\n\ndef sorted_boxes(dt_boxes):\n    \"\"\"\n    Sort text boxes in order from top to bottom, left to right\n    args:\n        dt_boxes(array):detected text boxes with shape [4, 2]\n    return:\n        sorted boxes(array) with shape [4, 2]\n    \"\"\"\n    num_boxes = dt_boxes.shape[0]\n    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))\n    _boxes = list(sorted_boxes)\n\n    for i in range(num_boxes - 1):\n        for j in range(i, -1, -1):\n            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \\\n                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):\n                tmp = _boxes[j]\n                _boxes[j] = _boxes[j + 1]\n                _boxes[j + 1] = tmp\n            else:\n                break\n    return _boxes\n\n\ndef __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):\n    \"\"\"Check if two bounding boxes overlap on the y-axis, and if the height of the overlapping region exceeds 80% of the height of the shorter bounding box.\"\"\"\n    _, y0_1, _, y1_1 = bbox1\n    _, y0_2, _, y1_2 = bbox2\n\n    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))\n    height1, height2 = y1_1 - y0_1, y1_2 - y0_2\n    max_height = max(height1, height2)\n    min_height = min(height1, height2)\n\n    return (overlap / min_height) > overlap_ratio_threshold\n\n\ndef bbox_to_points(bbox):\n    \"\"\" change bbox(shape: N * 4) to polygon(shape: N * 8) \"\"\"\n    x0, y0, x1, y1 = bbox\n    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')\n\n\ndef points_to_bbox(points):\n    \"\"\" change polygon(shape: N * 8) to bbox(shape: N * 4) \"\"\"\n    x0, y0 = points[0]\n    x1, _ = points[1]\n    _, y1 = points[2]\n    return [x0, y0, x1, y1]\n\n\ndef merge_intervals(intervals):\n    # Sort the intervals based on 
the start value\n    intervals.sort(key=lambda x: x[0])\n\n    merged = []\n    for interval in intervals:\n        # If the list of merged intervals is empty or if the current\n        # interval does not overlap with the previous, simply append it.\n        if not merged or merged[-1][1] < interval[0]:\n            merged.append(interval)\n        else:\n            # Otherwise, there is overlap, so we merge the current and previous intervals.\n            merged[-1][1] = max(merged[-1][1], interval[1])\n\n    return merged\n\n\ndef remove_intervals(original, masks):\n    # Merge all mask intervals\n    merged_masks = merge_intervals(masks)\n\n    result = []\n    original_start, original_end = original\n\n    for mask in merged_masks:\n        mask_start, mask_end = mask\n\n        # If the mask starts after the original range, ignore it\n        if mask_start > original_end:\n            continue\n\n        # If the mask ends before the original range starts, ignore it\n        if mask_end < original_start:\n            continue\n\n        # Remove the masked part from the original range\n        if original_start < mask_start:\n            result.append([original_start, mask_start - 1])\n\n        original_start = max(mask_end + 1, original_start)\n\n    # Add the remaining part of the original range, if any\n    if original_start <= original_end:\n        result.append([original_start, original_end])\n\n    return result\n\n\ndef update_det_boxes(dt_boxes, mfd_res):\n    new_dt_boxes = []\n    for text_box in dt_boxes:\n        text_bbox = points_to_bbox(text_box)\n        masks_list = []\n        for mf_box in mfd_res:\n            mf_bbox = mf_box['bbox']\n            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):\n                masks_list.append([mf_bbox[0], mf_bbox[2]])\n        text_x_range = [text_bbox[0], text_bbox[2]]\n        text_remove_mask_range = remove_intervals(text_x_range, masks_list)\n        temp_dt_box = []\n        for 
text_remove_mask in text_remove_mask_range:\n            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))\n        if len(temp_dt_box) > 0:\n            new_dt_boxes.extend(temp_dt_box)\n    return new_dt_boxes\n\n\ndef merge_spans_to_line(spans):\n    \"\"\"\n    Merge given spans into lines. Spans are considered based on their position in the document.\n    If spans overlap sufficiently on the Y-axis, they are merged into the same line; otherwise, a new line is started.\n\n    Parameters:\n    spans (list): A list of spans, where each span is a dictionary containing at least the key 'bbox',\n                  which itself is a list of four integers representing the bounding box:\n                  [x0, y0, x1, y1], where (x0, y0) is the top-left corner and (x1, y1) is the bottom-right corner.\n\n    Returns:\n    list: A list of lines, where each line is a list of spans.\n    \"\"\"\n    # Return an empty list if the spans list is empty\n    if len(spans) == 0:\n        return []\n    else:\n        # Sort spans by the Y0 coordinate\n        spans.sort(key=lambda span: span['bbox'][1])\n\n        lines = []\n        current_line = [spans[0]]\n        for span in spans[1:]:\n            # If the current span overlaps with the last span in the current line on the Y-axis, add it to the current line\n            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']):\n                current_line.append(span)\n            else:\n                # Otherwise, start a new line\n                lines.append(current_line)\n                current_line = [span]\n\n        # Add the last line if it exists\n        if current_line:\n            lines.append(current_line)\n\n        return lines\n\n\ndef merge_overlapping_spans(spans):\n    \"\"\"\n    Merges overlapping spans on the same line.\n\n    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]\n    :return: A list of merged 
spans\n    \"\"\"\n    # Return an empty list if the input spans list is empty\n    if not spans:\n        return []\n\n    # Sort spans by their starting x-coordinate\n    spans.sort(key=lambda x: x[0])\n\n    # Initialize the list of merged spans\n    merged = []\n    for span in spans:\n        # Unpack span coordinates\n        x1, y1, x2, y2 = span\n        # If the merged list is empty or there's no horizontal overlap, add the span directly\n        if not merged or merged[-1][2] < x1:\n            merged.append(span)\n        else:\n            # If there is horizontal overlap, merge the current span with the previous one\n            last_span = merged.pop()\n            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)\n            x1 = min(last_span[0], x1)\n            y1 = min(last_span[1], y1)\n            x2 = max(last_span[2], x2)\n            y2 = max(last_span[3], y2)\n            # Add the merged span back to the list\n            merged.append((x1, y1, x2, y2))\n\n    # Return the list of merged spans\n    return merged\n\n\ndef merge_det_boxes(dt_boxes):\n    \"\"\"\n    Merge detection boxes.\n\n    This function takes a list of detected bounding boxes, each represented by four corner points.\n    The goal is to merge these bounding boxes into larger text regions.\n\n    Parameters:\n    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.\n\n    Returns:\n    list: A list containing the merged text regions, where each region is represented by four corner points.\n    \"\"\"\n    # Convert the detection boxes into a dictionary format with bounding boxes and type\n    dt_boxes_dict_list = []\n    for text_box in dt_boxes:\n        text_bbox = points_to_bbox(text_box)\n        text_box_dict = {\n            'bbox': text_bbox,\n        }\n        dt_boxes_dict_list.append(text_box_dict)\n\n    # Merge adjacent text regions into 
lines\n    lines = merge_spans_to_line(dt_boxes_dict_list)\n\n    # Initialize a new list for storing the merged text regions\n    new_dt_boxes = []\n    for line in lines:\n        line_bbox_list = []\n        for span in line:\n            line_bbox_list.append(span['bbox'])\n\n        # Merge overlapping text regions within the same line\n        merged_spans = merge_overlapping_spans(line_bbox_list)\n\n        # Convert the merged text regions back to point format and add them to the new detection box list\n        for span in merged_spans:\n            new_dt_boxes.append(bbox_to_points(span))\n\n    return new_dt_boxes\n\n@MODEL_REGISTRY.register('ocr_ppocr')\nclass ModifiedPaddleOCR(PaddleOCR):\n    def __init__(self, config):\n        super().__init__(**config)\n        \n    def predict(self, img, **kwargs):\n        ppocr_res = self.ocr(img, **kwargs)[0]\n        ocr_res = []\n        for box_ocr_res in ppocr_res:\n            p1, p2, p3, p4 = box_ocr_res[0]\n            text, score = box_ocr_res[1]\n            ocr_res.append({\n                \"category_type\": \"text\",\n                'poly': p1 + p2 + p3 + p4,\n                'score': round(score, 2),\n                'text': text,\n            })\n        return ocr_res\n        \n    def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):\n        \"\"\"\n        OCR with PaddleOCR\n        args：\n            img: img for OCR, support ndarray, img_path and list or ndarray\n            det: use text detection or not. If False, only rec will be exec. Default is True\n            rec: use text recognition or not. If False, only det will be exec. Default is True\n            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. 
Text with rotation of 90 or 270 degrees can be recognized even if cls=False.\n            bin: binarize image to black and white. Default is False.\n            inv: invert image colors. Default is False.\n            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.\n        \"\"\"\n        assert isinstance(img, (np.ndarray, list, str, bytes, Image.Image))\n        if isinstance(img, list) and det == True:\n            logger.error('When input a list of images, det must be false')\n            exit(0)\n        if cls == True and self.use_angle_cls == False:\n            logger.warning(\n                'Since the angle classifier is not initialized, it will not be used during the forward process'\n            )\n\n        img = check_img(img)\n        # for infer pdf file\n        if isinstance(img, list):\n            if self.page_num > len(img) or self.page_num == 0:\n                self.page_num = len(img)\n            imgs = img[:self.page_num]\n        else:\n            imgs = [img]\n\n        def preprocess_image(_image):\n            _image = alpha_to_color(_image, alpha_color)\n            if inv:\n                _image = cv2.bitwise_not(_image)\n            if bin:\n                _image = binarize_img(_image)\n            return _image\n\n        if det and rec:\n            ocr_res = []\n            for idx, img in enumerate(imgs):\n                img = preprocess_image(img)\n                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)\n                if not dt_boxes and not rec_res:\n                    ocr_res.append(None)\n                    continue\n                tmp_res = [[box.tolist(), res]\n                           for box, res in zip(dt_boxes, rec_res)]\n                ocr_res.append(tmp_res)\n            return ocr_res\n        elif det and not rec:\n            ocr_res = []\n            for idx, img in enumerate(imgs):\n                img = preprocess_image(img)\n      
          dt_boxes, elapse = self.text_detector(img)\n                if not dt_boxes:\n                    ocr_res.append(None)\n                    continue\n                tmp_res = [box.tolist() for box in dt_boxes]\n                ocr_res.append(tmp_res)\n            return ocr_res\n        else:\n            ocr_res = []\n            cls_res = []\n            for idx, img in enumerate(imgs):\n                if not isinstance(img, list):\n                    img = preprocess_image(img)\n                    img = [img]\n                if self.use_angle_cls and cls:\n                    img, cls_res_tmp, elapse = self.text_classifier(img)\n                    if not rec:\n                        cls_res.append(cls_res_tmp)\n                rec_res, elapse = self.text_recognizer(img)\n                ocr_res.append(rec_res)\n            if not rec:\n                return cls_res\n            return ocr_res\n        \n    def __call__(self, img, cls=True, mfd_res=None):\n        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}\n\n        if img is None:\n            logger.debug(\"no valid image provided\")\n            return None, None, time_dict\n\n        start = time.time()\n        ori_im = img.copy()\n        dt_boxes, elapse = self.text_detector(img)\n        time_dict['det'] = elapse\n\n        if dt_boxes is None:\n            logger.debug(\"no dt_boxes found, elapsed : {}\".format(elapse))\n            end = time.time()\n            time_dict['all'] = end - start\n            return None, None, time_dict\n        else:\n            logger.debug(\"dt_boxes num : {}, elapsed : {}\".format(\n                len(dt_boxes), elapse))\n        img_crop_list = []\n\n        dt_boxes = sorted_boxes(dt_boxes)\n\n        dt_boxes = merge_det_boxes(dt_boxes)\n\n        if mfd_res:\n            bef = time.time()\n            dt_boxes = update_det_boxes(dt_boxes, mfd_res)\n            aft = time.time()\n            logger.debug(\"split text box by formula, 
new dt_boxes num : {}, elapsed : {}\".format(\n                len(dt_boxes), aft-bef))\n\n        for bno in range(len(dt_boxes)):\n            tmp_box = copy.deepcopy(dt_boxes[bno])\n            if self.args.det_box_type == \"quad\":\n                img_crop = get_rotate_crop_image(ori_im, tmp_box)\n            else:\n                img_crop = get_minarea_rect_crop(ori_im, tmp_box)\n            img_crop_list.append(img_crop)\n        if self.use_angle_cls and cls:\n            img_crop_list, angle_list, elapse = self.text_classifier(\n                img_crop_list)\n            time_dict['cls'] = elapse\n            logger.debug(\"cls num  : {}, elapsed : {}\".format(\n                len(img_crop_list), elapse))\n\n        rec_res, elapse = self.text_recognizer(img_crop_list)\n        time_dict['rec'] = elapse\n        logger.debug(\"rec_res num  : {}, elapsed : {}\".format(\n            len(rec_res), elapse))\n        if self.args.save_crop_res:\n            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,\n                                   rec_res)\n        filter_boxes, filter_rec_res = [], []\n        for box, rec_result in zip(dt_boxes, rec_res):\n            text, score = rec_result\n            if score >= self.drop_score:\n                filter_boxes.append(box)\n                filter_rec_res.append(rec_result)\n        end = time.time()\n        time_dict['all'] = end - start\n        return filter_boxes, filter_rec_res, time_dict"
  },
  {
    "path": "pdf_extract_kit/tasks/ocr/task.py",
    "content": "import os\nimport json\nimport random\nfrom PIL import Image, ImageDraw\nfrom pdf_extract_kit.registry.registry import TASK_REGISTRY\nfrom pdf_extract_kit.utils.data_preprocess import load_pdf\nfrom pdf_extract_kit.tasks.base_task import BaseTask\n\n\n@TASK_REGISTRY.register(\"ocr\")\nclass OCRTask(BaseTask):\n    def __init__(self, model):\n        \"\"\"init the task based on the given model.\n        \n        Args:\n            model: task model, must contains predict function.\n        \"\"\"\n        super().__init__(model)\n\n    def predict_image(self, image):\n        \"\"\"predict on one image, reture text detection and recognition results.\n        \n        Args:\n            image: PIL.Image.Image, (if the model.predict function support other types, remenber add change-format-function in model.predict)\n            \n        Returns:\n            List[dict]: list of text bbox with it's content\n            \n        Return example:\n            [\n                {\n                    \"category_type\": \"text\",\n                    \"poly\": [\n                        380.6792698635707,\n                        159.85058512958923,\n                        765.1419999999998,\n                        159.85058512958923,\n                        765.1419999999998,\n                        192.51073013642917,\n                        380.6792698635707,\n                        192.51073013642917\n                    ],\n                    \"text\": \"this is an example text\",\n                    \"score\": 0.97\n                },\n                ...\n            ]\n        \"\"\"\n        return self.model.predict(image)\n        \n    def prepare_input_files(self, input_path):\n        if os.path.isdir(input_path):\n            file_list = [os.path.join(input_path, fname) for fname in os.listdir(input_path)]\n        else:\n            file_list = [input_path]\n        return file_list\n            \n    def process(self, 
input_path, save_dir=None, visualize=False):\n        file_list = self.prepare_input_files(input_path)\n        res_list = []\n        for fpath in file_list:\n            basename = os.path.basename(fpath)[:-4]\n            if fpath.endswith(\".pdf\") or fpath.endswith(\".PDF\"):\n                images = load_pdf(fpath)\n                pdf_res = []\n                for page, img in enumerate(images):\n                    page_res = self.predict_image(img)\n                    pdf_res.append(page_res)\n                    if save_dir:\n                        os.makedirs(os.path.join(save_dir, basename), exist_ok=True)\n                        self.save_json_result(page_res, os.path.join(save_dir, basename, f\"page_{page+1}.json\"))\n                        if visualize:\n                            self.visualize_image(img, page_res, os.path.join(save_dir, basename, f\"page_{page+1}.jpg\"))\n                        \n                res_list.append(pdf_res)\n            else:\n                image = Image.open(fpath)\n                img_res = self.predict_image(image)\n                res_list.append(img_res)\n                if save_dir:\n                    os.makedirs(save_dir, exist_ok=True)\n                    self.save_json_result(img_res, os.path.join(save_dir, f\"{basename}.json\"))\n                    if visualize:\n                        self.visualize_image(image, img_res, os.path.join(save_dir, f\"{basename}.png\"))\n                \n        return res_list\n    \n    def visualize_image(self, image, ocr_res, save_path=\"\", cate2color={}):\n        \"\"\"plot each result's bbox and category on image.\n        \n        Args:\n            image: PIL.Image.Image\n            ocr_res: list of ocr det and rec, whose format following the results of self.predict_image function\n            save_path: path to save visualized image\n        \"\"\"\n        draw = ImageDraw.Draw(image)\n        for res in ocr_res:\n            box_color = 
cate2color.get(res['category_type'], (0, 255, 0))\n            x_min, y_min = int(res['poly'][0]), int(res['poly'][1])\n            x_max, y_max = int(res['poly'][4]), int(res['poly'][5])\n            draw.rectangle([x_min, y_min, x_max, y_max], fill=None, outline=box_color, width=1)\n            draw.text((x_min, y_min), res['category_type'], (255, 0, 0))\n        if save_path:\n            image.save(save_path)\n        \n    def save_json_result(self, ocr_res, save_path):\n        \"\"\"save results to a json file.\n        \n        Args:\n            ocr_res: list of ocr det and rec, whose format following the results of self.predict_image function\n            save_path: path to save visualized image\n        \"\"\"\n        with open(save_path, \"w\", encoding=\"utf-8\") as f:\n            f.write(json.dumps(ocr_res, indent=2, ensure_ascii=False))\n        \n        \n"
  },
  {
    "path": "pdf_extract_kit/tasks/table_parsing/__init__.py",
    "content": "from pdf_extract_kit.tasks.table_parsing.models.struct_eqtable import TableParsingStructEqTable\n\nfrom pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n__all__ = [\n    \"TableParsingStructEqTable\",\n]\n"
  },
  {
    "path": "pdf_extract_kit/tasks/table_parsing/models/struct_eqtable.py",
    "content": "import torch\n\nfrom PIL import Image\nfrom struct_eqtable import build_model\nfrom pdf_extract_kit.registry.registry import MODEL_REGISTRY\n\n\n@MODEL_REGISTRY.register(\"table_parsing_struct_eqtable\")\nclass TableParsingStructEqTable:\n    def __init__(self, config):\n        \"\"\"\n        Initialize the TableParsingStructEqTable class.\n\n        Args:\n            config (dict): Configuration dictionary containing model parameters.\n        \"\"\"\n        assert torch.cuda.is_available(), \"CUDA must be available for StructEqTable model.\"\n\n        self.model_dir = config.get('model_path', 'U4R/StructTable-InternVL2-1B')\n        self.max_new_tokens = config.get('max_new_tokens', 1024)\n        self.max_time = config.get('max_time', 30)\n\n        self.lmdeploy = config.get('lmdeploy', False)\n        self.flash_attn = config.get('flash_attn', True)\n        self.batch_size = config.get('batch_size', 1)\n        self.default_format = config.get('output_format', 'latex')\n\n        # Load the StructEqTable model\n        self.model = build_model(\n            model_ckpt=self.model_dir,\n            max_new_tokens=self.max_new_tokens,\n            max_time=self.max_time,\n            lmdeploy=self.lmdeploy,\n            flash_attn=self.flash_attn,\n            batch_size=self.batch_size,\n        ).cuda()\n\n    def predict(self, images, result_path, output_format=None, **kwargs):        \n\n        load_images = [Image.open(image_path) for image_path in images]\n\n        if output_format is None:\n            output_format = self.default_format\n        else:\n            if output_format not in ['latex', 'markdown', 'html']:\n                raise ValueError(f\"Output format {output_format} is not supported.\")\n\n        results = self.model(\n            load_images, output_format=output_format\n        )\n\n        return results\n"
  },
  {
    "path": "pdf_extract_kit/tasks/table_parsing/task.py",
    "content": "from pdf_extract_kit.registry.registry import TASK_REGISTRY\nfrom pdf_extract_kit.tasks.base_task import BaseTask\n\n\n@TASK_REGISTRY.register(\"table_parsing\")\nclass TableParsingTask(BaseTask):\n    def __init__(self, model):\n        super().__init__(model)\n\n    def predict(self, input_data, result_path, **kwargs):\n        images = self.load_images(input_data)\n        # Perform layout detection on input_data\n        return self.model.predict(images, result_path, **kwargs)"
  },
  {
    "path": "pdf_extract_kit/utils/__init__.py",
    "content": ""
  },
  {
    "path": "pdf_extract_kit/utils/config_loader.py",
    "content": "import yaml\nimport warnings\nfrom pdf_extract_kit.registry.registry import TASK_REGISTRY, MODEL_REGISTRY\n\n\ndef load_config(config_path):\n    if config_path is None:\n        warnings.warn(\n            (\"Configuration path is None. Please provide a valid configuration file path. \")\n        )\n        return None\n    \n    with open(config_path, 'r') as file:\n        config = yaml.safe_load(file)\n    return config\n\n\n# def initialize_task_and_model(config):\n#     task_name = config['task']\n#     model_name = config['model']\n#     model_config = config['model_config']\n\n#     TaskClass = TASK_REGISTRY.get(task_name)\n#     ModelClass = MODEL_REGISTRY.get(model_name)\n\n#     model_instance = ModelClass(model_config)\n#     task_instance = TaskClass(model_instance)\n\n#     return task_instance\n\ndef initialize_tasks_and_models(config):\n\n    task_instances = {}\n    for task_name in config['tasks']:\n\n        model_name = config['tasks'][task_name]['model']\n        model_config = config['tasks'][task_name]['model_config']\n\n        TaskClass = TASK_REGISTRY.get(task_name)\n        ModelClass = MODEL_REGISTRY.get(model_name)\n\n        model_instance = ModelClass(model_config)\n        task_instance = TaskClass(model_instance)\n\n        task_instances[task_name] = task_instance\n\n    return task_instances"
  },
  {
    "path": "pdf_extract_kit/utils/data_preprocess.py",
    "content": "import fitz\nfrom PIL import Image\n\n\ndef load_pdf_page(page, dpi):\n    pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72))\n    image = Image.frombytes(\"RGB\", [pix.width, pix.height], pix.samples)\n    if pix.width > 3000 or pix.height > 3000:\n        pix = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)\n        image = Image.frombytes(\"RGB\", [pix.width, pix.height], pix.samples)\n    return image\n\ndef load_pdf(pdf_path, dpi=144):\n    images = []\n    doc = fitz.open(pdf_path)\n    for i in range(len(doc)):\n        page = doc[i]\n        image = load_pdf_page(page, dpi)\n        images.append(image)\n    return images"
  },
  {
    "path": "pdf_extract_kit/utils/merge_blocks_and_spans.py",
    "content": "# revised from https://github.com/opendatalab/MinerU/blob/7f0fe20004af7416db886f4b75c116bcc1c986b4/magic_pdf/pdf_parse_union_core.py#L177\n# from fast_langdetect import detect_language\n# import unicodedata\nimport re\n\n\ndef __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):\n    \"\"\"检查两个bbox在y轴上是否有重叠，并且该重叠区域的高度占两个bbox高度更低的那个超过80%\"\"\"\n    _, y0_1, _, y1_1 = bbox1\n    _, y0_2, _, y1_2 = bbox2\n\n    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))\n    height1, height2 = y1_1 - y0_1, y1_2 - y0_2\n    max_height = max(height1, height2)\n    min_height = min(height1, height2)\n\n    return (overlap / min_height) > overlap_ratio_threshold\n\ndef merge_spans_to_line(spans):\n    if len(spans) == 0:\n        return []\n    else:\n        # 按照y0坐标排序\n        spans.sort(key=lambda span: span['bbox'][1])\n\n        lines = []\n        current_line = [spans[0]]\n        for span in spans[1:]:\n            # 如果当前的span类型为\"isolated\" 或者 当前行中已经有\"isolated\"\n            # image和table类型，同上\n            if span['type'] in ['isolated'] or any(\n                    s['type'] in ['isolated'] for s in\n                    current_line):\n                # 则开始新行\n                lines.append(current_line)\n                current_line = [span]\n                continue\n\n            # 如果当前的span与当前行的最后一个span在y轴上重叠，则添加到当前行\n            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']):\n                current_line.append(span)\n            else:\n                # 否则，开始新行\n                lines.append(current_line)\n                current_line = [span]\n\n        # 添加最后一行\n        if current_line:\n            lines.append(current_line)\n\n        return lines\n\n# 将每一个line中的span从左到右排序\ndef line_sort_spans_by_left_to_right(lines):\n    line_objects = []\n    for line in lines:\n        # 按照x0坐标排序\n        line.sort(key=lambda span: span['bbox'][0])\n        line_bbox = [\n            min(span['bbox'][0] 
for span in line),  # x0\n            min(span['bbox'][1] for span in line),  # y0\n            max(span['bbox'][2] for span in line),  # x1\n            max(span['bbox'][3] for span in line),  # y1\n        ]\n        line_objects.append({\n            \"bbox\": line_bbox,\n            \"spans\": line,\n        })\n    return line_objects\n\ndef fix_text_block(block):\n    # 文本block中的公式span都应该转换成行内type\n    for span in block['spans']:\n        if span['type'] == \"isolated\":\n            span['type'] = \"inline\"\n    block_lines = merge_spans_to_line(block['spans'])\n    sort_block_lines = line_sort_spans_by_left_to_right(block_lines)\n    block['lines'] = sort_block_lines\n    del block['spans']\n    return block\n\n\ndef fix_interline_block(block):\n    block_lines = merge_spans_to_line(block['spans'])\n    sort_block_lines = line_sort_spans_by_left_to_right(block_lines)\n    block['lines'] = sort_block_lines\n    del block['spans']\n    return block\n\ndef calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):\n    \"\"\"\n    计算box1和box2的重叠面积占bbox1的比例\n    \"\"\"\n    # Determine the coordinates of the intersection rectangle\n    x_left = max(bbox1[0], bbox2[0])\n    y_top = max(bbox1[1], bbox2[1])\n    x_right = min(bbox1[2], bbox2[2])\n    y_bottom = min(bbox1[3], bbox2[3])\n\n    if x_right < x_left or y_bottom < y_top:\n        return 0.0\n\n    # The area of overlap area\n    intersection_area = (x_right - x_left) * (y_bottom - y_top)\n    bbox1_area = (bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1])\n    if bbox1_area == 0:\n        return 0\n    else:\n        return intersection_area / bbox1_area\n\ndef fill_spans_in_blocks(blocks, spans, radio):\n    '''\n    将allspans中的span按位置关系，放入blocks中\n    '''\n    block_with_spans = []\n    for block in blocks:\n        block_type = block[\"category_type\"]\n        L = block['poly'][0]\n        U = block['poly'][1]\n        R = block['poly'][2]\n        D = block['poly'][5]\n        L, R = min(L, R), max(L, R)\n    
    U, D = min(U, D), max(U, D)\n        block_bbox = [L, U, R, D]\n        block_dict = {\n            'type': block_type,\n            'bbox': block_bbox,\n            'saved_info': block\n        }\n        block_spans = []\n        for span in spans:\n            span_bbox = span[\"bbox\"]\n            if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > radio:\n                block_spans.append(span)\n\n        '''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''\n        # displayed_list = []\n        # text_inline_lines = []\n        # modify_y_axis(block_spans, displayed_list, text_inline_lines)\n\n        '''模型识别错误的行间公式, type类型转换成行内公式'''\n        # block_spans = modify_inline(block_spans, displayed_list, text_inline_lines)\n\n        '''bbox去除粘连'''  # 去粘连会影响span的bbox，导致后续fill的时候出错\n        # block_spans = remove_overlap_between_bbox_for_span(block_spans)\n\n        block_dict['spans'] = block_spans\n        block_with_spans.append(block_dict)\n\n        # 从spans删除已经放入block_spans中的span\n        if len(block_spans) > 0:\n            for span in block_spans:\n                spans.remove(span)\n\n    return block_with_spans, spans\n\ndef fix_block_spans(block_with_spans):\n    '''\n    1、img_block和table_block因为包含caption和footnote的关系，存在block的嵌套关系\n        需要将caption和footnote的text_span放入相应img_block和table_block内的\n        caption_block和footnote_block中\n    2、同时需要删除block中的spans字段\n    '''\n    fix_blocks = []\n    for block in block_with_spans:\n        block_type = block['type']\n\n        # if block_type == BlockType.Image:\n        #     block = fix_image_block(block, img_blocks)\n        # elif block_type == BlockType.Table:\n        #     block = fix_table_block(block, table_blocks)\n        if block_type == \"isolate_formula\":\n            block = fix_interline_block(block)\n        else:\n            block = fix_text_block(block)\n        fix_blocks.append(block)\n    return fix_blocks\n\n\n# def detect_lang(text: str) -> str:\n\n#     if len(text) 
== 0:\n#         return \"\"\n#     try:\n#         lang_upper = detect_language(text)\n#     except:\n#         html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])\n#         lang_upper = detect_language(html_no_ctrl_chars)\n#     try:\n#         lang = lang_upper.lower()\n#     except:\n#         lang = \"\"\n#     return lang\n\ndef detect_lang(string):\n    \"\"\"\n    Detect whether a string contains Chinese characters.\n\n    Only the CJK Unified Ideographs block (U+4E00..U+9FFF) is checked.\n\n    :param string: the text to inspect\n    :return: 'zh' if any character falls in that block, otherwise 'en'\n    \"\"\"\n    for ch in string:\n        if u'\\u4e00' <= ch <= u'\\u9fff':\n            return 'zh'\n    return 'en'\n\ndef ocr_escape_special_markdown_char(content):\n    \"\"\"\n    Backslash-escape characters that have special meaning in markdown\n    (`*`, backtick, `~` and `$`) so body text renders literally.\n    \"\"\"\n    special_chars = [\"*\", \"`\", \"~\", \"$\"]\n    for char in special_chars:\n        content = content.replace(char, \"\\\\\" + char)\n\n    return content\n\n# def split_long_words(text):\n#     segments = text.split(' ')\n#     for i in range(len(segments)):\n#         words = re.findall(r'\\w+|[^\\w]', segments[i], re.UNICODE)\n#         for j in range(len(words)):\n#             if len(words[j]) > 15:\n#                 words[j] = ' '.join(wordninja.split(words[j]))\n#         segments[i] = ''.join(words)\n#     return ' '.join(segments)\n\n\ndef merge_para_with_text(para_block):\n    \"\"\"\n    Merge the spans of one paragraph block into a single markdown string.\n\n    Text spans are markdown-escaped, 'inline' and 'ignore-formula' spans are\n    wrapped as `$...$`, 'isolated' spans as a `$$` display block, and\n    'footnote' markers become LaTeX superscripts.\n\n    :param para_block: dict with a 'lines' list; every line holds a 'spans'\n        list of dicts with 'type' and 'content' keys\n    :return: the merged paragraph text\n    \"\"\"\n    para_text = ''\n    for line in para_block['lines']:\n        # Decide the language from the whole line: in documents with one\n        # character per span, per-span detection is unreliable.\n        line_text = ''.join(\n            span['content'].strip() for span in line['spans'] if span['type'] == \"text\"\n        )\n        line_lang = detect_lang(line_text) if line_text != \"\" else \"\"\n        for span in line['spans']:\n            span_type = span['type']\n            content = ''\n            if span_type == \"text\":\n                content = ocr_escape_special_markdown_char(span['content'])\n            elif span_type in ('inline', 'ignore-formula'):\n                content = f\" ${span['content'].strip('$')}$ \"\n            elif span_type == 'isolated':\n                content = f\"\\n$$\\n{span['content'].strip('$')}\\n$$\\n\"\n            elif span_type == 'footnote':\n                content_ori = span['content'].strip('$')\n                if '^' in content_ori:\n                    content = f\" ${content_ori}$ \"\n                else:\n                    # Brace the marker: a bare `^abc` would only superscript `a`.\n                    content = f\" $^{{{content_ori}}}$ \"\n\n            if content != '':\n                if 'zh' in line_lang:\n                    # Chinese context: spans are concatenated without separators.\n                    para_text += content.strip()\n                else:\n                    # English context: spans are separated by a single space.\n                    para_text += content.strip() + ' '\n    return para_text"
  },
  {
    "path": "pdf_extract_kit/utils/pdf_utils.py",
    "content": "from pdf2image import convert_from_path\n\ndef load_pdf(pdf_path):\n    images = convert_from_path(pdf_path)\n    return images\n"
  },
  {
    "path": "pdf_extract_kit/utils/visualization.py",
    "content": "import numpy as np\nimport cv2\nfrom PIL import Image\n\ndef colormap(N=256, normalized=False):\n    \"\"\"\n    Generate the color map.\n\n    Args:\n        N (int): Number of labels (default is 256).\n        normalized (bool): If True, return colors normalized to [0, 1]. Otherwise, return [0, 255].\n\n    Returns:\n        np.ndarray: Color map array of shape (N, 3).\n    \"\"\"\n    def bitget(byteval, idx):\n        \"\"\"\n        Get the bit value at the specified index.\n\n        Args:\n            byteval (int): The byte value.\n            idx (int): The index of the bit.\n\n        Returns:\n            int: The bit value (0 or 1).\n        \"\"\"\n        return ((byteval & (1 << idx)) != 0)\n\n    cmap = np.zeros((N, 3), dtype=np.uint8)\n    for i in range(N):\n        r = g = b = 0\n        c = i\n        for j in range(8):\n            r = r | (bitget(c, 0) << (7 - j))\n            g = g | (bitget(c, 1) << (7 - j))\n            b = b | (bitget(c, 2) << (7 - j))\n            c = c >> 3\n        cmap[i] = np.array([r, g, b])\n    \n    if normalized:\n        cmap = cmap.astype(np.float32) / 255.0\n\n    return cmap\n\ndef visualize_bbox(image_path, bboxes, classes, scores, id_to_names, alpha=0.3):\n    \"\"\"\n    Visualize layout detection results on an image.\n\n    Args:\n        image_path (str): Path to the input image.\n        bboxes (list): List of bounding boxes, each represented as [x_min, y_min, x_max, y_max].\n        classes (list): List of class IDs corresponding to the bounding boxes.\n        id_to_names (dict): Dictionary mapping class IDs to class names.\n        alpha (float): Transparency factor for the filled color (default is 0.3).\n\n    Returns:\n        np.ndarray: Image with visualized layout detection results.\n    \"\"\"\n    # Check if image_path is a PIL.Image.Image object\n    if isinstance(image_path, Image.Image):\n        image = np.array(image_path)\n        image = cv2.cvtColor(image, 
cv2.COLOR_RGB2BGR)  # Convert RGB to BGR for OpenCV\n    else:\n        image = cv2.imread(image_path)\n\n    overlay = image.copy()\n    \n    cmap = colormap(N=len(id_to_names), normalized=False)\n    \n    # Iterate over each bounding box\n    for i, bbox in enumerate(bboxes):\n        x_min, y_min, x_max, y_max = map(int, bbox)\n        class_id = int(classes[i])\n        class_name = id_to_names[class_id]\n        \n        text = class_name + f\":{scores[i]:.3f}\"\n        \n        color = tuple(int(c) for c in cmap[class_id])\n        cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)\n        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)\n        \n        # Add the class name with a background rectangle\n        (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2)\n        cv2.rectangle(image, (x_min, y_min - text_height - baseline), (x_min + text_width, y_min), color, -1)\n        cv2.putText(image, text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 255), 2)\n    \n    # Blend the overlay with the original image\n    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)\n    \n    return image"
  },
  {
    "path": "pdf_extract_kit/version.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\n__version__ = '0.1.0'\nshort_version = __version__\n\n\ndef parse_version_info(version_str: str) -> Tuple:\n    \"\"\"Parse version from a string.\n\n    Args:\n        version_str (str): A string represents a version info.\n\n    Returns:\n        tuple: A sequence of integer and string represents version.\n    \"\"\"\n    _version_info = []\n    for x in version_str.split('.'):\n        if x.isdigit():\n            _version_info.append(int(x))\n        elif x.find('rc') != -1:\n            patch_version = x.split('rc')\n            _version_info.append(int(patch_version[0]))\n            _version_info.append(f'rc{patch_version[1]}')\n    return tuple(_version_info)\n\n\nversion_info = parse_version_info(__version__)"
  },
  {
    "path": "project/pdf2markdown/README.md",
    "content": "# PDF2Markdown\n\n**Demo:(left: input image; right: rendered markdown.)**\n\n![demo](demo.png)\n\n\n1. Extract PDF features by these tasks:\n\n    - Layout Detection: Using the YOLOv8 model for region detection, such as images, tables, titles, text, etc.;\n\n    - Formula Detection: Using YOLOv8 for detecting formulas, including inline formulas and isolated formulas;\n\n    - Formula Recognition: Using UniMERNet for formula recognition;\n\n    - Table Recognition: Using StructEqTable for table recognition;\n\n    - Optical Character Recognition: Using PaddleOCR for text recognition;\n\n2. Convert features to markdown file:\n\n    Using simple rules to convert the identified result to markdown (*Note: this is a simply convert code and can only support one-column PDFs, see [MinerU](https://github.com/opendatalab/MinerU) for more complex situation*).\n\n\n# Usage\n\n```\npython project/pdf2markdown/scripts/run_project.py --config project/pdf2markdown/configs/pdf2markdown.yaml\n```\n"
  },
  {
    "path": "project/pdf2markdown/configs/pdf2markdown.yaml",
    "content": "inputs: assets/demo/formula_detection\noutputs: outputs/pdf2markdown\nvisualize: True\nmerge2markdown: True\ntasks:\n  layout_detection:\n    model: layout_detection_yolo\n    model_config:\n      img_size: 1024\n      conf_thres: 0.25\n      iou_thres: 0.45\n      model_path: models/Layout/YOLO/doclayout_yolo_ft.pt\n  formula_detection:\n    model: formula_detection_yolo\n    model_config:\n      img_size: 1280\n      conf_thres: 0.25\n      iou_thres: 0.45\n      batch_size: 1\n      model_path: models/MFD/YOLO/yolo_v8_ft.pt\n  formula_recognition:\n    model: formula_recognition_unimernet\n    model_config:\n      batch_size: 128\n      cfg_path: pdf_extract_kit/configs/unimernet.yaml\n      model_path: models/MFR/unimernet_tiny\n  ocr:\n    model: ocr_ppocr\n    model_config:\n      lang: ch\n      show_log: True\n      det_model_dir: models/OCR/PaddleOCR/det/ch_PP-OCRv4_det\n      rec_model_dir: models/OCR/PaddleOCR/rec/ch_PP-OCRv4_rec\n      det_db_box_thresh: 0.3\n\n  "
  },
  {
    "path": "project/pdf2markdown/scripts/pdf2markdown.py",
    "content": "import os\nimport re\nimport gc\nimport sys\nimport time\nimport torch\nfrom PIL import Image, ImageDraw\nfrom torchvision import transforms\nfrom torch.utils.data import DataLoader\n\nsys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..'))\nfrom pdf_extract_kit.utils.data_preprocess import load_pdf\nfrom pdf_extract_kit.tasks.ocr.task import OCRTask\nfrom pdf_extract_kit.dataset.dataset import MathDataset\nfrom pdf_extract_kit.registry.registry import TASK_REGISTRY\nfrom pdf_extract_kit.utils.merge_blocks_and_spans import (\n    fill_spans_in_blocks,\n    fix_block_spans,\n    merge_para_with_text\n)\n\n\ndef latex_rm_whitespace(s: str):\n    \"\"\"Remove unnecessary whitespace from LaTeX code.\n    \"\"\"\n    text_reg = r'(\\\\(operatorname|mathrm|text|mathbf)\\s?\\*? {.*?})'\n    letter = '[a-zA-Z]'\n    noletter = '[\\W_^\\d]'\n    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]\n    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)\n    news = s\n    while True:\n        s = news\n        news = re.sub(r'(?!\\\\ )(%s)\\s+?(%s)' % (noletter, noletter), r'\\1\\2', s)\n        news = re.sub(r'(?!\\\\ )(%s)\\s+?(%s)' % (noletter, letter), r'\\1\\2', news)\n        news = re.sub(r'(%s)\\s+?(%s)' % (letter, noletter), r'\\1\\2', news)\n        if news == s:\n            break\n    return s\n\ndef crop_img(input_res, input_pil_img, padding_x=0, padding_y=0):\n    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])\n    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])\n    # Create a white background with an additional width and height of 50\n    crop_new_width = crop_xmax - crop_xmin + padding_x * 2\n    crop_new_height = crop_ymax - crop_ymin + padding_y * 2\n    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')\n\n    # Crop image\n    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)\n    cropped_img = 
input_pil_img.crop(crop_box)\n    return_image.paste(cropped_img, (padding_x, padding_y))\n    return_list = [padding_x, padding_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]\n    return return_image, return_list\n\n@TASK_REGISTRY.register(\"pdf2markdown\")\nclass PDF2MARKDOWN(OCRTask):\n    def __init__(self, layout_model, mfd_model, mfr_model, ocr_model):\n        self.layout_model = layout_model\n        self.mfd_model = mfd_model\n        self.mfr_model = mfr_model\n        self.ocr_model = ocr_model\n        if self.mfr_model is not None:\n            assert self.mfd_model is not None, \"formula recognition based on formula detection, mfd_model can not be None.\"\n            self.mfr_transform = transforms.Compose([self.mfr_model.vis_processor, ])\n            \n        self.color_palette  = {\n            'title': (255, 64, 255),\n            'plain text': (255, 255, 0),\n            'abandon': (0, 255, 255),\n            'figure': (255, 215, 135),\n            'figure_caption': (215, 0, 95),\n            'table': (100, 0, 48),\n            'table_caption': (0, 175, 0),\n            'table_footnote': (95, 0, 95),\n            'isolate_formula': (175, 95, 0),\n            'formula_caption': (95, 95, 0),\n            'inline': (0, 0, 255),\n            'isolated': (0, 255, 0),\n            'text': (255, 0, 0)\n        }\n\n    def convert_format(self, yolo_res, id_to_names, ):\n        \"\"\"\n        convert yolo format to pdf-extract format.\n        \"\"\"\n        res_list = []\n        for xyxy, conf, cla in zip(yolo_res.boxes.xyxy.cpu(), yolo_res.boxes.conf.cpu(), yolo_res.boxes.cls.cpu()):\n            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]\n            new_item = {\n                'category_type': id_to_names[int(cla.item())],\n                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],\n                'score': round(float(conf.item()), 2),\n            }\n            
res_list.append(new_item)\n        return res_list\n    \n    \n    def process_single_pdf(self, image_list):\n        \"\"\"predict on one image, reture text detection and recognition results.\n        \n        Args:\n            image_list: List[PIL.Image.Image]\n            \n        Returns:\n            List[dict]: list of PDF extract results\n            \n        Return example:\n            [\n                {\n                    \"layout_dets\": [\n                        {\n                            \"category_type\": \"text\",\n                            \"poly\": [\n                                380.6792698635707,\n                                159.85058512958923,\n                                765.1419999999998,\n                                159.85058512958923,\n                                765.1419999999998,\n                                192.51073013642917,\n                                380.6792698635707,\n                                192.51073013642917\n                            ],\n                            \"text\": \"this is an example text\",\n                            \"score\": 0.97\n                        },\n                        ...\n                    ], \n                    \"page_info\": {\n                        \"page_no\": 0,\n                        \"height\": 2339,\n                        \"width\": 1654,\n                    }\n                },\n                ...\n            ]\n        \"\"\"\n        pdf_extract_res = []\n        mf_image_list = []\n        latex_filling_list = []\n        for idx, image in enumerate(image_list):\n            img_W, img_H = image.size\n            if self.layout_model is not None:\n                ori_layout_res = self.layout_model.predict([image], \"\")[0]\n                layout_res = self.convert_format(ori_layout_res, self.layout_model.id_to_names)\n            else:\n                layout_res = []\n            single_page_res = {'layout_dets': 
layout_res}\n            single_page_res['page_info'] = dict(\n                page_no = idx,\n                height = img_H,\n                width = img_W\n            )\n            if self.mfd_model is not None:\n                mfd_res = self.mfd_model.predict([image], \"\")[0]\n                for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):\n                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]\n                    new_item = {\n                        'category_type': self.mfd_model.id_to_names[int(cla.item())],\n                        'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],\n                        'score': round(float(conf.item()), 2),\n                        'latex': '',\n                    }\n                    single_page_res['layout_dets'].append(new_item)\n                    if self.mfr_model is not None:\n                        latex_filling_list.append(new_item)\n                        bbox_img = image.crop((xmin, ymin, xmax, ymax))\n                        mf_image_list.append(bbox_img)\n                    \n                pdf_extract_res.append(single_page_res)\n                \n                del mfd_res\n                torch.cuda.empty_cache()\n                gc.collect()\n            \n        # Formula recognition, collect all formula images in whole pdf file, then batch infer them.\n        if self.mfr_model is not None:\n            a = time.time()\n            dataset = MathDataset(mf_image_list, transform=self.mfr_transform)\n            dataloader = DataLoader(dataset, batch_size=self.mfr_model.batch_size, num_workers=0)\n\n            mfr_res = []\n            for imgs in dataloader:\n                imgs = imgs.to(self.mfr_model.device)\n                output = self.mfr_model.model.generate({'image': imgs})\n                mfr_res.extend(output['pred_str'])\n            for res, latex in zip(latex_filling_list, mfr_res):\n     
           res['latex'] = latex_rm_whitespace(latex)\n            b = time.time()\n            print(\"formula nums:\", len(mf_image_list), \"mfr time:\", round(b-a, 2))\n        \n        # ocr_res = self.ocr_model.predict(image)\n            \n        # ocr and table recognition\n        for idx, image in enumerate(image_list):\n            layout_res = pdf_extract_res[idx]['layout_dets']\n            pil_img = image.copy()\n\n            ocr_res_list = []\n            table_res_list = []\n            single_page_mfdetrec_res = []\n\n            for res in layout_res:\n                if res['category_type'] in self.mfd_model.id_to_names.values():\n                    single_page_mfdetrec_res.append({\n                        \"bbox\": [int(res['poly'][0]), int(res['poly'][1]),\n                                 int(res['poly'][4]), int(res['poly'][5])],\n                    })\n                elif res['category_type'] in [self.layout_model.id_to_names[cid] for cid in [0, 1, 2, 4, 6, 7]]:\n                    ocr_res_list.append(res)\n                elif res['category_type'] in [self.layout_model.id_to_names[5]]:\n                    table_res_list.append(res)\n\n            ocr_start = time.time()\n            # Process each area that requires OCR processing\n            for res in ocr_res_list:\n                new_image, useful_list = crop_img(res, pil_img, padding_x=25, padding_y=25)\n                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list\n                # Adjust the coordinates of the formula area\n                adjusted_mfdetrec_res = []\n                for mf_res in single_page_mfdetrec_res:\n                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res[\"bbox\"]\n                    # Adjust the coordinates of the formula area to the coordinates relative to the cropping area\n                    x0 = mf_xmin - xmin + paste_x\n                    y0 = mf_ymin - ymin + paste_y\n                    x1 = mf_xmax - 
xmin + paste_x\n                    y1 = mf_ymax - ymin + paste_y\n                    # Filter formula blocks outside the graph\n                    if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):\n                        continue\n                    else:\n                        adjusted_mfdetrec_res.append({\n                            \"bbox\": [x0, y0, x1, y1],\n                        })\n\n                # OCR recognition\n                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]\n\n                # Integration results\n                if ocr_res:\n                    for box_ocr_res in ocr_res:\n                        p1, p2, p3, p4 = box_ocr_res[0]\n                        text, score = box_ocr_res[1]\n\n                        # Convert the coordinates back to the original coordinate system\n                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]\n                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]\n                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]\n                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]\n\n                        layout_res.append({\n                            'category_type': 'text',\n                            'poly': p1 + p2 + p3 + p4,\n                            'score': round(score, 2),\n                            'text': text,\n                        })\n\n            ocr_cost = round(time.time() - ocr_start, 2)\n            print(f\"ocr cost: {ocr_cost}\")\n        return pdf_extract_res\n    \n    def order_blocks(self, blocks):\n        def calculate_oder(poly):\n            xmin, ymin, _, _, xmax, ymax, _, _ = poly\n            return ymin*3000 + xmin\n        return sorted(blocks, key=lambda item: calculate_oder(item['poly']))\n                 \n    def convert2md(self, extract_res):\n        blocks = []\n        spans = []\n\n        for item in 
extract_res['layout_dets']:\n            if item['category_type'] in ['inline', 'text', 'isolated']:\n                text_key = 'text' if item['category_type'] == 'text' else 'latex'\n                xmin, ymin, _, _, xmax, ymax, _, _ = item['poly']\n                spans.append(\n                    {\n                        \"type\": item['category_type'],\n                        \"bbox\": [xmin, ymin, xmax, ymax],\n                        \"content\": item[text_key]\n                    }\n                )\n                if item['category_type'] == \"isolated\":\n                    item['category_type'] = \"isolate_formula\"\n                    blocks.append(item)\n            else:\n                blocks.append(item)\n                \n        blocks_types = [\"title\", \"plain text\", \"figure_caption\", \"table_caption\", \"table_footnote\", \"isolate_formula\", \"formula_caption\"]\n\n        need_fix_bbox = []\n        final_block = []\n        for block in blocks:\n            block_type = block[\"category_type\"]\n            if block_type in blocks_types:\n                need_fix_bbox.append(block)\n            else:\n                final_block.append(block)\n                \n        block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6)\n        \n        fix_blocks = fix_block_spans(block_with_spans)\n        for para_block in fix_blocks:\n            result = merge_para_with_text(para_block)\n            if para_block['type'] == \"isolate_formula\":\n                para_block['saved_info']['latex'] = result\n            else:\n                para_block['saved_info']['text'] = result\n            final_block.append(para_block['saved_info'])\n            \n        final_block = self.order_blocks(final_block)\n        md_text = \"\"\n        for block in final_block:\n            if block['category_type'] == \"title\":\n                md_text += \"\\n# \"+block['text'] +\"\\n\"\n            elif block['category_type'] in 
[\"isolate_formula\"]:\n                md_text += \"\\n\"+block['latex']+\"\\n\"\n            elif block['category_type'] in [\"plain text\", \"figure_caption\", \"table_caption\"]:\n                md_text += \" \"+block['text']+\" \"\n            elif block['category_type'] in [\"figure\", \"table\"]:\n                continue\n            else:\n                continue\n        return md_text\n        \n    def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False):\n        file_list = self.prepare_input_files(input_path)\n        res_list = []\n        for fpath in file_list:\n            basename = os.path.basename(fpath)[:-4]\n            if fpath.endswith(\".pdf\") or fpath.endswith(\".PDF\"):\n                images = load_pdf(fpath)\n            else:\n                images = [Image.open(fpath)]\n            pdf_extract_res = self.process_single_pdf(images)\n            res_list.append(pdf_extract_res)\n            if save_dir:\n                os.makedirs(save_dir, exist_ok=True)\n                self.save_json_result(pdf_extract_res, os.path.join(save_dir, f\"{basename}.json\"))\n                \n                if merge2markdown:\n                    md_content = []\n                    for extract_res in pdf_extract_res:\n                        md_text = self.convert2md(extract_res)\n                        md_content.append(md_text)\n                    with open(os.path.join(save_dir, f\"{basename}.md\"), \"w\") as f:\n                        f.write(\"\\n\\n\".join(md_content))\n                        \n                if visualize:\n                    for image, page_res in zip(images, pdf_extract_res):\n                        self.visualize_image(image, page_res['layout_dets'], cate2color=self.color_palette)\n                    if fpath.endswith(\".pdf\") or fpath.endswith(\".PDF\"):\n                        first_page = images.pop(0)\n                        first_page.save(os.path.join(save_dir, 
f'{basename}.pdf'), 'PDF', resolution=100, save_all=True, append_images=images)\n                    else:\n                        images[0].save(os.path.join(save_dir, f\"{basename}.png\"))\n\n        return res_list\n        \n        \n        \n        "
  },
  {
    "path": "project/pdf2markdown/scripts/run_project.py",
    "content": "import os\nimport sys\nimport os.path as osp\nimport argparse\nfrom pdf2markdown import PDF2MARKDOWN\n\nsys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..'))\nfrom pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\nfrom pdf_extract_kit.registry.registry import TASK_REGISTRY\n\n\nTASK_NAME = 'pdf2markdown'\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n    return parser.parse_args()\n\ndef main(config_path):\n    config = load_config(config_path)\n    task_instances = initialize_tasks_and_models(config)\n\n    # get input and output path from config\n    input_data = config.get('inputs', None)\n    result_path = config.get('outputs', 'outputs/pdf_extract')\n    visualize = config.get('visualize', False)\n    merge2markdown = config.get('merge2markdown', False)\n\n    layout_model = task_instances['layout_detection'].model if 'layout_detection' in task_instances else None\n    mfd_model = task_instances['formula_detection'].model if 'formula_detection' in task_instances else None\n    mfr_model = task_instances['formula_recognition'].model if 'formula_recognition' in task_instances else None\n    ocr_model = task_instances['ocr'].model if 'ocr' in task_instances else None\n    \n    pdf_extract_task = TASK_REGISTRY.get(TASK_NAME)(layout_model, mfd_model, mfr_model, ocr_model)\n    extract_results = pdf_extract_task.process(input_data, save_dir=result_path, visualize=visualize, merge2markdown=merge2markdown)\n\n    print(f'Task done, results can be found at {result_path}')\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args.config)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=42\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"pdf-extract-kit\"\nversion = \"0.1.0\"\nauthors = [\n    { name=\"Bin Wang\", email=\"ictwangbin@gmail.com\" }\n]\ndescription = \"A Comprehensive Toolkit for High-Quality PDF Content Extraction.\"\nreadme = \"README.md\"\nlicense = { file=\"LICENSE\" }\nrequires-python = \">=3.10\"\ndependencies = [\n    \"PyPDF2\",\n    \"matplotlib\",\n    \"pyyaml\",\n    \"frontend\",\n    \"pymupdf\",\n    opencv-python = \"^4.6.0\"\n    # Add other common dependencies\n]\n\n[project.optional-dependencies]\nlayout_detection = [\n    \"transformers\",  # for layoutlmv3\n    # Add other dependencies for layout detection\n]\nformula_detection = [\n    \"ultralytics\",  # for yolov8\n    # Add other dependencies for formula detection\n]\n# Add additional dependencies for other models\n"
  },
  {
    "path": "requirements/docs.txt",
    "content": "myst-parser\nsphinx\nsphinx-book-theme\nsphinx-copybutton\nsphinx-tabs\nsphinxcontrib-mermaid"
  },
  {
    "path": "requirements-cpu.txt",
    "content": "omegaconf\nmatplotlib\nPyMuPDF\nultralytics>=8.2.85\ndoclayout-yolo==0.0.2\nunimernet==0.2.1\npaddlepaddle\npaddleocr==2.7.3\nstruct-eqtable\n"
  },
  {
    "path": "requirements.txt",
    "content": "omegaconf\nmatplotlib\nPyMuPDF\nultralytics>=8.2.85\ndoclayout-yolo==0.0.2\nunimernet==0.2.1\npaddlepaddle-gpu\npaddleocr==2.7.3\nstruct-eqtable\nlmdeploy\n"
  },
  {
    "path": "scripts/formula_detection.py",
    "content": "import os\nimport sys\nimport os.path as osp\nimport argparse\n\nsys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\nfrom pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\nimport pdf_extract_kit.tasks\n\nTASK_NAME = 'formula_detection'\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n    return parser.parse_args()\n\ndef main(config_path):\n    config = load_config(config_path)\n    task_instances = initialize_tasks_and_models(config)\n\n    # get input and output path from config\n    input_data = config.get('inputs', None)\n    result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)\n\n    # formula_detection_task\n    model_formula_detection = task_instances[TASK_NAME]\n\n    # for image detection\n    detection_results = model_formula_detection.predict_images(input_data, result_path)\n\n    # for pdf detection\n    # detection_results = model_formula_detection.predict_pdfs(input_data, result_path)\n\n    # print(detection_results)\n    print(f'The predicted results can be found at {result_path}')\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args.config)\n"
  },
  {
    "path": "scripts/formula_recognition.py",
    "content": "import os\nimport sys\nimport os.path as osp\nimport argparse\n\nsys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\nfrom pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\nimport pdf_extract_kit.tasks\n\nTASK_NAME = 'formula_recognition'\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n    return parser.parse_args()\n\ndef main(config_path):\n    config = load_config(config_path)\n    task_instances = initialize_tasks_and_models(config)\n\n    # get input and output path from config\n    input_data = config.get('inputs', None)\n    result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)\n\n    # formula_detection_task\n    model_formula_recognition = task_instances[TASK_NAME]\n\n    # for image detection\n    recognition_results = model_formula_recognition.predict(input_data, result_path)\n\n\n    print('Recognition results are as follows:')\n    for id, math in enumerate(recognition_results):\n        print(str(id+1)+': ', math)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args.config)\n"
  },
  {
    "path": "scripts/layout_detection.py",
    "content": "import os\nimport sys\nimport os.path as osp\nimport argparse\n\nsys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\nfrom pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\nimport pdf_extract_kit.tasks\n\nTASK_NAME = 'layout_detection'\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n    return parser.parse_args()\n\ndef main(config_path):\n    config = load_config(config_path)\n    task_instances = initialize_tasks_and_models(config)\n\n    # get input and output path from config\n    input_data = config.get('inputs', None)\n    result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)\n\n    # layout_detection_task\n    model_layout_detection = task_instances[TASK_NAME]\n\n    # for image detection\n    detection_results = model_layout_detection.predict_images(input_data, result_path)\n\n    # for pdf detection\n    # detection_results = model_layout_detection.predict_pdfs(input_data, result_path)\n\n    # print(detection_results)\n    print(f'The predicted results can be found at {result_path}')\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args.config)\n"
  },
  {
    "path": "scripts/ocr.py",
    "content": "import os\nimport sys\nimport os.path as osp\nimport argparse\n\nsys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\nfrom pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\nimport pdf_extract_kit.tasks\n\nTASK_NAME = 'ocr'\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n    return parser.parse_args()\n\ndef main(config_path):\n    config = load_config(config_path)\n    task_instances = initialize_tasks_and_models(config)\n\n    # get input and output paths and the visualize flag from config\n    input_data = config.get('inputs', None)\n    result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)\n    visualize = config.get('visualize', False)\n\n    # ocr_task\n    task = task_instances[TASK_NAME]\n\n    detection_results = task.process(input_data, save_dir=result_path, visualize=visualize)\n\n    print(f'Task done, results can be found at {result_path}')\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args.config)\n"
  },
  {
    "path": "scripts/run_task.py",
    "content": "import os\nimport sys\nimport os.path as osp\nimport argparse\n\nsys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\nfrom pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\nimport pdf_extract_kit.tasks  # ensure all task modules are imported\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n    return parser.parse_args()\n\ndef main(config_path):\n    config = load_config(config_path)\n    task_instances = initialize_tasks_and_models(config)\n\n    # get input and output paths from config\n    input_data = config.get('inputs', None)\n    result_path = config.get('outputs', 'outputs')\n\n    # formula_detection_task\n    model_formula_detection = task_instances['formula_detection']\n    detection_results = model_formula_detection.predict(input_data, result_path)\n    print(detection_results)\n\n    # formula_recognition_task\n    # model_formula_recognition = task_instances['formula_recognition']\n    # recognition_results = model_formula_recognition.predict(input_data, result_path)\n\n    # for id, math in enumerate(recognition_results):\n    #     print(str(id+1)+': ', math)\n\n    # results = task_instance.run(input_data)\n    # print(results)\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args.config)\n"
  },
  {
    "path": "scripts/table_parsing.py",
    "content": "import os\nimport sys\nimport os.path as osp\nimport argparse\n\nsys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))\nfrom pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models\nimport pdf_extract_kit.tasks\n\nTASK_NAME = 'table_parsing'\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a task with a given configuration file.\")\n    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')\n    return parser.parse_args()\n\ndef main(config_path):\n    config = load_config(config_path)\n    task_instances = initialize_tasks_and_models(config)\n\n    # get input and output path from config\n    input_data = config.get('inputs', None)\n    result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)\n\n    # table_parsing_task\n    model_table_parsing = task_instances[TASK_NAME]\n\n    # run table parsing on the inputs\n    parsing_results = model_table_parsing.predict(input_data, result_path)\n\n\n    print('Table Parsing results are as follows:')\n    for id, result in enumerate(parsing_results):\n        print(str(id+1)+':\\n', result)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args.config)\n"
  }
]