[
  {
    "path": ".gitignore",
    "content": "dist/\nbuild/\n**.egg-info/\n**__pycache__/\n**.cache\nckpts/\n**version.py\n\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n<h1>StructEqTable-Deploy: A High-efficiency Open-source Toolkit for Table-to-Latex Transformation</h1>\n\n\n[[ Paper ]](https://arxiv.org/abs/2505.16938) [[ Website ]](https://alpha-innovator.github.io/InternAgent-project-page) [[ Dataset🤗 ]](https://huggingface.co/datasets/U4R/DocGenome/tree/main) [[ Models🤗 ]](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) [[ Demo💬 ]](https://www.modelscope.cn/studios/HongbinZhou/StructEqTable-Demo/)\n\n\n</div>\n\nWelcome to the official repository StructEqTable-Deploy of InternScience group, a solution that converts images of Table into LaTeX/HTML/MarkDown, powered by scalable data from [DocGenome benchmark](https://unimodal4reasoning.github.io/DocGenome_page/).\n\n\n## Overview\nTable is an effective way to represent structured data in scientific publications, financial statements, invoices, web pages, and many other scenarios. Extracting tabular data from a visual table image and performing the downstream reasoning tasks according to the extracted data is challenging, mainly due to that tables often present complicated column and row headers with spanning cell operation. To address these challenges, we present TableX, a large-scale multi-modal table benchmark extracted from [DocGenome benchmark](https://alpha-innovator.github.io/InternAgent-project-page/) for table pre-training, comprising more than 2 million high-quality Image-LaTeX pair data covering 156 disciplinary classes. Besides, benefiting from such large-scale data, we train an end-to-end model, StructEqTable, which provides the capability to precisely obtain the corresponding LaTeX description from a visual table image and perform multiple table-related reasoning tasks, including structural extraction and question answering, broadening its application scope and potential.\n\n## Changelog\n- [2024/12/12] 🔥 We have released latest model **[StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main)** with enhanced recognition stability for HTML and Markdown formats!\n\n- [2024/10/19] We have released our latest model StructTable-InternVL2-1B!\n\n  Thanks to IntenrVL2 powerful foundational capabilities, and through fine-tuning on the synthetic tabular data and DocGenome dataset, StructTable can convert table image into various common table formats including LaTeX, HTML, and Markdown. Moreover, inference speed has been significantly improved compared to the v0.2 version.\n- [2024/8/22] We have released our StructTable-base-v0.2, fine-tuned on the DocGenome dataset. This version features improved inference speed and robustness, achieved through data augmentation and reduced image token num.\n- [2024/8/08] We have released the TensorRT accelerated version, which only takes about 1 second for most images on GPU A100. Please follow the tutorial to install the environment and compile the model weights.\n- [2024/7/30] We have released the first version of StructEqTable. 
\n\n## TODO\n\n- [x] Release inference code and checkpoints of StructEqTable.\n- [x] Support Chinese version of StructEqTable.\n- [x] Accelerated version of StructEqTable using TensorRT-LLM.\n- [x] Expand more domains of table image to improve the model's general capabilities.\n- [x] Efficient inference of StructTable-InternVL2-1B by [LMDeploy](https://github.com/InternLM/lmdeploy) Tookit.\n- [ ] Release our table pre-training and fine-tuning code\n\n\n## Installation\n``` bash \nconda create -n structeqtable python>=3.10\nconda activate structeqtable\n\n# Install from Source code  (Suggested)\ngit clone https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git\ncd StructEqTable-Deploy\npip install -r requirements.txt\npython setup develop\n\n# or Install from Github repo\npip install \"git+https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git\"\n\n# or Install from PyPI\npip install struct-eqtable --upgrade\n```\n\n## Model Zoo\n\n| Base Model | Model Size | Training Data | Data Augmentation | LMDeploy | TensorRT | HuggingFace |\n|---------------------|------------|------------------|-------------------|----------|----------|-------------------|\n| InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) |\n| InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.1](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/v0.1) |\n| Pix2Struct-base | ~300M | DocGenome | ✔ | | ✔ | [StructTable-base v0.2](https://huggingface.co/U4R/StructTable-base/tree/v0.2) |\n| Pix2Struct-base | ~300M | DocGenome | | | ✔ | [StructTable-base v0.1](https://huggingface.co/U4R/StructTable-base/tree/v0.1) |\n\n\n\n## Quick Demo\n- Run the demo/demo.py\n```shell script\ncd tools/demo\n\npython demo.py \\\n  --image_path ./demo.png \\\n  --ckpt_path U4R/StructTable-InternVL2-1B \\\n  --output_format latex\n```\n\n- HTML or Markdown format output (Only Supported by StructTable-InternVL2-1B)\n\n```shell script\npython demo.py \\\n  --image_path ./demo.png \\\n  --ckpt_path U4R/StructTable-InternVL2-1B \\\n  --output_format html markdown\n```\n\n## Efficient Inference\n- Install LMDeploy Tookit\n```shell script\npip install lmdeploy\n```\n\n- Run the demo/demo.py\n```shell script\ncd tools/demo\n\npython demo.py \\\n  --image_path ./demo.png \\\n  --ckpt_path U4R/StructTable-InternVL2-1B \\\n  --output_format latex \\\n  --lmdeploy\n```\n\n\n- Visualization Result\n\n  You can copy the output LaTeX code into [demo.tex](../tools/demo/demo.tex), then use [Overleaf](https://www.overleaf.com/project) for table visualization.\n![](docs/imgs/output.png)\n\n\n## Acknowledgements\n- [DocGenome](https://github.com/UniModal4Reasoning/DocGenome). An Open Large-scale Scientific Document Benchmark for Training and Testing Multi-modal Large Models.\n- [ChartVLM](https://github.com/UniModal4Reasoning/ChartVLM). A Versatile Benchmark and Foundation Model for Complicated Chart Reasoning.\n- [Pix2Struct](https://github.com/google-research/pix2struct). Screenshot Parsing as Pretraining for Visual Language Understanding.\n- [InternVL Family](https://github.com/OpenGVLab/InternVL). A Series of Powerful Foundational Vision-Language Models.\n- [LMDeploy](https://github.com/InternLM/lmdeploy). A toolkit for compressing, deploying, and serving LLM and MLLM.\n- [UniMERNet](https://github.com/opendatalab/UniMERNet). 
A Universal Network for Real-World Mathematical Expression Recognition.\n- [Donut](https://huggingface.co/naver-clova-ix/donut-base). UniMERNet's Transformer encoder-decoder is based on Donut.\n- [Nougat](https://github.com/facebookresearch/nougat). Data Augmentation follows Nougat.  \n- [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Model inference acceleration uses TensorRT-LLM.\n\n\n## License\nStructEqTable is released under the [Apache License 2.0](LICENSE).\n\n## Citation\nIf you find our models / code / papers useful in your research, please consider giving us a ⭐ and a citation 📝, thanks :)  \n```bibtex\n@article{xia2024docgenome,\n  title={DocGenome: An Open Large-scale Scientific Document Benchmark for Training and Testing Multi-modal Large Language Models},\n  author={Xia, Renqiu and Mao, Song and Yan, Xiangchao and Zhou, Hongbin and Zhang, Bo and Peng, Haoyang and Pi, Jiahao and Fu, Daocheng and Wu, Wenjie and Ye, Hancheng and others},\n  journal={arXiv preprint arXiv:2406.11633},\n  year={2024}\n}\n```\n\n## Contact Us\nIf you encounter any issues or have questions, please feel free to contact us via zhouhongbin@pjlab.org.cn.\n"
  },
  {
    "path": "docs/TENSORRT_GETTING_STARTED.md",
    "content": "# Getting Started\n[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) is used for model inference speeding up.  \n\nAll the codes are successfully tested in the following enviroments:\n* Linux (18.04, 20.04, 22.04)\n* Python 3.10\n* Pytorch 2.0 or higher\n* CUDA 12.1 or higher\n* TensorRT-LLM 0.11.0 (stable version)\n\n### 1. Conda or Python Environment Preparation\n\n\n* Please follow the step 1, 2 from the [official tutorial](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) of TensorRT-LLM to install the environment.  \n\nNote we used the TensorRT-LLM **stable version `0.11.0`**.\n``` bash\n# Installing on Linux\nStep 1. Retrieve and launch the docker container (optional).\n\n    You can pre-install the environment using the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit) to avoid manual environment configuration.\n\n    ```bash\n    # Obtain and start the basic docker image environment (optional).\n    docker run --rm --ipc=host --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.4.1-devel-ubuntu22.04\n    ```\n    Note: please make sure to set `--ipc=host` as a docker run argument to avoid `Bus error (core dumped)`.\n\nStep 2. Install TensorRT-LLM.\n\n    ```bash\n    # Install dependencies, TensorRT-LLM requires Python 3.10\n    apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev git git-lfs\n\n    # Install the latest preview version (corresponding to the main branch) of TensorRT-LLM.\n    # If you want to install the stable version (corresponding to the release branch), please\n    # remove the `--pre` option.\n    pip3 install tensorrt_llm==0.11.0 --extra-index-url https://pypi.nvidia.com\n\n    # Check installation\n    python3 -c \"import tensorrt_llm\"\n    ```\n\n    Please note that TensorRT-LLM depends on TensorRT. In earlier versions that include TensorRT 8,\n    overwriting an upgraded to a new version may require explicitly running `pip uninstall tensorrt`\n    to uninstall the old version.\n```\n* Once you successfully execute `python3 -c \"import tensorrt_llm\"`, it means that you have completed Environment Preparation.  \n\nTips: If you want to install the environment manually, please note that the version of Python require >= 3.10\n\n\n### 2. Model Compilation\nYou can refer to the [official tutorial](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) to complete the model compilation, or follow our instructions and use the provided scripts to implement it.\n\n#### 2.1 Download [StructEqTable checkpoints](https://huggingface.co/U4R/StructTable-base/tree/v0.2)\n```\ncd StructEqTable-Deploy\n\n# using huggingface-cli download checkpoint\nhuggingface-cli download --resume-download --local-dir-use-symlinks False U4R/StructTable-base --local-dir ckpts/StructTable-base\n\n```\nAfter above steps, the files to directory of StructEqTable-Deploy as follows:  \n```\nStructEqTable-Deploy\n├── ckpts\n│   ├── StructTable-base \n├── docs\n├── struct_eqtable\n├── tools\n```\n\n#### 2.2 Convert Checkpoint and Build Engine\nWe provide a script to help users quickly implement model compilation.\n\n``` bash\ncd StructEqTable-Deploy/tools\n# execute the script to quickly compile the model.\nbash scripts/build_tensorrt.sh \n```\nAfter the script runs successfully, the built models can be found in `ckpts/StructTable-base-TensorRT`.  
\nThe file structure in the path `ckpts/StructTable-base-TensorRT` should be as follows:  \n```\nckpts\n├── StructTable-base \n├── StructTable-base-TensorRT \n│   ├── trt_engines \n│   ├── trt_models\n│   ├── visual_engines\n```\n\n#### 2.3 Run the Quick Demo\nRun tools/demo/demo.py in TensorRT mode.\n\n``` bash\ncd StructEqTable-Deploy/tools/demo\n\npython demo.py \\\n  --image_path ./demo.png \\\n  --ckpt_path ../../ckpts/StructTable-base \\\n  --output_format latex \\\n  --tensorrt ../../ckpts/StructTable-base-TensorRT\n```\n\nYou should get output similar to the following:\n```\ntotal cost time: 0.88s\nTable 0 LATEX format output:\n\\begin{tabular}{|c|c|c|c|}\n\\hline\nQuantity $\\backslash$ Unit System & International System SI (kg-m-s) & Traditional aeronautical (lb-ft-s) & Traditional structural (lb-inch-s) \\\\\n\\hline\nMass (translational inertia), $m$ & kilogram mass (kg) & slug = lb-s$^2$/f & lb-s$^2$/inch \\\\\n\\hline\nLength, translational motion & meter (m) & foot (ft) & inch (in.) \\\\\n\\hline\nTime, $t$ & second (s) & second (s) & second (s) \\\\\n\\hline\nForce, translational action & newton (N) = kg-m/s$^2$ & pound force (lb) & pound force (lb) \\\\\n\\hline\nTranslational stiffness constant, $k$ & N/m & lb/ft & lb/inch \\\\\n\\hline\nTranslational damping constant, $c$ & N/(m/s) = N-s/m & lb/(ft/s) = lb-s/ft & lb/(inch/s) = lb-s/inch \\\\\n\\hline\nAngle, rotational motion & radial (rad), which is dimensionless & radial (rad), which is dimensionless & radial (rad), which is dimensionless \\\\\n\\hline\nRotational inertia, $J$ & kg-m$^2$ & slug-ft$^2$ = lb-s$^2$ - ft & lb-s$^2$ - inch \\\\\n\\hline\nMoment or torque, rotational action & N-m & lb-ft & lb-inch \\\\\n\\hline\nRotational stiffness constant, $k_\\theta$ & (N-m)/rad = N-m & (lb-ft)/rad = lb-ft & (lb-inch)/rad = lb-inch \\\\\n\\hline\nRotational damping constant, $c_\\theta$ & (N-m)/(rad/s) = N-m-s & (lb-ft)/(rad/s) = lb-ft-s & (lb-inch)/(rad/s) = lb-inch-s \\\\\n\\hline\n\\end{tabular}\n```\n\n\n### 3. Table Visualization\nYou can copy the output LaTeX code into [demo.tex](../tools/demo/demo.tex), then use [Overleaf](https://www.overleaf.com/project) or the Visual Studio Code LaTeX Workshop extension for table visualization.\n\n![](./imgs/demo.png)"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntransformers<=4.47\n"
  },
  {
    "path": "setup.py",
    "content": "from pathlib import Path\nfrom setuptools import find_packages, setup\n\n\ndef write_version_to_file(version, target_file):\n    with open(target_file, 'w') as f:\n        print('__version__ = \"%s\"' % version, file=f)\n\nif __name__ == '__main__':\n    version = '0.3.3'\n    write_version_to_file(version, 'struct_eqtable/version.py')\n    with Path(Path(__file__).parent,\n              'README.md').open(encoding='utf-8') as file:\n        long_description = file.read()\n    setup(\n        name='struct_eqtable',\n        version=version,\n        description='A High-efficiency Open-source Toolkit for Table-to-Latex Transformation',\n        long_description=long_description,\n        long_description_content_type=\"text/markdown\",\n        install_requires=[\n            'torch',\n            'transformers<=4.47',\n        ],\n        python_requires=\">=3.9\",\n        author='Hongbin Zhou, Xiangchao Yan, Bo Zhang',\n        author_email='zhangbo@pjlab.org.cn',\n        url=\"https://github.com/UniModal4Reasoning/StructEqTable-Deploy\",\n        license='Apache License 2.0',\n        packages=find_packages(exclude=['demo']),\n    )\n"
  },
  {
    "path": "struct_eqtable/__init__.py",
    "content": "from .pix2s import Pix2Struct, Pix2StructTensorRT\nfrom .internvl import InternVL, InternVL_LMDeploy\n\nfrom transformers import AutoConfig\n\n\n__ALL_MODELS__ = {\n    'Pix2Struct': Pix2Struct,\n    'Pix2StructTensorRT': Pix2StructTensorRT,\n    'InternVL': InternVL,\n    'InternVL_LMDeploy': InternVL_LMDeploy,\n}\n\n\ndef get_model_name(model_path):\n    model_config = AutoConfig.from_pretrained(\n        model_path,\n        trust_remote_code=True,\n    )\n\n    if 'Pix2Struct' in model_config.architectures[0]:\n        model_name = 'Pix2Struct'\n    elif 'InternVL' in model_config.architectures[0]:\n        model_name = 'InternVL'\n    else:\n        raise ValueError(f\"Unsupported model type: {model_config.architectures[0]}\")\n\n    return model_name\n\n\ndef build_model(model_ckpt='U4R/StructTable-InternVL2-1B', **kwargs):\n    model_name = get_model_name(model_ckpt)\n    if model_name == 'InternVL' and kwargs.get('lmdeploy', False):\n        model_name = 'InternVL_LMDeploy'\n    elif model_name == 'Pix2Struct' and kwargs.get('tensorrt_path', None):\n        model_name = 'Pix2StructTensorRT'\n\n    model = __ALL_MODELS__[model_name](\n        model_ckpt, \n        **kwargs\n    )\n\n    return model"
  },
  {
    "path": "struct_eqtable/internvl/__init__.py",
    "content": "from .internvl import InternVL\nfrom .internvl_lmdeploy import InternVL_LMDeploy"
  },
  {
    "path": "struct_eqtable/internvl/conversation.py",
    "content": "\"\"\"\nConversation prompt templates.\n\nWe kindly request that you import fastchat instead of copying this file if you wish to use it.\nIf you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.\n\"\"\"\n\nimport dataclasses\nfrom enum import IntEnum, auto\nfrom typing import Any, Dict, List, Tuple, Union\n\n\nclass SeparatorStyle(IntEnum):\n    \"\"\"Separator styles.\"\"\"\n\n    ADD_COLON_SINGLE = auto()\n    ADD_COLON_TWO = auto()\n    ADD_COLON_SPACE_SINGLE = auto()\n    NO_COLON_SINGLE = auto()\n    NO_COLON_TWO = auto()\n    ADD_NEW_LINE_SINGLE = auto()\n    LLAMA2 = auto()\n    CHATGLM = auto()\n    CHATML = auto()\n    CHATINTERN = auto()\n    DOLLY = auto()\n    RWKV = auto()\n    PHOENIX = auto()\n    ROBIN = auto()\n    FALCON_CHAT = auto()\n    CHATGLM3 = auto()\n    INTERNVL_ZH = auto()\n    MPT = auto()\n\n\n@dataclasses.dataclass\nclass Conversation:\n    \"\"\"A class that manages prompt templates and keeps all conversation history.\"\"\"\n\n    # The name of this template\n    name: str\n    # The template of the system prompt\n    system_template: str = '{system_message}'\n    # The system message\n    system_message: str = ''\n    # The names of two roles\n    roles: Tuple[str] = ('USER', 'ASSISTANT')\n    # All messages. Each item is (role, message).\n    messages: List[List[str]] = ()\n    # The number of few shot examples\n    offset: int = 0\n    # The separator style and configurations\n    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE\n    sep: str = '\\n'\n    sep2: str = None\n    # Stop criteria (the default one is EOS token)\n    stop_str: Union[str, List[str]] = None\n    # Stops generation if meeting any token in this list\n    stop_token_ids: List[int] = None\n\n    def get_prompt(self) -> str:\n        \"\"\"Get the prompt for generation.\"\"\"\n        system_prompt = self.system_template.format(system_message=self.system_message)\n        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + ': ' + message + self.sep\n                else:\n                    ret += role + ':'\n            return ret\n        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:\n            seps = [self.sep, self.sep2]\n            ret = system_prompt + seps[0]\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + ': ' + message + seps[i % 2]\n                else:\n                    ret += role + ':'\n            return ret\n        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + ': ' + message + self.sep\n                else:\n                    ret += role + ': '  # must be end with a space\n            return ret\n        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:\n            ret = '' if system_prompt == '' else system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + '\\n' + message + self.sep\n                else:\n                    ret += role + '\\n'\n            return ret\n        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:\n           
 ret = system_prompt\n            for role, message in self.messages:\n                if message:\n                    ret += role + message + self.sep\n                else:\n                    ret += role\n            return ret\n        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:\n            seps = [self.sep, self.sep2]\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + message + seps[i % 2]\n                else:\n                    ret += role\n            return ret\n        elif self.sep_style == SeparatorStyle.RWKV:\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += (\n                        role\n                        + ': '\n                        + message.replace('\\r\\n', '\\n').replace('\\n\\n', '\\n')\n                    )\n                    ret += '\\n\\n'\n                else:\n                    ret += role + ':'\n            return ret\n        elif self.sep_style == SeparatorStyle.LLAMA2:\n            seps = [self.sep, self.sep2]\n            if self.system_message:\n                ret = system_prompt\n            else:\n                ret = '[INST] '\n            for i, (role, message) in enumerate(self.messages):\n                tag = self.roles[i % 2]\n                if message:\n                    if i == 0:\n                        ret += message + ' '\n                    else:\n                        ret += tag + ' ' + message + seps[i % 2]\n                else:\n                    ret += tag\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATGLM:\n            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308\n            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926\n            round_add_n = 1 if self.name == 'chatglm2' else 0\n            if system_prompt:\n                ret = system_prompt + self.sep\n            else:\n                ret = ''\n\n            for i, (role, message) in enumerate(self.messages):\n                if i % 2 == 0:\n                    ret += f'[Round {i//2 + round_add_n}]{self.sep}'\n\n                if message:\n                    ret += f'{role}：{message}{self.sep}'\n                else:\n                    ret += f'{role}：'\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATML:\n            ret = '' if system_prompt == '' else system_prompt + self.sep + '\\n'\n            for role, message in self.messages:\n                if message:\n                    ret += role + '\\n' + message + self.sep + '\\n'\n                else:\n                    ret += role + '\\n'\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATGLM3:\n            ret = ''\n            if self.system_message:\n                ret += system_prompt\n            for role, message in self.messages:\n                if message:\n                    ret += role + '\\n' + ' ' + message\n                else:\n                    ret += role\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATINTERN:\n            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771\n            seps = [self.sep, 
self.sep2]\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                # if i % 2 == 0:\n                #     ret += \"<s>\"\n                if message:\n                    ret += role + ':' + message + seps[i % 2] + '\\n'\n                else:\n                    ret += role + ':'\n            return ret\n        elif self.sep_style == SeparatorStyle.DOLLY:\n            seps = [self.sep, self.sep2]\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + ':\\n' + message + seps[i % 2]\n                    if i % 2 == 1:\n                        ret += '\\n\\n'\n                else:\n                    ret += role + ':\\n'\n            return ret\n        elif self.sep_style == SeparatorStyle.PHOENIX:\n            ret = system_prompt\n            for role, message in self.messages:\n                if message:\n                    ret += role + ': ' + '<s>' + message + '</s>'\n                else:\n                    ret += role + ': ' + '<s>'\n            return ret\n        elif self.sep_style == SeparatorStyle.ROBIN:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + ':\\n' + message + self.sep\n                else:\n                    ret += role + ':\\n'\n            return ret\n        elif self.sep_style == SeparatorStyle.FALCON_CHAT:\n            ret = ''\n            if self.system_message:\n                ret += system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + ': ' + message + self.sep\n                else:\n                    ret += role + ':'\n\n            return ret\n        elif self.sep_style == SeparatorStyle.INTERNVL_ZH:\n            seps = [self.sep, self.sep2]\n            ret = self.system_message + seps[0]\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + ': ' + message + seps[i % 2]\n                else:\n                    ret += role + ':'\n            return ret\n        elif self.sep_style == SeparatorStyle.MPT:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += role + message + self.sep\n                else:\n                    ret += role\n            return ret\n        else:\n            raise ValueError(f'Invalid style: {self.sep_style}')\n\n    def set_system_message(self, system_message: str):\n        \"\"\"Set the system message.\"\"\"\n        self.system_message = system_message\n\n    def append_message(self, role: str, message: str):\n        \"\"\"Append a new message.\"\"\"\n        self.messages.append([role, message])\n\n    def update_last_message(self, message: str):\n        \"\"\"Update the last output.\n\n        The last message is typically set to be None when constructing the prompt,\n        so we need to update it in-place after getting the response from a model.\n        \"\"\"\n        self.messages[-1][1] = message\n\n    def to_gradio_chatbot(self):\n        \"\"\"Convert the conversation to gradio chatbot format.\"\"\"\n        ret = []\n        for i, (role, msg) in 
enumerate(self.messages[self.offset :]):\n            if i % 2 == 0:\n                ret.append([msg, None])\n            else:\n                ret[-1][-1] = msg\n        return ret\n\n    def to_openai_api_messages(self):\n        \"\"\"Convert the conversation to OpenAI chat completion format.\"\"\"\n        ret = [{'role': 'system', 'content': self.system_message}]\n\n        for i, (_, msg) in enumerate(self.messages[self.offset :]):\n            if i % 2 == 0:\n                ret.append({'role': 'user', 'content': msg})\n            else:\n                if msg is not None:\n                    ret.append({'role': 'assistant', 'content': msg})\n        return ret\n\n    def copy(self):\n        return Conversation(\n            name=self.name,\n            system_template=self.system_template,\n            system_message=self.system_message,\n            roles=self.roles,\n            messages=[[x, y] for x, y in self.messages],\n            offset=self.offset,\n            sep_style=self.sep_style,\n            sep=self.sep,\n            sep2=self.sep2,\n            stop_str=self.stop_str,\n            stop_token_ids=self.stop_token_ids,\n        )\n\n    def dict(self):\n        return {\n            'template_name': self.name,\n            'system_message': self.system_message,\n            'roles': self.roles,\n            'messages': self.messages,\n            'offset': self.offset,\n        }\n\n\n# A global registry for all conversation templates\nconv_templates: Dict[str, Conversation] = {}\n\n\ndef register_conv_template(template: Conversation, override: bool = False):\n    \"\"\"Register a new conversation template.\"\"\"\n    if not override:\n        assert (\n            template.name not in conv_templates\n        ), f'{template.name} has been registered.'\n\n    conv_templates[template.name] = template\n\n\ndef get_conv_template(name: str) -> Conversation:\n    \"\"\"Get a conversation template.\"\"\"\n    return conv_templates[name].copy()\n\n\n# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. 
The difference\n# is that during training, the preprocessing function for the Hermes-2 template doesn't add\n# <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.\n# Therefore, they are completely equivalent during inference.\nregister_conv_template(\n    Conversation(\n        name='Hermes-2',\n        system_template='<|im_start|>system\\n{system_message}',\n        # note: The new system prompt was not used here to avoid changes in benchmark performance.\n        # system_message='我是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',\n        # system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型，英文名叫InternVL, 是一个有用无害的人工智能助手。',\n        system_message='You are a Table Image to LaTeX/Markdown/HMTL Code converter.',\n        roles=('<|im_start|>user\\n', '<|im_start|>assistant\\n'),\n        sep_style=SeparatorStyle.MPT,\n        sep='<|im_end|>',\n        stop_token_ids=[\n            2,\n            6,\n            7,\n            8,\n        ],\n        stop_str='<|endoftext|>',\n    )\n)\n\n\nregister_conv_template(\n    Conversation(\n        name='internlm2-chat',\n        system_template='<|im_start|>system\\n{system_message}',\n        # note: The new system prompt was not used here to avoid changes in benchmark performance.\n        # system_message='我是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',\n        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型，英文名叫InternVL, 是一个有用无害的人工智能助手。',\n        roles=('<|im_start|>user\\n', '<|im_start|>assistant\\n'),\n        sep_style=SeparatorStyle.MPT,\n        sep='<|im_end|>',\n        stop_token_ids=[\n            2,\n            92543,\n            92542\n        ]\n    )\n)\n\n\nregister_conv_template(\n    Conversation(\n        name='phi3-chat',\n        system_template='<|system|>\\n{system_message}',\n        # note: The new system prompt was not used here to avoid changes in benchmark performance.\n        # system_message='我是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',\n        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型，英文名叫InternVL, 是一个有用无害的人工智能助手。',\n        roles=('<|user|>\\n', '<|assistant|>\\n'),\n        sep_style=SeparatorStyle.MPT,\n        sep='<|end|>',\n        stop_token_ids=[\n            2,\n            32000,\n            32007\n        ]\n    )\n)\n"
  },
  {
    "path": "struct_eqtable/internvl/internvl.py",
    "content": "import torch\n\nfrom torch import nn\nfrom transformers import AutoModel, AutoTokenizer, AutoImageProcessor, GenerationConfig\n\nfrom .conversation import get_conv_template\n\nclass InternVL(nn.Module):\n    def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_tokens=1024, max_time=30, flash_attn=True, **kwargs):\n        super().__init__()\n        self.model_path = model_path\n        self.max_new_tokens = max_new_tokens\n        self.max_generate_time = max_time\n        self.flash_attn = flash_attn\n\n        # init model and image processor from ckpt path\n        self.init_tokenizer(model_path)\n        self.init_image_processor(model_path)\n        self.init_model(model_path)\n\n        self.prompt_template = {\n            'latex': '<latex>',\n            'html': '<html>',\n            'markdown': '<markdown>',\n        }\n        # support output format\n        self.supported_output_format = ['latex', 'html', 'markdown']\n\n    def init_model(self, model_path):\n        self.model = AutoModel.from_pretrained(\n            model_path,\n            trust_remote_code=True,\n            torch_dtype=torch.bfloat16,\n            low_cpu_mem_usage=True,\n            use_flash_attn=self.flash_attn,\n        )\n        self.model.eval()\n    \n    def init_image_processor(self, image_processor_path):\n        self.image_processor = AutoImageProcessor.from_pretrained(\n            image_processor_path,\n            trust_remote_code=True,\n        )\n\n    def init_tokenizer(self, tokenizer_path):\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            tokenizer_path,\n            trust_remote_code=True,\n            use_fast=False,\n        )\n\n        self.image_context_token = '<IMG_CONTEXT>'\n        self.image_token_num = 256\n        self.image_start_token = '<img>'\n        self.image_end_token = '</img>'\n        self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(self.image_context_token)\n    \n    def format_image_tokens(self, path_num):\n        return f'{self.image_start_token}{self.image_context_token* self.image_token_num * path_num}{self.image_end_token}'\n\n    def forward(self, images, output_format='latex', **kwargs):\n        # process image to tokens\n        if not isinstance(images, list):\n            images = [images] \n        \n        pixel_values_list = []\n        for image in images:\n            path_images = self.dynamic_preprocess(\n                image, image_size=448, max_num=12\n            )\n            pixel_values = self.image_processor(\n                path_images, \n                return_tensors='pt'\n            )['pixel_values'].to(torch.bfloat16)\n            pixel_values_list.append(pixel_values)\n        \n        batch_size = len(pixel_values_list)\n        conversation_list = []\n        for bs_idx in range(batch_size):\n            pixel_values= pixel_values_list[bs_idx].to(torch.bfloat16)\n\n            image_tokens = self.format_image_tokens(pixel_values.shape[0])\n            question = '<image>\\n' + self.prompt_template[output_format]\n            answer = None\n        \n            template = get_conv_template(self.model.config.template)\n            template.append_message(template.roles[0], question)\n            template.append_message(template.roles[1], answer)\n            conversation = template.get_prompt()\n            conversation = conversation.replace('<image>', image_tokens, 1)\n            conversation_list.append(conversation)\n\n        device = 
next(self.parameters()).device\n        self.tokenizer.padding_side = 'left'\n        model_inputs = self.tokenizer(\n            conversation_list, \n            return_tensors='pt', \n            padding=True,\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n        ).to(device)\n        pixel_values = torch.cat(pixel_values_list, axis=0).to(device)\n\n        # generation config\n        generation_config = dict(\n            max_new_tokens=self.max_new_tokens,\n            max_time=self.max_generate_time,\n            img_context_token_id=self.img_context_token_id,\n            pad_token_id=self.tokenizer.pad_token_id,\n            eos_token_id=self.tokenizer.eos_token_id,\n            do_sample=False,\n            no_repeat_ngram_size=20,\n        )\n\n        # generate text from image tokens\n        model_output = self.model.generate(\n            pixel_values=pixel_values,\n            input_ids=model_inputs.input_ids,\n            attention_mask=model_inputs.attention_mask, \n            **generation_config,\n            # **kwargs\n        )\n\n        batch_decode_texts = self.tokenizer.batch_decode(\n            model_output,\n            skip_special_tokens=True\n        )\n        return batch_decode_texts\n    \n    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):\n        best_ratio_diff = float('inf')\n        best_ratio = (1, 1)\n        area = width * height\n        for ratio in target_ratios:\n            target_aspect_ratio = ratio[0] / ratio[1]\n            ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n            if ratio_diff < best_ratio_diff:\n                best_ratio_diff = ratio_diff\n                best_ratio = ratio\n            elif ratio_diff == best_ratio_diff:\n                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n                    best_ratio = ratio\n        return best_ratio\n\n    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):\n        orig_width, orig_height = image.size\n        aspect_ratio = orig_width / orig_height\n\n        # calculate the existing image aspect ratio\n        target_ratios = set(\n            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if\n            i * j <= max_num and i * j >= min_num)\n        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n        # find the closest aspect ratio to the target\n        target_aspect_ratio = self.find_closest_aspect_ratio(\n            aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n\n        # calculate the target width and height\n        target_width = image_size * target_aspect_ratio[0]\n        target_height = image_size * target_aspect_ratio[1]\n        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n        # resize the image\n        resized_img = image.resize((target_width, target_height))\n        processed_images = []\n        for i in range(blocks):\n            box = (\n                (i % (target_width // image_size)) * image_size,\n                (i // (target_width // image_size)) * image_size,\n                ((i % (target_width // image_size)) + 1) * image_size,\n                ((i // (target_width // image_size)) + 1) * image_size\n            )\n            # split the image\n            split_img = resized_img.crop(box)\n            processed_images.append(split_img)\n        assert 
len(processed_images) == blocks\n        if use_thumbnail and len(processed_images) != 1:\n            thumbnail_img = image.resize((image_size, image_size))\n            processed_images.append(thumbnail_img)\n        return processed_images\n"
  },
  {
    "path": "struct_eqtable/internvl/internvl_lmdeploy.py",
    "content": "import torch\nfrom torch import nn\n\nfrom transformers import AutoTokenizer\ntry:\n    from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig, ChatTemplateConfig\nexcept:\n    print(\"\\033[93mimport lmdeploy failed, if do not use lmdeploy, ignore this message\\033[0m\")\n\n\nclass InternVL_LMDeploy(nn.Module):\n    def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_tokens=1024, batch_size=4, **kwargs):\n        super().__init__()\n        self.model_path = model_path\n        self.max_new_tokens = max_new_tokens\n        self.max_batch_size = batch_size\n\n        # init model and tokenizer from ckpt path\n        self.init_tokenizer(model_path)\n        self.init_model(model_path)\n\n        self.prompt_template = {\n            'latex': '<latex>',\n            'html': '<html>',\n            'markdown': '<markdown>',\n        }\n        # support output format\n        self.supported_output_format = ['latex', 'html', 'markdown']\n    \n    def init_tokenizer(self, tokenizer_path):\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            tokenizer_path,\n            trust_remote_code=True,\n            use_fast=False,\n        )\n\n    def init_model(self, model_path):\n        engine_config = PytorchEngineConfig(\n            dtype='bfloat16',\n            max_batch_size=self.max_batch_size,\n            cache_max_entry_count=0.1\n        )\n        self.pipeline = pipeline(\n            model_path,\n            backend_config=engine_config,\n            chat_template_config=ChatTemplateConfig(model_name='internvl2-internlm2')\n        )\n\n    def forward(self, images, output_format='latex', **kwargs):\n        # process image to tokens\n        if not isinstance(images, list):\n            images = [images] \n        \n        prompts = [self.prompt_template[output_format]] * len(images)\n        generation_config = GenerationConfig(\n            max_new_tokens=self.max_new_tokens,\n            do_sample=False,\n            temperature=1.0,\n            stop_token_ids=[self.tokenizer.eos_token_id],\n        )\n        \n        responses = self.pipeline(\n            [(x, y) for x, y in zip(prompts, images)],\n            gen_config=generation_config,\n        )\n        batch_decode_texts = [responce.text for responce in responses]\n        return batch_decode_texts\n    \n\n"
  },
  {
    "path": "struct_eqtable/pix2s/__init__.py",
    "content": "from .pix2s import Pix2Struct\nfrom .pix2s_trt import Pix2StructTensorRT\n    "
  },
  {
    "path": "struct_eqtable/pix2s/pix2s.py",
    "content": "import torch\n\nfrom torch import nn\nfrom transformers import AutoModelForVision2Seq, AutoProcessor\n\n\nclass Pix2Struct(nn.Module):\n    def __init__(self, model_path='U4R/StructTable-base', max_new_tokens=1024, max_time=30, **kwargs):\n        super().__init__()\n        self.model_path = model_path\n        self.max_new_tokens = max_new_tokens\n        self.max_generate_time = max_time\n\n        # init model and image processor from ckpt path\n        self.init_image_processor(model_path)\n        self.init_model(model_path)\n\n        self.special_str_list = ['\\\\midrule', '\\\\hline']\n        self.supported_output_format = ['latex']\n\n    def postprocess_latex_code(self, code):\n        for special_str in self.special_str_list:\n            code = code.replace(special_str, special_str + ' ')\n        return code\n\n    def init_model(self, model_path):\n        self.model = AutoModelForVision2Seq.from_pretrained(model_path)\n        self.model.eval()\n\n    def init_image_processor(self, image_processor_path):\n        self.data_processor = AutoProcessor.from_pretrained(image_processor_path)\n\n    def forward(self, image, **kwargs):\n        # process image to tokens\n        image_tokens = self.data_processor.image_processor(\n            images=image,\n            return_tensors='pt',\n        )\n\n        device = next(self.parameters()).device\n        for k, v in image_tokens.items():\n            image_tokens[k] = v.to(device)\n\n        # generate text from image tokens\n        model_output = self.model.generate(\n            flattened_patches=image_tokens['flattened_patches'],\n            attention_mask=image_tokens['attention_mask'], \n            max_new_tokens=self.max_new_tokens,\n            max_time=self.max_generate_time,\n            no_repeat_ngram_size=20,\n        )\n\n        latex_codes = self.data_processor.batch_decode(model_output, skip_special_tokens=True)\n        # postprocess\n        for i, code in enumerate(latex_codes):\n            latex_codes[i] = self.postprocess_latex_code(code)\n\n        return latex_codes\n"
  },
  {
    "path": "struct_eqtable/pix2s/pix2s_trt.py",
    "content": "import os\nimport time\nimport json\n\nimport torch\nimport torch.nn as nn\n\ntry:\n    import tensorrt_llm\n    import tensorrt as trt\n    import tensorrt_llm.profiler as profiler\n\n    from tensorrt_llm._utils import str_dtype_to_trt, torch_to_numpy\n    from tensorrt_llm.lora_manager import LoraManager\n    from tensorrt_llm.runtime import Session, TensorInfo, ModelConfig, SamplingConfig\nexcept:\n    print(\"\\033[93mimport tensorrt_llm failed, if do not use tensorrt, ignore this message\\033[0m\")\n\nfrom typing import List\nfrom transformers import AutoProcessor, AutoTokenizer, AutoConfig\n\n\ndef trt_dtype_to_torch(dtype):\n    if dtype == trt.float16:\n        return torch.float16\n    elif dtype == trt.float32:\n        return torch.float32\n    elif dtype == trt.int32:\n        return torch.int32\n    elif dtype == trt.bfloat16:\n        return torch.bfloat16\n    else:\n        raise TypeError(\"%s is not supported\" % dtype)\n\n\nclass Pix2StructTensorRT(nn.Module):\n\n    def __init__(self, model_path, tensorrt_path, batch_size=1, max_new_tokens=4096, **kwargs):\n        \n        self.model_ckpt_path = model_path\n        self.tensorrt_path = tensorrt_path\n        self.batch_size = batch_size\n        self.max_new_tokens = max_new_tokens\n\n        self.llm_engine_path = os.path.join(tensorrt_path, 'llm_engines')\n        self.visual_engine_path = os.path.join(tensorrt_path, 'visual_engines')\n        \n        device_id = torch.cuda.current_device() % torch.cuda.device_count()\n        self.device_id = device_id\n        self.device = \"cuda:%d\" % (device_id)\n        \n        self.stream = torch.cuda.Stream(torch.cuda.current_device())\n        torch.cuda.set_stream(self.stream)\n\n        # parse model type from visual engine config\n        with open(os.path.join(self.visual_engine_path, \"config.json\"),\n                  \"r\") as f:\n            config = json.load(f)\n        self.model_type = config['builder_config']['model_type']\n        self.vision_precision = config['builder_config']['precision']\n\n        self.vision_precision = 'float16'\n        self.decoder_llm = not (\n            't5' in self.model_type\n            or self.model_type in ['nougat', 'pix2struct', 'StructEqTable']\n        )  # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs\n\n        self.profiling_iterations = 20\n\n        self.init_image_encoder()\n        self.init_tokenizer()\n        self.init_llm()\n        self.init_image_processor()\n\n        self.special_str_list = ['\\\\midrule', '\\\\hline']\n        self.supported_output_format = ['latex']\n\n    def postprocess_latex_code(self, code):\n        for special_str in self.special_str_list:\n            code = code.replace(special_str, special_str + ' ')\n        return code\n\n    def init_image_processor(self):\n        self.data_processor = AutoProcessor.from_pretrained(\n            self.model_ckpt_path)\n\n    def init_tokenizer(self):\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            self.model_ckpt_path, use_fast=True, use_legacy=False)\n        # self.tokenizer.padding_side = \"right\"\n\n    def init_image_encoder(self):\n        vision_encoder_path = os.path.join(self.visual_engine_path,\n                                           'visual_encoder.engine')\n        with open(vision_encoder_path, 'rb') as f:\n            engine_buffer = f.read()\n        self.visual_encoder_session = Session.from_serialized_engine(\n            engine_buffer)\n\n    def 
init_llm(self):\n\n        self.model = TRTLLMEncDecModel.from_engine(\n            os.path.basename(self.model_ckpt_path),\n            self.llm_engine_path,\n            skip_encoder=self.model_type in ['nougat', 'pix2struct', 'StructEqTable'],\n            debug_mode=False,\n            stream=self.stream)\n\n        self.model_config = self.model.decoder_model_config\n        self.runtime_mapping = self.model.decoder_runtime_mapping\n\n    def __call__(self, image, **kwargs):\n        # process image to tokens\n        image_tokens = self.data_processor.image_processor(\n            images=image,\n            return_tensors='pt',\n        )\n\n        for k, v in image_tokens.items():\n            image_tokens[k] = v.cuda()\n\n        model_output = self.run(\n            flattened_patches=image_tokens['flattened_patches'],\n            attention_mask=image_tokens['attention_mask'], \n            max_new_tokens=self.max_new_tokens\n        )\n\n        # postprocess\n        latex_codes = []\n        for i, code in enumerate(model_output):\n            latex_codes.append(self.postprocess_latex_code(code[0]))\n\n        return latex_codes\n\n    def preprocess(self, warmup, pre_prompt, post_prompt, image,\n                   attention_mask):\n        if not warmup:\n            profiler.start(\"Vision\")\n\n        visual_features, visual_atts = self.get_visual_features(\n            torch.stack(image['image_patches'], dim=0)\n            if self.model_type == 'fuyu' else image, attention_mask)\n\n        if not warmup:\n            profiler.stop(\"Vision\")\n       \n        pre_input_ids = self.tokenizer(pre_prompt,\n                                        return_tensors=\"pt\",\n                                        padding=True).input_ids\n        if post_prompt[0] is not None:\n            post_input_ids = self.tokenizer(post_prompt,\n                                            return_tensors=\"pt\",\n                                            padding=True).input_ids\n            length = pre_input_ids.shape[1] + post_input_ids.shape[\n                1] + visual_atts.shape[1]\n        else:\n            post_input_ids = None\n            length = pre_input_ids.shape[1] + visual_atts.shape[1]\n\n        input_lengths = torch.IntTensor([length] * 1).to(\n            torch.int32)\n\n        input_ids, ptuning_args = self.setup_fake_prompts(\n            visual_features, pre_input_ids, post_input_ids, input_lengths)\n\n        return input_ids, input_lengths, ptuning_args, visual_features\n\n    def generate(self, pre_prompt, post_prompt, image, decoder_input_ids,\n                 max_new_tokens, attention_mask, warmup):\n        if not warmup:\n            profiler.start(\"Generate\")\n\n        input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(\n            warmup, pre_prompt, post_prompt, image, attention_mask)\n\n        if warmup: return None\n\n        profiler.start(\"LLM\")\n\n        # Trim encoder input_ids to match visual features shape\n        ids_shape = (self.batch_size, visual_features.shape[1])\n\n        input_ids = torch.ones(ids_shape, dtype=torch.int32)\n\n        output_ids = self.model.generate(\n            input_ids,\n            decoder_input_ids,\n            max_new_tokens,\n            num_beams=1,\n            bos_token_id=self.tokenizer.bos_token_id,\n            pad_token_id=self.tokenizer.pad_token_id,\n            eos_token_id=self.tokenizer.eos_token_id,\n            debug_mode=False,\n            
prompt_embedding_table=ptuning_args[0],\n            prompt_tasks=ptuning_args[1],\n            prompt_vocab_size=ptuning_args[2],\n            attention_mask=attention_mask)\n\n        # Reset input_lengths to match decoder_input_ids\n        input_lengths = torch.ones(input_lengths.shape,\n                                    dtype=input_lengths.dtype)\n        profiler.stop(\"LLM\")\n\n        if tensorrt_llm.mpi_rank() == 0:\n            # Extract a list of tensors of shape beam_width x output_ids.\n            output_beams_list = [\n                self.tokenizer.batch_decode(\n                    output_ids[batch_idx, :, input_lengths[batch_idx]:],\n                    skip_special_tokens=True)\n                for batch_idx in range(self.batch_size)\n            ]\n\n            stripped_text = [[\n                output_beams_list[batch_idx][beam_idx].strip()\n                for beam_idx in range(1)\n            ] for batch_idx in range(self.batch_size)]\n            profiler.stop(\"Generate\")\n            return stripped_text\n        else:\n            profiler.stop(\"Generate\")\n            return None\n        \n    def get_visual_features(self, image, attention_mask):\n        visual_features = {\n            'input':\n            image.to(\n                tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision))\n        }\n        if attention_mask is not None:\n            visual_features['attention_mask'] = attention_mask\n        tensor_info = [\n            TensorInfo('input', str_dtype_to_trt(self.vision_precision),\n                       image.shape)\n        ]\n        if attention_mask is not None:\n            tensor_info.append(\n                TensorInfo('attention_mask', trt.DataType.INT32,\n                           attention_mask.shape))\n        visual_output_info = self.visual_encoder_session.infer_shapes(\n            tensor_info)\n        visual_outputs = {\n            t.name: torch.empty(tuple(t.shape),\n                                dtype=trt_dtype_to_torch(t.dtype),\n                                device=image.device)\n            for t in visual_output_info\n        }\n\n        ok = self.visual_encoder_session.run(visual_features, visual_outputs,\n                                             self.stream.cuda_stream)\n        assert ok, \"Runtime execution failed for vision encoder session\"\n        self.stream.synchronize()\n\n        image_embeds = visual_outputs['output']\n        image_atts = torch.ones(image_embeds.size()[:-1],\n                                dtype=torch.long).to(image.device)\n\n        return image_embeds, image_atts\n    \n    def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids,\n                           input_lengths):\n        # Assemble fake prompts which points to image embedding actually\n        fake_prompt_id = torch.arange(\n            self.model_config.vocab_size, self.model_config.vocab_size +\n            visual_features.shape[0] * visual_features.shape[1])\n        fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0],\n                                                visual_features.shape[1])\n\n        if post_input_ids is not None:\n            input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]\n        else:\n            input_ids = [fake_prompt_id, pre_input_ids]\n        \n        input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)\n\n        if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():\n            ptuning_args = 
self.ptuning_setup(visual_features, input_ids,\n                                              input_lengths)\n        else:\n            ptuning_args = [None, None, None]\n\n        return input_ids, ptuning_args\n\n    def ptuning_setup(self, prompt_table, input_ids, input_lengths):\n        hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size\n        if prompt_table is not None:\n            task_vocab_size = torch.tensor(\n                [prompt_table.shape[1]],\n                dtype=torch.int32,\n            ).cuda()\n            prompt_table = prompt_table.view(\n                (prompt_table.shape[0] * prompt_table.shape[1],\n                 prompt_table.shape[2]))\n            assert prompt_table.shape[\n                1] == hidden_size, \"Prompt table dimensions do not match hidden size\"\n\n            prompt_table = prompt_table.cuda().to(\n                dtype=tensorrt_llm._utils.str_dtype_to_torch(\n                    self.model_config.dtype))\n        else:\n            prompt_table = torch.empty([1, hidden_size]).cuda()\n            task_vocab_size = torch.zeros([1]).cuda()\n\n        if self.model_config.remove_input_padding:\n            tasks = torch.zeros([torch.sum(input_lengths)],\n                                dtype=torch.int32).cuda()\n            if self.decoder_llm: tasks = tasks.unsqueeze(0)\n        else:\n            tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()\n\n        return [prompt_table, tasks, task_vocab_size]\n\n    def setup_inputs(self, input_text, raw_image):\n        attention_mask = None\n       \n        image_processor = AutoProcessor.from_pretrained(self.model_ckpt_path)\n        if input_text is None:\n            input_text = \"\"\n        inputs = image_processor(\n            images=raw_image,\n            text=input_text,\n            return_tensors=\"pt\",\n        )\n        image = inputs['flattened_patches']\n        image = image.expand(self.batch_size, -1, -1).contiguous()\n        attention_mask = inputs['attention_mask'].to(self.device).to(\n            torch.int)\n        attention_mask = attention_mask.expand(self.batch_size,\n                                                -1).contiguous()\n        pre_prompt = \"\"\n        post_prompt = None\n\n        # Repeat inputs to match batch size\n        pre_prompt = [pre_prompt] * self.batch_size\n        post_prompt = [post_prompt] * self.batch_size\n        image = image.to(self.device)\n\n        # Generate decoder_input_ids for enc-dec models\n        # Custom prompts can be added as:\n        # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids\n        if self.decoder_llm:\n            decoder_input_ids = None\n        else:\n            config = AutoConfig.from_pretrained(self.model_ckpt_path)\n            decoder_start_id = config.decoder_start_token_id  # T5\n            if decoder_start_id is None:\n                decoder_start_id = config.decoder.bos_token_id  # Nougat\n\n            decoder_input_ids = torch.IntTensor([[decoder_start_id]])\n            decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1))\n\n        return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask\n\n    def run(self, flattened_patches, attention_mask, max_new_tokens):\n        # input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = self.setup_inputs(\n        #     None, raw_image)\n        pre_prompt = [\"\"] * self.batch_size\n        post_prompt = [None] * 
self.batch_size\n        config = AutoConfig.from_pretrained(self.model_ckpt_path)\n        decoder_start_id = config.decoder_start_token_id  # T5 \n        decoder_input_ids = torch.IntTensor([[decoder_start_id]])\n        decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1))\n\n        processed_image = flattened_patches.expand(self.batch_size, -1, -1).contiguous()\n        attention_mask = attention_mask.to(self.device).to(torch.int)\n        attention_mask = attention_mask.expand(self.batch_size,-1).contiguous()\n\n        self.generate(pre_prompt,\n                       post_prompt,\n                       processed_image,\n                       decoder_input_ids,\n                       max_new_tokens,\n                       attention_mask=attention_mask,\n                       warmup=True)\n        # num_iters = self.profiling_iterations if self.args.run_profiling else 1\n        num_iters = 1\n        # print(num_iters)\n        for _ in range(num_iters):\n            output_text = self.generate(pre_prompt,\n                                         post_prompt,\n                                         processed_image,\n                                         decoder_input_ids,\n                                         max_new_tokens,\n                                         attention_mask=attention_mask,\n                                         warmup=False)\n        # if self.runtime_rank == 0:\n        #     self.print_result(input_text, output_text)\n        return output_text\n\n\ndef read_config(config_path):\n    with open(config_path, \"r\") as f:\n        config = json.load(f)\n\n    builder_config = config['build_config']\n    plugin_config = builder_config['plugin_config']\n    pretrained_config = config['pretrained_config']\n    lora_config = builder_config['lora_config']\n    auto_parallel_config = builder_config['auto_parallel_config']\n    use_gpt_attention_plugin = plugin_config[\"gpt_attention_plugin\"]\n    remove_input_padding = plugin_config[\"remove_input_padding\"]\n    use_lora_plugin = plugin_config[\"lora_plugin\"]\n    tp_size = pretrained_config['mapping']['tp_size']\n    pp_size = pretrained_config['mapping']['pp_size']\n    gpus_per_node = auto_parallel_config['gpus_per_node']\n    world_size = tp_size * pp_size\n    assert world_size == tensorrt_llm.mpi_world_size(), \\\n        f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'\n    num_heads = pretrained_config[\"num_attention_heads\"]\n    hidden_size = pretrained_config[\"hidden_size\"]\n    head_size = pretrained_config[\"head_size\"]\n    vocab_size = pretrained_config[\"vocab_size\"]\n    max_batch_size = builder_config[\"max_batch_size\"]\n    max_beam_width = builder_config[\"max_beam_width\"]\n    num_layers = pretrained_config[\"num_hidden_layers\"]\n    num_kv_heads = pretrained_config.get('num_kv_heads', num_heads)\n\n    assert (num_heads % tp_size) == 0\n    num_heads = num_heads // tp_size\n    hidden_size = hidden_size // tp_size\n    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size\n\n    cross_attention = pretrained_config[\"architecture\"] == \"DecoderModel\"\n    skip_cross_qkv = pretrained_config.get('skip_cross_qkv', False)\n    has_position_embedding = pretrained_config[\"has_position_embedding\"]\n    has_token_type_embedding = hasattr(pretrained_config, \"type_vocab_size\")\n    use_custom_all_reduce = plugin_config.get('use_custom_all_reduce', False)\n    dtype = pretrained_config[\"dtype\"]\n\n    
paged_kv_cache = plugin_config['paged_kv_cache']\n    tokens_per_block = plugin_config['tokens_per_block']\n\n    gather_context_logits = builder_config.get('gather_context_logits', False)\n    gather_generation_logits = builder_config.get('gather_generation_logits',\n                                                  False)\n    max_prompt_embedding_table_size = builder_config.get(\n        'max_prompt_embedding_table_size', 0)\n\n    model_config = ModelConfig(\n        num_heads=num_heads,\n        num_kv_heads=num_kv_heads,\n        hidden_size=hidden_size,\n        head_size=head_size,\n        max_batch_size=max_batch_size,\n        max_beam_width=max_beam_width,\n        vocab_size=vocab_size,\n        num_layers=num_layers,\n        gpt_attention_plugin=use_gpt_attention_plugin,\n        remove_input_padding=remove_input_padding,\n        paged_kv_cache=paged_kv_cache,\n        tokens_per_block=tokens_per_block,\n        cross_attention=cross_attention,\n        has_position_embedding=has_position_embedding,\n        has_token_type_embedding=has_token_type_embedding,\n        use_custom_all_reduce=use_custom_all_reduce,\n        dtype=dtype,\n        gather_context_logits=gather_context_logits,\n        gather_generation_logits=gather_generation_logits,\n        max_prompt_embedding_table_size=max_prompt_embedding_table_size,\n        lora_plugin=use_lora_plugin,\n        lora_target_modules=lora_config.get('lora_target_modules'),\n        trtllm_modules_to_hf_modules=lora_config.get(\n            'trtllm_modules_to_hf_modules'),\n        skip_cross_qkv=skip_cross_qkv,\n    )\n\n    return model_config, tp_size, pp_size, gpus_per_node, dtype\n\n\nclass Mapping(object):\n    def __init__(\n            self,\n            world_size=1,\n            rank=0,\n            gpus_per_node=8,\n            tp_size=1,\n            pp_size=1,\n            moe_tp_size=-1,  # -1 means no moe\n            moe_ep_size=-1):  # -1 means no moe\n        # set default values for non-moe cases\n        if moe_tp_size == -1:\n            moe_tp_size = tp_size\n            moe_ep_size = 1\n\n        if pp_size * tp_size != world_size:\n            raise ValueError(\n                f\"world_size must equal to pp_size * tp_size, but got {world_size} != {pp_size} * {tp_size}\"\n            )\n\n        moe_tp_ep_size = moe_tp_size * moe_ep_size\n        if moe_tp_ep_size != tp_size:\n            raise ValueError(\n                f\"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}\"\n            )\n\n        self.tp_size = tp_size\n        self.pp_size = pp_size\n        self.moe_tp_size = moe_tp_size\n        self.moe_ep_size = moe_ep_size\n        self.world_size = world_size\n        self.rank = rank\n        self.gpus_per_node = gpus_per_node\n\n        self.pp_groups = []\n        self.tp_groups = []\n        self.moe_tp_groups = []\n        self.moe_ep_groups = []\n\n        # init pp group\n        for i in range(tp_size):\n            ranks = range(i+ self.rank, world_size+ self.rank, tp_size)\n            self.pp_groups.append(list(ranks))\n\n        # init tp group\n        for i in range(pp_size):\n            ranks = range(i * tp_size + self.rank, (i + 1) * tp_size + self.rank)\n            self.tp_groups.append(list(ranks))\n\n        # init moe tp group\n        for i in range(pp_size):\n            for j in range(moe_ep_size):\n                ranks = range(i * moe_tp_ep_size + j, (i + 1) * moe_tp_ep_size,\n                              
moe_ep_size)\n                self.moe_tp_groups.append(list(ranks))\n\n        # init moe ep group\n        for i in range(pp_size):\n            for j in range(moe_tp_size):\n                ranks = range(i * moe_tp_ep_size + j * moe_ep_size,\n                              i * moe_tp_ep_size + (j + 1) * moe_ep_size)\n                self.moe_ep_groups.append(list(ranks))\n\n        # self.pp_rank = self.rank // self.tp_size\n        # self.tp_rank = self.rank % self.tp_size\n        self.pp_rank = 0\n        self.tp_rank = 0\n        self.moe_tp_rank = self.tp_rank // self.moe_ep_size\n        self.moe_ep_rank = self.tp_rank % self.moe_ep_size\n\n        # self.tp_group = self.tp_groups[self.pp_rank]\n        # self.pp_group = self.pp_groups[self.tp_rank]\n        self.moe_tp_group = self.moe_tp_groups[self.pp_rank * moe_ep_size +\n                                               self.moe_ep_rank]\n        self.moe_ep_group = self.moe_ep_groups[self.pp_rank * moe_tp_size +\n                                               self.moe_tp_rank]\n\n        self.node_rank = self.rank // self.gpus_per_node\n        self.local_rank = self.rank % self.gpus_per_node\n\n    def get_node_rank(self, rank: int):\n        return rank // self.gpus_per_node\n\n    def get_local_rank(self, rank: int):\n        return rank % self.gpus_per_node\n\n    def has_tp(self):\n        return self.tp_size > 1\n\n    def is_last_pp_rank(self):\n        return self.pp_rank == self.pp_size - 1\n\n    def is_first_pp_rank(self):\n        return self.pp_rank == 0\n\n    def has_pp(self):\n        return self.pp_size > 1\n\n    def prev_pp_rank(self):\n        p = self.rank - self.tp_size\n        if p < 0:\n            p = p + self.world_size\n        return p\n\n    def next_pp_rank(self):\n        p = self.rank + self.tp_size\n        if p >= self.world_size:\n            p = p - self.world_size\n        return p\n\n    def has_moe_tp(self):\n        return self.moe_tp_size > 1\n\n    def has_moe_ep(self):\n        return self.moe_ep_size > 1\n\n    def pp_layers(self, num_layers: int) -> List[int]:\n        layers_per_pipeline_stage = num_layers // self.pp_size\n        layers_range = range(self.pp_rank * layers_per_pipeline_stage,\n                             (self.pp_rank + 1) * layers_per_pipeline_stage)\n        return list(layers_range)\n\n    def ep_experts(self, num_experts: int) -> List[int]:\n        experts_per_rank = num_experts // self.moe_ep_size\n        experts_range = range(self.moe_ep_rank * experts_per_rank,\n                              (self.moe_ep_rank + 1) * experts_per_rank)\n        return list(experts_range)\n\n\ndef get_engine_name(rank):\n    return 'rank{}.engine'.format(rank)\n\nclass TRTLLMEncDecModel:\n\n    def __init__(\n        self,\n        engine_name,\n        engine_dir,\n        lora_dir=None,\n        lora_task_uids=None,\n        debug_mode=False,\n        skip_encoder=False,\n        stream: torch.cuda.Stream = None,\n    ):\n        # in multi-node setup, it's important to set_device at the very beginning so .to('cuda') refers to current device\n        # accordingly, all input & output tensors should be moved to current device\n        # otherwise, it's default to 'cuda:0'\n        \n        # self.runtime_rank = tensorrt_llm.mpi_rank()\n        self.device_id = torch.cuda.current_device()\n        # torch.cuda.set_device(device_id)\n        self.device = torch.cuda.current_device()\n        self.skip_encoder = skip_encoder\n        self.lora_task_uids = lora_task_uids\n\n    
    # when enc-dec runs by itself, stream can be None and we create new stream here\n        # when enc-dec has to run as a component in a bigger workflow (e.g., multimodal), earlier components in the workflow may have results in its stream, which we should pass that stream in to avoid unnecessary stream sync\n        self.stream = stream\n        if self.stream is None:\n            self.stream = torch.cuda.Stream(self.device)\n        torch.cuda.set_stream(self.stream)\n\n        def engine_setup(component):\n            # model config\n            config_path = os.path.join(engine_dir, component, \"config.json\")\n            model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(\n                config_path)\n\n            # MGMN config\n            world_size = tp_size * pp_size\n            # runtime_rank = tensorrt_llm.mpi_rank()\n            runtime_rank = torch.cuda.current_device()\n            # assert runtime_rank < world_size, \"Runtime GPU rank exceeds MPI world size. Did you launch more MPI processes than required?\"\n            # runtime_mapping = tensorrt_llm.Mapping(world_size,\n            #                                        runtime_rank,\n            #                                        tp_size=tp_size,\n            #                                        pp_size=pp_size,\n            #                                        gpus_per_node=gpus_per_node)\n            # tensorrt_llm.Mapping\n            runtime_mapping = Mapping(world_size,\n                                      runtime_rank,\n                                      tp_size=tp_size,\n                                      pp_size=pp_size,\n                                      gpus_per_node=gpus_per_node)\n            # load engine\n            # engine_fname = get_engine_name(runtime_rank)\n            engine_fname = get_engine_name(0)\n            with open(os.path.join(engine_dir, component, engine_fname), \"rb\") as f:\n                engine_buffer = f.read()\n\n            return model_config, runtime_mapping, engine_buffer\n\n        # Note: encoder and decoder doesn't necessarily have the same TP & PP config\n\n        if not skip_encoder:\n            self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup(\n                component='encoder')\n\n            self.nccl_comm = None\n            if self.encoder_runtime_mapping.has_pp():\n                # for Pipeline Parallelism in encoder\n                self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(\n                    self.encoder_runtime_mapping.tp_size,\n                    self.encoder_runtime_mapping.pp_size,\n                    self.encoder_runtime_mapping.rank)\n\n            # session setup\n            self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(\n                encoder_engine_buffer)\n\n            # encoder lora manager setup\n            if self.encoder_model_config.lora_plugin:\n                self.encoder_lora_manager = LoraManager()\n                # TODO: this is only for bart\n                self.encoder_lora_manager.load_from_hf(\n                    model_dirs=lora_dir,\n                    model_config=self.encoder_model_config,\n                    runtime_mapping=self.encoder_runtime_mapping,\n                    component='encoder',\n                )\n            else:\n                self.encoder_lora_manager = None\n        else:\n            self.encoder_model_config, self.encoder_runtime_mapping, 
encoder_engine_buffer = None, None, None\n            self.nccl_comm, self.encoder_session = None, None\n\n        self.decoder_model_config, self.decoder_runtime_mapping, decoder_engine_buffer = engine_setup(\n            component='decoder')\n\n        self.decoder_session = tensorrt_llm.runtime.GenerationSession(\n            self.decoder_model_config,\n            decoder_engine_buffer,\n            self.decoder_runtime_mapping,\n            debug_mode=debug_mode)\n\n        # decoder lora manager setup\n        if self.decoder_model_config.lora_plugin:\n            self.decoder_lora_manager = LoraManager()\n            # TODO: this is only for bart\n            self.decoder_lora_manager.load_from_hf(\n                model_dirs=lora_dir,\n                model_config=self.decoder_model_config,\n                runtime_mapping=self.decoder_runtime_mapping,\n                component='decoder',\n            )\n        else:\n            self.decoder_lora_manager = None\n    \n    @classmethod\n    def from_engine(cls,\n                    engine_name,\n                    engine_dir,\n                    lora_dir=None,\n                    lora_task_uids=None,\n                    debug_mode=False,\n                    skip_encoder=False,\n                    stream=None):\n        return cls(engine_name,\n                   engine_dir,\n                   lora_dir,\n                   lora_task_uids,\n                   debug_mode=debug_mode,\n                   skip_encoder=skip_encoder,\n                   stream=stream)\n\n    def process_input(self,\n                      input_ids,\n                      remove_input_padding=False,\n                      pad_token_id=0,\n                      prompt_tasks=None):\n        if remove_input_padding:\n            # in remove padding mode --> flatten input, calculate actual length and max length\n            # Note: 1st token should never be removed, even if it is pad_token_id\n            first_ids = input_ids[:, 0]\n            input_ids = input_ids[:, 1:]\n            input_lengths = 1 + (input_ids != pad_token_id).sum(dim=1).type(\n                torch.IntTensor).to(self.device)  # [batch_size]\n            new_ids = []\n            for i in range(len(input_ids)):\n                row = input_ids[i, :]\n                row = row[row != pad_token_id]\n                new_ids.append(\n                    torch.cat(\n                        (torch.IntTensor([first_ids[i]]).to(self.device), row)))\n            input_ids = torch.cat(new_ids)  # [num_tokens]\n            if prompt_tasks is not None:\n                prompt_tasks = prompt_tasks[:input_ids.shape[0]]\n        else:\n            # in padding mode --> keep input, just calculate actual length and max length\n            # Note: 1st token should always count, even if it is pad_token_id. 
e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count\n            input_lengths = torch.tensor(\n                1 + (input_ids[:, 1:] != pad_token_id).sum(dim=1).type(\n                    torch.IntTensor).to(self.device),\n                dtype=torch.int32,\n                device=self.device)\n        max_input_length = torch.max(input_lengths).item()\n        return input_ids, input_lengths, max_input_length, prompt_tasks\n\n    def encoder_run(self,\n                    input_ids,\n                    input_lengths,\n                    max_input_length,\n                    position_ids=None,\n                    token_type_ids=None,\n                    debug_mode=False,\n                    prompt_embedding_table=None,\n                    prompt_tasks=None,\n                    prompt_vocab_size=None,\n                    attention_mask=None):\n\n        # each engine has hidden_dim/TP, don't forget to multiply TP\n        hidden_size = self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size\n        if input_ids.dim() == 1:\n            hidden_states_shape = (input_ids.shape[0], hidden_size\n                                   )  # [num_tokens,D]\n        else:\n            hidden_states_shape = (input_ids.shape[0], input_ids.shape[1],\n                                   hidden_size)  # [BS,seqlen,D]\n        hidden_states_dtype = lambda name: trt_dtype_to_torch(\n            self.encoder_session.engine.get_tensor_dtype(name))\n\n        # input tensors. only first PP rank has id input, others are hidden_states input\n        inputs = {}\n        if self.encoder_runtime_mapping.is_first_pp_rank():\n            inputs['input_ids'] = input_ids.contiguous()\n            if self.encoder_model_config.has_position_embedding:\n                if position_ids is None:\n                    if self.encoder_model_config.remove_input_padding:\n                        position_ids = [\n                            torch.arange(sample_length,\n                                         dtype=torch.int32,\n                                         device=input_ids.device)\n                            for sample_length in torch_to_numpy(input_lengths)\n                        ]\n                        position_ids = torch.cat(position_ids)\n                    else:\n                        bsz, seq_len = input_ids.shape[:2]\n                        position_ids = torch.arange(\n                            seq_len, dtype=torch.int32,\n                            device=input_ids.device).expand(bsz, -1)\n                inputs['position_ids'] = position_ids.contiguous()\n            if self.encoder_model_config.has_token_type_embedding:\n                inputs['token_type_ids'] = token_type_ids.contiguous()\n\n            if self.encoder_model_config.max_prompt_embedding_table_size > 0:\n                inputs[\n                    'prompt_embedding_table'] = prompt_embedding_table.contiguous(\n                    )\n                inputs['tasks'] = prompt_tasks.contiguous()\n                inputs['prompt_vocab_size'] = prompt_vocab_size.contiguous()\n        else:\n            # just need a placeholder, engine will call NCCL to recv and fill data from previous rank\n            inputs['hidden_states_input'] = torch.empty(\n                hidden_states_shape,\n                dtype=hidden_states_dtype('hidden_states_input'),\n                device=self.device).contiguous()\n        if attention_mask is not None and not 
self.encoder_model_config.gpt_attention_plugin:\n            inputs['attention_mask'] = attention_mask.contiguous()\n\n        inputs['input_lengths'] = input_lengths\n        # use shape info to pass max length info in remove padding mode\n        inputs['max_input_length'] = torch.empty(\n            (max_input_length, ),\n            dtype=hidden_states_dtype('max_input_length'),\n            device=self.device).contiguous()\n        batch_size = input_lengths.size(0)\n        inputs['host_request_types'] = torch.IntTensor([0] *\n                                                       batch_size).to('cpu')\n        if self.encoder_model_config.remove_input_padding:\n            inputs['host_context_lengths'] = input_lengths.to('cpu')\n\n        if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None:\n            inputs.update(\n                self.encoder_lora_manager.input_buffers(\n                    self.lora_task_uids,\n                    self.encoder_runtime_mapping,\n                    self.encoder_model_config.num_layers,\n                ))\n\n        # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape\n        self.encoder_session.set_shapes(inputs)\n\n        # output tensors. only last PP rank final encoder output, others are intermediate hidden_states output. Need broadcast later\n        outputs = {}\n        if self.encoder_runtime_mapping.is_last_pp_rank():\n            outputs['encoder_output'] = torch.empty(\n                hidden_states_shape,\n                dtype=hidden_states_dtype('encoder_output'),\n                device=self.device).contiguous()\n        else:\n            outputs['hidden_states_output'] = torch.empty(\n                hidden_states_shape,\n                dtype=hidden_states_dtype('hidden_states_output'),\n                device=self.device).contiguous()\n\n        # -------------------------------------------\n        if debug_mode:\n            engine = self.encoder_session.engine\n            context = self.encoder_session.context\n            # setup debugging buffer for the encoder\n            for i in range(self.encoder_session.engine.num_io_tensors):\n                name = engine.get_tensor_name(i)\n                if engine.get_tensor_mode(\n                        name\n                ) == trt.TensorIOMode.OUTPUT and name not in outputs.keys():\n                    dtype = engine.get_tensor_dtype(name)\n                    shape = context.get_tensor_shape(name)\n                    outputs[name] = torch.zeros(tuple(shape),\n                                                dtype=trt_dtype_to_torch(dtype),\n                                                device=self.device)\n                    context.set_tensor_address(name, outputs[name].data_ptr())\n        # -------------------------------------------\n\n        # TRT session run\n        # Note: need cuda stream ID, not a torch Stream\n        ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream)\n        assert ok, \"Runtime execution failed\"\n        self.stream.synchronize()\n\n        # Tensor Parallelism is handled by model/engine definition\n        # But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism\n        # After this, all ranks should recv the encoder output, and world might be re-configured using decoder's TP-PP config\n        def pp_communicate_encoder_output(encoder_output):\n            if 
self.encoder_runtime_mapping.is_last_pp_rank():\n                for pp_rank in self.encoder_runtime_mapping.pp_group:\n                    if pp_rank != self.encoder_runtime_mapping.rank:\n                        self.nccl_comm.send(encoder_output, pp_rank)\n                return encoder_output\n            else:\n                self.nccl_comm.recv(encoder_output,\n                                    self.encoder_runtime_mapping.pp_group[-1])\n                return encoder_output\n\n        if self.encoder_runtime_mapping.has_pp():\n            # use hidden_states output buffer to receive output as the shapes are same\n            encoder_output_buf = outputs[\n                'encoder_output'] if self.encoder_runtime_mapping.is_last_pp_rank(\n                ) else outputs['hidden_states_output']\n            encoder_output = pp_communicate_encoder_output(encoder_output_buf)\n        else:\n            encoder_output = outputs['encoder_output']\n\n        return encoder_output\n\n    def generate(self,\n                 encoder_input_ids,\n                 decoder_input_ids,\n                 max_new_tokens,\n                 num_beams=1,\n                 pad_token_id=None,\n                 eos_token_id=None,\n                 bos_token_id=None,\n                 debug_mode=False,\n                 return_dict=False,\n                 prompt_embedding_table=None,\n                 prompt_tasks=None,\n                 prompt_vocab_size=None,\n                 attention_mask=None,\n                 time_encoder=False,\n                 return_encoder_output=False):\n        ## ensure all externally provided tensors are on the correct device.\n        encoder_input_ids = encoder_input_ids.to(self.device)\n        decoder_input_ids = decoder_input_ids.to(self.device)\n\n        if attention_mask is not None:\n            attention_mask = torch.tensor(attention_mask,\n                                          dtype=torch.int32,\n                                          device=self.device)\n\n        ## encoder run\n        encoder_remove_input_padding = self.encoder_model_config.remove_input_padding if self.encoder_model_config else self.decoder_model_config.remove_input_padding\n\n        encoder_input_ids, encoder_input_lengths, encoder_max_input_length, prompt_tasks = self.process_input(\n            encoder_input_ids, encoder_remove_input_padding, pad_token_id,\n            prompt_tasks)\n\n        if not self.skip_encoder:\n            #logger.info(f\"Rank {self.runtime_rank} Running encoder engine ...\")\n            if time_encoder:\n                tik = time.time()\n            encoder_output = self.encoder_run(\n                encoder_input_ids,\n                encoder_input_lengths,\n                encoder_max_input_length,\n                debug_mode=debug_mode,\n                prompt_embedding_table=prompt_embedding_table,\n                prompt_tasks=prompt_tasks,\n                prompt_vocab_size=prompt_vocab_size,\n                attention_mask=attention_mask)\n            if time_encoder:\n                tok = time.time()\n                print(f\"TRT-LLM Encoder time {(tok-tik)*1000}ms\")\n        else:\n            encoder_output = prompt_embedding_table\n            if encoder_input_ids.dim() > 1:\n                encoder_output = encoder_output.unsqueeze(0)\n\n        ## decoder run\n        # logger.info(f\"Rank {self.runtime_rank} Running decoder engine ...\")\n        decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = 
self.process_input(\n            decoder_input_ids, self.decoder_model_config.remove_input_padding,\n            pad_token_id)\n\n        # `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len]\n        # where query_len happens to be 1 in current cases, but not necessarily always, and\n        # `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where\n        # the query_len is always 1 since we have kv cache.\n        cross_attention_mask = None\n        if attention_mask is not None:\n            cross_attention_mask = torch.tensor(attention_mask,\n                                                dtype=torch.int32,\n                                                device=self.device).reshape(\n                                                    attention_mask.shape[0], 1,\n                                                    attention_mask.shape[1])\n\n        # generation config\n        sampling_config = SamplingConfig(end_id=eos_token_id,\n                                         pad_id=pad_token_id,\n                                         num_beams=num_beams,\n                                         min_length=1,\n                                         return_dict=return_dict)\n        sampling_config.update(output_cum_log_probs=return_dict,\n                               output_log_probs=return_dict)\n\n        # decoder autoregressive generation\n        self.decoder_session.setup(\n            decoder_input_lengths.size(0),\n            decoder_max_input_length,\n            max_new_tokens,\n            num_beams,\n            max_attention_window_size=None,\n            encoder_max_input_length=encoder_max_input_length,\n            lora_manager=self.decoder_lora_manager,\n            lora_uids=self.lora_task_uids,\n        )\n\n        output = self.decoder_session.decode(\n            decoder_input_ids,\n            decoder_input_lengths,\n            sampling_config,\n            encoder_output=encoder_output,\n            encoder_input_lengths=encoder_input_lengths,\n            return_dict=return_dict,\n            cross_attention_mask=cross_attention_mask)\n\n        if return_dict and return_encoder_output:\n            output['encoder_output'] = encoder_output\n\n        return output\n"
  },
  {
    "path": "tools/demo/demo.py",
    "content": "import time\nimport torch\nimport argparse\n\nfrom PIL import Image\nfrom struct_eqtable import build_model\n\n\ndef parse_config():\n    parser = argparse.ArgumentParser(description='arg parser')\n    parser.add_argument('--image_path', type=str, default='demo.png', help='data path for table image')\n    parser.add_argument('--ckpt_path', type=str, default='U4R/StructTable-InternVL2-1B', help='ckpt path for table model, which can be downloaded from huggingface')\n    parser.add_argument('--max_new_tokens', type=int, default=1024, help='maximum output tokens of model inference')\n    parser.add_argument('-t', '--max_waiting_time', type=int, default=60, help='maximum waiting time of model inference')\n    parser.add_argument('-f', '--output_format', type=str, nargs='+', default=['latex'], \n                        help='The model outputs LaTeX format code by default. Simple structured table LaTeX code can be converted to HTML or Markdown format using pypandoc.')\n    parser.add_argument('--tensorrt_path', type=str, default=None, help='enable tensorrt for model acceleration')\n    parser.add_argument('--lmdeploy', action='store_true', help='use lmdepoly to accelerate model inference')\n    parser.add_argument('--disable_flash_attn', action='store_true', help='disable flash attention for non ampere gpu')\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = parse_config()\n\n    # build model\n    model = build_model(\n        args.ckpt_path, \n        max_new_tokens=args.max_new_tokens, \n        max_time=args.max_waiting_time,\n        tensorrt_path=args.tensorrt_path,\n        lmdeploy=args.lmdeploy,\n        flash_attn=not args.disable_flash_attn\n    )\n\n    assert torch.cuda.is_available(), \"Our model current only support with gpu\"\n    if not args.tensorrt_path:\n        model = model.cuda()\n\n    # process output format\n    output_formats = list(set(args.output_format) & set(model.supported_output_format))\n    print(f\"Supported output format: {' '.join(output_formats)}\")\n\n    # model inference\n    raw_image = Image.open(args.image_path)\n\n    output_list = []\n    start_time = time.time()\n\n    with torch.no_grad():\n        for tgt_fmt in output_formats:\n            output = model(raw_image, output_format=tgt_fmt)\n            output_list.append(output)\n\n    # show output latex code of table\n    cost_time = time.time() - start_time\n    print(f\"total cost time: {cost_time:.2f}s\")\n\n    if cost_time >= args.max_waiting_time:\n        warn_log = f\"\\033[93mThe model inference time exceeds the maximum waiting time {args.max_waiting_time} seconds, the result may be incomplete.\\n\" \\\n        \"Please increase the maximum waiting time with argument --max_waiting_time or Model may not support the type of input table image \\033[0m\"\n        print(warn_log)\n\n    for i, tgt_fmt in enumerate(output_formats):\n        for j, output in enumerate(output_list[i]):\n            print(f\"Table {j} {tgt_fmt.upper()} format output:\\n{output}\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tools/demo/demo.tex",
    "content": "\n\\documentclass[border=20pt]{standalone}\n\\usepackage{blindtext}%\n\\usepackage{subcaption}\n\\usepackage{url}\n\\usepackage{graphicx}\n\\usepackage{caption}\n\\usepackage{multirow}\n\\usepackage{booktabs}\n\\usepackage{color}\n\\usepackage{colortbl}\n\\usepackage{xcolor,soul,framed}\n\\usepackage{xeCJK}\n%\\usepackage{fontspec}\n%\\usepackage[margin=1in]{geometry} \n\\usepackage{printlen}\n\\usepackage{amsmath,amssymb,mathtools,bm,mathrsfs,textcomp}\n\\setlength{\\parindent}{0pt}\n\n\\begin{document}\n\n\\begin{tabular}{|c|c|c|c|}\n  \\hline\n  Quantity $\\backslash$ Unit System & International System SI (kg-m-s) & Traditional aeronautical (lb-ft-s) & Traditional structural (lb-inch-s) \\\\\n  \\hline\n  Mass (translational inertia), $m$ & kilogram mass (kg) & slug = lb-s$^2$/f & lb-s$^2$/inch \\\\\n  \\hline\n  Length, translational motion & meter (m) & foot (ft) & inch (in.) \\\\\n  \\hline\n  Time, $t$ & second (s) & second (s) & second (s) \\\\\n  \\hline\n  Force, translational action & newton (N) = kg-m/s$^2$ & pound force (lb) & pound force (lb) \\\\\n  \\hline\n  Translational stiffness constant, $k$ & N/m & lb/ft & lb/inch \\\\\n  \\hline\n  Translational damping constant, $c$ & N/(m/s) = N-s/m & lb/(ft/s) = lb-s/ft & lb/(inch/s) = lb-s/inch \\\\\n  \\hline\n  Angle, rotational motion & radial (rad), which is dimensionless & radial (rad), which is dimensionless & radial (rad), which is dimensionless \\\\\n  \\hline\n  Rotational inertia, $J$ & kg-m$^2$ & slug-ft$^2$ = lb-s$^2$ - ft & lb-s$^2$ - inch \\\\\n  \\hline\n  Moment or torque, rotational action & N-m & lb-ft & lb-inch \\\\\n  \\hline\n  Rotational stiffness constant, $k_\\theta$ & (N-m)/rad = N-m & (lb-ft)/rad = lb-ft & (lb-inch)/rad = lb-inch \\\\\n  \\hline\n  Rotational damping constant, $c_\\theta$ & (N-m)/(rad/s) = N-m-s & (lb-ft)/(rad/s) = lb-ft-s & (lb-inch)/(rad/s) = lb-inch-s \\\\\n  \\hline\n\\end{tabular}\n\n\\end{document}"
  },
  {
    "path": "tools/scripts/build_tensorrt.sh",
    "content": "set -x \n\nHF_CKPT_PATH=${1:-\"../ckpts/StructTable-base\"}\nMODEL_OUTPUT=${2:-\"../ckpts/StructTable-base-TensorRT\"}\nMAX_IMAGE_TOKEN_NUM=${3:-2048}\nMAX_OUPTPUT_TOKEN_NUM=${4:-2048}\nMODEL_TYPE=${5:-\"StructEqTable\"}\n\nif [ ! -d $MODEL_OUTPUT ]; then\n    mkdir -p $MODEL_OUTPUT\nfi\n\n# Step1 Convert the model into TensorrtLLM checkpoint format\necho \"Step1 Convert the model into TensorrtLLM checkpoint format\"\n\npython tensorrt_utils/convert_checkpoint.py --model_type $MODEL_TYPE \\\n    --model_dir $HF_CKPT_PATH \\\n    --output_dir $MODEL_OUTPUT/trt_models/float16 \\\n    --tp_size 1 \\\n    --pp_size 1 \\\n    --workers 1 \\\n    --dtype float16\n\n# Step2 Compile the model\necho \"Step2 build LLM Engine\"\n\ntrtllm-build --checkpoint_dir $MODEL_OUTPUT/trt_models/float16/decoder \\\n    --output_dir $MODEL_OUTPUT/llm_engines/decoder \\\n    --paged_kv_cache disable \\\n    --moe_plugin disable \\\n    --enable_xqa disable \\\n    --use_custom_all_reduce disable \\\n    --gemm_plugin float16 \\\n    --bert_attention_plugin float16 \\\n    --gpt_attention_plugin float16 \\\n    --remove_input_padding enable \\\n    --context_fmha disable \\\n    --max_beam_width 1 \\\n    --max_batch_size 1 \\\n    --max_seq_len $MAX_OUPTPUT_TOKEN_NUM \\\n    --max_encoder_input_len $MAX_IMAGE_TOKEN_NUM \\\n    --max_input_len 1\n\n# Step3 build visual engine\necho \"Step3 Build Visual Engine\"\n\npython tensorrt_utils/build_visual_engine.py --model_type $MODEL_TYPE \\\n    --model_path $HF_CKPT_PATH \\\n    --output_dir $MODEL_OUTPUT/visual_engines \\\n    --max_batch_size 1\n\nif [ -f './model.cache' ]; then\n    rm ./model.cache\nfi\n\necho \"Build TensorRT model and Visual Engine Successfully\""
  },
  {
    "path": "tools/tensorrt_utils/build_visual_engine.py",
    "content": "import argparse\nimport os\nimport shutil\nimport sys\nimport tarfile\nfrom time import time\n\nimport yaml\n\n# isort: off\nimport torch\nimport tensorrt as trt\nfrom tensorrt_llm.builder import Builder\n# isort: on\nimport json\nimport math\n\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom safetensors.torch import save_file\nfrom transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,\n                          AutoModelForVision2Seq, AutoProcessor,\n                          Blip2ForConditionalGeneration, Blip2Processor,\n                          FuyuForCausalLM, FuyuProcessor,\n                          LlavaForConditionalGeneration, NougatProcessor,\n                          Pix2StructForConditionalGeneration,\n                          VisionEncoderDecoderModel)\n\n\ndef parse_arguments():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--model_type',\n                        type=str,\n                        default=None,\n                        choices=[\n                            'opt-2.7b', 'opt-6.7b', 'flan-t5-xl', 'flan-t5-xxl',\n                            'llava', 'vila', 'nougat', 'cogvlm', 'fuyu', 'pix2struct',\n                            'StructEqTable', 'neva', 'kosmos-2', 'video-neva',\n                            'phi-3-vision'\n                        ],\n                        help=\"Model type\")\n    parser.add_argument(\n        '--model_path',\n        type=str,\n        default=None,\n        help=\n        \"Huggingface repo, local directory with weights or path to checkpoint file\"\n    )\n    parser.add_argument('--vila_path',\n                        type=str,\n                        default=None,\n                        help=\"Path to VILA source code directory\")\n    parser.add_argument('--output_dir',\n                        type=str,\n                        default=None,\n                        help=\"Directory where visual TRT engines are saved\")\n    parser.add_argument('--max_batch_size',\n                        type=int,\n                        default=4,\n                        help=\"Maximum batch size for input images\")\n    return parser.parse_args()\n\n\nclass VisionEngineBuilder:\n\n    def __init__(self, args):\n        args.device = torch.device(\n            \"cuda\") if torch.cuda.is_available() else \"cpu\"\n        if args.output_dir is None:\n            args.output_dir = 'visual_engines/%s' % (\n                args.model_path.split('/')[-1] if args.vila_path is not None\n                else args.model_path.split('/')[-1])\n        if not os.path.exists(args.output_dir):\n            os.makedirs(args.output_dir)\n\n        self.args = args\n\n    def build(self):\n        args = self.args\n        if 'opt' in args.model_type or 't5' in args.model_type:\n            build_blip2_engine(args)\n        elif args.model_type == 'pix2struct':\n            build_pix2struct_engine(args)\n        elif args.model_type == 'StructEqTable':\n            build_StructEqTable_engine(args)\n        elif args.model_type == 'llava':\n            build_llava_engine(args)\n        elif args.model_type == 'vila':\n            assert args.vila_path is not None, \"Please clone and provide VILA source code path\"\n            build_vila_engine(args)\n        elif args.model_type == 'nougat':\n            build_nougat_engine(args)\n        elif args.model_type == 'cogvlm':\n            build_cogvlm_engine(args)\n        elif args.model_type == 'fuyu':\n            
build_fuyu_engine(args)\n        elif args.model_type == 'neva':\n            build_neva_engine(args)\n        elif args.model_type == 'video-neva':\n            build_video_neva_engine(args)\n        elif args.model_type == 'kosmos-2':\n            build_kosmos_engine(args)\n        elif args.model_type == 'phi-3-vision':\n            build_phi_engine(args)\n        else:\n            raise RuntimeError(f\"Invalid model type {args.model_type}\")\n\n\ndef export_visual_wrapper_onnx(visual_wrapper,\n                               input,\n                               output_dir,\n                               input_names=['input'],\n                               dynamic_axes={'input': {\n                                   0: 'batch'\n                               }}):\n    logger.log(trt.Logger.INFO, \"Exporting onnx\")\n    os.makedirs(f'{output_dir}/onnx', exist_ok=True)\n    torch.onnx.export(visual_wrapper,\n                      input,\n                      f'{output_dir}/onnx/visual_encoder.onnx',\n                      opset_version=17,\n                      input_names=input_names,\n                      output_names=['output'],\n                      dynamic_axes=dynamic_axes)\n\n\ndef build_trt_engine(model_type,\n                     input_sizes,\n                     output_dir,\n                     max_batch_size,\n                     dtype=torch.float16,\n                     num_frames=None):\n    part_name = 'visual_encoder'\n    onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name)\n    engine_file = '%s/%s.engine' % (output_dir, part_name)\n    config_file = '%s/%s' % (output_dir, \"config.json\")\n    logger.log(trt.Logger.INFO, \"Building TRT engine for %s\" % part_name)\n\n    builder = trt.Builder(logger)\n    network = builder.create_network(\n        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n    profile = builder.create_optimization_profile()\n\n    config_args = {\n        \"precision\": str(dtype).split('.')[-1],\n        \"model_type\": model_type\n    }\n    if num_frames is not None:\n        config_args[\"num_frames\"] = num_frames\n\n    config_wrapper = Builder().create_builder_config(**config_args)\n    config = config_wrapper.trt_builder_config\n\n    parser = trt.OnnxParser(network, logger)\n\n    with open(onnx_file, 'rb') as model:\n        if not parser.parse(model.read(), os.path.abspath(onnx_file)):\n            logger.log(trt.Logger.ERROR, \"Failed parsing %s\" % onnx_file)\n            for error in range(parser.num_errors):\n                logger.log(trt.Logger.ERROR, parser.get_error(error))\n        logger.log(trt.Logger.INFO, \"Succeeded parsing %s\" % onnx_file)\n\n    # Delete onnx files since we don't need them now\n    shutil.rmtree(f'{output_dir}/onnx')\n\n    nBS = -1\n    nMinBS = 1\n    nOptBS = max(nMinBS, int(max_batch_size / 2))\n    nMaxBS = max_batch_size\n\n    inputT = network.get_input(0)\n\n    # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images,\n    # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]).\n    assert isinstance(input_sizes, list), \"input_sizes must be a list\"\n    if isinstance(input_sizes[0], int):\n        logger.log(trt.Logger.INFO, f\"Processed input sizes {input_sizes}\")\n        inputT.shape = [nBS, *input_sizes]\n        min_size = opt_size = max_size = input_sizes\n    elif len(input_sizes) == 3 and isinstance(input_sizes[0], list):\n        min_size, opt_size, max_size = input_sizes\n        logger.log(\n           
 trt.Logger.INFO,\n            f\"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}\"\n        )\n    else:\n        raise ValueError(f\"invalid input sizes: {input_sizes}\")\n\n    profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size],\n                      [nMaxBS, *max_size])\n    if model_type == \"pix2struct\" or model_type == \"StructEqTable\" :\n        inputT = network.get_input(1)\n        P = input_sizes[0]  # Number of patches\n        inputT.shape = [nBS, P]\n        profile.set_shape(inputT.name, [nMinBS, P], [nOptBS, P], [nMaxBS, P])\n    config.add_optimization_profile(profile)\n\n    t0 = time()\n    engine_string = builder.build_serialized_network(network, config)\n    t1 = time()\n    if engine_string is None:\n        raise RuntimeError(\"Failed building %s\" % (engine_file))\n    else:\n        logger.log(trt.Logger.INFO,\n                   \"Succeeded building %s in %d s\" % (engine_file, t1 - t0))\n        with open(engine_file, 'wb') as f:\n            f.write(engine_string)\n\n    Builder.save_config(config_wrapper, config_file)\n\n\ndef build_blip2_engine(args):\n    model_type = 'Salesforce/blip2-' + args.model_type\n    processor = Blip2Processor.from_pretrained(model_type)\n\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    prompt = \"Question: what is this? Answer:\"\n    inputs = processor(raw_image, prompt,\n                       return_tensors=\"pt\").to(args.device, torch.float16)\n    image = inputs['pixel_values']\n\n    class Blip2VisionWrapper(torch.nn.Module):\n\n        def __init__(self, vision_model, qformer, projector, query_tokens):\n            super().__init__()\n            self.vision_model = vision_model\n            self.qformer = qformer\n            self.projector = projector\n            self.query_tokens = query_tokens\n\n        def forward(self, image):\n            features = self.vision_model(image)[0]\n            qformer_output = self.qformer(query_embeds=self.query_tokens,\n                                          encoder_hidden_states=features,\n                                          return_dict=True)\n            return self.projector(qformer_output.last_hidden_state)\n\n    model = Blip2ForConditionalGeneration.from_pretrained(\n        model_type, torch_dtype=torch.float16)\n    wrapper = Blip2VisionWrapper(model.vision_model, model.qformer,\n                                 model.language_projection, model.query_tokens)\n    wrapper.to(args.device)\n\n    export_visual_wrapper_onnx(wrapper, image, args.output_dir)\n    build_trt_engine(\n        model_type,\n        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]\n        args.output_dir,\n        args.max_batch_size)\n\n\ndef build_pix2struct_engine(args):\n    processor = AutoProcessor.from_pretrained(args.model_path)\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    dtype = torch.float16\n    inputs = processor(text=\"dummy\", images=raw_image, return_tensors=\"pt\", max_patches=processor.image_processor.max_patches)\n    image = inputs['flattened_patches'].to(args.device, dtype)\n    attention_mask = inputs['attention_mask'].to(args.device, torch.int)\n    class pix2structVisionWrapper(torch.nn.Module):\n\n        def __init__(self, encoder):\n            super().__init__()\n            self.encoder = encoder\n\n        def forward(self, image, attention_mask):\n            vision_x = self.encoder.embeddings(image)\n            img_features = self.encoder.encoder(vision_x,\n               
                                 attention_mask=attention_mask)\n            img_features = self.encoder.layernorm(img_features[0])\n            return img_features\n\n    model = Pix2StructForConditionalGeneration.from_pretrained(\n        args.model_path, torch_dtype=dtype)\n\n    wrapper = pix2structVisionWrapper(model.encoder.to(args.device))\n    # input shape: batch size, number of patches, hidden dimension\n    # attention mask shape: batch size, number of patches\n    # The number of image patches can vary depending on the image size, but it typically\n    # falls within a relatively narrow range. To improve performance, we can avoid using\n    # dynamic axis for the input patches and instead use a fixed number of patches along\n    # with an attention mask.\n    export_visual_wrapper_onnx(wrapper, (image, attention_mask),\n                               args.output_dir,\n                               input_names=['input', 'attention_mask'],\n                               dynamic_axes={\n                                   'input': {\n                                       0: 'batch'\n                                   },\n                                   'attention_mask': {\n                                       0: 'batch'\n                                   }\n                               })\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2]],  # Number of Patches, Hidden Dimension\n        args.output_dir,\n        args.max_batch_size,\n        torch.bfloat16)\n\n\ndef build_StructEqTable_engine(args):\n    processor = AutoProcessor.from_pretrained(args.model_path)\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    dtype = torch.float16\n    inputs = processor(text=\"dummy\", images=raw_image, return_tensors=\"pt\", max_patches=processor.image_processor.max_patches)\n    image = inputs['flattened_patches'].to(args.device, dtype)\n    attention_mask = inputs['attention_mask'].to(args.device, torch.int)\n    class StructEqTableVisionWrapper(torch.nn.Module):\n\n        def __init__(self, encoder):\n            super().__init__()\n            self.encoder = encoder\n\n        def forward(self, image, attention_mask):\n            vision_x = self.encoder.embeddings(image)\n            img_features = self.encoder.encoder(vision_x,\n                                                attention_mask=attention_mask)\n            img_features = self.encoder.layernorm(img_features[0])\n            return img_features\n\n    model = AutoModelForVision2Seq.from_pretrained(\n        args.model_path, torch_dtype=dtype)\n\n    wrapper = StructEqTableVisionWrapper(model.encoder.to(args.device))\n    # input shape: batch size, number of patches, hidden dimension\n    # attention mask shape: batch size, number of patches\n    # The number of image patches can vary depending on the image size, but it typically\n    # falls within a relatively narrow range. 
To improve performance, we can avoid using\n    # dynamic axis for the input patches and instead use a fixed number of patches along\n    # with an attention mask.\n    export_visual_wrapper_onnx(wrapper, (image, attention_mask),\n                               args.output_dir,\n                               input_names=['input', 'attention_mask'],\n                               dynamic_axes={\n                                   'input': {\n                                       0: 'batch'\n                                   },\n                                   'attention_mask': {\n                                       0: 'batch'\n                                   }\n                               })\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2]],  # Number of Patches, Hidden Dimension\n        args.output_dir,\n        args.max_batch_size,\n        torch.bfloat16)\n\n\ndef build_llava_engine(args):\n    processor = AutoProcessor.from_pretrained(args.model_path)\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    image = processor(text=\"dummy\", images=raw_image,\n                      return_tensors=\"pt\")['pixel_values'].to(\n                          args.device, torch.float16)\n\n    class LlavaVisionWrapper(torch.nn.Module):\n\n        def __init__(self, tower, projector, feature_layer):\n            super().__init__()\n            self.tower = tower\n            self.projector = projector\n            self.feature_layer = feature_layer\n\n        def forward(self, image):\n            all_hidden_states = self.tower(\n                image, output_hidden_states=True).hidden_states\n            features = all_hidden_states[self.feature_layer][:, 1:]\n            return self.projector(features)\n\n    model = LlavaForConditionalGeneration.from_pretrained(\n        args.model_path, torch_dtype=torch.float16)\n    wrapper = LlavaVisionWrapper(model.vision_tower.to(args.device),\n                                 model.multi_modal_projector.to(args.device),\n                                 model.config.vision_feature_layer)\n\n    export_visual_wrapper_onnx(wrapper, image, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]\n        args.output_dir,\n        args.max_batch_size)\n\n\ndef build_vila_engine(args):\n    # Note: VILA model is not in public HF model zoo yet. 
We need to explicitly import from the git repo\n    sys.path.append(args.vila_path)\n    from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa\n    from transformers import AutoModel\n    model = AutoModel.from_pretrained(\n        args.model_path,\n        device_map='auto',\n    )\n\n    vision_tower = model.get_vision_tower()\n    image_processor = vision_tower.image_processor\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    image = image_processor(images=raw_image,\n                            return_tensors=\"pt\")['pixel_values']\n    if isinstance(image, list):\n        image = image[0].unsqueeze(0)\n    image = image.to(args.device, torch.float16)\n\n    class VilaVisionWrapper(torch.nn.Module):\n\n        def __init__(self, tower, projector):\n            super().__init__()\n            self.tower = tower\n            self.projector = projector\n\n        def forward(self, image):\n            features = self.tower(image)\n            return self.projector(features)\n\n    model = AutoModel.from_pretrained(\n        args.model_path,\n        device_map='auto',\n    )\n    wrapper = VilaVisionWrapper(model.get_vision_tower().to(args.device),\n                                model.mm_projector.to(args.device))\n    export_visual_wrapper_onnx(wrapper, image, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]\n        args.output_dir,\n        args.max_batch_size)\n\n\ndef build_nougat_engine(args):\n    processor = NougatProcessor.from_pretrained(args.model_path)\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    image = processor(raw_image, return_tensors=\"pt\")['pixel_values'].to(\n        args.device, torch.float16)\n\n    class SwinEncoderWrapper(torch.nn.Module):\n\n        def __init__(self, encoder):\n            super().__init__()\n            self.encoder = encoder\n\n        def forward(self, image):\n            return self.encoder(image).last_hidden_state\n\n    model = VisionEncoderDecoderModel.from_pretrained(args.model_path,\n                                                      torch_dtype=torch.float16)\n    swin_encoder = model.get_encoder().to(args.device)\n    wrapper = SwinEncoderWrapper(swin_encoder)\n\n    export_visual_wrapper_onnx(wrapper, image, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]\n        args.output_dir,\n        args.max_batch_size)\n\n\ndef build_cogvlm_engine(args):\n    hf_config = AutoConfig.from_pretrained(args.model_path,\n                                           trust_remote_code=True)\n    image_size = hf_config.vision_config['image_size']\n    dtype = hf_config.torch_dtype\n    image = torch.empty(1,\n                        3,\n                        image_size,\n                        image_size,\n                        dtype=dtype,\n                        device=args.device)  # dummy image\n\n    class CogVlmVisionWrapper(torch.nn.Module):\n\n        def __init__(self, encoder):\n            super().__init__()\n            self.encoder = encoder\n\n        def forward(self, image):\n            return self.encoder(image)\n\n    cogvlm = AutoModelForCausalLM.from_pretrained(args.model_path,\n                                                  torch_dtype=dtype,\n                                                  trust_remote_code=True)\n    vit_encoder = cogvlm.model.vision.to(args.device).eval()\n\n    wrapper = 
CogVlmVisionWrapper(vit_encoder)\n    export_visual_wrapper_onnx(wrapper, image, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]\n        args.output_dir,\n        args.max_batch_size,\n        dtype)\n\n\ndef build_fuyu_engine(args):\n    processor = FuyuProcessor.from_pretrained(args.model_path)\n    raw_image = Image.new('RGB', [10, 10])\n    image = processor(text=\"dummy\", images=raw_image,\n                      return_tensors=\"pt\")['image_patches'][0].to(\n                          args.device, torch.float16).unsqueeze(0)\n\n    class FuyuEncoderWrapper(torch.nn.Module):\n\n        def __init__(self, linear):\n            super().__init__()\n            self.linear = linear.to(torch.float16)\n\n        def forward(self, patches):\n            return self.linear(patches).flatten(0, 1)\n\n    model = FuyuForCausalLM.from_pretrained(args.model_path,\n                                            torch_dtype=torch.float16)\n\n    vision_encoder = model.vision_embed_tokens\n    wrapper = FuyuEncoderWrapper(vision_encoder).to(args.device)\n\n    export_visual_wrapper_onnx(wrapper,\n                               image,\n                               args.output_dir,\n                               dynamic_axes={'input': {\n                                   0: 'batch',\n                                   2: 'patch'\n                               }})\n    build_trt_engine(\n        args.model_type,\n        # [nImgs, nImgPatches, nDims]\n        # nImgs is always one since each query has exactly one image\n        # nImgPatches depends on image size (patch size: 30x30)\n        # nDims is 30x30x3=2700 (patch size x color channels)\n        [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]],\n        args.output_dir,\n        args.max_batch_size)\n\n\ndef build_neva_engine(args):\n    # extract NeMo checkpoint\n    with tarfile.open(args.model_path) as tar:\n        nemo_config = yaml.safe_load(tar.extractfile(\"./model_config.yaml\"))\n        try:\n            # trained without TP\n            mp0_weights = torch.load(tar.extractfile(\"./model_weights.ckpt\"),\n                                     map_location=args.device)\n        except KeyError:\n            # trained with TP\n            mp0_weights = torch.load(\n                tar.extractfile(\"./mp_rank_00/model_weights.ckpt\"),\n                map_location=args.device)\n\n    vision_config = nemo_config[\"mm_cfg\"][\"vision_encoder\"]\n\n    class VisionEncoderWrapper(torch.nn.Module):\n\n        def __init__(self, encoder, connector):\n            super().__init__()\n            self.encoder = encoder\n            self.connector = connector\n\n        def forward(self, images):\n            vision_x = self.encoder(pixel_values=images,\n                                    output_hidden_states=True)\n            vision_x = vision_x.hidden_states[-2]\n            vision_x = vision_x[:, 1:]\n            vision_x = self.connector(vision_x)\n            return vision_x\n\n    encoder = AutoModel.from_pretrained(vision_config[\"from_pretrained\"],\n                                        torch_dtype=torch.bfloat16,\n                                        trust_remote_code=True)\n    vision_encoder = encoder.vision_model\n    hf_config = encoder.config\n    dtype = hf_config.torch_dtype\n\n    # connector\n    assert nemo_config[\"mm_cfg\"][\"mm_mlp_adapter_type\"] == \"mlp2x_gelu\"\n    vision_connector = torch.nn.Sequential(\n        
torch.nn.Linear(vision_config[\"hidden_size\"],\n                        nemo_config[\"hidden_size\"],\n                        bias=True), torch.nn.GELU(),\n        torch.nn.Linear(nemo_config[\"hidden_size\"],\n                        nemo_config[\"hidden_size\"],\n                        bias=True)).to(dtype=dtype)\n\n    key_prefix = \"model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector\"\n    for layer in range(0, 3, 2):\n        vision_connector[layer].load_state_dict({\n            'weight':\n            mp0_weights[f\"{key_prefix}.{layer}.weight\"].to(dtype),\n            'bias':\n            mp0_weights[f\"{key_prefix}.{layer}.bias\"].to(dtype),\n        })\n\n    # export the whole wrapper\n    wrapper = VisionEncoderWrapper(vision_encoder,\n                                   vision_connector).to(args.device, dtype)\n    image_size = hf_config.vision_config.image_size\n    dummy_image = torch.empty(\n        1, 3, image_size, image_size, dtype=dtype,\n        device=args.device)  # dummy image shape [B, C, H, W]\n    export_visual_wrapper_onnx(wrapper, dummy_image, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [3, image_size, image_size],  # [3, H, W]\n        args.output_dir,\n        args.max_batch_size,\n        dtype)\n\n\ndef build_video_neva_engine(args):\n    # extract NeMo checkpoint\n    with tarfile.open(args.model_path) as tar:\n        nemo_config = yaml.safe_load(tar.extractfile(\"./model_config.yaml\"))\n        try:\n            # trained without TP\n            mp0_weights = torch.load(tar.extractfile(\"./model_weights.ckpt\"),\n                                     map_location=args.device)\n        except KeyError:\n            # trained with TP\n            mp0_weights = torch.load(\n                tar.extractfile(\"./mp_rank_00/model_weights.ckpt\"),\n                map_location=args.device)\n\n    vision_config = nemo_config[\"mm_cfg\"][\"vision_encoder\"]\n\n    class VisionEncoderWrapper(torch.nn.Module):\n\n        def __init__(self, encoder, connector):\n            super().__init__()\n            self.encoder = encoder\n            self.connector = connector\n\n        def forward(self, images):\n            b, num_frames, c, h, w = images.shape\n            images = images.view(b * num_frames, c, h, w)\n            vision_x = self.encoder(\n                pixel_values=images,  #[(B num_frames), C, H, W]\n                output_hidden_states=True)\n            vision_x = vision_x.hidden_states[-2]\n            vision_x = vision_x[:, 1:]\n\n            # reshape back to [B, num_frames, img_size, hidden_size]\n            vision_x = vision_x.view(b, num_frames, -1, vision_x.shape[-1])\n\n            vision_x = self.connector(vision_x)\n            return vision_x\n\n    encoder = AutoModel.from_pretrained(vision_config[\"from_pretrained\"],\n                                        torch_dtype=torch.bfloat16,\n                                        trust_remote_code=True)\n    vision_encoder = encoder.vision_model\n    hf_config = encoder.config\n    dtype = hf_config.torch_dtype\n\n    # connector\n    assert nemo_config[\"mm_cfg\"][\"mm_mlp_adapter_type\"] == \"linear\"\n    vision_connector = torch.nn.Linear(vision_config[\"hidden_size\"],\n                                       nemo_config[\"hidden_size\"],\n                                       bias=True)\n\n    key_prefix = \"model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector\"\n    
vision_connector.load_state_dict({\n        'weight':\n        mp0_weights[f\"{key_prefix}.weight\"].to(dtype),\n        'bias':\n        mp0_weights[f\"{key_prefix}.bias\"].to(dtype),\n    })\n\n    # export the whole wrapper\n    wrapper = VisionEncoderWrapper(vision_encoder,\n                                   vision_connector).to(args.device, dtype)\n    image_size = hf_config.vision_config.image_size\n    num_frames = nemo_config['data']['num_frames']\n    dummy_video = torch.empty(1,\n                              num_frames,\n                              3,\n                              image_size,\n                              image_size,\n                              dtype=dtype,\n                              device=args.device)  # dummy image\n    export_visual_wrapper_onnx(wrapper, dummy_video, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [num_frames, 3, image_size, image_size],  # [num_frames, 3, H, W]\n        args.output_dir,\n        args.max_batch_size,\n        dtype,\n        num_frames=num_frames)\n\n\ndef build_kosmos_engine(args):\n    processor = AutoProcessor.from_pretrained(args.model_path)\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    image = processor(text=\"dummy\", images=raw_image,\n                      return_tensors=\"pt\")['pixel_values'].to(\n                          args.device, torch.float16)\n\n    class VisionEncoderWrapper(torch.nn.Module):\n\n        def __init__(self, encoder, connector):\n            super().__init__()\n            self.encoder = encoder\n            self.connector = connector\n\n        def forward(self, images):\n            vision_x = self.encoder(images, output_hidden_states=True)\n            img_features = self.encoder.model.post_layernorm(\n                vision_x.last_hidden_state)\n            img_features = F.normalize(img_features, dim=-1)\n            img_features, _ = self.connector(img_features)\n            return img_features\n\n    model = AutoModelForVision2Seq.from_pretrained(args.model_path,\n                                                   torch_dtype=torch.float16)\n    wrapper = VisionEncoderWrapper(\n        model.vision_model.to(args.device),\n        model.image_to_text_projection.to(args.device))\n\n    export_visual_wrapper_onnx(wrapper, image, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]\n        args.output_dir,\n        args.max_batch_size)\n\n\ndef build_phi_engine(args):\n    processor = AutoProcessor.from_pretrained(args.model_path,\n                                              trust_remote_code=True)\n    raw_image = Image.new('RGB', [10, 10])  # dummy image\n    image = processor(text=\"<|image_1|>\\ndummy\",\n                      images=raw_image,\n                      return_tensors=\"pt\")['pixel_values'].to(\n                          args.device, torch.float16)\n    try:\n        with open(f\"{args.model_path}/preprocessor_config.json\", \"r\") as file:\n            config = file.read()\n            config_dict = json.loads(config)\n            num_crops = config_dict.get(\"num_crops\")\n    except:\n        num_crops = 16\n\n    class Phi3VisionWrapper(torch.nn.Module):\n\n        def __init__(self, img_processor, img_projection, layer_idx,\n                     image_dim_out):\n            super().__init__()\n            self.img_processor = img_processor\n            self.img_projection = img_projection\n            self.layer_idx = 
layer_idx\n            self.image_dim_out = image_dim_out\n\n        def get_img_features(\n                self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:\n            LAYER_IDX = self.layer_idx\n\n            img_processor_output = self.img_processor(img_embeds,\n                                                      output_hidden_states=True)\n            img_feature = img_processor_output.hidden_states[LAYER_IDX]\n\n            patch_feature = img_feature[:, 1:]\n            return patch_feature\n\n        def forward(self, image):\n            img_features = self.get_img_features(image)\n            base_feat_height = int(math.sqrt(img_features.shape[1]))\n            C = self.image_dim_out\n            H = base_feat_height\n            img_features = img_features.reshape(-1, H, H, C).reshape(\n                -1, H // 2, 2, H // 2, 2,\n                C).contiguous().permute(0, 1, 3, 2, 4,\n                                        5).reshape(-1, H // 2, H // 2,\n                                                   4 * C).contiguous()\n            return self.apply_img_projection(img_features)\n\n        def apply_img_projection(self, input):\n            return self.img_projection(input)\n\n    model = AutoModelForCausalLM.from_pretrained(args.model_path,\n                                                 torch_dtype=torch.float16,\n                                                 trust_remote_code=True).to(\n                                                     args.device)\n\n    wrapper = Phi3VisionWrapper(model.model.vision_embed_tokens.img_processor,\n                                model.model.vision_embed_tokens.img_projection,\n                                model.model.vision_embed_tokens.layer_idx,\n                                model.model.vision_embed_tokens.image_dim_out)\n    image = image.flatten(0, 1)\n    glb_GN = wrapper.apply_img_projection(\n        model.model.vision_embed_tokens.glb_GN)\n    sub_GN = wrapper.apply_img_projection(\n        model.model.vision_embed_tokens.sub_GN)\n    tensors = {\"glb_GN\": glb_GN, \"sub_GN\": sub_GN}\n    save_file(tensors, args.output_dir + \"/image_newlines.safetensors\")\n    export_visual_wrapper_onnx(wrapper, image, args.output_dir)\n    build_trt_engine(\n        args.model_type,\n        [image.shape[1], image.shape[2], image.shape[3]], args.output_dir,\n        args.max_batch_size * (num_crops + 1))  #TODO: Take input from config\n\n\nif __name__ == '__main__':\n    logger = trt.Logger(trt.Logger.INFO)\n    args = parse_arguments()\n    builder = VisionEngineBuilder(args)\n    builder.build()\n"
  },
  {
    "path": "tools/tensorrt_utils/convert_checkpoint.py",
    "content": "import argparse\nimport configparser\nimport copy\nimport json\nimport logging\nimport os\nimport types\nfrom ast import literal_eval\nfrom datetime import datetime\nfrom pathlib import Path\n\nimport safetensors\nfrom helper import convert_weight_to_dtype, fuse_qkv_one_layer, reshape, split\nfrom transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,\n                          MBartForConditionalGeneration,\n                          Pix2StructForConditionalGeneration,\n                          AutoModelForVision2Seq,\n                          T5ForConditionalGeneration, VisionEncoderDecoderModel)\n\nfrom tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,\n                                     MLPType)\nfrom tensorrt_llm.models import PretrainedConfig\n\ndir_path = os.path.dirname(os.path.realpath(__file__))\nLOGGER = logging.getLogger(__name__)\n\nlayernorm_type_map = {i.name: i.value for i in LayerNormType}\nlayernorm_position_map = {i.name: i.value for i in LayerNormPositionType}\nmlp_type_map = {i.name: i.value for i in MLPType}\n\n\ndef copy_args_to_component_config(component_config, args):\n    for arg in vars(args):\n        setattr(component_config, arg, getattr(args, arg))\n    return component_config\n\n\ndef parse_t5_config(args, hf_model):\n    config = configparser.ConfigParser()\n\n    config[\"encoder\"] = {}\n    for key, val in hf_model.encoder.config.to_dict().items():\n        config[\"encoder\"][key] = f\"{val}\"\n\n    # manually set q_scaling to offset attention scaling's effect.\n    # TODO: modify kernels to control whether to disable attention scaling\n    def get_offset_q_scaling(config):\n        scaling = 1 / config.head_size**.5\n        return scaling\n\n    config[\"decoder\"] = {}\n    for key, val in hf_model.decoder.config.to_dict().items():\n        config[\"decoder\"][key] = f\"{val}\"\n\n    config[\"structure\"] = dict()\n    config[\"structure\"][\"t5_with_bias\"] = \"false\"\n    config[\"structure\"][\"use_gated_activation\"] = str(\n        hf_model.encoder.config.is_gated_act)\n    config[\"structure\"][\"position_embedding_type\"] = \"relative\"\n    config[\"structure\"][\"model_type\"] = args.model_type\n\n    def parse_t5_config_by_component(config, component, args):\n        component_config = types.SimpleNamespace()\n        component_config = copy_args_to_component_config(component_config, args)\n        component_config.n_head = config.getint(component, 'num_heads')\n        component_config.head_size = config.getint(component, 'd_kv')\n        component_config.hidden_size = config.getint(component, 'd_model')\n        component_config.ffn_hidden_size = config.getint(component, 'd_ff')\n        component_config.vocab_size = config.getint(component, 'vocab_size')\n        component_config.n_positions = config.getint(component,\n                                                     'n_positions',\n                                                     fallback=512)\n        component_config.has_position_embedding = config.getboolean(\n            component, 'has_position_embedding',\n            fallback=False)  # TODO: hardcoded here\n\n        component_config.has_token_type_embedding = config.getboolean(\n            component, 'has_token_type_embedding', fallback=False)\n        component_config.has_embedding_layernorm = config.getboolean(\n            component, 'has_embedding_layernorm', fallback=False)\n        component_config.has_embedding_scale = config.getboolean(\n            
component, 'has_embedding_scale', fallback=False)\n        component_config.q_scaling = get_offset_q_scaling(component_config)\n        component_config.has_attention_qkvo_bias = config.getboolean(\n            component, 'has_attention_qkvo_bias',\n            fallback=False)  # TODO: hardcoded here\n        component_config.has_mlp_bias = config.getboolean(component,\n                                                          'has_mlp_bias',\n                                                          fallback=False)\n        component_config.has_model_final_layernorm = config.getboolean(\n            component, 'has_model_final_layernorm', fallback=True)\n        component_config.layernorm_eps = config.getfloat(\n            component, 'layer_norm_epsilon')\n        component_config.layernorm_position = layernorm_position_map[config.get(\n            component, 'layernorm_position',\n            fallback='pre_layernorm')]  # TODO: hardcoded here\n        component_config.layernorm_type = layernorm_type_map[config.get(\n            component, 'layernorm_type', fallback='RmsNorm')]\n        component_config.hidden_act = config.get(component, 'dense_act_fn')\n        component_config.gated_act = config.getboolean(component,\n                                                       'is_gated_act')\n        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.\n                                                 gated_act else 'MLP']\n        component_config.num_buckets = config.getint(\n            component, 'relative_attention_num_buckets')\n        component_config.max_distance = config.getint(\n            component, 'relative_attention_max_distance')\n        component_config.position_embedding_type = config.get(\n            'structure', 'position_embedding_type')\n        component_config.logits_dtype = config.get(component,\n                                                   'logits_dtype',\n                                                   fallback='float32')\n\n        if component == 'encoder':\n            component_config.n_layer = config.getint(component, 'num_layers')\n\n            component_config.relative_attention = config.get(\n                'structure', 'position_embedding_type') == 'relative'\n\n        elif component == 'decoder':\n            component_config.n_layer = config.getint(component,\n                                                     'num_decoder_layers')\n            component_config.has_lm_head_bias = config.getboolean(\n                component,  # TODO: T5 with bias\n                'has_lm_head_bias',\n                fallback=False)\n            component_config.relative_attention = config.getboolean(\n                component, 'relative_attention', fallback=True)\n            component_config.rescale_before_lm_head = config.getboolean(\n                component, 'tie_word_embeddings'\n            )  # default is True (for T5), but False for Flan-T5\n            component_config.encoder_hidden_size = config.getint(\n                'encoder', 'd_model')\n            component_config.encoder_num_heads = config.getint(\n                'encoder', 'num_heads')\n            component_config.encoder_head_size = config.getint(\n                'encoder', 'd_kv')\n            component_config.decoder_start_token_id = config.getint(\n                'decoder', 'decoder_start_token_id')\n\n        else:\n            assert False, 'Unsupported component!'\n\n        return component_config\n\n    encoder_config = 
parse_t5_config_by_component(config, \"encoder\", args)\n    decoder_config = parse_t5_config_by_component(config, \"decoder\", args)\n\n    return encoder_config, decoder_config\n\n\ndef convert_t5_weights_to_tllm_safetensors(config, component, params):\n    weights = {}\n\n    mapping = config.mapping\n\n    convert_weight_to_dtype(params, config.dtype)\n    hidden_size = config.hidden_size\n    ffn_hidden_size = config.intermediate_size\n    num_layers = config.num_hidden_layers\n    n_head = config.num_attention_heads\n    head_size = config.head_size\n    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5\n\n    hf_param_prefix = f'{component}'\n    trtllm_layer_name = f'{component}_layers'\n    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'\n    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'\n    hf_component_idx = 1 if component == 'encoder' else 2\n\n    def get_attn_module_name(component, block, layer, attn_type):\n        return f'{component}.block.{int(block)}.layer.{int(layer)}.{attn_type}'\n\n    weights['embedding.vocab_embedding.weight'] = reshape(\n        params['shared.weight'].clone(), None)\n\n    layers_range = mapping.pp_layers(num_layers)\n    for layer_idx in layers_range:\n        local_layer_idx = layer_idx - layers_range[0]\n        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'\n        hf_layer_name_prefix = f'{hf_param_prefix}.block.{layer_idx}'\n\n        hidden_layer_name_split = {\n            f'{hf_layer_name_prefix}.layer.0.SelfAttention.o.weight': {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',\n                \"shape\":\n                (hidden_size, attention_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wo.weight':\n            {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp.proj.weight',\n                \"shape\": (hidden_size, ffn_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi.weight':\n            {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp.fc.weight',\n                \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                \"split_dim\": 0\n            },\n            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_0.weight':\n            {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp.fc.weight',\n                \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                \"split_dim\": 0\n            },\n        }\n\n        hidden_layer_name_no_split = {\n            f'{hf_layer_name_prefix}.layer.0.layer_norm.weight': {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',\n                \"shape\": None\n            },\n            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.layer_norm.weight':\n            {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',\n                \"shape\": None\n            },\n        }\n\n        if config.gated_act:\n            hidden_layer_name_split.update({\n                
f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi2.weight':\n                {\n                    \"name\": f'{trtllm_layer_name_prefix}.mlp.gate.weight',\n                    \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                    \"split_dim\": 0\n                },\n                f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_1.weight':\n                {\n                    \"name\": f'{trtllm_layer_name_prefix}.mlp.gate.weight',\n                    \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                    \"split_dim\": 0\n                },\n            })\n\n        if component == 'decoder':\n            hidden_layer_name_split.update({\n                f'{hf_layer_name_prefix}.layer.1.EncDecAttention.o.weight': {\n                    \"name\":\n                    f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',\n                    \"shape\":\n                    (hidden_size, attention_hidden_size // mapping.tp_size),\n                    \"split_dim\": -1\n                },\n            })\n            hidden_layer_name_no_split.update({\n                f'{hf_layer_name_prefix}.layer.1.layer_norm.weight': {\n                    \"name\":\n                    f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',\n                    \"shape\": None\n                },\n            })\n            self_attn_module_name = get_attn_module_name(\n                component, layer_idx, \"1\", 'EncDecAttention')\n            weights.update(\n                fuse_qkv_one_layer(\n                    params, self_attn_module_name,\n                    f'{trtllm_layer_name_prefix}.cross_attention',\n                    mapping.tp_size, mapping.tp_rank, config.model_type,\n                    (attention_hidden_size * 3 // mapping.tp_size, hidden_size),\n                    None))\n\n        self_attn_module_name = get_attn_module_name(component, layer_idx, \"0\",\n                                                     'SelfAttention')\n        weights.update(\n            fuse_qkv_one_layer(\n                params, self_attn_module_name,\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',\n                mapping.tp_size, mapping.tp_rank, config.model_type,\n                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),\n                None))\n\n        weights[\n            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(\n                split(\n                    params[\n                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']\n                    .T, mapping.tp_size, mapping.tp_rank, 0),\n                (n_head // mapping.tp_size, config.num_buckets))\n\n        for hf_weight_name, weight_info in hidden_layer_name_split.items():\n            if hf_weight_name in params.keys():\n                weights[weight_info[\"name\"]] = reshape(\n                    split(params[hf_weight_name],\n                          mapping.tp_size,\n                          mapping.tp_rank,\n                          dim=weight_info[\"split_dim\"]), weight_info[\"shape\"])\n        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():\n            if hf_weight_name in params.keys():\n                weights[weight_info[\"name\"]] = reshape(\n                    params[hf_weight_name].clone(), shape=weight_info[\"shape\"])\n\n    
weights['final_layernorm.weight'] = reshape(\n        params[f'{component}.final_layer_norm.weight'].clone(), None)\n\n    if component == 'decoder':\n        weights['lm_head.weight'] = reshape(\n            split(params['lm_head.weight'],\n                  mapping.tp_size,\n                  mapping.tp_rank,\n                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))\n        if not config.use_implicit_relative_attention:\n            weights['rel_attn_table'] = reshape(\n                split(\n                    params[\n                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']\n                    .T, mapping.tp_size, mapping.tp_rank, 0),\n                (n_head // mapping.tp_size, config.num_buckets))\n\n    return weights\n\n\nconvert_blip2_weights_to_tllm_safetensors = convert_t5_weights_to_tllm_safetensors  # func alias\n\n\ndef parse_nmt_config(args, model):\n    config = configparser.ConfigParser()\n    fairseq_config = vars(model.cfg.model)  # Namespace --> dict\n\n    config['encoder'] = dict()\n    for key, val in fairseq_config.items():\n        config[\"encoder\"][key] = f\"{val}\"\n    config[\"encoder\"][\"q_scaling\"] = '1'\n    # NMT has final layernorm for pre-norm model architecture.\n    config['encoder']['has_model_final_layernorm'] = config['encoder'][\n        'encoder_normalize_before']\n    config['encoder']['vocab_size'] = str(len(model.src_dict))  # fairseq naming\n\n    config['decoder'] = dict()\n    for key, val in fairseq_config.items():\n        config[\"decoder\"][key] = f\"{val}\"\n    config[\"decoder\"][\"q_scaling\"] = '1'\n    config[\"decoder\"][\"rescale_before_lm_head\"] = 'false'\n    config['decoder']['has_model_final_layernorm'] = config['decoder'][\n        'decoder_normalize_before'] and not config['decoder'].getboolean(\n            'no_decoder_final_norm', False)\n    config['decoder']['vocab_size'] = str(len(model.tgt_dict))  # fairseq naming\n\n    config[\"structure\"] = dict()\n    config[\"structure\"][\"t5_with_bias\"] = \"true\"\n    config[\"structure\"][\"use_gated_activation\"] = \"false\"\n    config[\"structure\"][\n        \"position_embedding_type\"] = \"learned_absolute\"  # \"sinusoid\"\n    config[\"structure\"][\"model_type\"] = args.model_type\n\n    def parse_nmt_config_by_component(config, component, args):\n        assert component in ('encoder', 'decoder'), 'Unsupported component!'\n        component_config = types.SimpleNamespace()\n        component_config = copy_args_to_component_config(component_config, args)\n        component_config.n_layer = config.getint(component,\n                                                 f'{component}_layers')\n        component_config.n_head = config.getint(component,\n                                                f'{component}_attention_heads')\n        component_config.hidden_size = config.getint(\n            component, f'{component}_embed_dim')  # fairseq naming\n        component_config.head_size = config.getint(\n            component,\n            'd_kv',\n            fallback=component_config.hidden_size // component_config.n_head)\n        component_config.ffn_hidden_size = config.getint(\n            component, f'{component}_ffn_embed_dim')  # fairseq naming\n        component_config.vocab_size = config.getint(component, 'vocab_size')\n        component_config.n_positions = config.getint(\n            component, 'max_source_positions')  # fairseq naming\n        component_config.has_position_embedding = 
not config.getboolean(\n            component, 'no_token_positional_embeddings',\n            fallback=False)  # fairseq naming\n        component_config.has_token_type_embedding = config.getboolean(\n            component, 'has_token_type_embedding', fallback=False)\n        component_config.has_embedding_layernorm = config.getboolean(\n            component, 'layernorm_embedding', fallback=True)  # fairseq naming\n        component_config.has_embedding_scale = not config.getboolean(\n            component, 'no_scale_embedding')  # fairseq naming\n        component_config.q_scaling = config.getfloat(component,\n                                                     'q_scaling',\n                                                     fallback=1.0)\n        component_config.has_attention_qkvo_bias = config.getboolean(\n            'structure', 't5_with_bias', fallback=True)\n        component_config.has_mlp_bias = config.getboolean('structure',\n                                                          't5_with_bias',\n                                                          fallback=True)\n        component_config.has_model_final_layernorm = config.getboolean(\n            component, 'has_model_final_layernorm')\n        component_config.layernorm_eps = config.getfloat(\n            component, 'layer_norm_epsilon', fallback=1e-5)  # fairseq naming\n\n        normalize_before = config.getboolean(\n            component, f'{component}_normalize_before')  # fairseq naming\n        component_config.layernorm_position = layernorm_position_map[\n            'pre_layernorm' if normalize_before else 'post_layernorm']\n\n        component_config.layernorm_type = layernorm_type_map[config.get(\n            component, 'layernorm_type', fallback='LayerNorm')]\n        component_config.hidden_act = config.get(\n            component, 'activation_fn')  # fairseq naming\n        component_config.gated_act = config.getboolean(component,\n                                                       'is_gated_act',\n                                                       fallback=False)\n        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.\n                                                 gated_act else 'MLP']\n        component_config.relative_attention = config.get(\n            'structure', 'position_embedding_type') == 'relative'\n\n        component_config.num_buckets = config.getint(\n            component, 'relative_attention_num_buckets', fallback=0)\n        component_config.max_distance = config.getint(\n            component, 'relative_attention_max_distance', fallback=0)\n        component_config.position_embedding_type = config.get(\n            'structure', 'position_embedding_type')\n        component_config.logits_dtype = config.get(component,\n                                                   'logits_dtype',\n                                                   fallback='float32')\n        if component == 'decoder':\n            component_config.rescale_before_lm_head = config.getboolean(\n                component, 'rescale_before_lm_head')\n\n            component_config.encoder_hidden_size = config.getint(\n                'encoder', 'encoder_embed_dim')  # fairseq naming\n            component_config.encoder_num_heads = config.getint(\n                'encoder', 'encoder_attention_heads')\n            component_config.encoder_head_size = config.getint(\n                'encoder',\n                'd_kv',\n                fallback=component_config.encoder_hidden_size 
//\n                component_config.encoder_num_heads)\n            component_config.decoder_start_token_id = config.getint(\n                'decoder', 'decoder_start_token_id')\n\n        return component_config\n\n    encoder_config = parse_nmt_config_by_component(config, \"encoder\", args)\n    decoder_config = parse_nmt_config_by_component(config, \"decoder\", args)\n\n    return encoder_config, decoder_config\n\n\ndef convert_nmt_weights_to_tllm_safetensors(config, component, params,\n                                            sin_pos_embedding):\n    weights = {}\n\n    mapping = config.mapping\n\n    hidden_size = config.hidden_size\n\n    convert_weight_to_dtype(params, config.dtype)\n    ffn_hidden_size = config.intermediate_size\n    vocab_size = config.vocab_size\n\n    hf_param_prefix = f'models.0.{component}'\n    trtllm_layer_name = f'{component}_layers'\n    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'\n    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'\n\n    hidden_layer_name_split = {\n        'self_attn.out_proj.weight': {\n            \"name\": f'{trtllm_attn_layer_name}.dense.weight',\n            \"shape\": (hidden_size, hidden_size // mapping.tp_size),\n            \"split_dim\": -1\n        },\n        'fc1.weight': {\n            \"name\": 'mlp.fc.weight',\n            \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n            \"split_dim\": 0\n        },\n        'fc1.bias': {\n            \"name\": 'mlp.fc.bias',\n            \"shape\": (ffn_hidden_size // mapping.tp_size),\n            \"split_dim\": 0\n        },\n        'fc2.weight': {\n            \"name\": 'mlp.proj.weight',\n            \"shape\": (hidden_size, ffn_hidden_size // mapping.tp_size),\n            \"split_dim\": -1\n        },\n    }\n\n    hidden_layer_name_no_split = {\n        'self_attn.out_proj.bias': {\n            \"name\": f'{trtllm_attn_layer_name}.dense.bias',\n            \"shape\": (hidden_size)\n        },\n        'self_attn_layer_norm.weight': {\n            \"name\": f'{trtllm_attn_layernorm_name}.weight',\n            \"shape\": None\n        },\n        'self_attn_layer_norm.bias': {\n            \"name\": f'{trtllm_attn_layernorm_name}.bias',\n            \"shape\": None\n        },\n        'fc2.bias': {\n            \"name\": 'mlp.proj.bias',\n            \"shape\": (hidden_size)\n        },\n        'final_layer_norm.weight': {\n            \"name\": 'mlp_layernorm.weight',\n            \"shape\": None\n        },\n        'final_layer_norm.bias': {\n            \"name\": 'mlp_layernorm.bias',\n            \"shape\": None\n        },\n    }\n\n    if component == \"decoder\":\n        hidden_layer_name_split.update({\n            'encoder_attn.out_proj.weight': {\n                \"name\": 'cross_attention.dense.weight',\n                \"shape\": (hidden_size, hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n        })\n        hidden_layer_name_no_split.update({\n            'encoder_attn.out_proj.bias': {\n                \"name\": 'cross_attention.dense.bias',\n                \"shape\": (hidden_size)\n            },\n            'encoder_attn_layer_norm.weight': {\n                \"name\": 'cross_attention_layernorm.weight',\n                \"shape\": None,\n            },\n            'encoder_attn_layer_norm.bias': {\n                \"name\": 'cross_attention_layernorm.bias',\n                \"shape\": 
None\n            },\n        })\n\n    def get_attn_module_name(component, layer, attn_type):\n        return f'models.0.{component}.layers.{int(layer)}.{attn_type}'\n\n    weights[\"embedding.vocab_embedding.weight\"] = reshape(\n        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(),\n        (vocab_size, -1))\n    weights[\"embedding.position_embedding.weight\"] = reshape(\n        sin_pos_embedding, (config.max_position_embeddings, hidden_size))\n\n    num_layers = config.num_hidden_layers\n\n    layers_range = mapping.pp_layers(num_layers)\n    for layer_idx in layers_range:\n        local_layer_idx = layer_idx - layers_range[0]\n        hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'\n        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'\n\n        for hf_weight_name, weight_info in hidden_layer_name_split.items():\n            weights[\n                f'{trtllm_layer_name_prefix}.{weight_info[\"name\"]}'] = reshape(\n                    split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],\n                          mapping.tp_size,\n                          mapping.tp_rank,\n                          dim=weight_info[\"split_dim\"]), weight_info[\"shape\"])\n\n        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():\n            trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info[\"name\"]}'\n            hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'\n            weights[trtllm_layer_fullname] = reshape(\n                params[hf_layer_fullname].clone(), shape=weight_info[\"shape\"])\n\n        self_attn_module_name = get_attn_module_name(component, layer_idx,\n                                                     'self_attn')\n        weights.update(\n            fuse_qkv_one_layer(\n                params, self_attn_module_name,\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',\n                mapping.tp_size, mapping.tp_rank, config.model_type,\n                (hidden_size * 3 // mapping.tp_size, hidden_size),\n                (hidden_size * 3 // mapping.tp_size)))\n        if component == 'decoder':\n            cross_attn_module_name = get_attn_module_name(\n                component, layer_idx, 'encoder_attn')\n            weights.update(\n                fuse_qkv_one_layer(\n                    params, cross_attn_module_name,\n                    f'{trtllm_layer_name_prefix}.cross_attention',\n                    mapping.tp_size, mapping.tp_rank, config.model_type,\n                    (hidden_size * 3 // mapping.tp_size, hidden_size),\n                    (hidden_size * 3 // mapping.tp_size)))\n\n    if component == 'decoder':\n        weights['lm_head.weight'] = reshape(\n            split(params[f'{hf_param_prefix}.output_projection.weight'],\n                  mapping.tp_size,\n                  mapping.tp_rank,\n                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))\n\n    if config.has_model_final_layernorm:\n        weights['final_layernorm.weight'] = params[\n            f'{hf_param_prefix}.layer_norm.weight'].clone()\n        weights['final_layernorm.bias'] = params[\n            f'{hf_param_prefix}.layer_norm.bias'].clone()\n\n    return weights\n\n\ndef parse_bart_config(args, hf_model):\n\n    config = configparser.ConfigParser()\n\n    config['decoder'] = dict()\n    for key, val in hf_model.model.decoder.config.to_dict().items():\n        config[\"decoder\"][key] = f\"{val}\"\n    
config[\"decoder\"][\"q_scaling\"] = '1'\n    config[\"decoder\"][\"rescale_before_lm_head\"] = str(False)\n    config['decoder']['has_model_final_layernorm'] = str(\n        args.nougat or isinstance(hf_model, MBartForConditionalGeneration))\n\n    if args.nougat:\n        # These flags are true for mbart decoders, but missing in HF config\n        config['decoder']['normalize_before'] = str(True)\n        config['decoder']['normalize_embeddings'] = str(True)\n\n        config['encoder'] = dict()\n        # Init few encoder configs, needed by build, from decoder config\n        encoder_config_keys = [\n            \"encoder_ffn_dim\", \"encoder_layers\", \"encoder_attention_heads\",\n            \"encoder_layerdrop\", \"d_model\"\n        ]\n        for key in encoder_config_keys:\n            config['encoder'][key] = config['decoder'][key]\n    else:\n        config['encoder'] = dict()\n        for key, val in hf_model.model.encoder.config.to_dict().items():\n            config[\"encoder\"][key] = f\"{val}\"\n        config[\"encoder\"][\"q_scaling\"] = '1'\n\n        # mBART has final layernorm, BART does not\n        config['encoder']['has_model_final_layernorm'] = str(\n            isinstance(hf_model, MBartForConditionalGeneration))\n\n    config[\"structure\"] = dict()\n    config[\"structure\"][\"t5_with_bias\"] = \"true\"\n    config[\"structure\"][\"use_gated_activation\"] = \"false\"\n    config[\"structure\"][\"position_embedding_type\"] = \"learned_absolute\"\n    config[\"structure\"][\"model_type\"] = args.model_type\n\n    def parse_bart_config_by_component(config, component, args):\n        assert component in ('encoder', 'decoder'), 'Unsupported component!'\n        component_config = types.SimpleNamespace()\n        component_config = copy_args_to_component_config(component_config, args)\n        component_config.n_layer = config.getint(component,\n                                                 f'{component}_layers')\n        component_config.n_head = config.getint(component,\n                                                f'{component}_attention_heads')\n        component_config.hidden_size = config.getint(component, 'd_model')\n        component_config.head_size = config.getint(\n            component,\n            'd_kv',\n            fallback=component_config.hidden_size // component_config.n_head)\n        component_config.ffn_hidden_size = config.getint(\n            component, f'{component}_ffn_dim')\n        component_config.vocab_size = config.getint(component, 'vocab_size')\n        component_config.n_positions = config.getint(component,\n                                                     'max_position_embeddings')\n        component_config.has_position_embedding = config.getboolean(\n            component, 'has_position_embedding',\n            fallback=True)  # TODO: hardcoded here\n        component_config.has_token_type_embedding = config.getboolean(\n            component, 'has_token_type_embedding', fallback=False)\n        component_config.has_embedding_layernorm = config.getboolean(\n            component, 'has_embedding_layernorm', fallback=True)\n        component_config.has_embedding_scale = config.getboolean(\n            component, 'scale_embedding')\n        component_config.q_scaling = config.getfloat(component,\n                                                     'q_scaling',\n                                                     fallback=1.0)\n        component_config.has_attention_qkvo_bias = config.getboolean(\n            'structure', 
't5_with_bias', fallback=True)\n        component_config.has_mlp_bias = config.getboolean('structure',\n                                                          't5_with_bias',\n                                                          fallback=True)\n        component_config.has_model_final_layernorm = config.getboolean(\n            component, 'has_model_final_layernorm')\n        component_config.layernorm_eps = config.getfloat(component,\n                                                         'layer_norm_epsilon',\n                                                         fallback=False)\n\n        normalize_before = config.getboolean(component, 'normalize_before')\n        component_config.layernorm_position = layernorm_position_map[\n            'pre_layernorm' if normalize_before else 'post_layernorm']\n\n        component_config.layernorm_type = layernorm_type_map[config.get(\n            component, 'layernorm_type', fallback='LayerNorm')]\n        component_config.hidden_act = config.get(component,\n                                                 'activation_function')\n        component_config.gated_act = config.getboolean(component,\n                                                       'is_gated_act',\n                                                       fallback=False)\n        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.\n                                                 gated_act else 'MLP']\n        component_config.relative_attention = config.get(\n            'structure', 'position_embedding_type') == 'relative'\n\n        component_config.num_buckets = config.getint(\n            component, 'relative_attention_num_buckets', fallback=0)\n        component_config.max_distance = config.getint(\n            component, 'relative_attention_max_distance', fallback=0)\n        component_config.max_lora_rank = config.getint(component,\n                                                       'max_lora_rank',\n                                                       fallback=0)\n        component_config.lora_target_modules = literal_eval(\n            config.get(component, 'lora_target_modules', fallback=\"[]\"))\n        component_config.hf_modules_to_trtllm_modules = literal_eval(\n            config.get(component, 'hf_modules_to_trtllm_modules',\n                       fallback=\"{}\"))\n        component_config.trtllm_modules_to_hf_modules = literal_eval(\n            config.get(component, 'trtllm_modules_to_hf_modules',\n                       fallback=\"{}\"))\n        component_config.logits_dtype = config.get(component,\n                                                   'logits_dtype',\n                                                   fallback='float32')\n        component_config.position_embedding_type = config.get(\n            'structure', 'position_embedding_type')\n\n        if component == 'decoder':\n            component_config.rescale_before_lm_head = config.getboolean(\n                component, 'rescale_before_lm_head')\n\n            component_config.encoder_hidden_size = config.getint(\n                'encoder', 'd_model')\n            component_config.encoder_num_heads = config.getint(\n                'encoder', 'encoder_attention_heads')\n            component_config.encoder_head_size = config.getint(\n                'encoder',\n                'd_kv',\n                fallback=component_config.encoder_hidden_size //\n                component_config.encoder_num_heads)\n\n            # nougat has decoder_start_token_id = 
None, special handling\n            decoder_start_token_id = config.get('decoder',\n                                                'decoder_start_token_id')\n            component_config.decoder_start_token_id = int(\n                decoder_start_token_id\n            ) if decoder_start_token_id != \"None\" else None\n\n        return component_config\n\n    encoder_config = None\n    if not args.nougat:\n        encoder_config = parse_bart_config_by_component(config, \"encoder\", args)\n    decoder_config = parse_bart_config_by_component(config, \"decoder\", args)\n\n    return encoder_config, decoder_config\n\n\ndef convert_bart_weights_to_tllm_safetensors(config, component, params):\n    weights = {}\n\n    mapping = config.mapping\n\n    hidden_size = config.hidden_size\n\n    convert_weight_to_dtype(params, config.dtype)\n    ffn_hidden_size = config.intermediate_size\n    vocab_size = config.vocab_size\n\n    hf_param_prefix = f'model.{component}'\n    trtllm_layer_name = f'{component}_layers'\n    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'\n    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'\n    embedding_layer_names = {\n        'embed_tokens.weight': {\n            \"name\": 'embedding.vocab_embedding.weight',\n            \"shape\": (vocab_size, -1)\n        },\n        'embed_positions.weight': {\n            \"name\": 'embedding.position_embedding.weight',\n            \"shape\": (config.max_position_embeddings, hidden_size)\n        },\n        'layernorm_embedding.weight': {\n            \"name\": 'embedding.embedding_layernorm.weight',\n            \"shape\": None\n        },\n        'layernorm_embedding.bias': {\n            \"name\": 'embedding.embedding_layernorm.bias',\n            \"shape\": None\n        },\n    }\n\n    hidden_layer_name_split = {\n        'self_attn.out_proj.weight': {\n            \"name\": f'{trtllm_attn_layer_name}.dense.weight',\n            \"shape\": (hidden_size, hidden_size // mapping.tp_size),\n            \"split_dim\": -1\n        },\n        'fc1.weight': {\n            \"name\": 'mlp.fc.weight',\n            \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n            \"split_dim\": 0\n        },\n        'fc1.bias': {\n            \"name\": 'mlp.fc.bias',\n            \"shape\": (ffn_hidden_size // mapping.tp_size),\n            \"split_dim\": 0\n        },\n        'fc2.weight': {\n            \"name\": 'mlp.proj.weight',\n            \"shape\": (hidden_size, ffn_hidden_size // mapping.tp_size),\n            \"split_dim\": -1\n        },\n    }\n\n    hidden_layer_name_no_split = {\n        'self_attn.out_proj.bias': {\n            \"name\": f'{trtllm_attn_layer_name}.dense.bias',\n            \"shape\": (hidden_size)\n        },\n        'self_attn_layer_norm.weight': {\n            \"name\": f'{trtllm_attn_layernorm_name}.weight',\n            \"shape\": None\n        },\n        'self_attn_layer_norm.bias': {\n            \"name\": f'{trtllm_attn_layernorm_name}.bias',\n            \"shape\": None\n        },\n        'fc2.bias': {\n            \"name\": 'mlp.proj.bias',\n            \"shape\": (hidden_size)\n        },\n        'final_layer_norm.weight': {\n            \"name\": 'mlp_layernorm.weight',\n            \"shape\": None\n        },\n        'final_layer_norm.bias': {\n            \"name\": 'mlp_layernorm.bias',\n            \"shape\": None\n        },\n    }\n\n    if config.model_type == 'mbart':\n    
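    # mBART has a model-level final layernorm (plain BART does not); add it to the weight-name maps\n    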
    hidden_layer_name_split['layer_norm.weight'] = {\n            \"name\": 'final_layernorm.weight',\n            \"shape\": None,\n            \"split_dim\": 0\n        }\n        hidden_layer_name_no_split['layer_norm.bias'] = {\n            \"name\": 'final_layernorm.bias',\n            \"shape\": None,\n            \"split_dim\": 0\n        }\n\n    if component == \"decoder\":\n        hidden_layer_name_split.update({\n            'encoder_attn.out_proj.weight': {\n                \"name\": 'cross_attention.dense.weight',\n                \"shape\": (hidden_size, hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            }\n        })\n        hidden_layer_name_no_split.update({\n            'encoder_attn.out_proj.bias': {\n                \"name\": 'cross_attention.dense.bias',\n                \"shape\": (hidden_size)\n            },\n            'encoder_attn_layer_norm.weight': {\n                \"name\": 'cross_attention_layernorm.weight',\n                \"shape\": None\n            },\n            'encoder_attn_layer_norm.bias': {\n                \"name\": 'cross_attention_layernorm.bias',\n                \"shape\": None\n            },\n        })\n\n    def get_attn_module_name(component, layer, attn_type):\n        return f'model.{component}.layers.{int(layer)}.{attn_type}'\n\n    for hf_weight_name, weight_info in embedding_layer_names.items():\n        if 'position' in hf_weight_name:\n            weights[weight_info[\"name\"]] = params[\n                f'{hf_param_prefix}.{hf_weight_name}'][2:].clone()\n        else:\n            weights[weight_info[\"name\"]] = params[\n                f'{hf_param_prefix}.{hf_weight_name}'].clone()\n        weights[weight_info[\"name\"]] = reshape(weights[weight_info[\"name\"]],\n                                               weight_info[\"shape\"])\n\n    num_layers = config.num_hidden_layers\n\n    layers_range = mapping.pp_layers(num_layers)\n    for layer_idx in layers_range:\n        local_layer_idx = layer_idx - layers_range[0]\n        hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'\n        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'\n\n        for hf_weight_name, weight_info in hidden_layer_name_split.items():\n            weights[\n                f'{trtllm_layer_name_prefix}.{weight_info[\"name\"]}'] = reshape(\n                    split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],\n                          mapping.tp_size,\n                          mapping.tp_rank,\n                          dim=weight_info[\"split_dim\"]), weight_info[\"shape\"])\n\n        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():\n            trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info[\"name\"]}'\n            hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'\n            weights[trtllm_layer_fullname] = reshape(\n                params[hf_layer_fullname].clone(), shape=weight_info[\"shape\"])\n\n        self_attn_module_name = get_attn_module_name(component, layer_idx,\n                                                     'self_attn')\n        weights.update(\n            fuse_qkv_one_layer(\n                params, self_attn_module_name,\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',\n                mapping.tp_size, mapping.tp_rank, config.model_type,\n                (hidden_size * 3 // mapping.tp_size, hidden_size),\n                (hidden_size * 3 // mapping.tp_size)))\n        if 
component == 'decoder':\n            cross_attn_module_name = get_attn_module_name(\n                component, layer_idx, 'encoder_attn')\n            weights.update(\n                fuse_qkv_one_layer(\n                    params, cross_attn_module_name,\n                    f'{trtllm_layer_name_prefix}.cross_attention',\n                    mapping.tp_size, mapping.tp_rank, config.model_type,\n                    (hidden_size * 3 // mapping.tp_size, hidden_size),\n                    (hidden_size * 3 // mapping.tp_size)))\n\n    if component == 'decoder':\n        weights['lm_head.weight'] = reshape(\n            split(params['lm_head.weight'],\n                  mapping.tp_size,\n                  mapping.tp_rank,\n                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))\n\n    if config.has_model_final_layernorm:\n        weights['final_layernorm.weight'] = params[\n            f'{hf_param_prefix}.layer_norm.weight'].clone()\n        weights['final_layernorm.bias'] = params[\n            f'{hf_param_prefix}.layer_norm.bias'].clone()\n\n    return weights\n\n\ndef parse_pix2struct_config(args, hf_model):\n    # manually set q_scaling to offset attention scaling's effect.\n    # TODO: modify kernels to control whether to disable attention scaling\n    config = configparser.ConfigParser()\n\n    def get_offset_q_scaling(config) -> str:\n        d_model = config.hidden_size\n        num_heads = config.num_heads\n        head_size = d_model / num_heads\n        scaling = 1 / head_size**.5\n        return str(scaling)\n\n    config[\"decoder\"] = {}\n    for key, val in hf_model.decoder.config.to_dict().items():\n        config[\"decoder\"][key] = f\"{val}\"\n\n    config[\"decoder\"][\"q_scaling\"] = get_offset_q_scaling(\n        hf_model.decoder.config)\n\n    config[\"structure\"] = dict()\n    config[\"structure\"][\"pix2struct_with_bias\"] = \"false\"\n    config[\"structure\"][\"use_gated_activation\"] = \"false\"\n    config[\"structure\"][\"position_embedding_type\"] = \"relative\"\n    config[\"structure\"][\"model_type\"] = args.model_type\n\n    def parse_pix2struct_config_by_component(config, component, args):\n        if component == 'decoder':\n            args.n_layer = config.getint(component, 'num_layers')\n            args.n_head = config.getint(component, 'num_heads')\n            args.head_size = config.getint(component, 'd_kv')\n            args.hidden_size = config.getint(component, 'hidden_size')\n            args.ffn_hidden_size = config.getint(component, 'd_ff')\n            args.vocab_size = config.getint(component, 'vocab_size')\n            args.n_positions = config.getint(component,\n                                             'n_positions',\n                                             fallback=512)\n            args.has_position_embedding = config.getboolean(\n                component, 'has_position_embedding',\n                fallback=False)  # TODO: hardcoded here\n            args.has_token_type_embedding = config.getboolean(\n                component, 'has_token_type_embedding', fallback=False)\n            args.has_embedding_layernorm = config.getboolean(\n                component, 'has_embedding_layernorm', fallback=False)\n            args.has_embedding_scale = config.getboolean(component,\n                                                         'has_embedding_scale',\n                                                         fallback=False)\n            args.q_scaling = config.getfloat(component,\n                           
                  'q_scaling',\n                                             fallback=1.0)\n            args.has_attention_qkvo_bias = config.getboolean(\n                component, 'has_attention_qkvo_bias', fallback=False)\n            args.has_mlp_bias = config.getboolean(component,\n                                                  'has_mlp_bias',\n                                                  fallback=False)\n            args.has_model_final_layernorm = config.getboolean(\n                component, 'has_model_final_layernorm', fallback=True)\n            args.layernorm_eps = config.getfloat(component,\n                                                 'layer_norm_epsilon')\n            args.layernorm_position = layernorm_position_map[config.get(\n                component, 'layernorm_position',\n                fallback='pre_layernorm')]  # TODO: hardcoded here\n            args.layernorm_type = layernorm_type_map[config.get(\n                component, 'layernorm_type', fallback='RmsNorm')]\n            args.hidden_act = config.get(component, 'dense_act_fn')\n            args.gated_act = True\n            args.mlp_type = mlp_type_map['GatedMLP' if args.\n                                         gated_act else 'MLP']\n            args.has_lm_head_bias = config.getboolean(\n                component,  # TODO: T5 with bias\n                'has_lm_head_bias',\n                fallback=False)\n            args.relative_attention = config.getboolean(component,\n                                                        'relative_attention',\n                                                        fallback=True)\n            args.num_buckets = config.getint(component,\n                                             'relative_attention_num_buckets')\n            args.max_distance = config.getint(\n                component, 'relative_attention_max_distance')\n            args.logits_dtype = config.get(component,\n                                           'logits_dtype',\n                                           fallback='float32')\n            args.rescale_before_lm_head = config.getboolean(\n                component, 'tie_word_embeddings'\n            )  # default is True (for T5), but False for Flan-T5\n            args.encoder_hidden_size = config.getint('decoder', 'hidden_size')\n            args.encoder_num_heads = config.getint('decoder', 'num_heads')\n            args.encoder_head_size = config.getint('decoder', 'd_kv')\n            args.position_embedding_type = config.get(\n                'structure', 'position_embedding_type')\n            args.decoder_start_token_id = config.getint(\n                'decoder', 'decoder_start_token_id')\n\n        else:\n            assert False, 'Unsupported component!'\n        return args\n\n    decoder_args = parse_pix2struct_config_by_component(config, \"decoder\", args)\n    return None, decoder_args\n\n\ndef convert_pix2struct_weights_to_tllm_safetensors(config, component, params):\n    weights = {}\n\n    mapping = config.mapping\n\n    convert_weight_to_dtype(params, config.dtype)\n    hidden_size = config.hidden_size\n    ffn_hidden_size = config.intermediate_size\n    num_layers = config.num_hidden_layers\n    n_head = config.num_attention_heads\n    head_size = config.head_size\n    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5\n\n    hf_param_prefix = f'{component}'\n    trtllm_layer_name = f'{component}_layers'\n    trtllm_attn_layer_name = 
'self_attention'\n    trtllm_attn_layernorm_name = 'self_attention_layernorm'\n\n    def get_attn_module_name(component, layer, attn_type):\n        return f'{component}.layer.{int(layer)}.{attn_type}.attention'\n\n    weights['embedding.vocab_embedding.weight'] = reshape(\n        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)\n\n    layers_range = mapping.pp_layers(num_layers)\n    for layer_idx in layers_range:\n        local_layer_idx = layer_idx - layers_range[0]\n        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'\n        hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'\n\n        hidden_layer_name_split = {\n            f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',\n                \"shape\":\n                (hidden_size, attention_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp.proj.weight',\n                \"shape\": (hidden_size, ffn_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp.fc.weight',\n                \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                \"split_dim\": 0\n            },\n        }\n\n        hidden_layer_name_no_split = {\n            f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',\n                \"shape\": None\n            },\n            f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',\n                \"shape\": None\n            },\n        }\n\n        if config.gated_act:\n            hidden_layer_name_split.update({\n                f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {\n                    \"name\": f'{trtllm_layer_name_prefix}.mlp.gate.weight',\n                    \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                    \"split_dim\": 0\n                },\n            })\n\n        hidden_layer_name_split.update({\n            f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':\n            {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',\n                \"shape\":\n                (hidden_size, attention_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n        })\n        hidden_layer_name_no_split.update({\n            f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':\n            {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',\n                \"shape\": None\n            },\n        })\n        self_attn_module_name = get_attn_module_name(\n            component, layer_idx, 'encoder_decoder_attention')\n        weights.update(\n            fuse_qkv_one_layer(\n                params, self_attn_module_name,\n                f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,\n                mapping.tp_rank, config.model_type,\n          
      (attention_hidden_size * 3 // mapping.tp_size, hidden_size),\n                None))\n\n        self_attn_module_name = get_attn_module_name(component, layer_idx,\n                                                     'self_attention')\n        weights.update(\n            fuse_qkv_one_layer(\n                params, self_attn_module_name,\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',\n                mapping.tp_size, mapping.tp_rank, config.model_type,\n                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),\n                None))\n\n        weights[\n            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(\n                split(\n                    params[\n                        f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']\n                    .T, mapping.tp_size, mapping.tp_rank, 0),\n                (n_head // mapping.tp_size, config.num_buckets))\n\n        for hf_weight_name, weight_info in hidden_layer_name_split.items():\n            if hf_weight_name in params.keys():\n                weights[weight_info[\"name\"]] = reshape(\n                    split(params[hf_weight_name],\n                          mapping.tp_size,\n                          mapping.tp_rank,\n                          dim=weight_info[\"split_dim\"]), weight_info[\"shape\"])\n        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():\n            if hf_weight_name in params.keys():\n                weights[weight_info[\"name\"]] = reshape(\n                    params[hf_weight_name].clone(), shape=weight_info[\"shape\"])\n\n    weights[f'final_layernorm.weight'] = reshape(\n        params[f'{component}.final_layer_norm.weight'].clone(), None)\n\n    weights['lm_head.weight'] = reshape(\n        split(params[f'{component}.lm_head.weight'],\n              mapping.tp_size,\n              mapping.tp_rank,\n              dim=0), (config.vocab_size // mapping.tp_size, hidden_size))\n    if not config.use_implicit_relative_attention:\n        weights[f'rel_attn_table'] = reshape(\n            split(\n                params[\n                    f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']\n                .T, mapping.tp_size, mapping.tp_rank, 0),\n            (n_head // mapping.tp_size, config.num_buckets))\n\n    return weights\n\n\ndef parse_StructEqTable_config(args, hf_model):\n    # manually set q_scaling to offset attention scaling's effect.\n    # TODO: modify kernels to control whether to disable attention scaling\n    config = configparser.ConfigParser()\n\n    def get_offset_q_scaling(config) -> str:\n        d_model = config.hidden_size\n        num_heads = config.num_heads\n        head_size = d_model / num_heads\n        scaling = 1 / head_size**.5\n        return str(scaling)\n\n    config[\"decoder\"] = {}\n    for key, val in hf_model.decoder.config.to_dict().items():\n        config[\"decoder\"][key] = f\"{val}\"\n\n    config[\"decoder\"][\"q_scaling\"] = get_offset_q_scaling(\n        hf_model.decoder.config)\n\n    config[\"structure\"] = dict()\n    config[\"structure\"][\"pix2struct_with_bias\"] = \"false\"\n    config[\"structure\"][\"use_gated_activation\"] = \"false\"\n    config[\"structure\"][\"position_embedding_type\"] = \"relative\"\n    config[\"structure\"][\"model_type\"] = args.model_type\n\n    def parse_StructEqTable_config_by_component(config, component, args):\n        if component == 'decoder':\n   
         args.n_layer = config.getint(component, 'num_layers')\n            args.n_head = config.getint(component, 'num_heads')\n            args.head_size = config.getint(component, 'd_kv')\n            args.hidden_size = config.getint(component, 'hidden_size')\n            args.ffn_hidden_size = config.getint(component, 'd_ff')\n            args.vocab_size = config.getint(component, 'vocab_size')\n            args.n_positions = config.getint(component,\n                                             'n_positions',\n                                             fallback=512)\n            args.has_position_embedding = config.getboolean(\n                component, 'has_position_embedding',\n                fallback=False)  # TODO: hardcoded here\n            args.has_token_type_embedding = config.getboolean(\n                component, 'has_token_type_embedding', fallback=False)\n            args.has_embedding_layernorm = config.getboolean(\n                component, 'has_embedding_layernorm', fallback=False)\n            args.has_embedding_scale = config.getboolean(component,\n                                                         'has_embedding_scale',\n                                                         fallback=False)\n            args.q_scaling = config.getfloat(component,\n                                             'q_scaling',\n                                             fallback=1.0)\n            args.has_attention_qkvo_bias = config.getboolean(\n                component, 'has_attention_qkvo_bias', fallback=False)\n            args.has_mlp_bias = config.getboolean(component,\n                                                  'has_mlp_bias',\n                                                  fallback=False)\n            args.has_model_final_layernorm = config.getboolean(\n                component, 'has_model_final_layernorm', fallback=True)\n            args.layernorm_eps = config.getfloat(component,\n                                                 'layer_norm_epsilon')\n            args.layernorm_position = layernorm_position_map[config.get(\n                component, 'layernorm_position',\n                fallback='pre_layernorm')]  # TODO: hardcoded here\n            args.layernorm_type = layernorm_type_map[config.get(\n                component, 'layernorm_type', fallback='RmsNorm')]\n            args.hidden_act = config.get(component, 'dense_act_fn')\n            args.gated_act = True\n            args.mlp_type = mlp_type_map['GatedMLP' if args.\n                                         gated_act else 'MLP']\n            args.has_lm_head_bias = config.getboolean(\n                component,  # TODO: T5 with bias\n                'has_lm_head_bias',\n                fallback=False)\n            args.relative_attention = config.getboolean(component,\n                                                        'relative_attention',\n                                                        fallback=True)\n            args.num_buckets = config.getint(component,\n                                             'relative_attention_num_buckets')\n            args.max_distance = config.getint(\n                component, 'relative_attention_max_distance')\n            args.logits_dtype = config.get(component,\n                                           'logits_dtype',\n                                           fallback='float32')\n            args.rescale_before_lm_head = config.getboolean(\n                component, 'tie_word_embeddings'\n            )  # default is True (for 
T5), but False for Flan-T5\n            args.encoder_hidden_size = config.getint('decoder', 'hidden_size')\n            args.encoder_num_heads = config.getint('decoder', 'num_heads')\n            args.encoder_head_size = config.getint('decoder', 'd_kv')\n            args.position_embedding_type = config.get(\n                'structure', 'position_embedding_type')\n            args.decoder_start_token_id = config.getint(\n                'decoder', 'decoder_start_token_id')\n\n        else:\n            assert False, 'Unsupported component!'\n        return args\n\n    decoder_args = parse_StructEqTable_config_by_component(config, \"decoder\", args)\n    return None, decoder_args\n\n\ndef convert_StructEqTable_weights_to_tllm_safetensors(config, component, params):\n    weights = {}\n\n    mapping = config.mapping\n\n    convert_weight_to_dtype(params, config.dtype)\n    hidden_size = config.hidden_size\n    ffn_hidden_size = config.intermediate_size\n    num_layers = config.num_hidden_layers\n    n_head = config.num_attention_heads\n    head_size = config.head_size\n    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5\n\n    hf_param_prefix = f'{component}'\n    trtllm_layer_name = f'{component}_layers'\n    trtllm_attn_layer_name = 'self_attention'\n    trtllm_attn_layernorm_name = 'self_attention_layernorm'\n\n    def get_attn_module_name(component, layer, attn_type):\n        return f'{component}.layer.{int(layer)}.{attn_type}.attention'\n\n    weights['embedding.vocab_embedding.weight'] = reshape(\n        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)\n\n    layers_range = mapping.pp_layers(num_layers)\n    for layer_idx in layers_range:\n        local_layer_idx = layer_idx - layers_range[0]\n        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'\n        hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'\n\n        hidden_layer_name_split = {\n            f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',\n                \"shape\":\n                (hidden_size, attention_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp.proj.weight',\n                \"shape\": (hidden_size, ffn_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp.fc.weight',\n                \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                \"split_dim\": 0\n            },\n        }\n\n        hidden_layer_name_no_split = {\n            f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',\n                \"shape\": None\n            },\n            f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {\n                \"name\": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',\n                \"shape\": None\n            },\n        }\n\n        if config.gated_act:\n            hidden_layer_name_split.update({\n                f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {\n           
         \"name\": f'{trtllm_layer_name_prefix}.mlp.gate.weight',\n                    \"shape\": (ffn_hidden_size // mapping.tp_size, hidden_size),\n                    \"split_dim\": 0\n                },\n            })\n\n        hidden_layer_name_split.update({\n            f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':\n            {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',\n                \"shape\":\n                (hidden_size, attention_hidden_size // mapping.tp_size),\n                \"split_dim\": -1\n            },\n        })\n        hidden_layer_name_no_split.update({\n            f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':\n            {\n                \"name\":\n                f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',\n                \"shape\": None\n            },\n        })\n        self_attn_module_name = get_attn_module_name(\n            component, layer_idx, 'encoder_decoder_attention')\n        weights.update(\n            fuse_qkv_one_layer(\n                params, self_attn_module_name,\n                f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,\n                mapping.tp_rank, config.model_type,\n                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),\n                None))\n\n        self_attn_module_name = get_attn_module_name(component, layer_idx,\n                                                     'self_attention')\n        weights.update(\n            fuse_qkv_one_layer(\n                params, self_attn_module_name,\n                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',\n                mapping.tp_size, mapping.tp_rank, config.model_type,\n                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),\n                None))\n\n        weights[\n            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(\n                split(\n                    params[\n                        f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']\n                    .T, mapping.tp_size, mapping.tp_rank, 0),\n                (n_head // mapping.tp_size, config.num_buckets))\n\n        for hf_weight_name, weight_info in hidden_layer_name_split.items():\n            if hf_weight_name in params.keys():\n                weights[weight_info[\"name\"]] = reshape(\n                    split(params[hf_weight_name],\n                          mapping.tp_size,\n                          mapping.tp_rank,\n                          dim=weight_info[\"split_dim\"]), weight_info[\"shape\"])\n        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():\n            if hf_weight_name in params.keys():\n                weights[weight_info[\"name\"]] = reshape(\n                    params[hf_weight_name].clone(), shape=weight_info[\"shape\"])\n\n    weights[f'final_layernorm.weight'] = reshape(\n        params[f'{component}.final_layer_norm.weight'].clone(), None)\n\n    weights['lm_head.weight'] = reshape(\n        split(params[f'{component}.lm_head.weight'],\n              mapping.tp_size,\n              mapping.tp_rank,\n              dim=0), (config.vocab_size // mapping.tp_size, hidden_size))\n    if not config.use_implicit_relative_attention:\n        weights[f'rel_attn_table'] = reshape(\n            split(\n                params[\n                    
f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']\n                .T, mapping.tp_size, mapping.tp_rank, 0),\n            (n_head // mapping.tp_size, config.num_buckets))\n\n    return weights\n\n\ndef get_model(args):\n    if args.model_type == \"t5\":\n        model = T5ForConditionalGeneration.from_pretrained(args.model_dir)\n    elif args.model_type == \"nmt\":\n        from fairseq.models.transformer import TransformerModel\n        model = TransformerModel.from_pretrained(args.model_dir)\n    elif args.model_type == \"bart\":\n        if args.nougat:\n            model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)\n            model = model.get_decoder()\n        else:\n            model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)\n    elif args.model_type == \"pix2struct\":\n        model = Pix2StructForConditionalGeneration.from_pretrained(\n            args.model_dir)\n    elif args.model_type == \"blip2\":\n        model = Blip2ForConditionalGeneration.from_pretrained(\n            args.model_dir).language_model\n    elif args.model_type == \"StructEqTable\":\n        model = AutoModelForVision2Seq.from_pretrained(\n            args.model_dir)\n    return model\n\n\ndef convert_checkpoint(args):\n\n    model = get_model(args)\n\n    saved_dir = Path(args.output_dir)\n    saved_dir.mkdir(parents=True, exist_ok=True)\n\n    encoder_saved_dir = saved_dir / \"encoder\"\n    encoder_saved_dir.mkdir(parents=True, exist_ok=True)\n    decoder_saved_dir = saved_dir / \"decoder\"\n    decoder_saved_dir.mkdir(parents=True, exist_ok=True)\n\n    world_size = args.tp_size * args.pp_size\n\n    kv_cache_quant_algo = None\n    quant_algo = None\n\n    model_type = args.model_type if args.model_type != \"blip2\" else \"t5\"\n    encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](\n        args, model)\n\n    additional_settings = [\"gated_act\"]\n    if not args.nougat and args.model_type != \"pix2struct\" and args.model_type != \"StructEqTable\":\n        tllm_encoder_config = {\n            'architecture': \"EncoderModel\",\n            'dtype': args.dtype,\n            'logits_dtype': encoder_config.logits_dtype,\n            'num_hidden_layers': encoder_config.n_layer,\n            'num_attention_heads': encoder_config.n_head,\n            'hidden_size': encoder_config.hidden_size,\n            'norm_epsilon': encoder_config.layernorm_eps,\n            'vocab_size': encoder_config.vocab_size,\n            'position_embedding_type': encoder_config.position_embedding_type,\n            'hidden_act': encoder_config.hidden_act,\n            'quantization': {\n                'quant_algo': quant_algo,\n                'kv_cache_quant_algo': kv_cache_quant_algo,\n            },\n            'mapping': {\n                'world_size': world_size,\n                'tp_size': args.tp_size,\n                'pp_size': args.pp_size,\n            },\n            'use_parallel_embedding': args.use_parallel_embedding,\n            'embedding_sharding_dim': args.embedding_sharding_dim,\n            'share_embedding_table': args.use_embedding_sharing,\n            'max_position_embeddings': encoder_config.n_positions,\n            'num_key_value_heads': encoder_config.n_head,\n            'head_size': encoder_config.head_size,\n            'has_position_embedding': encoder_config.has_position_embedding,\n            'layernorm_type': encoder_config.layernorm_type,\n            'has_attention_qkvo_bias': 
encoder_config.has_attention_qkvo_bias,\n            'has_mlp_bias': encoder_config.has_mlp_bias,\n            'has_model_final_layernorm':\n            encoder_config.has_model_final_layernorm,\n            'has_embedding_layernorm': encoder_config.has_embedding_layernorm,\n            'has_embedding_scale': encoder_config.has_embedding_scale,\n            'intermediate_size': encoder_config.ffn_hidden_size,\n            'q_scaling': encoder_config.q_scaling,\n            'layernorm_position': encoder_config.layernorm_position,\n            'mlp_type': encoder_config.mlp_type,\n            'relative_attention': encoder_config.relative_attention,\n            'max_distance': encoder_config.max_distance,\n            'num_buckets': encoder_config.num_buckets,\n            'model_type': encoder_config.model_type,\n        }\n\n        for additional_setting in additional_settings:\n            if hasattr(encoder_config, additional_setting):\n                tllm_encoder_config.update({\n                    additional_setting:\n                    getattr(encoder_config, additional_setting)\n                })\n\n        with (encoder_saved_dir / \"config.json\").open('w') as f:\n            json.dump(tllm_encoder_config, f, indent=4)\n\n        encoder_convert_args = dict(params=model.state_dict(),\n                                    component=\"encoder\")\n    tllm_decoder_config = {\n        'architecture': \"DecoderModel\",\n        'dtype': args.dtype,\n        'logits_dtype': decoder_config.logits_dtype,\n        'num_hidden_layers': decoder_config.n_layer,\n        'num_attention_heads': decoder_config.n_head,\n        'hidden_size': decoder_config.hidden_size,\n        'norm_epsilon': decoder_config.layernorm_eps,\n        'vocab_size': decoder_config.vocab_size,\n        'position_embedding_type': decoder_config.position_embedding_type,\n        'hidden_act': decoder_config.hidden_act,\n        'quantization': {\n            'quant_algo': quant_algo,\n            'kv_cache_quant_algo': kv_cache_quant_algo,\n        },\n        'mapping': {\n            'world_size': world_size,\n            'tp_size': args.tp_size,\n            'pp_size': args.pp_size,\n        },\n        'use_parallel_embedding': args.use_parallel_embedding,\n        'embedding_sharding_dim': args.embedding_sharding_dim,\n        'share_embedding_table': args.use_embedding_sharing,\n        'max_position_embeddings': decoder_config.n_positions,\n        'head_size': decoder_config.head_size,\n        'has_position_embedding': decoder_config.has_position_embedding,\n        'layernorm_type': decoder_config.layernorm_type,\n        'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias,\n        'has_mlp_bias': decoder_config.has_mlp_bias,\n        'has_model_final_layernorm': decoder_config.has_model_final_layernorm,\n        'has_embedding_layernorm': decoder_config.has_embedding_layernorm,\n        'has_embedding_scale': decoder_config.has_embedding_scale,\n        'intermediate_size': decoder_config.ffn_hidden_size,\n        'q_scaling': decoder_config.q_scaling,\n        'layernorm_position': decoder_config.layernorm_position,\n        'mlp_type': decoder_config.mlp_type,\n        'relative_attention': decoder_config.relative_attention,\n        'max_distance': decoder_config.max_distance,\n        'num_buckets': decoder_config.num_buckets,\n        'model_type': decoder_config.model_type,\n        'rescale_before_lm_head': decoder_config.rescale_before_lm_head,\n        'encoder_hidden_size': 
decoder_config.encoder_hidden_size,\n        'encoder_num_heads': decoder_config.encoder_num_heads,\n        'encoder_head_size': decoder_config.encoder_head_size,\n        'skip_cross_qkv': args.skip_cross_qkv,\n        'use_implicit_relative_attention': args.use_implicit_relative_attention,\n        'decoder_start_token_id': decoder_config.decoder_start_token_id,\n    }\n    for additional_setting in additional_settings:\n        if hasattr(decoder_config, additional_setting):\n            tllm_decoder_config.update({\n                additional_setting:\n                getattr(decoder_config, additional_setting)\n            })\n\n    with (decoder_saved_dir / \"config.json\").open('w') as f:\n        json.dump(tllm_decoder_config, f, indent=4)\n\n    decoder_convert_args = dict(params=model.state_dict(), component=\"decoder\")\n\n    if args.model_type == \"nmt\":\n        fairseq_config = vars(model.cfg.model)  # Namespace --> dict\n        num_embeddings = fairseq_config['max_source_positions']\n        embedding_dim = fairseq_config['encoder_embed_dim']\n        padding_idx = model.models[0].encoder.embed_tokens.padding_idx  # 1\n\n        sin_pos_embedding = model.models[\n            0].encoder.embed_positions.get_embedding(\n                padding_idx + 1 + num_embeddings,\n                embedding_dim,\n                padding_idx=padding_idx)  # [2 + num_embeddings, embed_dim]\n        sin_pos_embedding = sin_pos_embedding[2:, :]  # remove offset embeddings\n\n        encoder_convert_args[\"sin_pos_embedding\"] = sin_pos_embedding\n        decoder_convert_args[\"sin_pos_embedding\"] = sin_pos_embedding\n\n    if args.workers == 1:\n        if not args.nougat and args.model_type != \"pix2struct\" and args.model_type != \"StructEqTable\":\n            convert(0, world_size, args, tllm_encoder_config,\n                    encoder_convert_args, encoder_saved_dir)\n        convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,\n                decoder_saved_dir)\n    else:\n        if args.workers > world_size:\n            args.workers = world_size\n        LOGGER.info(f'Convert checkpoint using {args.workers} workers.')\n        import torch.multiprocessing as mp\n        if not args.nougat and args.model_type != \"pix2struct\" and args.model_type != \"StructEqTable\":\n            mp.spawn(convert,\n                     nprocs=args.workers,\n                     args=(world_size, args, tllm_encoder_config,\n                           encoder_convert_args, encoder_saved_dir))\n        mp.spawn(convert,\n                 nprocs=args.workers,\n                 args=(world_size, args, tllm_decoder_config,\n                       decoder_convert_args, decoder_saved_dir))\n\n\ndef convert(worker_rank, world_size, args, model_config, convert_args,\n            saved_dir):\n    for rank in range(worker_rank, world_size, args.workers):\n        rank_config = copy.deepcopy(PretrainedConfig.from_dict(model_config))\n        rank_config.set_rank(rank)\n        weights = globals(\n        )[f'convert_{rank_config.model_type}_weights_to_tllm_safetensors'](\n            config=rank_config, **convert_args)\n        safetensors.torch.save_file(weights,\n                                    f'{saved_dir}/rank{rank}.safetensors')\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        formatter_class=argparse.RawTextHelpFormatter)\n    parser.add_argument(\n        '--model_type',\n        type=str,\n        default='t5',\n        choices=['t5', 'nmt', 
'bart', 'pix2struct', 'blip2', 'StructEqTable'],\n        help=\n        'Multimodal type when this script is used for multimodal conversion.')\n    parser.add_argument('--world_size',\n                        type=int,\n                        default=1,\n                        help='MPI world size (must equal TP * PP)')\n    parser.add_argument('--tp_size',\n                        type=int,\n                        default=1,\n                        help='N-way tensor parallelism size')\n    parser.add_argument('--pp_size',\n                        type=int,\n                        default=1,\n                        help='N-way pipeline parallelism size')\n    parser.add_argument(\"--model_dir\",\n                        \"-i\",\n                        type=str,\n                        help=\"Path to the framework checkpoint file\",\n                        required=True)\n    parser.add_argument(\"--output_dir\",\n                        \"-o\",\n                        type=str,\n                        help=\"Path to the converted TRT-LLM model weight file\",\n                        required=True)\n    parser.add_argument(\n        \"--workers\",\n        type=int,\n        help=\"How many workers to spawn for conversion (default: 4)\",\n        default=4)\n    parser.add_argument(\"--nougat\",\n                        action=\"store_true\",\n                        help=\"Model which uses vision encoder + mbart decoder\")\n    parser.add_argument(\"--verbose\",\n                        action=\"store_true\",\n                        help=\"Provide verbose messages\")\n    parser.add_argument(\n        '--use_parallel_embedding',\n        action=\"store_true\",\n        default=False,\n        help=\n        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled.'\n    )\n    parser.add_argument(\n        '--embedding_sharding_dim',\n        type=int,\n        default=0,\n        choices=[0, 1],\n        help=\n        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '\n        'To shard it along hidden dimension, set embedding_sharding_dim=1. '\n        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'\n    )\n    parser.add_argument(\n        '--use_weight_only',\n        default=False,\n        action=\"store_true\",\n        help='Quantize weights for the various GEMMs to INT4/INT8. '\n        'See --weight_only_precision to set the precision')\n    parser.add_argument(\n        '--weight_only_precision',\n        const='int8',\n        type=str,\n        nargs='?',\n        default='int8',\n        choices=['int8', 'int4'],\n        help=\n        'Define the precision for the weights when using weight-only quantization. '\n        'You must also use --use_weight_only for that argument to have an impact.'\n    )\n    parser.add_argument(\n        '--use_embedding_sharing',\n        action=\"store_true\",\n        default=False,\n        help=\n        'Try to reduce the engine size by sharing the embedding lookup table between two layers. '\n        'Note: the flag might not take effect when the criteria are not met.')\n    parser.add_argument(\n        '--dtype',\n        type=str,\n        default='float16',\n        choices=['float16', 'float32', 'bfloat16'],\n        help=\n
        'Target inference dtype. Weights and computation will be in this dtype, no matter what original dtype the weight checkpoint has.'\n    )\n    parser.add_argument(\n        '--skip_cross_qkv',\n        action='store_true',\n        help=\n        'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).'\n    )\n    parser.add_argument(\n        '--use_implicit_relative_attention',\n        action='store_true',\n        help=\n        'Compute relative attention bias on the fly instead of pre-computing a relative attention bias table.'\n    )\n    args = parser.parse_args()\n    log_format = \"%(asctime)s %(name)s [%(levelname)s] %(message)s\"\n    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,\n                        format=log_format)\n    LOGGER.info(\"\\n=============== Argument ===============\")\n    for key in vars(args):\n        LOGGER.info(f\"{key}: {vars(args)[key]}\")\n    LOGGER.info(\"========================================\")\n\n    start_time = datetime.now()\n    convert_checkpoint(args)\n    stop_time = datetime.now()\n    run_time = (stop_time - start_time)\n    LOGGER.info(\"Spent {} (h:m:s) converting the model\".format(run_time))\n"
  },
  {
    "path": "tools/tensorrt_utils/helper.py",
    "content": "import typing\nfrom typing import Union\n\nimport numpy as np\nimport torch  # pytype: disable=import-error\n\nfrom tensorrt_llm._utils import str_dtype_to_torch\n\n\ndef split(v: Union[np.ndarray, torch.Tensor],\n          tp_size: int,\n          tp_rank: int,\n          dim=0):\n    if tp_size == 1:\n        if isinstance(v, np.ndarray):\n            return np.ascontiguousarray(v.copy())\n        else:\n            return v.clone().detach()\n    assert len(v.shape) > 1 or dim == 0\n    if isinstance(v, np.ndarray):\n        return np.ascontiguousarray(\n            np.split(v, tp_size, axis=dim)[tp_rank].copy())\n    else:\n        assert v.shape[dim] % tp_size == 0, \\\n            'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'\n        split_size = v.shape[dim] // tp_size\n        return v.split(split_size, dim=dim)[tp_rank].clone().detach()\n\n\ndef reshape(v: torch.Tensor, shape=None):\n    if shape is None:\n        return v.contiguous()\n    else:\n        return v.reshape(shape).contiguous()\n\n\ndef fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,\n                       tp_rank, model_type, weight_shape, bias_shape):\n\n    qkv_module_names = get_qkv_module_name(model_type)\n\n    weight = {}\n\n    # fuse weights of q, k, v\n    q_w = params[f'{attn_module_name}.{qkv_module_names[\"q\"]}.weight']\n    k_w = params[f'{attn_module_name}.{qkv_module_names[\"k\"]}.weight']\n    v_w = params[f'{attn_module_name}.{qkv_module_names[\"v\"]}.weight']\n\n    # fuse qkv weight\n    shape = q_w.shape  # (do, din)\n    qkv_w = torch.cat([q_w, k_w, v_w],\n                      dim=0).reshape([3, shape[0], shape[1]])  # (3, do, din)\n    qkv_w = split(qkv_w, tp_size, tp_rank, dim=1)\n    weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w,\n                                                        shape=weight_shape)\n\n    # fuse qkv biases if present\n    if f'{attn_module_name}.{qkv_module_names[\"q\"]}.bias' in params.keys(\n    ) and params[f'{attn_module_name}.{qkv_module_names[\"q\"]}.bias'] is not None:\n        q_b = params[f'{attn_module_name}.{qkv_module_names[\"q\"]}.bias']\n        k_b = params[f'{attn_module_name}.{qkv_module_names[\"k\"]}.bias']\n        v_b = params[f'{attn_module_name}.{qkv_module_names[\"v\"]}.bias']\n        shape = q_b.shape[0]  # (do,)\n        qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape])  # (3, do)\n        qkv_b = split(qkv_b, tp_size, tp_rank, dim=1)\n        weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b,\n                                                          shape=bias_shape)\n    return weight\n\n\ndef get_qkv_module_name(model_type):\n    if model_type == \"t5\":\n        q = \"q\"\n        k = \"k\"\n        v = \"v\"\n    elif model_type == \"bart\" or model_type == \"nmt\":\n        q = \"q_proj\"\n        k = \"k_proj\"\n        v = \"v_proj\"\n    elif model_type == \"pix2struct\":\n        q = \"query\"\n        k = \"key\"\n        v = \"value\"\n    elif model_type == \"StructEqTable\":\n        q = \"query\"\n        k = \"key\"\n        v = \"value\"\n    return {\"q\": q, \"k\": k, \"v\": v}\n\n\ndef convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],\n                            dtype: typing.Optional[np.dtype] = None):\n    if dtype is not None:\n        assert isinstance(dtype,\n                          str), f\"dtype must be str, but get type {type(dtype)}\"\n        for name in params.keys():\n            params[name] = 
params[name].to(str_dtype_to_torch(dtype))\n"
  }
]