Repository: UniModal4Reasoning/StructEqTable-Deploy
Branch: main
Commit: 1b4f0859bd82
Files: 20
Total size: 212.6 KB

Directory structure:
StructEqTable-Deploy/

├── .gitignore
├── LICENSE
├── README.md
├── docs/
│   └── TENSORRT_GETTING_STARTED.md
├── requirements.txt
├── setup.py
├── struct_eqtable/
│   ├── __init__.py
│   ├── internvl/
│   │   ├── __init__.py
│   │   ├── conversation.py
│   │   ├── internvl.py
│   │   └── internvl_lmdeploy.py
│   └── pix2s/
│       ├── __init__.py
│       ├── pix2s.py
│       └── pix2s_trt.py
└── tools/
    ├── demo/
    │   ├── demo.py
    │   └── demo.tex
    ├── scripts/
    │   └── build_tensorrt.sh
    └── tensorrt_utils/
        ├── build_visual_engine.py
        ├── convert_checkpoint.py
        └── helper.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
dist/
build/
**.egg-info/
**__pycache__/
**.cache
ckpts/
**version.py



================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
<div align="center">
<h1>StructEqTable-Deploy: A High-efficiency Open-source Toolkit for Table-to-Latex Transformation</h1>


[[ Paper ]](https://arxiv.org/abs/2505.16938) [[ Website ]](https://alpha-innovator.github.io/InternAgent-project-page) [[ Dataset🤗 ]](https://huggingface.co/datasets/U4R/DocGenome/tree/main) [[ Models🤗 ]](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) [[ Demo💬 ]](https://www.modelscope.cn/studios/HongbinZhou/StructEqTable-Demo/)


</div>

Welcome to StructEqTable-Deploy, the official repository of the InternScience group: a solution that converts table images into LaTeX/HTML/Markdown, powered by scalable data from the [DocGenome benchmark](https://unimodal4reasoning.github.io/DocGenome_page/).


## Overview
Tables are an effective way to represent structured data in scientific publications, financial statements, invoices, web pages, and many other scenarios. Extracting tabular data from a table image and performing downstream reasoning tasks on the extracted data is challenging, mainly because tables often contain complicated column and row headers with spanning cells. To address these challenges, we present TableX, a large-scale multi-modal table benchmark extracted from the [DocGenome benchmark](https://alpha-innovator.github.io/InternAgent-project-page/) for table pre-training, comprising more than 2 million high-quality image-LaTeX pairs covering 156 disciplinary classes. Benefiting from such large-scale data, we also train an end-to-end model, StructEqTable, which precisely produces the corresponding LaTeX description from a table image and performs multiple table-related reasoning tasks, including structural extraction and question answering, broadening its application scope and potential.

## Changelog
- [2024/12/12] 🔥 We have released the latest model **[StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main)**, with enhanced recognition stability for HTML and Markdown formats!

- [2024/10/19] We have released our latest model StructTable-InternVL2-1B!

  Thanks to InternVL2's powerful foundational capabilities, and through fine-tuning on synthetic tabular data and the DocGenome dataset, StructTable can convert table images into various common table formats, including LaTeX, HTML, and Markdown. Moreover, inference speed has been significantly improved compared to the v0.2 version.
- [2024/8/22] We have released StructTable-base-v0.2, fine-tuned on the DocGenome dataset. This version features improved inference speed and robustness, achieved through data augmentation and a reduced image token count.
- [2024/8/08] We have released the TensorRT-accelerated version, which takes only about 1 second for most images on an A100 GPU. Please follow the tutorial to install the environment and compile the model weights.
- [2024/7/30] We have released the first version of StructEqTable. 

## TODO

- [x] Release inference code and checkpoints of StructEqTable.
- [x] Support Chinese version of StructEqTable.
- [x] Accelerated version of StructEqTable using TensorRT-LLM.
- [x] Expand more domains of table image to improve the model's general capabilities.
- [x] Efficient inference of StructTable-InternVL2-1B via the [LMDeploy](https://github.com/InternLM/lmdeploy) toolkit.
- [ ] Release our table pre-training and fine-tuning code


## Installation
``` bash 
conda create -n structeqtable "python>=3.10"
conda activate structeqtable

# Install from Source code  (Suggested)
git clone https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git
cd StructEqTable-Deploy
pip install -r requirements.txt
python setup.py develop

# or Install from Github repo
pip install "git+https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git"

# or Install from PyPI
pip install struct-eqtable --upgrade
```
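
You can quickly verify the installation by importing the package (a minimal sanity check; `build_model` is the entry point exposed by `struct_eqtable/__init__.py`):

```bash
python -c "from struct_eqtable import build_model; print('struct_eqtable is ready')"
```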

## Model Zoo

| Base Model | Model Size | Training Data | Data Augmentation | LMDeploy | TensorRT | HuggingFace |
|---------------------|------------|------------------|-------------------|----------|----------|-------------------|
| InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) |
| InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.1](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/v0.1) |
| Pix2Struct-base | ~300M | DocGenome | ✔ | | ✔ | [StructTable-base v0.2](https://huggingface.co/U4R/StructTable-base/tree/v0.2) |
| Pix2Struct-base | ~300M | DocGenome | | | ✔ | [StructTable-base v0.1](https://huggingface.co/U4R/StructTable-base/tree/v0.1) |



## Quick Demo
- Run tools/demo/demo.py
```shell script
cd tools/demo

python demo.py \
  --image_path ./demo.png \
  --ckpt_path U4R/StructTable-InternVL2-1B \
  --output_format latex
```

- HTML or Markdown format output (Only Supported by StructTable-InternVL2-1B)

```shell script
python demo.py \
  --image_path ./demo.png \
  --ckpt_path U4R/StructTable-InternVL2-1B \
  --output_format html markdown
```
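
If you prefer calling the model from Python instead of the demo script, the sketch below shows the package-level API based on `build_model` from `struct_eqtable/__init__.py`; the local image path and the optional `.cuda()` placement are assumptions for illustration.

```python
import torch
from PIL import Image
from struct_eqtable import build_model

# Build the Transformers-backed InternVL model (downloads the checkpoint if needed).
model = build_model('U4R/StructTable-InternVL2-1B', max_new_tokens=1024, max_time=30)
if torch.cuda.is_available():
    model = model.cuda()

# The forward pass accepts a PIL image (or a list of images) and an output format.
image = Image.open('demo.png').convert('RGB')  # hypothetical local table image
outputs = model(image, output_format='latex')  # returns a list of decoded strings
print(outputs[0])
```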

## Efficient Inference
- Install the LMDeploy toolkit
```shell script
pip install lmdeploy
```

- Run tools/demo/demo.py
```shell script
cd tools/demo

python demo.py \
  --image_path ./demo.png \
  --ckpt_path U4R/StructTable-InternVL2-1B \
  --output_format latex \
  --lmdeploy
```
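
The `--lmdeploy` flag presumably maps to the `lmdeploy=True` keyword handled by `build_model` (see `struct_eqtable/__init__.py`), which selects the `InternVL_LMDeploy` backend. A minimal Python sketch of the same path; the image path is an assumption:

```python
from PIL import Image
from struct_eqtable import build_model

# lmdeploy=True routes build_model to the LMDeploy PyTorch-engine backend.
model = build_model('U4R/StructTable-InternVL2-1B', lmdeploy=True, batch_size=4)

image = Image.open('demo.png').convert('RGB')  # hypothetical local table image
outputs = model(image, output_format='latex')  # returns a list of decoded strings
print(outputs[0])
```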


- Visualization Result

  You can copy the output LaTeX code into [demo.tex](tools/demo/demo.tex), then use [Overleaf](https://www.overleaf.com/project) for table visualization.
![](docs/imgs/output.png)


## Acknowledgements
- [DocGenome](https://github.com/UniModal4Reasoning/DocGenome). An Open Large-scale Scientific Document Benchmark for Training and Testing Multi-modal Large Models.
- [ChartVLM](https://github.com/UniModal4Reasoning/ChartVLM). A Versatile Benchmark and Foundation Model for Complicated Chart Reasoning.
- [Pix2Struct](https://github.com/google-research/pix2struct). Screenshot Parsing as Pretraining for Visual Language Understanding.
- [InternVL Family](https://github.com/OpenGVLab/InternVL). A Series of Powerful Foundational Vision-Language Models.
- [LMDeploy](https://github.com/InternLM/lmdeploy). A toolkit for compressing, deploying, and serving LLM and MLLM.
- [UniMERNet](https://github.com/opendatalab/UniMERNet). A Universal Network for Real-World Mathematical Expression Recognition.
- [Donut](https://huggingface.co/naver-clova-ix/donut-base). UniMERNet's Transformer encoder-decoder is based on Donut.
- [Nougat](https://github.com/facebookresearch/nougat). Data Augmentation follows Nougat.  
- [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Model inference acceleration uses TensorRT-LLM.


## License
StructEqTable is released under the [Apache License 2.0](LICENSE).

## Citation
If you find our models / code / papers useful in your research, please consider giving a ⭐ and a citation 📝. Thanks! :)
```bibtex
@article{xia2024docgenome,
  title={DocGenome: An Open Large-scale Scientific Document Benchmark for Training and Testing Multi-modal Large Language Models},
  author={Xia, Renqiu and Mao, Song and Yan, Xiangchao and Zhou, Hongbin and Zhang, Bo and Peng, Haoyang and Pi, Jiahao and Fu, Daocheng and Wu, Wenjie and Ye, Hancheng and others},
  journal={arXiv preprint arXiv:2406.11633},
  year={2024}
}
```

## Contact Us
If you encounter any issues or have questions, please feel free to contact us via zhouhongbin@pjlab.org.cn.


================================================
FILE: docs/TENSORRT_GETTING_STARTED.md
================================================
# Getting Started
[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) is used to speed up model inference.

All the code has been successfully tested in the following environments:
* Linux (18.04, 20.04, 22.04)
* Python 3.10
* PyTorch 2.0 or higher
* CUDA 12.1 or higher
* TensorRT-LLM 0.11.0 (stable version)

### 1. Conda or Python Environment Preparation


* Please follow steps 1 and 2 of the TensorRT-LLM [official tutorial](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) to install the environment.

Note that we use the TensorRT-LLM **stable version `0.11.0`**.
**Installing on Linux**

Step 1. Retrieve and launch the docker container (optional).

You can pre-install the environment using the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit) to avoid manual environment configuration.

```bash
# Obtain and start the basic docker image environment (optional).
docker run --rm --ipc=host --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.4.1-devel-ubuntu22.04
```

Note: please make sure to set `--ipc=host` as a docker run argument to avoid `Bus error (core dumped)`.

Step 2. Install TensorRT-LLM.

```bash
# Install dependencies; TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev git git-lfs

# Install the stable version 0.11.0 of TensorRT-LLM
pip3 install tensorrt_llm==0.11.0 --extra-index-url https://pypi.nvidia.com

# Check installation
python3 -c "import tensorrt_llm"
```

Please note that TensorRT-LLM depends on TensorRT. If an earlier installation included TensorRT 8, upgrading to the new version may require explicitly running `pip uninstall tensorrt` to remove the old version first.
* Once `python3 -c "import tensorrt_llm"` executes successfully, the environment preparation is complete.

Tip: if you install the environment manually, please note that Python >= 3.10 is required.


### 2. Model Compilation
You can refer to the [official tutorial](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) to complete the model compilation, or follow the instructions below and use the provided scripts.

#### 2.1 Download [StructEqTable checkpoints](https://huggingface.co/U4R/StructTable-base/tree/v0.2)
```
cd StructEqTable-Deploy

# download the checkpoint with huggingface-cli
huggingface-cli download --resume-download --local-dir-use-symlinks False U4R/StructTable-base --local-dir ckpts/StructTable-base

```
After the above steps, the directory structure of StructEqTable-Deploy should look as follows:
```
StructEqTable-Deploy
├── ckpts
│   ├── StructTable-base 
├── docs
├── struct_eqtable
├── tools
```
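
Alternatively, you can download the checkpoint from Python with `huggingface_hub` (an equivalent sketch, not part of the provided scripts):

```python
from huggingface_hub import snapshot_download

# Download the Pix2Struct-based checkpoint into ckpts/StructTable-base
snapshot_download(repo_id='U4R/StructTable-base', local_dir='ckpts/StructTable-base')
```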

#### 2.2 Convert Checkpoint and Build Engine
We provide a script that helps users quickly compile the model.

``` bash
cd StructEqTable-Deploy/tools
# execute the script to quickly compile the model.
bash scripts/build_tensorrt.sh 
```
After the script runs successfully, the built models can be found in `ckpts/StructTable-base-TensorRT`.  
The file structure in the path `ckpts/StructTable-base-TensorRT` should be as follows:  
```
ckpts
├── StructTable-base 
├── StructTable-base-TensorRT 
│   ├── trt_engines 
│   ├── trt_models
│   ├── visual_engines
```

#### 2.3 Run the Quick Demo
Run tools/demo/demo.py in TensorRT mode.

``` bash
cd StructEqTable-Deploy/tools/demo

python demo.py \
  --image_path ./demo.png \
  --ckpt_path ../../ckpts/StructTable-base \
  --output_format latex \
  --tensorrt ../../ckpts/StructTable-base-TensorRT
```

You may get output as follows:
```
total cost time: 0.88s
Table 0 LATEX format output:
\begin{tabular}{|c|c|c|c|}
\hline
Quantity $\backslash$ Unit System & International System SI (kg-m-s) & Traditional aeronautical (lb-ft-s) & Traditional structural (lb-inch-s) \\
\hline
Mass (translational inertia), $m$ & kilogram mass (kg) & slug = lb-s$^2$/f & lb-s$^2$/inch \\
\hline
Length, translational motion & meter (m) & foot (ft) & inch (in.) \\
\hline
Time, $t$ & second (s) & second (s) & second (s) \\
\hline
Force, translational action & newton (N) = kg-m/s$^2$ & pound force (lb) & pound force (lb) \\
\hline
Translational stiffness constant, $k$ & N/m & lb/ft & lb/inch \\
\hline
Translational damping constant, $c$ & N/(m/s) = N-s/m & lb/(ft/s) = lb-s/ft & lb/(inch/s) = lb-s/inch \\
\hline
Angle, rotational motion & radial (rad), which is dimensionless & radial (rad), which is dimensionless & radial (rad), which is dimensionless \\
\hline
Rotational inertia, $J$ & kg-m$^2$ & slug-ft$^2$ = lb-s$^2$ - ft & lb-s$^2$ - inch \\
\hline
Moment or torque, rotational action & N-m & lb-ft & lb-inch \\
\hline
Rotational stiffness constant, $k_\theta$ & (N-m)/rad = N-m & (lb-ft)/rad = lb-ft & (lb-inch)/rad = lb-inch \\
\hline
Rotational damping constant, $c_\theta$ & (N-m)/(rad/s) = N-m-s & (lb-ft)/(rad/s) = lb-ft-s & (lb-inch)/(rad/s) = lb-inch-s \\
\hline
\end{tabular}
```
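
The same TensorRT path can also be driven from Python: `build_model` switches to `Pix2StructTensorRT` when a `tensorrt_path` is given (see `struct_eqtable/__init__.py`). A minimal sketch, assuming the engines were built into `ckpts/StructTable-base-TensorRT` and paths are relative to the repository root:

```python
from PIL import Image
from struct_eqtable import build_model

# tensorrt_path routes build_model to the TensorRT-LLM backend (Pix2StructTensorRT).
model = build_model(
    'ckpts/StructTable-base',
    tensorrt_path='ckpts/StructTable-base-TensorRT',
)

image = Image.open('demo.png').convert('RGB')  # hypothetical local table image
latex_codes = model(image)  # returns a list of LaTeX strings
print(latex_codes[0])
```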


### 3. Table Visualization
You can copy the output LaTeX code into [demo.tex](../tools/demo/demo.tex), then use [Overleaf](https://www.overleaf.com/project) or Visual Studio Code LaTeX Workshop Extension for table visualization.

![](./imgs/demo.png)

================================================
FILE: requirements.txt
================================================
torch
transformers<=4.47


================================================
FILE: setup.py
================================================
from pathlib import Path
from setuptools import find_packages, setup


def write_version_to_file(version, target_file):
    with open(target_file, 'w') as f:
        print('__version__ = "%s"' % version, file=f)

if __name__ == '__main__':
    version = '0.3.3'
    write_version_to_file(version, 'struct_eqtable/version.py')
    with Path(Path(__file__).parent,
              'README.md').open(encoding='utf-8') as file:
        long_description = file.read()
    setup(
        name='struct_eqtable',
        version=version,
        description='A High-efficiency Open-source Toolkit for Table-to-Latex Transformation',
        long_description=long_description,
        long_description_content_type="text/markdown",
        install_requires=[
            'torch',
            'transformers<=4.47',
        ],
        python_requires=">=3.9",
        author='Hongbin Zhou, Xiangchao Yan, Bo Zhang',
        author_email='zhangbo@pjlab.org.cn',
        url="https://github.com/UniModal4Reasoning/StructEqTable-Deploy",
        license='Apache License 2.0',
        packages=find_packages(exclude=['demo']),
    )


================================================
FILE: struct_eqtable/__init__.py
================================================
from .pix2s import Pix2Struct, Pix2StructTensorRT
from .internvl import InternVL, InternVL_LMDeploy

from transformers import AutoConfig


__ALL_MODELS__ = {
    'Pix2Struct': Pix2Struct,
    'Pix2StructTensorRT': Pix2StructTensorRT,
    'InternVL': InternVL,
    'InternVL_LMDeploy': InternVL_LMDeploy,
}


def get_model_name(model_path):
    model_config = AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=True,
    )

    if 'Pix2Struct' in model_config.architectures[0]:
        model_name = 'Pix2Struct'
    elif 'InternVL' in model_config.architectures[0]:
        model_name = 'InternVL'
    else:
        raise ValueError(f"Unsupported model type: {model_config.architectures[0]}")

    return model_name


def build_model(model_ckpt='U4R/StructTable-InternVL2-1B', **kwargs):
    model_name = get_model_name(model_ckpt)
    if model_name == 'InternVL' and kwargs.get('lmdeploy', False):
        model_name = 'InternVL_LMDeploy'
    elif model_name == 'Pix2Struct' and kwargs.get('tensorrt_path', None):
        model_name = 'Pix2StructTensorRT'

    model = __ALL_MODELS__[model_name](
        model_ckpt, 
        **kwargs
    )

    return model

================================================
FILE: struct_eqtable/internvl/__init__.py
================================================
from .internvl import InternVL
from .internvl_lmdeploy import InternVL_LMDeploy

================================================
FILE: struct_eqtable/internvl/conversation.py
================================================
"""
Conversation prompt templates.

We kindly request that you import fastchat instead of copying this file if you wish to use it.
If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
"""

import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List, Tuple, Union


class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    FALCON_CHAT = auto()
    CHATGLM3 = auto()
    INTERNVL_ZH = auto()
    MPT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = '{system_message}'
    # The system message
    system_message: str = ''
    # The names of two roles
    roles: Tuple[str] = ('USER', 'ASSISTANT')
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ': '  # must be end with a space
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += (
                        role
                        + ': '
                        + message.replace('\r\n', '\n').replace('\n\n', '\n')
                    )
                    ret += '\n\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            if self.system_message:
                ret = system_prompt
            else:
                ret = '[INST] '
            for i, (role, message) in enumerate(self.messages):
                tag = self.roles[i % 2]
                if message:
                    if i == 0:
                        ret += message + ' '
                    else:
                        ret += tag + ' ' + message + seps[i % 2]
                else:
                    ret += tag
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            round_add_n = 1 if self.name == 'chatglm2' else 0
            if system_prompt:
                ret = system_prompt + self.sep
            else:
                ret = ''

            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += f'[Round {i//2 + round_add_n}]{self.sep}'

                if message:
                    ret += f'{role}:{message}{self.sep}'
                else:
                    ret += f'{role}:'
            return ret
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep + '\n'
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM3:
            ret = ''
            if self.system_message:
                ret += system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + ' ' + message
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                # if i % 2 == 0:
                #     ret += "<s>"
                if message:
                    ret += role + ':' + message + seps[i % 2] + '\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':\n' + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += '\n\n'
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + '<s>' + message + '</s>'
                else:
                    ret += role + ': ' + '<s>'
            return ret
        elif self.sep_style == SeparatorStyle.ROBIN:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ':\n' + message + self.sep
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.FALCON_CHAT:
            ret = ''
            if self.system_message:
                ret += system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'

            return ret
        elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
            seps = [self.sep, self.sep2]
            ret = self.system_message + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f'Invalid style: {self.sep_style}')

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        ret = [{'role': 'system', 'content': self.system_message}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            else:
                if msg is not None:
                    ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f'{template.name} has been registered.'

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
# is that during training, the preprocessing function for the Hermes-2 template doesn't add
# <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.
# Therefore, they are completely equivalent during inference.
register_conv_template(
    Conversation(
        name='Hermes-2',
        system_template='<|im_start|>system\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
        # system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        system_message='You are a Table Image to LaTeX/Markdown/HMTL Code converter.',
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>',
        stop_token_ids=[
            2,
            6,
            7,
            8,
        ],
        stop_str='<|endoftext|>',
    )
)


register_conv_template(
    Conversation(
        name='internlm2-chat',
        system_template='<|im_start|>system\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>',
        stop_token_ids=[
            2,
            92543,
            92542
        ]
    )
)


register_conv_template(
    Conversation(
        name='phi3-chat',
        system_template='<|system|>\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        roles=('<|user|>\n', '<|assistant|>\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|end|>',
        stop_token_ids=[
            2,
            32000,
            32007
        ]
    )
)


================================================
FILE: struct_eqtable/internvl/internvl.py
================================================
import torch

from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor, GenerationConfig

from .conversation import get_conv_template

class InternVL(nn.Module):
    def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_tokens=1024, max_time=30, flash_attn=True, **kwargs):
        super().__init__()
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens
        self.max_generate_time = max_time
        self.flash_attn = flash_attn

        # init model and image processor from ckpt path
        self.init_tokenizer(model_path)
        self.init_image_processor(model_path)
        self.init_model(model_path)

        self.prompt_template = {
            'latex': '<latex>',
            'html': '<html>',
            'markdown': '<markdown>',
        }
        # support output format
        self.supported_output_format = ['latex', 'html', 'markdown']

    def init_model(self, model_path):
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            use_flash_attn=self.flash_attn,
        )
        self.model.eval()
    
    def init_image_processor(self, image_processor_path):
        self.image_processor = AutoImageProcessor.from_pretrained(
            image_processor_path,
            trust_remote_code=True,
        )

    def init_tokenizer(self, tokenizer_path):
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            use_fast=False,
        )

        self.image_context_token = '<IMG_CONTEXT>'
        self.image_token_num = 256
        self.image_start_token = '<img>'
        self.image_end_token = '</img>'
        self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(self.image_context_token)
    
    def format_image_tokens(self, path_num):
        return f'{self.image_start_token}{self.image_context_token* self.image_token_num * path_num}{self.image_end_token}'

    def forward(self, images, output_format='latex', **kwargs):
        # process image to tokens
        if not isinstance(images, list):
            images = [images] 
        
        pixel_values_list = []
        for image in images:
            path_images = self.dynamic_preprocess(
                image, image_size=448, max_num=12
            )
            pixel_values = self.image_processor(
                path_images, 
                return_tensors='pt'
            )['pixel_values'].to(torch.bfloat16)
            pixel_values_list.append(pixel_values)
        
        batch_size = len(pixel_values_list)
        conversation_list = []
        for bs_idx in range(batch_size):
            pixel_values= pixel_values_list[bs_idx].to(torch.bfloat16)

            image_tokens = self.format_image_tokens(pixel_values.shape[0])
            question = '<image>\n' + self.prompt_template[output_format]
            answer = None
        
            template = get_conv_template(self.model.config.template)
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], answer)
            conversation = template.get_prompt()
            conversation = conversation.replace('<image>', image_tokens, 1)
            conversation_list.append(conversation)

        device = next(self.parameters()).device
        self.tokenizer.padding_side = 'left'
        model_inputs = self.tokenizer(
            conversation_list, 
            return_tensors='pt', 
            padding=True,
            max_length=self.tokenizer.model_max_length,
            truncation=True,
        ).to(device)
        pixel_values = torch.cat(pixel_values_list, axis=0).to(device)

        # generation config
        generation_config = dict(
            max_new_tokens=self.max_new_tokens,
            max_time=self.max_generate_time,
            img_context_token_id=self.img_context_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=False,
            no_repeat_ngram_size=20,
        )

        # generate text from image tokens
        model_output = self.model.generate(
            pixel_values=pixel_values,
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask, 
            **generation_config,
            # **kwargs
        )

        batch_decode_texts = self.tokenizer.batch_decode(
            model_output,
            skip_special_tokens=True
        )
        return batch_decode_texts
    
    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
        best_ratio_diff = float('inf')
        best_ratio = (1, 1)
        area = width * height
        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_ratio = ratio
            elif ratio_diff == best_ratio_diff:
                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                    best_ratio = ratio
        return best_ratio

    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        # calculate the existing image aspect ratio
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = self.find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)

        # calculate the target width and height
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        # resize the image
        resized_img = image.resize((target_width, target_height))
        processed_images = []
        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
            # split the image
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
        if use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)
        return processed_images


================================================
FILE: struct_eqtable/internvl/internvl_lmdeploy.py
================================================
import torch
from torch import nn

from transformers import AutoTokenizer
try:
    from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig, ChatTemplateConfig
except ImportError:
    print("\033[93mimport lmdeploy failed; if you do not use lmdeploy, ignore this message\033[0m")


class InternVL_LMDeploy(nn.Module):
    def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_tokens=1024, batch_size=4, **kwargs):
        super().__init__()
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens
        self.max_batch_size = batch_size

        # init model and tokenizer from ckpt path
        self.init_tokenizer(model_path)
        self.init_model(model_path)

        self.prompt_template = {
            'latex': '<latex>',
            'html': '<html>',
            'markdown': '<markdown>',
        }
        # support output format
        self.supported_output_format = ['latex', 'html', 'markdown']
    
    def init_tokenizer(self, tokenizer_path):
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            use_fast=False,
        )

    def init_model(self, model_path):
        engine_config = PytorchEngineConfig(
            dtype='bfloat16',
            max_batch_size=self.max_batch_size,
            cache_max_entry_count=0.1
        )
        self.pipeline = pipeline(
            model_path,
            backend_config=engine_config,
            chat_template_config=ChatTemplateConfig(model_name='internvl2-internlm2')
        )

    def forward(self, images, output_format='latex', **kwargs):
        # process image to tokens
        if not isinstance(images, list):
            images = [images] 
        
        prompts = [self.prompt_template[output_format]] * len(images)
        generation_config = GenerationConfig(
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            temperature=1.0,
            stop_token_ids=[self.tokenizer.eos_token_id],
        )
        
        responses = self.pipeline(
            [(x, y) for x, y in zip(prompts, images)],
            gen_config=generation_config,
        )
        batch_decode_texts = [response.text for response in responses]
        return batch_decode_texts
    



================================================
FILE: struct_eqtable/pix2s/__init__.py
================================================
from .pix2s import Pix2Struct
from .pix2s_trt import Pix2StructTensorRT
    

================================================
FILE: struct_eqtable/pix2s/pix2s.py
================================================
import torch

from torch import nn
from transformers import AutoModelForVision2Seq, AutoProcessor


class Pix2Struct(nn.Module):
    def __init__(self, model_path='U4R/StructTable-base', max_new_tokens=1024, max_time=30, **kwargs):
        super().__init__()
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens
        self.max_generate_time = max_time

        # init model and image processor from ckpt path
        self.init_image_processor(model_path)
        self.init_model(model_path)

        self.special_str_list = ['\\midrule', '\\hline']
        self.supported_output_format = ['latex']

    def postprocess_latex_code(self, code):
        for special_str in self.special_str_list:
            code = code.replace(special_str, special_str + ' ')
        return code

    def init_model(self, model_path):
        self.model = AutoModelForVision2Seq.from_pretrained(model_path)
        self.model.eval()

    def init_image_processor(self, image_processor_path):
        self.data_processor = AutoProcessor.from_pretrained(image_processor_path)

    def forward(self, image, **kwargs):
        # process image to tokens
        image_tokens = self.data_processor.image_processor(
            images=image,
            return_tensors='pt',
        )

        device = next(self.parameters()).device
        for k, v in image_tokens.items():
            image_tokens[k] = v.to(device)

        # generate text from image tokens
        model_output = self.model.generate(
            flattened_patches=image_tokens['flattened_patches'],
            attention_mask=image_tokens['attention_mask'], 
            max_new_tokens=self.max_new_tokens,
            max_time=self.max_generate_time,
            no_repeat_ngram_size=20,
        )

        latex_codes = self.data_processor.batch_decode(model_output, skip_special_tokens=True)
        # postprocess
        for i, code in enumerate(latex_codes):
            latex_codes[i] = self.postprocess_latex_code(code)

        return latex_codes


================================================
FILE: struct_eqtable/pix2s/pix2s_trt.py
================================================
import os
import time
import json

import torch
import torch.nn as nn

try:
    import tensorrt_llm
    import tensorrt as trt
    import tensorrt_llm.profiler as profiler

    from tensorrt_llm._utils import str_dtype_to_trt, torch_to_numpy
    from tensorrt_llm.lora_manager import LoraManager
    from tensorrt_llm.runtime import Session, TensorInfo, ModelConfig, SamplingConfig
except ImportError:
    print("\033[93mimport tensorrt_llm failed; if you do not use tensorrt, ignore this message\033[0m")

from typing import List
from transformers import AutoProcessor, AutoTokenizer, AutoConfig


def trt_dtype_to_torch(dtype):
    if dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.bfloat16:
        return torch.bfloat16
    else:
        raise TypeError("%s is not supported" % dtype)


class Pix2StructTensorRT(nn.Module):

    def __init__(self, model_path, tensorrt_path, batch_size=1, max_new_tokens=4096, **kwargs):
        super().__init__()  # initialize nn.Module state before assigning attributes
        self.model_ckpt_path = model_path
        self.tensorrt_path = tensorrt_path
        self.batch_size = batch_size
        self.max_new_tokens = max_new_tokens

        self.llm_engine_path = os.path.join(tensorrt_path, 'llm_engines')
        self.visual_engine_path = os.path.join(tensorrt_path, 'visual_engines')
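        # tensorrt_path is expected to contain the llm_engines/ and visual_engines/
        # folders produced by tools/scripts/build_tensorrt.sh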
        
        device_id = torch.cuda.current_device() % torch.cuda.device_count()
        self.device_id = device_id
        self.device = "cuda:%d" % (device_id)
        
        self.stream = torch.cuda.Stream(torch.cuda.current_device())
        torch.cuda.set_stream(self.stream)

        # parse model type from visual engine config
        with open(os.path.join(self.visual_engine_path, "config.json"),
                  "r") as f:
            config = json.load(f)
        self.model_type = config['builder_config']['model_type']
        self.vision_precision = config['builder_config']['precision']
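        # note: the precision recorded in the config is overridden below; the visual
        # engine built by tools/scripts/build_tensorrt.sh uses float16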

        self.vision_precision = 'float16'
        self.decoder_llm = not (
            't5' in self.model_type
            or self.model_type in ['nougat', 'pix2struct', 'StructEqTable']
        )  # BLIP2-T5, Pix2Struct and Nougat use encoder-decoder models as their LLMs

        self.profiling_iterations = 20

        self.init_image_encoder()
        self.init_tokenizer()
        self.init_llm()
        self.init_image_processor()

        self.special_str_list = ['\\midrule', '\\hline']
        self.supported_output_format = ['latex']

    def postprocess_latex_code(self, code):
        for special_str in self.special_str_list:
            code = code.replace(special_str, special_str + ' ')
        return code

    def init_image_processor(self):
        self.data_processor = AutoProcessor.from_pretrained(
            self.model_ckpt_path)

    def init_tokenizer(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_ckpt_path, use_fast=True, use_legacy=False)
        # self.tokenizer.padding_side = "right"

    def init_image_encoder(self):
        vision_encoder_path = os.path.join(self.visual_engine_path,
                                           'visual_encoder.engine')
        with open(vision_encoder_path, 'rb') as f:
            engine_buffer = f.read()
        self.visual_encoder_session = Session.from_serialized_engine(
            engine_buffer)

    def init_llm(self):

        self.model = TRTLLMEncDecModel.from_engine(
            os.path.basename(self.model_ckpt_path),
            self.llm_engine_path,
            skip_encoder=self.model_type in ['nougat', 'pix2struct', 'StructEqTable'],
            debug_mode=False,
            stream=self.stream)

        self.model_config = self.model.decoder_model_config
        self.runtime_mapping = self.model.decoder_runtime_mapping

    def __call__(self, image, **kwargs):
        # process image to tokens
        image_tokens = self.data_processor.image_processor(
            images=image,
            return_tensors='pt',
        )

        for k, v in image_tokens.items():
            image_tokens[k] = v.cuda()

        model_output = self.run(
            flattened_patches=image_tokens['flattened_patches'],
            attention_mask=image_tokens['attention_mask'], 
            max_new_tokens=self.max_new_tokens
        )

        # postprocess
        latex_codes = []
        for i, code in enumerate(model_output):
            latex_codes.append(self.postprocess_latex_code(code[0]))

        return latex_codes

    def preprocess(self, warmup, pre_prompt, post_prompt, image,
                   attention_mask):
        if not warmup:
            profiler.start("Vision")

        visual_features, visual_atts = self.get_visual_features(
            torch.stack(image['image_patches'], dim=0)
            if self.model_type == 'fuyu' else image, attention_mask)

        if not warmup:
            profiler.stop("Vision")
       
        pre_input_ids = self.tokenizer(pre_prompt,
                                        return_tensors="pt",
                                        padding=True).input_ids
        if post_prompt[0] is not None:
            post_input_ids = self.tokenizer(post_prompt,
                                            return_tensors="pt",
                                            padding=True).input_ids
            length = pre_input_ids.shape[1] + post_input_ids.shape[
                1] + visual_atts.shape[1]
        else:
            post_input_ids = None
            length = pre_input_ids.shape[1] + visual_atts.shape[1]

        input_lengths = torch.IntTensor([length] * 1).to(
            torch.int32)

        input_ids, ptuning_args = self.setup_fake_prompts(
            visual_features, pre_input_ids, post_input_ids, input_lengths)

        return input_ids, input_lengths, ptuning_args, visual_features

    def generate(self, pre_prompt, post_prompt, image, decoder_input_ids,
                 max_new_tokens, attention_mask, warmup):
        if not warmup:
            profiler.start("Generate")

        input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
            warmup, pre_prompt, post_prompt, image, attention_mask)

        if warmup: return None

        profiler.start("LLM")

        # Trim encoder input_ids to match visual features shape
        ids_shape = (self.batch_size, visual_features.shape[1])

        input_ids = torch.ones(ids_shape, dtype=torch.int32)

        output_ids = self.model.generate(
            input_ids,
            decoder_input_ids,
            max_new_tokens,
            num_beams=1,
            bos_token_id=self.tokenizer.bos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            debug_mode=False,
            prompt_embedding_table=ptuning_args[0],
            prompt_tasks=ptuning_args[1],
            prompt_vocab_size=ptuning_args[2],
            attention_mask=attention_mask)

        # Reset input_lengths to match decoder_input_ids
        input_lengths = torch.ones(input_lengths.shape,
                                    dtype=input_lengths.dtype)
        profiler.stop("LLM")

        if tensorrt_llm.mpi_rank() == 0:
            # Extract a list of tensors of shape beam_width x output_ids.
            output_beams_list = [
                self.tokenizer.batch_decode(
                    output_ids[batch_idx, :, input_lengths[batch_idx]:],
                    skip_special_tokens=True)
                for batch_idx in range(self.batch_size)
            ]

            stripped_text = [[
                output_beams_list[batch_idx][beam_idx].strip()
                for beam_idx in range(1)
            ] for batch_idx in range(self.batch_size)]
            profiler.stop("Generate")
            return stripped_text
        else:
            profiler.stop("Generate")
            return None
        
    def get_visual_features(self, image, attention_mask):
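        # run the serialized TensorRT vision encoder: bind the flattened patches
        # (and optional attention mask), infer output shapes, allocate output
        # buffers on the image's device, and execute on our CUDA stream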
        visual_features = {
            'input':
            image.to(
                tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision))
        }
        if attention_mask is not None:
            visual_features['attention_mask'] = attention_mask
        tensor_info = [
            TensorInfo('input', str_dtype_to_trt(self.vision_precision),
                       image.shape)
        ]
        if attention_mask is not None:
            tensor_info.append(
                TensorInfo('attention_mask', trt.DataType.INT32,
                           attention_mask.shape))
        visual_output_info = self.visual_encoder_session.infer_shapes(
            tensor_info)
        visual_outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device=image.device)
            for t in visual_output_info
        }

        ok = self.visual_encoder_session.run(visual_features, visual_outputs,
                                             self.stream.cuda_stream)
        assert ok, "Runtime execution failed for vision encoder session"
        self.stream.synchronize()

        image_embeds = visual_outputs['output']
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        return image_embeds, image_atts
    
    def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids,
                           input_lengths):
        # Assemble fake prompt ids that actually point to the image embeddings
        # (ids beyond vocab_size index into the prompt-tuning table)
        fake_prompt_id = torch.arange(
            self.model_config.vocab_size, self.model_config.vocab_size +
            visual_features.shape[0] * visual_features.shape[1])
        fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0],
                                                visual_features.shape[1])

        if post_input_ids is not None:
            input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
        else:
            input_ids = [fake_prompt_id, pre_input_ids]
        
        input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)

        if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():
            ptuning_args = self.ptuning_setup(visual_features, input_ids,
                                              input_lengths)
        else:
            ptuning_args = [None, None, None]

        return input_ids, ptuning_args

    def ptuning_setup(self, prompt_table, input_ids, input_lengths):
        hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
        if prompt_table is not None:
            task_vocab_size = torch.tensor(
                [prompt_table.shape[1]],
                dtype=torch.int32,
            ).cuda()
            prompt_table = prompt_table.view(
                (prompt_table.shape[0] * prompt_table.shape[1],
                 prompt_table.shape[2]))
            assert prompt_table.shape[
                1] == hidden_size, "Prompt table dimensions do not match hidden size"

            prompt_table = prompt_table.cuda().to(
                dtype=tensorrt_llm._utils.str_dtype_to_torch(
                    self.model_config.dtype))
        else:
            prompt_table = torch.empty([1, hidden_size]).cuda()
            task_vocab_size = torch.zeros([1]).cuda()

        if self.model_config.remove_input_padding:
            tasks = torch.zeros([torch.sum(input_lengths)],
                                dtype=torch.int32).cuda()
            if self.decoder_llm: tasks = tasks.unsqueeze(0)
        else:
            tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()

        return [prompt_table, tasks, task_vocab_size]

    def setup_inputs(self, input_text, raw_image):
        attention_mask = None
       
        image_processor = AutoProcessor.from_pretrained(self.model_ckpt_path)
        if input_text is None:
            input_text = ""
        inputs = image_processor(
            images=raw_image,
            text=input_text,
            return_tensors="pt",
        )
        image = inputs['flattened_patches']
        image = image.expand(self.batch_size, -1, -1).contiguous()
        attention_mask = inputs['attention_mask'].to(self.device).to(
            torch.int)
        attention_mask = attention_mask.expand(self.batch_size,
                                                -1).contiguous()
        pre_prompt = ""
        post_prompt = None

        # Repeat inputs to match batch size
        pre_prompt = [pre_prompt] * self.batch_size
        post_prompt = [post_prompt] * self.batch_size
        image = image.to(self.device)

        # Generate decoder_input_ids for enc-dec models
        # Custom prompts can be added as:
        # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
        if self.decoder_llm:
            decoder_input_ids = None
        else:
            config = AutoConfig.from_pretrained(self.model_ckpt_path)
            decoder_start_id = config.decoder_start_token_id  # T5
            if decoder_start_id is None:
                decoder_start_id = config.decoder.bos_token_id  # Nougat

            decoder_input_ids = torch.IntTensor([[decoder_start_id]])
            decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1))

        return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask

    def run(self, flattened_patches, attention_mask, max_new_tokens):
        # input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = self.setup_inputs(
        #     None, raw_image)
        pre_prompt = [""] * self.batch_size
        post_prompt = [None] * self.batch_size
        config = AutoConfig.from_pretrained(self.model_ckpt_path)
        decoder_start_id = config.decoder_start_token_id  # T5 
        decoder_input_ids = torch.IntTensor([[decoder_start_id]])
        decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1))

        processed_image = flattened_patches.expand(self.batch_size, -1, -1).contiguous()
        attention_mask = attention_mask.to(self.device).to(torch.int)
        attention_mask = attention_mask.expand(self.batch_size,-1).contiguous()

        self.generate(pre_prompt,
                       post_prompt,
                       processed_image,
                       decoder_input_ids,
                       max_new_tokens,
                       attention_mask=attention_mask,
                       warmup=True)
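        # the warm-up pass above only exercises preprocessing and the vision engine
        # (generate() returns early when warmup=True); the loop below does the real decode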
        # num_iters = self.profiling_iterations if self.args.run_profiling else 1
        num_iters = 1
        # print(num_iters)
        for _ in range(num_iters):
            output_text = self.generate(pre_prompt,
                                         post_prompt,
                                         processed_image,
                                         decoder_input_ids,
                                         max_new_tokens,
                                         attention_mask=attention_mask,
                                         warmup=False)
        # if self.runtime_rank == 0:
        #     self.print_result(input_text, output_text)
        return output_text


def read_config(config_path):
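    # parse a TRT-LLM engine config.json into a runtime ModelConfig plus the
    # tensor/pipeline-parallel sizes, gpus_per_node and dtype used by the engine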
    with open(config_path, "r") as f:
        config = json.load(f)

    builder_config = config['build_config']
    plugin_config = builder_config['plugin_config']
    pretrained_config = config['pretrained_config']
    lora_config = builder_config['lora_config']
    auto_parallel_config = builder_config['auto_parallel_config']
    use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
    remove_input_padding = plugin_config["remove_input_padding"]
    use_lora_plugin = plugin_config["lora_plugin"]
    tp_size = pretrained_config['mapping']['tp_size']
    pp_size = pretrained_config['mapping']['pp_size']
    gpus_per_node = auto_parallel_config['gpus_per_node']
    world_size = tp_size * pp_size
    assert world_size == tensorrt_llm.mpi_world_size(), \
        f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
    num_heads = pretrained_config["num_attention_heads"]
    hidden_size = pretrained_config["hidden_size"]
    head_size = pretrained_config["head_size"]
    vocab_size = pretrained_config["vocab_size"]
    max_batch_size = builder_config["max_batch_size"]
    max_beam_width = builder_config["max_beam_width"]
    num_layers = pretrained_config["num_hidden_layers"]
    num_kv_heads = pretrained_config.get('num_kv_heads', num_heads)

    assert (num_heads % tp_size) == 0
    num_heads = num_heads // tp_size
    hidden_size = hidden_size // tp_size
    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

    cross_attention = pretrained_config["architecture"] == "DecoderModel"
    skip_cross_qkv = pretrained_config.get('skip_cross_qkv', False)
    has_position_embedding = pretrained_config["has_position_embedding"]
    has_token_type_embedding = hasattr(pretrained_config, "type_vocab_size")
    use_custom_all_reduce = plugin_config.get('use_custom_all_reduce', False)
    dtype = pretrained_config["dtype"]

    paged_kv_cache = plugin_config['paged_kv_cache']
    tokens_per_block = plugin_config['tokens_per_block']

    gather_context_logits = builder_config.get('gather_context_logits', False)
    gather_generation_logits = builder_config.get('gather_generation_logits',
                                                  False)
    max_prompt_embedding_table_size = builder_config.get(
        'max_prompt_embedding_table_size', 0)

    model_config = ModelConfig(
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        head_size=head_size,
        max_batch_size=max_batch_size,
        max_beam_width=max_beam_width,
        vocab_size=vocab_size,
        num_layers=num_layers,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        paged_kv_cache=paged_kv_cache,
        tokens_per_block=tokens_per_block,
        cross_attention=cross_attention,
        has_position_embedding=has_position_embedding,
        has_token_type_embedding=has_token_type_embedding,
        use_custom_all_reduce=use_custom_all_reduce,
        dtype=dtype,
        gather_context_logits=gather_context_logits,
        gather_generation_logits=gather_generation_logits,
        max_prompt_embedding_table_size=max_prompt_embedding_table_size,
        lora_plugin=use_lora_plugin,
        lora_target_modules=lora_config.get('lora_target_modules'),
        trtllm_modules_to_hf_modules=lora_config.get(
            'trtllm_modules_to_hf_modules'),
        skip_cross_qkv=skip_cross_qkv,
    )

    return model_config, tp_size, pp_size, gpus_per_node, dtype


class Mapping(object):
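    # lightweight stand-in for tensorrt_llm.Mapping used for the single-process
    # runtime here; pp_rank/tp_rank are pinned to 0 further down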
    def __init__(
            self,
            world_size=1,
            rank=0,
            gpus_per_node=8,
            tp_size=1,
            pp_size=1,
            moe_tp_size=-1,  # -1 means no moe
            moe_ep_size=-1):  # -1 means no moe
        # set default values for non-moe cases
        if moe_tp_size == -1:
            moe_tp_size = tp_size
            moe_ep_size = 1

        if pp_size * tp_size != world_size:
            raise ValueError(
                f"world_size must equal to pp_size * tp_size, but got {world_size} != {pp_size} * {tp_size}"
            )

        moe_tp_ep_size = moe_tp_size * moe_ep_size
        if moe_tp_ep_size != tp_size:
            raise ValueError(
                f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}"
            )

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.moe_tp_size = moe_tp_size
        self.moe_ep_size = moe_ep_size
        self.world_size = world_size
        self.rank = rank
        self.gpus_per_node = gpus_per_node

        self.pp_groups = []
        self.tp_groups = []
        self.moe_tp_groups = []
        self.moe_ep_groups = []

        # init pp group
        for i in range(tp_size):
            ranks = range(i+ self.rank, world_size+ self.rank, tp_size)
            self.pp_groups.append(list(ranks))

        # init tp group
        for i in range(pp_size):
            ranks = range(i * tp_size + self.rank, (i + 1) * tp_size + self.rank)
            self.tp_groups.append(list(ranks))

        # init moe tp group
        for i in range(pp_size):
            for j in range(moe_ep_size):
                ranks = range(i * moe_tp_ep_size + j, (i + 1) * moe_tp_ep_size,
                              moe_ep_size)
                self.moe_tp_groups.append(list(ranks))

        # init moe ep group
        for i in range(pp_size):
            for j in range(moe_tp_size):
                ranks = range(i * moe_tp_ep_size + j * moe_ep_size,
                              i * moe_tp_ep_size + (j + 1) * moe_ep_size)
                self.moe_ep_groups.append(list(ranks))

        # self.pp_rank = self.rank // self.tp_size
        # self.tp_rank = self.rank % self.tp_size
        self.pp_rank = 0
        self.tp_rank = 0
        self.moe_tp_rank = self.tp_rank // self.moe_ep_size
        self.moe_ep_rank = self.tp_rank % self.moe_ep_size

        # self.tp_group = self.tp_groups[self.pp_rank]
        # self.pp_group = self.pp_groups[self.tp_rank]
        self.moe_tp_group = self.moe_tp_groups[self.pp_rank * moe_ep_size +
                                               self.moe_ep_rank]
        self.moe_ep_group = self.moe_ep_groups[self.pp_rank * moe_tp_size +
                                               self.moe_tp_rank]

        self.node_rank = self.rank // self.gpus_per_node
        self.local_rank = self.rank % self.gpus_per_node

    def get_node_rank(self, rank: int):
        return rank // self.gpus_per_node

    def get_local_rank(self, rank: int):
        return rank % self.gpus_per_node

    def has_tp(self):
        return self.tp_size > 1

    def is_last_pp_rank(self):
        return self.pp_rank == self.pp_size - 1

    def is_first_pp_rank(self):
        return self.pp_rank == 0

    def has_pp(self):
        return self.pp_size > 1

    def prev_pp_rank(self):
        p = self.rank - self.tp_size
        if p < 0:
            p = p + self.world_size
        return p

    def next_pp_rank(self):
        p = self.rank + self.tp_size
        if p >= self.world_size:
            p = p - self.world_size
        return p

    def has_moe_tp(self):
        return self.moe_tp_size > 1

    def has_moe_ep(self):
        return self.moe_ep_size > 1

    def pp_layers(self, num_layers: int) -> List[int]:
        layers_per_pipeline_stage = num_layers // self.pp_size
        layers_range = range(self.pp_rank * layers_per_pipeline_stage,
                             (self.pp_rank + 1) * layers_per_pipeline_stage)
        return list(layers_range)

    def ep_experts(self, num_experts: int) -> List[int]:
        experts_per_rank = num_experts // self.moe_ep_size
        experts_range = range(self.moe_ep_rank * experts_per_rank,
                              (self.moe_ep_rank + 1) * experts_per_rank)
        return list(experts_range)


def get_engine_name(rank):
    return 'rank{}.engine'.format(rank)

class TRTLLMEncDecModel:

    def __init__(
        self,
        engine_name,
        engine_dir,
        lora_dir=None,
        lora_task_uids=None,
        debug_mode=False,
        skip_encoder=False,
        stream: torch.cuda.Stream = None,
    ):
        # in multi-node setup, it's important to set_device at the very beginning so .to('cuda') refers to current device
        # accordingly, all input & output tensors should be moved to current device
        # otherwise, it defaults to 'cuda:0'
        
        # self.runtime_rank = tensorrt_llm.mpi_rank()
        self.device_id = torch.cuda.current_device()
        # torch.cuda.set_device(device_id)
        self.device = torch.cuda.current_device()
        self.skip_encoder = skip_encoder
        self.lora_task_uids = lora_task_uids

        # when enc-dec runs by itself, stream can be None and we create new stream here
        # when enc-dec runs as a component in a bigger workflow (e.g., multimodal), earlier components may still have work pending on their stream, so we pass that stream in to avoid an unnecessary stream sync
        self.stream = stream
        if self.stream is None:
            self.stream = torch.cuda.Stream(self.device)
        torch.cuda.set_stream(self.stream)

        def engine_setup(component):
            # model config
            config_path = os.path.join(engine_dir, component, "config.json")
            model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(
                config_path)

            # MGMN config
            world_size = tp_size * pp_size
            # runtime_rank = tensorrt_llm.mpi_rank()
            runtime_rank = torch.cuda.current_device()
            # assert runtime_rank < world_size, "Runtime GPU rank exceeds MPI world size. Did you launch more MPI processes than required?"
            # runtime_mapping = tensorrt_llm.Mapping(world_size,
            #                                        runtime_rank,
            #                                        tp_size=tp_size,
            #                                        pp_size=pp_size,
            #                                        gpus_per_node=gpus_per_node)
            # tensorrt_llm.Mapping
            runtime_mapping = Mapping(world_size,
                                      runtime_rank,
                                      tp_size=tp_size,
                                      pp_size=pp_size,
                                      gpus_per_node=gpus_per_node)
            # load engine
            # engine_fname = get_engine_name(runtime_rank)
            engine_fname = get_engine_name(0)
            with open(os.path.join(engine_dir, component, engine_fname), "rb") as f:
                engine_buffer = f.read()

            return model_config, runtime_mapping, engine_buffer

        # Note: encoder and decoder don't necessarily have the same TP & PP config

        if not skip_encoder:
            self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup(
                component='encoder')

            self.nccl_comm = None
            if self.encoder_runtime_mapping.has_pp():
                # for Pipeline Parallelism in encoder
                self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
                    self.encoder_runtime_mapping.tp_size,
                    self.encoder_runtime_mapping.pp_size,
                    self.encoder_runtime_mapping.rank)

            # session setup
            self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
                encoder_engine_buffer)

            # encoder lora manager setup
            if self.encoder_model_config.lora_plugin:
                self.encoder_lora_manager = LoraManager()
                # TODO: this is only for bart
                self.encoder_lora_manager.load_from_hf(
                    model_dirs=lora_dir,
                    model_config=self.encoder_model_config,
                    runtime_mapping=self.encoder_runtime_mapping,
                    component='encoder',
                )
            else:
                self.encoder_lora_manager = None
        else:
            self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = None, None, None
            self.nccl_comm, self.encoder_session = None, None

        self.decoder_model_config, self.decoder_runtime_mapping, decoder_engine_buffer = engine_setup(
            component='decoder')

        self.decoder_session = tensorrt_llm.runtime.GenerationSession(
            self.decoder_model_config,
            decoder_engine_buffer,
            self.decoder_runtime_mapping,
            debug_mode=debug_mode)

        # decoder lora manager setup
        if self.decoder_model_config.lora_plugin:
            self.decoder_lora_manager = LoraManager()
            # TODO: this is only for bart
            self.decoder_lora_manager.load_from_hf(
                model_dirs=lora_dir,
                model_config=self.decoder_model_config,
                runtime_mapping=self.decoder_runtime_mapping,
                component='decoder',
            )
        else:
            self.decoder_lora_manager = None
    
    @classmethod
    def from_engine(cls,
                    engine_name,
                    engine_dir,
                    lora_dir=None,
                    lora_task_uids=None,
                    debug_mode=False,
                    skip_encoder=False,
                    stream=None):
        return cls(engine_name,
                   engine_dir,
                   lora_dir,
                   lora_task_uids,
                   debug_mode=debug_mode,
                   skip_encoder=skip_encoder,
                   stream=stream)

    def process_input(self,
                      input_ids,
                      remove_input_padding=False,
                      pad_token_id=0,
                      prompt_tasks=None):
        if remove_input_padding:
            # in remove padding mode --> flatten input, calculate actual length and max length
            # Note: 1st token should never be removed, even if it is pad_token_id
            first_ids = input_ids[:, 0]
            input_ids = input_ids[:, 1:]
            input_lengths = 1 + (input_ids != pad_token_id).sum(dim=1).type(
                torch.IntTensor).to(self.device)  # [batch_size]
            new_ids = []
            for i in range(len(input_ids)):
                row = input_ids[i, :]
                row = row[row != pad_token_id]
                new_ids.append(
                    torch.cat(
                        (torch.IntTensor([first_ids[i]]).to(self.device), row)))
            input_ids = torch.cat(new_ids)  # [num_tokens]
            if prompt_tasks is not None:
                prompt_tasks = prompt_tasks[:input_ids.shape[0]]
        else:
            # in padding mode --> keep input, just calculate actual length and max length
            # Note: the 1st token should always count, even if it is pad_token_id; e.g., the decoder start id in enc-dec models could be a single pad_token_id, which we should still count
            input_lengths = torch.tensor(
                1 + (input_ids[:, 1:] != pad_token_id).sum(dim=1).type(
                    torch.IntTensor).to(self.device),
                dtype=torch.int32,
                device=self.device)
        max_input_length = torch.max(input_lengths).item()
        return input_ids, input_lengths, max_input_length, prompt_tasks

    def encoder_run(self,
                    input_ids,
                    input_lengths,
                    max_input_length,
                    position_ids=None,
                    token_type_ids=None,
                    debug_mode=False,
                    prompt_embedding_table=None,
                    prompt_tasks=None,
                    prompt_vocab_size=None,
                    attention_mask=None):

        # each engine has hidden_dim/TP, don't forget to multiply TP
        hidden_size = self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size
        if input_ids.dim() == 1:
            hidden_states_shape = (input_ids.shape[0], hidden_size
                                   )  # [num_tokens,D]
        else:
            hidden_states_shape = (input_ids.shape[0], input_ids.shape[1],
                                   hidden_size)  # [BS,seqlen,D]
        hidden_states_dtype = lambda name: trt_dtype_to_torch(
            self.encoder_session.engine.get_tensor_dtype(name))

        # input tensors. only first PP rank has id input, others are hidden_states input
        inputs = {}
        if self.encoder_runtime_mapping.is_first_pp_rank():
            inputs['input_ids'] = input_ids.contiguous()
            if self.encoder_model_config.has_position_embedding:
                if position_ids is None:
                    if self.encoder_model_config.remove_input_padding:
                        position_ids = [
                            torch.arange(sample_length,
                                         dtype=torch.int32,
                                         device=input_ids.device)
                            for sample_length in torch_to_numpy(input_lengths)
                        ]
                        position_ids = torch.cat(position_ids)
                    else:
                        bsz, seq_len = input_ids.shape[:2]
                        position_ids = torch.arange(
                            seq_len, dtype=torch.int32,
                            device=input_ids.device).expand(bsz, -1)
                inputs['position_ids'] = position_ids.contiguous()
            if self.encoder_model_config.has_token_type_embedding:
                inputs['token_type_ids'] = token_type_ids.contiguous()

            if self.encoder_model_config.max_prompt_embedding_table_size > 0:
                inputs[
                    'prompt_embedding_table'] = prompt_embedding_table.contiguous(
                    )
                inputs['tasks'] = prompt_tasks.contiguous()
                inputs['prompt_vocab_size'] = prompt_vocab_size.contiguous()
        else:
            # just need a placeholder, engine will call NCCL to recv and fill data from previous rank
            inputs['hidden_states_input'] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype('hidden_states_input'),
                device=self.device).contiguous()
        if attention_mask is not None and not self.encoder_model_config.gpt_attention_plugin:
            inputs['attention_mask'] = attention_mask.contiguous()

        inputs['input_lengths'] = input_lengths
        # use shape info to pass max length info in remove padding mode
        inputs['max_input_length'] = torch.empty(
            (max_input_length, ),
            dtype=hidden_states_dtype('max_input_length'),
            device=self.device).contiguous()
        batch_size = input_lengths.size(0)
        inputs['host_request_types'] = torch.IntTensor([0] *
                                                       batch_size).to('cpu')
        if self.encoder_model_config.remove_input_padding:
            inputs['host_context_lengths'] = input_lengths.to('cpu')

        if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None:
            inputs.update(
                self.encoder_lora_manager.input_buffers(
                    self.lora_task_uids,
                    self.encoder_runtime_mapping,
                    self.encoder_model_config.num_layers,
                ))

        # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape
        self.encoder_session.set_shapes(inputs)

        # output tensors. only last PP rank final encoder output, others are intermediate hidden_states output. Need broadcast later
        outputs = {}
        if self.encoder_runtime_mapping.is_last_pp_rank():
            outputs['encoder_output'] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype('encoder_output'),
                device=self.device).contiguous()
        else:
            outputs['hidden_states_output'] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype('hidden_states_output'),
                device=self.device).contiguous()

        # -------------------------------------------
        if debug_mode:
            engine = self.encoder_session.engine
            context = self.encoder_session.context
            # setup debugging buffer for the encoder
            for i in range(self.encoder_session.engine.num_io_tensors):
                name = engine.get_tensor_name(i)
                if engine.get_tensor_mode(
                        name
                ) == trt.TensorIOMode.OUTPUT and name not in outputs.keys():
                    dtype = engine.get_tensor_dtype(name)
                    shape = context.get_tensor_shape(name)
                    outputs[name] = torch.zeros(tuple(shape),
                                                dtype=trt_dtype_to_torch(dtype),
                                                device=self.device)
                    context.set_tensor_address(name, outputs[name].data_ptr())
        # -------------------------------------------

        # TRT session run
        # Note: need cuda stream ID, not a torch Stream
        ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream)
        assert ok, "Runtime execution failed"
        self.stream.synchronize()

        # Tensor Parallelism is handled by model/engine definition
        # But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism
        # After this, all ranks should recv the encoder output, and world might be re-configured using decoder's TP-PP config
        def pp_communicate_encoder_output(encoder_output):
            if self.encoder_runtime_mapping.is_last_pp_rank():
                for pp_rank in self.encoder_runtime_mapping.pp_group:
                    if pp_rank != self.encoder_runtime_mapping.rank:
                        self.nccl_comm.send(encoder_output, pp_rank)
                return encoder_output
            else:
                self.nccl_comm.recv(encoder_output,
                                    self.encoder_runtime_mapping.pp_group[-1])
                return encoder_output

        if self.encoder_runtime_mapping.has_pp():
            # use hidden_states output buffer to receive output as the shapes are same
            encoder_output_buf = outputs[
                'encoder_output'] if self.encoder_runtime_mapping.is_last_pp_rank(
                ) else outputs['hidden_states_output']
            encoder_output = pp_communicate_encoder_output(encoder_output_buf)
        else:
            encoder_output = outputs['encoder_output']

        return encoder_output

    def generate(self,
                 encoder_input_ids,
                 decoder_input_ids,
                 max_new_tokens,
                 num_beams=1,
                 pad_token_id=None,
                 eos_token_id=None,
                 bos_token_id=None,
                 debug_mode=False,
                 return_dict=False,
                 prompt_embedding_table=None,
                 prompt_tasks=None,
                 prompt_vocab_size=None,
                 attention_mask=None,
                 time_encoder=False,
                 return_encoder_output=False):
        ## ensure all externally provided tensors are on the correct device.
        encoder_input_ids = encoder_input_ids.to(self.device)
        decoder_input_ids = decoder_input_ids.to(self.device)

        if attention_mask is not None:
            attention_mask = torch.tensor(attention_mask,
                                          dtype=torch.int32,
                                          device=self.device)

        ## encoder run
        encoder_remove_input_padding = self.encoder_model_config.remove_input_padding if self.encoder_model_config else self.decoder_model_config.remove_input_padding

        encoder_input_ids, encoder_input_lengths, encoder_max_input_length, prompt_tasks = self.process_input(
            encoder_input_ids, encoder_remove_input_padding, pad_token_id,
            prompt_tasks)

        if not self.skip_encoder:
            #logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
            if time_encoder:
                tik = time.time()
            encoder_output = self.encoder_run(
                encoder_input_ids,
                encoder_input_lengths,
                encoder_max_input_length,
                debug_mode=debug_mode,
                prompt_embedding_table=prompt_embedding_table,
                prompt_tasks=prompt_tasks,
                prompt_vocab_size=prompt_vocab_size,
                attention_mask=attention_mask)
            if time_encoder:
                tok = time.time()
                print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
        else:
            encoder_output = prompt_embedding_table
            if encoder_input_ids.dim() > 1:
                encoder_output = encoder_output.unsqueeze(0)

        ## decoder run
        # logger.info(f"Rank {self.runtime_rank} Running decoder engine ...")
        decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = self.process_input(
            decoder_input_ids, self.decoder_model_config.remove_input_padding,
            pad_token_id)

        # `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len]
        # where query_len happens to be 1 in current cases, but not necessarily always, and
        # `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where
        # the query_len is always 1 since we have kv cache.
        cross_attention_mask = None
        if attention_mask is not None:
            cross_attention_mask = torch.tensor(attention_mask,
                                                dtype=torch.int32,
                                                device=self.device).reshape(
                                                    attention_mask.shape[0], 1,
                                                    attention_mask.shape[1])

        # generation config
        sampling_config = SamplingConfig(end_id=eos_token_id,
                                         pad_id=pad_token_id,
                                         num_beams=num_beams,
                                         min_length=1,
                                         return_dict=return_dict)
        sampling_config.update(output_cum_log_probs=return_dict,
                               output_log_probs=return_dict)

        # decoder autoregressive generation
        self.decoder_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            num_beams,
            max_attention_window_size=None,
            encoder_max_input_length=encoder_max_input_length,
            lora_manager=self.decoder_lora_manager,
            lora_uids=self.lora_task_uids,
        )

        output = self.decoder_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_output,
            encoder_input_lengths=encoder_input_lengths,
            return_dict=return_dict,
            cross_attention_mask=cross_attention_mask)

        if return_dict and return_encoder_output:
            output['encoder_output'] = encoder_output

        return output


================================================
FILE: tools/demo/demo.py
================================================
import time
import torch
import argparse

from PIL import Image
from struct_eqtable import build_model


def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--image_path', type=str, default='demo.png', help='data path for table image')
    parser.add_argument('--ckpt_path', type=str, default='U4R/StructTable-InternVL2-1B', help='ckpt path for table model, which can be downloaded from huggingface')
    parser.add_argument('--max_new_tokens', type=int, default=1024, help='maximum output tokens of model inference')
    parser.add_argument('-t', '--max_waiting_time', type=int, default=60, help='maximum waiting time of model inference')
    parser.add_argument('-f', '--output_format', type=str, nargs='+', default=['latex'], 
                        help='The model outputs LaTeX format code by default. Simple structured table LaTeX code can be converted to HTML or Markdown format using pypandoc.')
    parser.add_argument('--tensorrt_path', type=str, default=None, help='enable tensorrt for model acceleration')
    parser.add_argument('--lmdeploy', action='store_true', help='use lmdeploy to accelerate model inference')
    parser.add_argument('--disable_flash_attn', action='store_true', help='disable flash attention for non ampere gpu')
    args = parser.parse_args()
    return args

def main():
    args = parse_config()

    # build model
    model = build_model(
        args.ckpt_path, 
        max_new_tokens=args.max_new_tokens, 
        max_time=args.max_waiting_time,
        tensorrt_path=args.tensorrt_path,
        lmdeploy=args.lmdeploy,
        flash_attn=not args.disable_flash_attn
    )

    assert torch.cuda.is_available(), "Our model currently only supports running on a GPU"
    if not args.tensorrt_path:
        model = model.cuda()

    # process output format
    output_formats = list(set(args.output_format) & set(model.supported_output_format))
    print(f"Supported output format: {' '.join(output_formats)}")

    # model inference
    raw_image = Image.open(args.image_path)

    output_list = []
    start_time = time.time()

    with torch.no_grad():
        for tgt_fmt in output_formats:
            output = model(raw_image, output_format=tgt_fmt)
            output_list.append(output)

    # show output latex code of table
    cost_time = time.time() - start_time
    print(f"total cost time: {cost_time:.2f}s")

    if cost_time >= args.max_waiting_time:
        warn_log = f"\033[93mModel inference time exceeded the maximum waiting time of {args.max_waiting_time} seconds, so the result may be incomplete.\n" \
            "Please increase the maximum waiting time with --max_waiting_time, or the model may not support this type of input table image.\033[0m"
        print(warn_log)

    for i, tgt_fmt in enumerate(output_formats):
        for j, output in enumerate(output_list[i]):
            print(f"Table {j} {tgt_fmt.upper()} format output:\n{output}")


if __name__ == '__main__':
    main()
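
# A minimal sketch of converting the LaTeX output to HTML, as the --output_format
# help text suggests (assumes pypandoc and a pandoc binary are installed;
# `latex_code` stands in for one string from `output_list`):
#
#   import pypandoc
#   html = pypandoc.convert_text(latex_code, 'html', format='latex')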


================================================
FILE: tools/demo/demo.tex
================================================

\documentclass[border=20pt]{standalone}
\usepackage{blindtext}%
\usepackage{subcaption}
\usepackage{url}
\usepackage{graphicx}
\usepackage{caption}
\usepackage{multirow}
\usepackage{booktabs}
\usepackage{color}
\usepackage{colortbl}
\usepackage{xcolor,soul,framed}
\usepackage{xeCJK}
%\usepackage{fontspec}
%\usepackage[margin=1in]{geometry} 
\usepackage{printlen}
\usepackage{amsmath,amssymb,mathtools,bm,mathrsfs,textcomp}
\setlength{\parindent}{0pt}

\begin{document}

\begin{tabular}{|c|c|c|c|}
  \hline
  Quantity $\backslash$ Unit System & International System SI (kg-m-s) & Traditional aeronautical (lb-ft-s) & Traditional structural (lb-inch-s) \\
  \hline
  Mass (translational inertia), $m$ & kilogram mass (kg) & slug = lb-s$^2$/f & lb-s$^2$/inch \\
  \hline
  Length, translational motion & meter (m) & foot (ft) & inch (in.) \\
  \hline
  Time, $t$ & second (s) & second (s) & second (s) \\
  \hline
  Force, translational action & newton (N) = kg-m/s$^2$ & pound force (lb) & pound force (lb) \\
  \hline
  Translational stiffness constant, $k$ & N/m & lb/ft & lb/inch \\
  \hline
  Translational damping constant, $c$ & N/(m/s) = N-s/m & lb/(ft/s) = lb-s/ft & lb/(inch/s) = lb-s/inch \\
  \hline
  Angle, rotational motion & radial (rad), which is dimensionless & radial (rad), which is dimensionless & radial (rad), which is dimensionless \\
  \hline
  Rotational inertia, $J$ & kg-m$^2$ & slug-ft$^2$ = lb-s$^2$ - ft & lb-s$^2$ - inch \\
  \hline
  Moment or torque, rotational action & N-m & lb-ft & lb-inch \\
  \hline
  Rotational stiffness constant, $k_\theta$ & (N-m)/rad = N-m & (lb-ft)/rad = lb-ft & (lb-inch)/rad = lb-inch \\
  \hline
  Rotational damping constant, $c_\theta$ & (N-m)/(rad/s) = N-m-s & (lb-ft)/(rad/s) = lb-ft-s & (lb-inch)/(rad/s) = lb-inch-s \\
  \hline
\end{tabular}

\end{document}

================================================
FILE: tools/scripts/build_tensorrt.sh
================================================
set -x 

HF_CKPT_PATH=${1:-"../ckpts/StructTable-base"}
MODEL_OUTPUT=${2:-"../ckpts/StructTable-base-TensorRT"}
MAX_IMAGE_TOKEN_NUM=${3:-2048}
MAX_OUTPUT_TOKEN_NUM=${4:-2048}
MODEL_TYPE=${5:-"StructEqTable"}

if [ ! -d $MODEL_OUTPUT ]; then
    mkdir -p $MODEL_OUTPUT
fi

# Step1 Convert the model into TensorrtLLM checkpoint format
echo "Step1 Convert the model into TensorrtLLM checkpoint format"

python tensorrt_utils/convert_checkpoint.py --model_type $MODEL_TYPE \
    --model_dir $HF_CKPT_PATH \
    --output_dir $MODEL_OUTPUT/trt_models/float16 \
    --tp_size 1 \
    --pp_size 1 \
    --workers 1 \
    --dtype float16

# Step2 Compile the model
echo "Step2 build LLM Engine"

trtllm-build --checkpoint_dir $MODEL_OUTPUT/trt_models/float16/decoder \
    --output_dir $MODEL_OUTPUT/llm_engines/decoder \
    --paged_kv_cache disable \
    --moe_plugin disable \
    --enable_xqa disable \
    --use_custom_all_reduce disable \
    --gemm_plugin float16 \
    --bert_attention_plugin float16 \
    --gpt_attention_plugin float16 \
    --remove_input_padding enable \
    --context_fmha disable \
    --max_beam_width 1 \
    --max_batch_size 1 \
    --max_seq_len $MAX_OUTPUT_TOKEN_NUM \
    --max_encoder_input_len $MAX_IMAGE_TOKEN_NUM \
    --max_input_len 1

# Step3 build visual engine
echo "Step3 Build Visual Engine"

python tensorrt_utils/build_visual_engine.py --model_type $MODEL_TYPE \
    --model_path $HF_CKPT_PATH \
    --output_dir $MODEL_OUTPUT/visual_engines \
    --max_batch_size 1

if [ -f './model.cache' ]; then
    rm ./model.cache
fi

echo "Build TensorRT model and Visual Engine Successfully"

================================================
FILE: tools/tensorrt_utils/build_visual_engine.py
================================================
import argparse
import os
import shutil
import sys
import tarfile
from time import time

import yaml

# isort: off
import torch
import tensorrt as trt
from tensorrt_llm.builder import Builder
# isort: on
import json
import math

import torch.nn.functional as F
from PIL import Image
from safetensors.torch import save_file
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoModelForVision2Seq, AutoProcessor,
                          Blip2ForConditionalGeneration, Blip2Processor,
                          FuyuForCausalLM, FuyuProcessor,
                          LlavaForConditionalGeneration, NougatProcessor,
                          Pix2StructForConditionalGeneration,
                          VisionEncoderDecoderModel)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type',
                        type=str,
                        default=None,
                        choices=[
                            'opt-2.7b', 'opt-6.7b', 'flan-t5-xl', 'flan-t5-xxl',
                            'llava', 'vila', 'nougat', 'cogvlm', 'fuyu', 'pix2struct',
                            'StructEqTable', 'neva', 'kosmos-2', 'video-neva',
                            'phi-3-vision'
                        ],
                        help="Model type")
    parser.add_argument(
        '--model_path',
        type=str,
        default=None,
        help=
        "Huggingface repo, local directory with weights or path to checkpoint file"
    )
    parser.add_argument('--vila_path',
                        type=str,
                        default=None,
                        help="Path to VILA source code directory")
    parser.add_argument('--output_dir',
                        type=str,
                        default=None,
                        help="Directory where visual TRT engines are saved")
    parser.add_argument('--max_batch_size',
                        type=int,
                        default=4,
                        help="Maximum batch size for input images")
    return parser.parse_args()


class VisionEngineBuilder:

    def __init__(self, args):
        args.device = torch.device(
            "cuda") if torch.cuda.is_available() else "cpu"
        if args.output_dir is None:
            args.output_dir = 'visual_engines/%s' % (
                args.model_path.split('/')[-1] if args.vila_path is not None
                else args.model_path.split('/')[-1])
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        self.args = args

    def build(self):
        args = self.args
        if 'opt' in args.model_type or 't5' in args.model_type:
            build_blip2_engine(args)
        elif args.model_type == 'pix2struct':
            build_pix2struct_engine(args)
        elif args.model_type == 'StructEqTable':
            build_StructEqTable_engine(args)
        elif args.model_type == 'llava':
            build_llava_engine(args)
        elif args.model_type == 'vila':
            assert args.vila_path is not None, "Please clone and provide VILA source code path"
            build_vila_engine(args)
        elif args.model_type == 'nougat':
            build_nougat_engine(args)
        elif args.model_type == 'cogvlm':
            build_cogvlm_engine(args)
        elif args.model_type == 'fuyu':
            build_fuyu_engine(args)
        elif args.model_type == 'neva':
            build_neva_engine(args)
        elif args.model_type == 'video-neva':
            build_video_neva_engine(args)
        elif args.model_type == 'kosmos-2':
            build_kosmos_engine(args)
        elif args.model_type == 'phi-3-vision':
            build_phi_engine(args)
        else:
            raise RuntimeError(f"Invalid model type {args.model_type}")


def export_visual_wrapper_onnx(visual_wrapper,
                               input,
                               output_dir,
                               input_names=['input'],
                               dynamic_axes={'input': {
                                   0: 'batch'
                               }}):
    logger.log(trt.Logger.INFO, "Exporting onnx")
    os.makedirs(f'{output_dir}/onnx', exist_ok=True)
    torch.onnx.export(visual_wrapper,
                      input,
                      f'{output_dir}/onnx/visual_encoder.onnx',
                      opset_version=17,
                      input_names=input_names,
                      output_names=['output'],
                      dynamic_axes=dynamic_axes)


def build_trt_engine(model_type,
                     input_sizes,
                     output_dir,
                     max_batch_size,
                     dtype=torch.float16,
                     num_frames=None):
    part_name = 'visual_encoder'
    onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name)
    engine_file = '%s/%s.engine' % (output_dir, part_name)
    config_file = '%s/%s' % (output_dir, "config.json")
    logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)

    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    profile = builder.create_optimization_profile()

    config_args = {
        "precision": str(dtype).split('.')[-1],
        "model_type": model_type
    }
    if num_frames is not None:
        config_args["num_frames"] = num_frames

    config_wrapper = Builder().create_builder_config(**config_args)
    config = config_wrapper.trt_builder_config

    parser = trt.OnnxParser(network, logger)

    with open(onnx_file, 'rb') as model:
        if not parser.parse(model.read(), os.path.abspath(onnx_file)):
            logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file)
            for error in range(parser.num_errors):
                logger.log(trt.Logger.ERROR, parser.get_error(error))
        logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file)

    # Delete onnx files since we don't need them now
    shutil.rmtree(f'{output_dir}/onnx')

    nBS = -1
    nMinBS = 1
    nOptBS = max(nMinBS, int(max_batch_size / 2))
    nMaxBS = max_batch_size

    inputT = network.get_input(0)

    # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images,
    # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]).
    assert isinstance(input_sizes, list), "input_sizes must be a list"
    if isinstance(input_sizes[0], int):
        logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}")
        inputT.shape = [nBS, *input_sizes]
        min_size = opt_size = max_size = input_sizes
    elif len(input_sizes) == 3 and isinstance(input_sizes[0], list):
        min_size, opt_size, max_size = input_sizes
        logger.log(
            trt.Logger.INFO,
            f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}"
        )
    else:
        raise ValueError(f"invalid input sizes: {input_sizes}")

    profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size],
                      [nMaxBS, *max_size])
    if model_type == "pix2struct" or model_type == "StructEqTable" :
        inputT = network.get_input(1)
        P = input_sizes[0]  # Number of patches
        inputT.shape = [nBS, P]
        profile.set_shape(inputT.name, [nMinBS, P], [nOptBS, P], [nMaxBS, P])
    config.add_optimization_profile(profile)

    t0 = time()
    engine_string = builder.build_serialized_network(network, config)
    t1 = time()
    if engine_string is None:
        raise RuntimeError("Failed building %s" % (engine_file))
    else:
        logger.log(trt.Logger.INFO,
                   "Succeeded building %s in %d s" % (engine_file, t1 - t0))
        with open(engine_file, 'wb') as f:
            f.write(engine_string)

    Builder.save_config(config_wrapper, config_file)


def build_blip2_engine(args):
    model_type = 'Salesforce/blip2-' + args.model_type
    processor = Blip2Processor.from_pretrained(model_type)

    raw_image = Image.new('RGB', [10, 10])  # dummy image
    prompt = "Question: what is this? Answer:"
    inputs = processor(raw_image, prompt,
                       return_tensors="pt").to(args.device, torch.float16)
    image = inputs['pixel_values']

    class Blip2VisionWrapper(torch.nn.Module):

        def __init__(self, vision_model, qformer, projector, query_tokens):
            super().__init__()
            self.vision_model = vision_model
            self.qformer = qformer
            self.projector = projector
            self.query_tokens = query_tokens

        def forward(self, image):
            features = self.vision_model(image)[0]
            qformer_output = self.qformer(query_embeds=self.query_tokens,
                                          encoder_hidden_states=features,
                                          return_dict=True)
            return self.projector(qformer_output.last_hidden_state)

    model = Blip2ForConditionalGeneration.from_pretrained(
        model_type, torch_dtype=torch.float16)
    wrapper = Blip2VisionWrapper(model.vision_model, model.qformer,
                                 model.language_projection, model.query_tokens)
    wrapper.to(args.device)

    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size)


def build_pix2struct_engine(args):
    processor = AutoProcessor.from_pretrained(args.model_path)
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    dtype = torch.float16
    inputs = processor(text="dummy", images=raw_image, return_tensors="pt", max_patches=processor.image_processor.max_patches)
    image = inputs['flattened_patches'].to(args.device, dtype)
    attention_mask = inputs['attention_mask'].to(args.device, torch.int)
    class pix2structVisionWrapper(torch.nn.Module):

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image, attention_mask):
            vision_x = self.encoder.embeddings(image)
            img_features = self.encoder.encoder(vision_x,
                                                attention_mask=attention_mask)
            img_features = self.encoder.layernorm(img_features[0])
            return img_features

    model = Pix2StructForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=dtype)

    wrapper = pix2structVisionWrapper(model.encoder.to(args.device))
    # input shape: batch size, number of patches, hidden dimension
    # attention mask shape: batch size, number of patches
    # The number of image patches varies with the image size, but it typically
    # falls within a relatively narrow range. To improve performance, we avoid
    # using a dynamic axis for the input patches and instead use a fixed number
    # of patches together with an attention mask.
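    # Illustrative shapes (the flattened patch dimension depends on the processor's
    # patch size):
    #   flattened_patches: [batch, max_patches, flattened_patch_dim], zero-padded
    #   attention_mask:    [batch, max_patches], 1 for real patches, 0 for padding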
    export_visual_wrapper_onnx(wrapper, (image, attention_mask),
                               args.output_dir,
                               input_names=['input', 'attention_mask'],
                               dynamic_axes={
                                   'input': {
                                       0: 'batch'
                                   },
                                   'attention_mask': {
                                       0: 'batch'
                                   }
                               })
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2]],  # Number of Patches, Hidden Dimension
        args.output_dir,
        args.max_batch_size,
        torch.bfloat16)


def build_StructEqTable_engine(args):
    processor = AutoProcessor.from_pretrained(args.model_path)
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    dtype = torch.float16
    inputs = processor(text="dummy", images=raw_image, return_tensors="pt", max_patches=processor.image_processor.max_patches)
    image = inputs['flattened_patches'].to(args.device, dtype)
    attention_mask = inputs['attention_mask'].to(args.device, torch.int)
    class StructEqTableVisionWrapper(torch.nn.Module):

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image, attention_mask):
            vision_x = self.encoder.embeddings(image)
            img_features = self.encoder.encoder(vision_x,
                                                attention_mask=attention_mask)
            img_features = self.encoder.layernorm(img_features[0])
            return img_features

    model = AutoModelForVision2Seq.from_pretrained(
        args.model_path, torch_dtype=dtype)

    wrapper = StructEqTableVisionWrapper(model.encoder.to(args.device))
    # input shape: batch size, number of patches, hidden dimension
    # attention mask shape: batch size, number of patches
    # The number of image patches varies with the image size, but it typically
    # falls within a relatively narrow range. To improve performance, we avoid
    # using a dynamic axis for the input patches and instead use a fixed number
    # of patches together with an attention mask.
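    # Same fixed-patch-count plus attention-mask approach as the pix2struct export above.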
    export_visual_wrapper_onnx(wrapper, (image, attention_mask),
                               args.output_dir,
                               input_names=['input', 'attention_mask'],
                               dynamic_axes={
                                   'input': {
                                       0: 'batch'
                                   },
                                   'attention_mask': {
                                       0: 'batch'
                                   }
                               })
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2]],  # Number of Patches, Hidden Dimension
        args.output_dir,
        args.max_batch_size,
        torch.bfloat16)


def build_llava_engine(args):
    processor = AutoProcessor.from_pretrained(args.model_path)
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    image = processor(text="dummy", images=raw_image,
                      return_tensors="pt")['pixel_values'].to(
                          args.device, torch.float16)

    class LlavaVisionWrapper(torch.nn.Module):

        def __init__(self, tower, projector, feature_layer):
            super().__init__()
            self.tower = tower
            self.projector = projector
            self.feature_layer = feature_layer

        def forward(self, image):
            all_hidden_states = self.tower(
                image, output_hidden_states=True).hidden_states
            features = all_hidden_states[self.feature_layer][:, 1:]
            return self.projector(features)

    model = LlavaForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=torch.float16)
    wrapper = LlavaVisionWrapper(model.vision_tower.to(args.device),
                                 model.multi_modal_projector.to(args.device),
                                 model.config.vision_feature_layer)

    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size)


def build_vila_engine(args):
    # Note: the VILA model is not in the public HF model zoo yet, so we explicitly import its modules from the git repo.
    sys.path.append(args.vila_path)
    from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa
    from transformers import AutoModel
    model = AutoModel.from_pretrained(
        args.model_path,
        device_map='auto',
    )

    vision_tower = model.get_vision_tower()
    image_processor = vision_tower.image_processor
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    image = image_processor(images=raw_image,
                            return_tensors="pt")['pixel_values']
    if isinstance(image, list):
        image = image[0].unsqueeze(0)
    image = image.to(args.device, torch.float16)

    class VilaVisionWrapper(torch.nn.Module):

        def __init__(self, tower, projector):
            super().__init__()
            self.tower = tower
            self.projector = projector

        def forward(self, image):
            features = self.tower(image)
            return self.projector(features)

    model = AutoModel.from_pretrained(
        args.model_path,
        device_map='auto',
    )
    wrapper = VilaVisionWrapper(model.get_vision_tower().to(args.device),
                                model.mm_projector.to(args.device))
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size)


def build_nougat_engine(args):
    processor = NougatProcessor.from_pretrained(args.model_path)
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    image = processor(raw_image, return_tensors="pt")['pixel_values'].to(
        args.device, torch.float16)

    class SwinEncoderWrapper(torch.nn.Module):

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image):
            return self.encoder(image).last_hidden_state

    model = VisionEncoderDecoderModel.from_pretrained(args.model_path,
                                                      torch_dtype=torch.float16)
    swin_encoder = model.get_encoder().to(args.device)
    wrapper = SwinEncoderWrapper(swin_encoder)

    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size)


def build_cogvlm_engine(args):
    hf_config = AutoConfig.from_pretrained(args.model_path,
                                           trust_remote_code=True)
    image_size = hf_config.vision_config['image_size']
    dtype = hf_config.torch_dtype
    image = torch.empty(1,
                        3,
                        image_size,
                        image_size,
                        dtype=dtype,
                        device=args.device)  # dummy image

    class CogVlmVisionWrapper(torch.nn.Module):

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image):
            return self.encoder(image)

    cogvlm = AutoModelForCausalLM.from_pretrained(args.model_path,
                                                  torch_dtype=dtype,
                                                  trust_remote_code=True)
    vit_encoder = cogvlm.model.vision.to(args.device).eval()

    wrapper = CogVlmVisionWrapper(vit_encoder)
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
        dtype)


def build_fuyu_engine(args):
    processor = FuyuProcessor.from_pretrained(args.model_path)
    raw_image = Image.new('RGB', [10, 10])
    image = processor(text="dummy", images=raw_image,
                      return_tensors="pt")['image_patches'][0].to(
                          args.device, torch.float16).unsqueeze(0)

    class FuyuEncoderWrapper(torch.nn.Module):

        def __init__(self, linear):
            super().__init__()
            self.linear = linear.to(torch.float16)

        def forward(self, patches):
            return self.linear(patches).flatten(0, 1)

    model = FuyuForCausalLM.from_pretrained(args.model_path,
                                            torch_dtype=torch.float16)

    vision_encoder = model.vision_embed_tokens
    wrapper = FuyuEncoderWrapper(vision_encoder).to(args.device)

    export_visual_wrapper_onnx(wrapper,
                               image,
                               args.output_dir,
                               dynamic_axes={'input': {
                                   0: 'batch',
                                   2: 'patch'
                               }})
    build_trt_engine(
        args.model_type,
        # [nImgs, nImgPatches, nDims]
        # nImgs is always one since each query has exactly one image
        # nImgPatches depends on image size (patch size: 30x30)
        # nDims is 30x30x3=2700 (patch size x color channels)
        [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]],
        args.output_dir,
        args.max_batch_size)


def build_neva_engine(args):
    # extract NeMo checkpoint
    with tarfile.open(args.model_path) as tar:
        nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml"))
        try:
            # trained without TP
            mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"),
                                     map_location=args.device)
        except KeyError:
            # trained with TP
            mp0_weights = torch.load(
                tar.extractfile("./mp_rank_00/model_weights.ckpt"),
                map_location=args.device)

    vision_config = nemo_config["mm_cfg"]["vision_encoder"]

    class VisionEncoderWrapper(torch.nn.Module):

        def __init__(self, encoder, connector):
            super().__init__()
            self.encoder = encoder
            self.connector = connector

        def forward(self, images):
            vision_x = self.encoder(pixel_values=images,
                                    output_hidden_states=True)
            vision_x = vision_x.hidden_states[-2]
            vision_x = vision_x[:, 1:]
            vision_x = self.connector(vision_x)
            return vision_x

    encoder = AutoModel.from_pretrained(vision_config["from_pretrained"],
                                        torch_dtype=torch.bfloat16,
                                        trust_remote_code=True)
    vision_encoder = encoder.vision_model
    hf_config = encoder.config
    dtype = hf_config.torch_dtype

    # connector
    assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
    vision_connector = torch.nn.Sequential(
        torch.nn.Linear(vision_config["hidden_size"],
                        nemo_config["hidden_size"],
                        bias=True), torch.nn.GELU(),
        torch.nn.Linear(nemo_config["hidden_size"],
                        nemo_config["hidden_size"],
                        bias=True)).to(dtype=dtype)

    key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
    for layer in range(0, 3, 2):
        vision_connector[layer].load_state_dict({
            'weight':
            mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
            'bias':
            mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
        })

    # export the whole wrapper
    wrapper = VisionEncoderWrapper(vision_encoder,
                                   vision_connector).to(args.device, dtype)
    image_size = hf_config.vision_config.image_size
    dummy_image = torch.empty(
        1, 3, image_size, image_size, dtype=dtype,
        device=args.device)  # dummy image shape [B, C, H, W]
    export_visual_wrapper_onnx(wrapper, dummy_image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [3, image_size, image_size],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
        dtype)


def build_video_neva_engine(args):
    # extract NeMo checkpoint
    with tarfile.open(args.model_path) as tar:
        nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml"))
        try:
            # trained without TP
            mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"),
                                     map_location=args.device)
        except KeyError:
            # trained with TP
            mp0_weights = torch.load(
                tar.extractfile("./mp_rank_00/model_weights.ckpt"),
                map_location=args.device)

    vision_config = nemo_config["mm_cfg"]["vision_encoder"]

    class VisionEncoderWrapper(torch.nn.Module):

        def __init__(self, encoder, connector):
            super().__init__()
            self.encoder = encoder
            self.connector = connector

        def forward(self, images):
            b, num_frames, c, h, w = images.shape
            images = images.view(b * num_frames, c, h, w)
            vision_x = self.encoder(
                pixel_values=images,  # [(B * num_frames), C, H, W]
                output_hidden_states=True)
            vision_x = vision_x.hidden_states[-2]
            vision_x = vision_x[:, 1:]

            # reshape back to [B, num_frames, img_size, hidden_size]
            vision_x = vision_x.view(b, num_frames, -1, vision_x.shape[-1])

            vision_x = self.connector(vision_x)
            return vision_x

    encoder = AutoModel.from_pretrained(vision_config["from_pretrained"],
                                        torch_dtype=torch.bfloat16,
                                        trust_remote_code=True)
    vision_encoder = encoder.vision_model
    hf_config = encoder.config
    dtype = hf_config.torch_dtype

    # connector
    assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear"
    vision_connector = torch.nn.Linear(vision_config["hidden_size"],
                                       nemo_config["hidden_size"],
                                       bias=True)

    key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
    vision_connector.load_state_dict({
        'weight':
        mp0_weights[f"{key_prefix}.weight"].to(dtype),
        'bias':
        mp0_weights[f"{key_prefix}.bias"].to(dtype),
    })

    # export the whole wrapper
    wrapper = VisionEncoderWrapper(vision_encoder,
                                   vision_connector).to(args.device, dtype)
    image_size = hf_config.vision_config.image_size
    num_frames = nemo_config['data']['num_frames']
    dummy_video = torch.empty(1,
                              num_frames,
                              3,
                              image_size,
                              image_size,
                              dtype=dtype,
                              device=args.device)  # dummy image
    export_visual_wrapper_onnx(wrapper, dummy_video, args.output_dir)
    build_trt_engine(
        args.model_type,
        [num_frames, 3, image_size, image_size],  # [num_frames, 3, H, W]
        args.output_dir,
        args.max_batch_size,
        dtype,
        num_frames=num_frames)


def build_kosmos_engine(args):
    processor = AutoProcessor.from_pretrained(args.model_path)
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    image = processor(text="dummy", images=raw_image,
                      return_tensors="pt")['pixel_values'].to(
                          args.device, torch.float16)

    class VisionEncoderWrapper(torch.nn.Module):

        def __init__(self, encoder, connector):
            super().__init__()
            self.encoder = encoder
            self.connector = connector

        def forward(self, images):
            vision_x = self.encoder(images, output_hidden_states=True)
            img_features = self.encoder.model.post_layernorm(
                vision_x.last_hidden_state)
            img_features = F.normalize(img_features, dim=-1)
            img_features, _ = self.connector(img_features)
            return img_features

    model = AutoModelForVision2Seq.from_pretrained(args.model_path,
                                                   torch_dtype=torch.float16)
    wrapper = VisionEncoderWrapper(
        model.vision_model.to(args.device),
        model.image_to_text_projection.to(args.device))

    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size)


def build_phi_engine(args):
    processor = AutoProcessor.from_pretrained(args.model_path,
                                              trust_remote_code=True)
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    image = processor(text="<|image_1|>\ndummy",
                      images=raw_image,
                      return_tensors="pt")['pixel_values'].to(
                          args.device, torch.float16)
    try:
        with open(f"{args.model_path}/preprocessor_config.json", "r") as file:
            config = file.read()
            config_dict = json.loads(config)
            num_crops = config_dict.get("num_crops", 16)
    except Exception:
        # Fall back to the default when the preprocessor config is missing or unreadable.
        num_crops = 16

    class Phi3VisionWrapper(torch.nn.Module):

        def __init__(self, img_processor, img_projection, layer_idx,
                     image_dim_out):
            super().__init__()
            self.img_processor = img_processor
            self.img_projection = img_projection
            self.layer_idx = layer_idx
            self.image_dim_out = image_dim_out

        def get_img_features(
                self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
            LAYER_IDX = self.layer_idx

            img_processor_output = self.img_processor(img_embeds,
                                                      output_hidden_states=True)
            img_feature = img_processor_output.hidden_states[LAYER_IDX]

            patch_feature = img_feature[:, 1:]
            return patch_feature

        def forward(self, image):
            img_features = self.get_img_features(image)
            base_feat_height = int(math.sqrt(img_features.shape[1]))
            C = self.image_dim_out
            H = base_feat_height
            img_features = img_features.reshape(-1, H, H, C).reshape(
                -1, H // 2, 2, H // 2, 2,
                C).contiguous().permute(0, 1, 3, 2, 4,
                                        5).reshape(-1, H // 2, H // 2,
                                                   4 * C).contiguous()
            return self.apply_img_projection(img_features)

        def apply_img_projection(self, input):
            return self.img_projection(input)

    model = AutoModelForCausalLM.from_pretrained(args.model_path,
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True).to(
                                                     args.device)

    wrapper = Phi3VisionWrapper(model.model.vision_embed_tokens.img_processor,
                                model.model.vision_embed_tokens.img_projection,
                                model.model.vision_embed_tokens.layer_idx,
                                model.model.vision_embed_tokens.image_dim_out)
    image = image.flatten(0, 1)
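    # glb_GN and sub_GN are Phi-3-vision's learnable separator embeddings; they are
    # passed through the same image projection and saved alongside the engine so the
    # runtime can stitch global and per-crop features back together.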
    glb_GN = wrapper.apply_img_projection(
        model.model.vision_embed_tokens.glb_GN)
    sub_GN = wrapper.apply_img_projection(
        model.model.vision_embed_tokens.sub_GN)
    tensors = {"glb_GN": glb_GN, "sub_GN": sub_GN}
    save_file(tensors, args.output_dir + "/image_newlines.safetensors")
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]], args.output_dir,
        args.max_batch_size * (num_crops + 1))  #TODO: Take input from config


if __name__ == '__main__':
    logger = trt.Logger(trt.Logger.INFO)
    args = parse_arguments()
    builder = VisionEngineBuilder(args)
    builder.build()


================================================
FILE: tools/tensorrt_utils/convert_checkpoint.py
================================================
import argparse
import configparser
import copy
import json
import logging
import os
import types
from ast import literal_eval
from datetime import datetime
from pathlib import Path

import safetensors
from helper import convert_weight_to_dtype, fuse_qkv_one_layer, reshape, split
from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
                          MBartForConditionalGeneration,
                          Pix2StructForConditionalGeneration,
                          AutoModelForVision2Seq,
                          T5ForConditionalGeneration, VisionEncoderDecoderModel)

from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                     MLPType)
from tensorrt_llm.models import PretrainedConfig

dir_path = os.path.dirname(os.path.realpath(__file__))
LOGGER = logging.getLogger(__name__)

layernorm_type_map = {i.name: i.value for i in LayerNormType}
layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
mlp_type_map = {i.name: i.value for i in MLPType}


def copy_args_to_component_config(component_config, args):
    for arg in vars(args):
        setattr(component_config, arg, getattr(args, arg))
    return component_config


def parse_t5_config(args, hf_model):
    config = configparser.ConfigParser()

    config["encoder"] = {}
    for key, val in hf_model.encoder.config.to_dict().items():
        config["encoder"][key] = f"{val}"

    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
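    # Setting q_scaling = 1 / sqrt(head_size) cancels the 1 / sqrt(head_size) factor the
    # attention kernel applies by default, giving the unscaled QK^T product T5 expects.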
    def get_offset_q_scaling(config):
        scaling = 1 / config.head_size**.5
        return scaling

    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "false"
    config["structure"]["use_gated_activation"] = str(
        hf_model.encoder.config.is_gated_act)
    config["structure"]["position_embedding_type"] = "relative"
    config["structure"]["model_type"] = args.model_type

    def parse_t5_config_by_component(config, component, args):
        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(component_config, args)
        component_config.n_head = config.getint(component, 'num_heads')
        component_config.head_size = config.getint(component, 'd_kv')
        component_config.hidden_size = config.getint(component, 'd_model')
        component_config.ffn_hidden_size = config.getint(component, 'd_ff')
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(component,
                                                     'n_positions',
                                                     fallback=512)
        component_config.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=False)  # TODO: hardcoded here

        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=False)
        component_config.has_embedding_scale = config.getboolean(
            component, 'has_embedding_scale', fallback=False)
        component_config.q_scaling = get_offset_q_scaling(component_config)
        component_config.has_attention_qkvo_bias = config.getboolean(
            component, 'has_attention_qkvo_bias',
            fallback=False)  # TODO: hardcoded here
        component_config.has_mlp_bias = config.getboolean(component,
                                                          'has_mlp_bias',
                                                          fallback=False)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm', fallback=True)
        component_config.layernorm_eps = config.getfloat(
            component, 'layer_norm_epsilon')
        component_config.layernorm_position = layernorm_position_map[config.get(
            component, 'layernorm_position',
            fallback='pre_layernorm')]  # TODO: hardcoded here
        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='RmsNorm')]
        component_config.hidden_act = config.get(component, 'dense_act_fn')
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act')
        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
                                                 gated_act else 'MLP']
        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets')
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance')
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')

        if component == 'encoder':
            component_config.n_layer = config.getint(component, 'num_layers')

            component_config.relative_attention = config.get(
                'structure', 'position_embedding_type') == 'relative'

        elif component == 'decoder':
            component_config.n_layer = config.getint(component,
                                                     'num_decoder_layers')
            component_config.has_lm_head_bias = config.getboolean(
                component,  # TODO: T5 with bias
                'has_lm_head_bias',
                fallback=False)
            component_config.relative_attention = config.getboolean(
                component, 'relative_attention', fallback=True)
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'tie_word_embeddings'
            )  # default is True (for T5), but False for Flan-T5
            component_config.encoder_hidden_size = config.getint(
                'encoder', 'd_model')
            component_config.encoder_num_heads = config.getint(
                'encoder', 'num_heads')
            component_config.encoder_head_size = config.getint(
                'encoder', 'd_kv')
            component_config.decoder_start_token_id = config.getint(
                'decoder', 'decoder_start_token_id')

        else:
            assert False, 'Unsupported component!'

        return component_config

    encoder_config = parse_t5_config_by_component(config, "encoder", args)
    decoder_config = parse_t5_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config


def convert_t5_weights_to_tllm_safetensors(config, component, params):
    weights = {}

    mapping = config.mapping

    convert_weight_to_dtype(params, config.dtype)
    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    n_head = config.num_attention_heads
    head_size = config.head_size
    attention_hidden_size = n_head * head_size  # head_size * num_heads does not necessarily equal hidden_size, e.g. Flan-T5

    hf_param_prefix = f'{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
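    # In HF T5 blocks, layer 0 is self-attention; the feed-forward sub-layer sits at
    # index 1 in the encoder and index 2 in the decoder (index 1 is cross-attention there).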
    hf_component_idx = 1 if component == 'encoder' else 2

    def get_attn_module_name(component, block, layer, attn_type):
        return f'{component}.block.{int(block)}.layer.{int(layer)}.{attn_type}'

    weights['embedding.vocab_embedding.weight'] = reshape(
        params['shared.weight'].clone(), None)

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
        hf_layer_name_prefix = f'{hf_param_prefix}.block.{layer_idx}'

        hidden_layer_name_split = {
            f'{hf_layer_name_prefix}.layer.0.SelfAttention.o.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wo.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
                "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_0.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
        }

        hidden_layer_name_no_split = {
            f'{hf_layer_name_prefix}.layer.0.layer_norm.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
                "shape": None
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.layer_norm.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
                "shape": None
            },
        }

        if config.gated_act:
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi2.weight':
                {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
                f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_1.weight':
                {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
            })

        if component == 'decoder':
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.layer.1.EncDecAttention.o.weight': {
                    "name":
                    f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
                    "shape":
                    (hidden_size, attention_hidden_size // mapping.tp_size),
                    "split_dim": -1
                },
            })
            hidden_layer_name_no_split.update({
                f'{hf_layer_name_prefix}.layer.1.layer_norm.weight': {
                    "name":
                    f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
                    "shape": None
                },
            })
            self_attn_module_name = get_attn_module_name(
                component, layer_idx, "1", 'EncDecAttention')
            weights.update(
                fuse_qkv_one_layer(
                    params, self_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                    None))

        self_attn_module_name = get_attn_module_name(component, layer_idx, "0",
                                                     'SelfAttention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        weights[
            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))
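        # HF T5 stores the relative attention bias only in block 0; the same table is
        # copied into every TRT-LLM layer's rel_attn_table here.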

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    split(params[hf_weight_name],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])
        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    params[hf_weight_name].clone(), shape=weight_info["shape"])

    weights['final_layernorm.weight'] = reshape(
        params[f'{component}.final_layer_norm.weight'].clone(), None)

    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params['lm_head.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
        if not config.use_implicit_relative_attention:
            weights['rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))

    return weights


convert_blip2_weights_to_tllm_safetensors = convert_t5_weights_to_tllm_safetensors  # func alias


def parse_nmt_config(args, model):
    config = configparser.ConfigParser()
    fairseq_config = vars(model.cfg.model)  # Namespace --> dict

    config['encoder'] = dict()
    for key, val in fairseq_config.items():
        config["encoder"][key] = f"{val}"
    config["encoder"]["q_scaling"] = '1'
    # NMT models have a final layernorm when using the pre-norm architecture.
    config['encoder']['has_model_final_layernorm'] = config['encoder'][
        'encoder_normalize_before']
    config['encoder']['vocab_size'] = str(len(model.src_dict))  # fairseq naming

    config['decoder'] = dict()
    for key, val in fairseq_config.items():
        config["decoder"][key] = f"{val}"
    config["decoder"]["q_scaling"] = '1'
    config["decoder"]["rescale_before_lm_head"] = 'false'
    config['decoder']['has_model_final_layernorm'] = config['decoder'][
        'decoder_normalize_before'] and not config['decoder'].getboolean(
            'no_decoder_final_norm', False)
    config['decoder']['vocab_size'] = str(len(model.tgt_dict))  # fairseq naming

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "true"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"][
        "position_embedding_type"] = "learned_absolute"  # "sinusoid"
    config["structure"]["model_type"] = args.model_type

    def parse_nmt_config_by_component(config, component, args):
        assert component in ('encoder', 'decoder'), 'Unsupported component!'
        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(component_config, args)
        component_config.n_layer = config.getint(component,
                                                 f'{component}_layers')
        component_config.n_head = config.getint(component,
                                                f'{component}_attention_heads')
        component_config.hidden_size = config.getint(
            component, f'{component}_embed_dim')  # fairseq naming
        component_config.head_size = config.getint(
            component,
            'd_kv',
            fallback=component_config.hidden_size // component_config.n_head)
        component_config.ffn_hidden_size = config.getint(
            component, f'{component}_ffn_embed_dim')  # fairseq naming
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(
            component, 'max_source_positions')  # fairseq naming
        component_config.has_position_embedding = not config.getboolean(
            component, 'no_token_positional_embeddings',
            fallback=False)  # fairseq naming
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'layernorm_embedding', fallback=True)  # fairseq naming
        component_config.has_embedding_scale = not config.getboolean(
            component, 'no_scale_embedding')  # fairseq naming
        component_config.q_scaling = config.getfloat(component,
                                                     'q_scaling',
                                                     fallback=1.0)
        component_config.has_attention_qkvo_bias = config.getboolean(
            'structure', 't5_with_bias', fallback=True)
        component_config.has_mlp_bias = config.getboolean('structure',
                                                          't5_with_bias',
                                                          fallback=True)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm')
        component_config.layernorm_eps = config.getfloat(
            component, 'layer_norm_epsilon', fallback=1e-5)  # fairseq naming

        normalize_before = config.getboolean(
            component, f'{component}_normalize_before')  # fairseq naming
        component_config.layernorm_position = layernorm_position_map[
            'pre_layernorm' if normalize_before else 'post_layernorm']

        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='LayerNorm')]
        component_config.hidden_act = config.get(
            component, 'activation_fn')  # fairseq naming
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act',
                                                       fallback=False)
        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
                                                 gated_act else 'MLP']
        component_config.relative_attention = config.get(
            'structure', 'position_embedding_type') == 'relative'

        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets', fallback=0)
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance', fallback=0)
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')
        if component == 'decoder':
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'rescale_before_lm_head')

            component_config.encoder_hidden_size = config.getint(
                'encoder', 'encoder_embed_dim')  # fairseq naming
            component_config.encoder_num_heads = config.getint(
                'encoder', 'encoder_attention_heads')
            component_config.encoder_head_size = config.getint(
                'encoder',
                'd_kv',
                fallback=component_config.encoder_hidden_size //
                component_config.encoder_num_heads)
            component_config.decoder_start_token_id = config.getint(
                'decoder', 'decoder_start_token_id')

        return component_config

    encoder_config = parse_nmt_config_by_component(config, "encoder", args)
    decoder_config = parse_nmt_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config


def convert_nmt_weights_to_tllm_safetensors(config, component, params,
                                            sin_pos_embedding):
    weights = {}

    mapping = config.mapping

    hidden_size = config.hidden_size

    convert_weight_to_dtype(params, config.dtype)
    ffn_hidden_size = config.intermediate_size
    vocab_size = config.vocab_size

    hf_param_prefix = f'models.0.{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'

    hidden_layer_name_split = {
        'self_attn.out_proj.weight': {
            "name": f'{trtllm_attn_layer_name}.dense.weight',
            "shape": (hidden_size, hidden_size // mapping.tp_size),
            "split_dim": -1
        },
        'fc1.weight': {
            "name": 'mlp.fc.weight',
            "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
            "split_dim": 0
        },
        'fc1.bias': {
            "name": 'mlp.fc.bias',
            "shape": (ffn_hidden_size // mapping.tp_size),
            "split_dim": 0
        },
        'fc2.weight': {
            "name": 'mlp.proj.weight',
            "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
            "split_dim": -1
        },
    }

    hidden_layer_name_no_split = {
        'self_attn.out_proj.bias': {
            "name": f'{trtllm_attn_layer_name}.dense.bias',
            "shape": (hidden_size)
        },
        'self_attn_layer_norm.weight': {
            "name": f'{trtllm_attn_layernorm_name}.weight',
            "shape": None
        },
        'self_attn_layer_norm.bias': {
            "name": f'{trtllm_attn_layernorm_name}.bias',
            "shape": None
        },
        'fc2.bias': {
            "name": 'mlp.proj.bias',
            "shape": (hidden_size)
        },
        'final_layer_norm.weight': {
            "name": 'mlp_layernorm.weight',
            "shape": None
        },
        'final_layer_norm.bias': {
            "name": 'mlp_layernorm.bias',
            "shape": None
        },
    }

    if component == "decoder":
        hidden_layer_name_split.update({
            'encoder_attn.out_proj.weight': {
                "name": 'cross_attention.dense.weight',
                "shape": (hidden_size, hidden_size // mapping.tp_size),
                "split_dim": -1
            },
        })
        hidden_layer_name_no_split.update({
            'encoder_attn.out_proj.bias': {
                "name": 'cross_attention.dense.bias',
                "shape": (hidden_size)
            },
            'encoder_attn_layer_norm.weight': {
                "name": 'cross_attention_layernorm.weight',
                "shape": None,
            },
            'encoder_attn_layer_norm.bias': {
                "name": 'cross_attention_layernorm.bias',
                "shape": None
            },
        })

    def get_attn_module_name(component, layer, attn_type):
        return f'models.0.{component}.layers.{int(layer)}.{attn_type}'

    weights["embedding.vocab_embedding.weight"] = reshape(
        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(),
        (vocab_size, -1))
    weights["embedding.position_embedding.weight"] = reshape(
        sin_pos_embedding, (config.max_position_embeddings, hidden_size))
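    # fairseq NMT uses fixed sinusoidal position embeddings; the caller precomputes the
    # table (sin_pos_embedding) and it is stored here as an ordinary absolute embedding.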

    num_layers = config.num_hidden_layers

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            weights[
                f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
                    split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])

        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
            hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
            weights[trtllm_layer_fullname] = reshape(
                params[hf_layer_fullname].clone(), shape=weight_info["shape"])

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attn')
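        # Q, K and V projection weights are fused into a single tensor of shape
        # [3 * hidden_size // tp_size, hidden_size] and split across tensor-parallel ranks.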
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (hidden_size * 3 // mapping.tp_size, hidden_size),
                (hidden_size * 3 // mapping.tp_size)))
        if component == 'decoder':
            cross_attn_module_name = get_attn_module_name(
                component, layer_idx, 'encoder_attn')
            weights.update(
                fuse_qkv_one_layer(
                    params, cross_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (hidden_size * 3 // mapping.tp_size, hidden_size),
                    (hidden_size * 3 // mapping.tp_size)))

    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params[f'{hf_param_prefix}.output_projection.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))

    if config.has_model_final_layernorm:
        weights['final_layernorm.weight'] = params[
            f'{hf_param_prefix}.layer_norm.weight'].clone()
        weights['final_layernorm.bias'] = params[
            f'{hf_param_prefix}.layer_norm.bias'].clone()

    return weights


def parse_bart_config(args, hf_model):

    config = configparser.ConfigParser()

    config['decoder'] = dict()
    for key, val in hf_model.model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"
    config["decoder"]["q_scaling"] = '1'
    config["decoder"]["rescale_before_lm_head"] = str(False)
    config['decoder']['has_model_final_layernorm'] = str(
        args.nougat or isinstance(hf_model, MBartForConditionalGeneration))

    if args.nougat:
        # These flags are true for mBART decoders but missing from the HF config
        config['decoder']['normalize_before'] = str(True)
        config['decoder']['normalize_embeddings'] = str(True)

        config['encoder'] = dict()
        # Initialize the few encoder configs needed by the build from the decoder config
        encoder_config_keys = [
            "encoder_ffn_dim", "encoder_layers", "encoder_attention_heads",
            "encoder_layerdrop", "d_model"
        ]
        for key in encoder_config_keys:
            config['encoder'][key] = config['decoder'][key]
    else:
        config['encoder'] = dict()
        for key, val in hf_model.model.encoder.config.to_dict().items():
            config["encoder"][key] = f"{val}"
        config["encoder"]["q_scaling"] = '1'

        # mBART has final layernorm, BART does not
        config['encoder']['has_model_final_layernorm'] = str(
            isinstance(hf_model, MBartForConditionalGeneration))

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "true"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"]["position_embedding_type"] = "learned_absolute"
    config["structure"]["model_type"] = args.model_type

    def parse_bart_config_by_component(config, component, args):
        assert component in ('encoder', 'decoder'), 'Unsupported component!'
        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(component_config, args)
        component_config.n_layer = config.getint(component,
                                                 f'{component}_layers')
        component_config.n_head = config.getint(component,
                                                f'{component}_attention_heads')
        component_config.hidden_size = config.getint(component, 'd_model')
        component_config.head_size = config.getint(
            component,
            'd_kv',
            fallback=component_config.hidden_size // component_config.n_head)
        component_config.ffn_hidden_size = config.getint(
            component, f'{component}_ffn_dim')
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(component,
                                                     'max_position_embeddings')
        component_config.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=True)  # TODO: hardcoded here
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=True)
        component_config.has_embedding_scale = config.getboolean(
            component, 'scale_embedding')
        component_config.q_scaling = config.getfloat(component,
                                                     'q_scaling',
                                                     fallback=1.0)
        component_config.has_attention_qkvo_bias = config.getboolean(
            'structure', 't5_with_bias', fallback=True)
        component_config.has_mlp_bias = config.getboolean('structure',
                                                          't5_with_bias',
                                                          fallback=True)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm')
        component_config.layernorm_eps = config.getfloat(component,
                                                         'layer_norm_epsilon',
                                                         fallback=False)

        normalize_before = config.getboolean(component, 'normalize_before')
        component_config.layernorm_position = layernorm_position_map[
            'pre_layernorm' if normalize_before else 'post_layernorm']

        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='LayerNorm')]
        component_config.hidden_act = config.get(component,
                                                 'activation_function')
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act',
                                                       fallback=False)
        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
                                                 gated_act else 'MLP']
        component_config.relative_attention = config.get(
            'structure', 'position_embedding_type') == 'relative'

        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets', fallback=0)
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance', fallback=0)
        component_config.max_lora_rank = config.getint(component,
                                                       'max_lora_rank',
                                                       fallback=0)
        component_config.lora_target_modules = literal_eval(
            config.get(component, 'lora_target_modules', fallback="[]"))
        component_config.hf_modules_to_trtllm_modules = literal_eval(
            config.get(component, 'hf_modules_to_trtllm_modules',
                       fallback="{}"))
        component_config.trtllm_modules_to_hf_modules = literal_eval(
            config.get(component, 'trtllm_modules_to_hf_modules',
                       fallback="{}"))
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')

        if component == 'decoder':
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'rescale_before_lm_head')

            component_config.encoder_hidden_size = config.getint(
                'encoder', 'd_model')
            component_config.encoder_num_heads = config.getint(
                'encoder', 'encoder_attention_heads')
            component_config.encoder_head_size = config.getint(
                'encoder',
                'd_kv',
                fallback=component_config.encoder_hidden_size //
                component_config.encoder_num_heads)

            # nougat has decoder_start_token_id = None, special handling
            decoder_start_token_id = config.get('decoder',
                                                'decoder_start_token_id')
            component_config.decoder_start_token_id = int(
                decoder_start_token_id
            ) if decoder_start_token_id != "None" else None

        return component_config

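    # When args.nougat is set there is no text encoder to convert, so encoder_config
    # stays None; the Nougat vision encoder is exported separately as a TensorRT
    # engine by tools/tensorrt_utils/build_visual_engine.py.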
    encoder_config = None
    if not args.nougat:
        encoder_config = parse_bart_config_by_component(config, "encoder", args)
    decoder_config = parse_bart_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config


def convert_bart_weights_to_tllm_safetensors(config, component, params):
    weights = {}

    mapping = config.mapping

    hidden_size = config.hidden_size

    convert_weight_to_dtype(params, config.dtype)
    ffn_hidden_size = config.intermediate_size
    vocab_size = config.vocab_size

    hf_param_prefix = f'model.{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
    embedding_layer_names = {
        'embed_tokens.weight': {
            "name": 'embedding.vocab_embedding.weight',
            "shape": (vocab_size, -1)
        },
        'embed_positions.weight': {
            "name": 'embedding.position_embedding.weight',
            "shape": (config.max_position_embeddings, hidden_size)
        },
        'layernorm_embedding.weight': {
            "name": 'embedding.embedding_layernorm.weight',
            "shape": None
        },
        'layernorm_embedding.bias': {
            "name": 'embedding.embedding_layernorm.bias',
            "shape": None
        },
    }

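    # Per-layer weight maps: entries in hidden_layer_name_split are sharded across
    # tensor-parallel ranks along "split_dim" (mlp.fc by rows, attention/mlp output
    # projections by columns), while entries in hidden_layer_name_no_split (output
    # biases, layernorm parameters) are replicated unchanged on every rank.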
    hidden_layer_name_split = {
        'self_attn.out_proj.weight': {
            "name": f'{trtllm_attn_layer_name}.dense.weight',
            "shape": (hidden_size, hidden_size // mapping.tp_size),
            "split_dim": -1
        },
        'fc1.weight': {
            "name": 'mlp.fc.weight',
            "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
            "split_dim": 0
        },
        'fc1.bias': {
            "name": 'mlp.fc.bias',
            "shape": (ffn_hidden_size // mapping.tp_size),
            "split_dim": 0
        },
        'fc2.weight': {
            "name": 'mlp.proj.weight',
            "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
            "split_dim": -1
        },
    }

    hidden_layer_name_no_split = {
        'self_attn.out_proj.bias': {
            "name": f'{trtllm_attn_layer_name}.dense.bias',
            "shape": (hidden_size)
        },
        'self_attn_layer_norm.weight': {
            "name": f'{trtllm_attn_layernorm_name}.weight',
            "shape": None
        },
        'self_attn_layer_norm.bias': {
            "name": f'{trtllm_attn_layernorm_name}.bias',
            "shape": None
        },
        'fc2.bias': {
            "name": 'mlp.proj.bias',
            "shape": (hidden_size)
        },
        'final_layer_norm.weight': {
            "name": 'mlp_layernorm.weight',
            "shape": None
        },
        'final_layer_norm.bias': {
            "name": 'mlp_layernorm.bias',
            "shape": None
        },
    }

    if config.model_type == 'mbart':
        hidden_layer_name_split['layer_norm.weight'] = {
            "name": 'final_layernorm.weight',
            "shape": None,
            "split_dim": 0
        }
        hidden_layer_name_no_split['layer_norm.bias'] = {
            "name": 'final_layernorm.bias',
            "shape": None,
            "split_dim": 0
        }

    if component == "decoder":
        hidden_layer_name_split.update({
            'encoder_attn.out_proj.weight': {
                "name": 'cross_attention.dense.weight',
                "shape": (hidden_size, hidden_size // mapping.tp_size),
                "split_dim": -1
            }
        })
        hidden_layer_name_no_split.update({
            'encoder_attn.out_proj.bias': {
                "name": 'cross_attention.dense.bias',
                "shape": (hidden_size)
            },
            'encoder_attn_layer_norm.weight': {
                "name": 'cross_attention_layernorm.weight',
                "shape": None
            },
            'encoder_attn_layer_norm.bias': {
                "name": 'cross_attention_layernorm.bias',
                "shape": None
            },
        })

    def get_attn_module_name(component, layer, attn_type):
        return f'model.{component}.layers.{int(layer)}.{attn_type}'

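    # Copy the embedding tables. HF BART's learned position-embedding table reserves
    # its first two rows as a padding offset, so the [2:] slice drops them to match
    # the (max_position_embeddings, hidden_size) target shape.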
    for hf_weight_name, weight_info in embedding_layer_names.items():
        if 'position' in hf_weight_name:
            weights[weight_info["name"]] = params[
                f'{hf_param_prefix}.{hf_weight_name}'][2:].clone()
        else:
            weights[weight_info["name"]] = params[
                f'{hf_param_prefix}.{hf_weight_name}'].clone()
        weights[weight_info["name"]] = reshape(weights[weight_info["name"]],
                                               weight_info["shape"])

    num_layers = config.num_hidden_layers

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            weights[
                f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
                    split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])

        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
            hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
            weights[trtllm_layer_fullname] = reshape(
                params[hf_layer_fullname].clone(), shape=weight_info["shape"])

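        # HF stores separate q/k/v projections; fuse_qkv_one_layer (see
        # tools/tensorrt_utils/helper.py) concatenates them into a single fused QKV
        # weight of shape (3 * hidden_size // tp_size, hidden_size) and splits it
        # across tensor-parallel ranks. Decoder layers get the same treatment for
        # their cross-attention projections below.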
        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attn')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (hidden_size * 3 // mapping.tp_size, hidden_size),
                (hidden_size * 3 // mapping.tp_size)))
        if component == 'decoder':
            cross_attn_module_name = get_attn_module_name(
                component, layer_idx, 'encoder_attn')
            weights.update(
                fuse_qkv_one_layer(
                    params, cross_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (hidden_size * 3 // mapping.tp_size, hidden_size),
                    (hidden_size * 3 // mapping.tp_size)))

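    # The LM head is vocab-parallel: its weight is split along the vocabulary
    # dimension (dim 0) so each rank holds a (vocab_size // tp_size, hidden_size) shard.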
    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params['lm_head.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))

    if config.has_model_final_layernorm:
        weights['final_layernorm.weight'] = params[
            f'{hf_param_prefix}.layer_norm.weight'].clone()
        weights['final_layernorm.bias'] = params[
            f'{hf_param_prefix}.layer_norm.bias'].clone()

    return weights


def parse_pix2struct_config(args, hf_model):
    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
    config = configparser.ConfigParser()

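    # Pix2Struct (like T5) applies no 1/sqrt(head_size) factor inside its attention,
    # whereas the TensorRT-LLM attention kernels do. Setting q_scaling to
    # 1/sqrt(head_size) offsets the kernel's built-in scaling so the net factor is 1.
    # Illustrative numbers only (not taken from a specific checkpoint): with
    # d_model = 768 and num_heads = 12, head_size = 64 and q_scaling = 1 / 64 ** 0.5 = 0.125.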
    def get_offset_q_scaling(config) -> str:
        d_model = config.hidden_size
        num_heads = config.num_heads
        head_size = d_model / num_heads
        scaling = 1 / head_size**.5
        return str(scaling)

    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"

    config["decoder"]["q_scaling"] = get_offset_q_scaling(
        hf_model.decoder.config)

    config["structure"] = dict()
    config["structure"]["pix2struct_with_bias"] = "false"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"]["position_embedding_type"] = "relative"
    config["structure"]["model_type"] = args.model_type

    def parse_pix2struct_config_by_component(config, component, args):
        if component == 'decoder':
            args.n_layer = config.getint(component, 'num_layers')
            args.n_head = config.getint(component, 'num_heads')
            args.head_size = config.getint(component, 'd_kv')
            args.hidden_size = config.getint(component, 'hidden_size')
            args.ffn_hidden_size = config.getint(component, 'd_ff')
            args.vocab_size = config.getint(component, 'vocab_size')
            args.n_positions = config.getint(component,
                                             'n_positions',
                                             fallback=512)
            args.has_position_embedding = config.getboolean(
                component, 'has_position_embedding',
                fallback=False)  # TODO: hardcoded here
            args.has_token_type_embedding = config.getboolean(
                component, 'has_token_type_embedding', fallback=False)
            args.has_embedding_layernorm = config.getboolean(
                component, 'has_embedding_layernorm', fallback=False)
            args.has_embedding_scale = config.getboolean(component,
                                                         'has_embedding_scale',
                                                         fallback=False)
            args.q_scaling = config.getfloat(component,
                                             'q_scaling',
                                             fallback=1.0)
            args.has_attention_qkvo_bias = config.getboolean(
                component, 'has_attention_qkvo_bias', fallback=False)
            args.has_mlp_bias = config.getboolean(component,
                                                  'has_mlp_bias',
                                                  fallback=False)
            args.has_model_final_layernorm = config.getboolean(
                component, 'has_model_final_layernorm', fallback=True)
            args.layernorm_eps = config.getfloat(component,
                                                 'layer_norm_epsilon')
            args.layernorm_position = layernorm_position_map[config.get(
                component, 'layernorm_position',
                fallback='pre_layernorm')]  # TODO: hardcoded here
            args.layernorm_type = layernorm_type_map[config.get(
                component, 'layernorm_type', fallback='RmsNorm')]
            args.hidden_act = config.get(component, 'dense_act_fn')
            args.gated_act = True
            args.mlp_type = mlp_type_map['GatedMLP' if args.
                                         gated_act else 'MLP']
            args.has_lm_head_bias = config.getboolean(
                component,  # TODO: T5 with bias
                'has_lm_head_bias',
                fallback=False)
            args.relative_attention = config.getboolean(component,
                                                        'relative_attention',
                                                        fallback=True)
            args.num_buckets = config.getint(component,
                                             'relative_attention_num_buckets')
            args.max_distance = config.getint(
                component, 'relative_attention_max_distance')
            args.logits_dtype = config.get(component,
                                           'logits_dtype',
                                           fallback='float32')
            args.rescale_before_lm_head = config.getboolean(
                component, 'tie_word_embeddings'
            )  # default is True (for T5), but False for Flan-T5
            args.encoder_hidden_size = config.getint('decoder', 'hidden_size')
            args.encoder_num_heads = config.getint('decoder', 'num_heads')
            args.encoder_head_size = config.getint('decoder', 'd_kv')
            args.position_embedding_type = config.get(
                'structure', 'position_embedding_type')
            args.decoder_start_token_id = config.getint(
                'decoder', 'decoder_start_token_id')

        else:
            assert False, 'Unsupported component!'
        return args

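    # Only the text decoder is converted to a TensorRT-LLM checkpoint here, which is
    # why the encoder slot of the returned tuple is None; the Pix2Struct image encoder
    # is built separately by tools/tensorrt_utils/build_visual_engine.py.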
    decoder_args = parse_pix2struct_config_by_component(config, "decoder", args)
    return None, decoder_args


def convert_pix2struct_weights_to_tllm_safetensors(config, component, params):
    weights = {}

    mapping = config.mapping

    convert_weight_to_dtype(params, config.dtype)
    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    n_head = config.num_attention_heads
    head_size = config.head_size
    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5

    hf_param_prefix = f'{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm'

    def get_attn_module_name(component, layer, attn_type):
        return f'{component}.layer.{int(layer)}.{attn_type}.attention'

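    # Only the token embedding table is exported: the decoder relies on relative
    # position biases (position_embedding_type == 'relative') instead of learned
    # absolute position embeddings.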
    weights['embedding.vocab_embedding.weight'] = reshape(
        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
        hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'

        hidden_layer_name_split = {
            f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
                "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
        }

        hidden_layer_name_no_split = {
            f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
                "shape": None
            },
            f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
                "shape": None
            },
        }

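        # Gated-MLP checkpoints carry two up-projections: wi_0 maps to mlp.fc and
        # wi_1 to mlp.gate, both row-split (dim 0) across tensor-parallel ranks.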
        if config.gated_act:
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
            })

        hidden_layer_name_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
        })
        hidden_layer_name_no_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
                "shape": None
            },
        })
        self_attn_module_name = get_attn_module_name(
            component, layer_idx, 'encoder_decoder_attention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,
                mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

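        # T5-style relative attention uses a single bias table stored on layer 0 of
        # the HF checkpoint. It is transposed to (num_heads, num_buckets), split along
        # the head dimension per tensor-parallel rank, and written into every layer.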
        weights[
            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    split(params[hf_weight_name],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])
        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    params[hf_weight_name].clone(), shape=weight_info["shape"])

    weights[f'final_layernorm.weight'] = reshape(
        params[f'{component}.final_layer_norm.weight'].clone(), None)

    weights['lm_head.weight'] = reshape(
        split(params[f'{component}.lm_head.weight'],
              mapping.tp_size,
              mapping.tp_rank,
              dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
    if not config.use_implicit_relative_attention:
        weights[f'rel_attn_table'] = reshape(
            split(
                params[
                    f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
                .T, mapping.tp_size, mapping.tp_rank, 0),
            (n_head // mapping.tp_size, config.num_buckets))

    return weights


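# StructEqTable (StructTable-base) is handled by the Pix2Struct pipeline in this
# repository, so the two functions below mirror parse_pix2struct_config and
# convert_pix2struct_weights_to_tllm_safetensors above.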
def parse_StructEqTable_config(args, hf_model):
    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
    config = configparser.ConfigParser()

    def get_offset_q_scaling(config) -> str:
        d_model = config.hidden_size
        num_heads = config.num_heads
        head_size = d_model / num_heads
        scaling = 1 / head_size**.5
        return str(scaling)

    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"

    config["decoder"]["q_scaling"] = get_offset_q_scaling(
        hf_model.decoder.config)

    config["structure"] = dict()
    config["structure"]["pix2struct_with_bias"] = "false"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"]["position_embedding_type"] = "relative"
    config["structure"]["model_type"] = args.model_type

    def parse_StructEqTable_config_by_component(config, component, args):
        if component == 'decoder':
            args.n_layer = config.getint(component, 'num_layers')
            args.n_head = config.getint(component, 'num_heads')
            args.head_size = config.getint(component, 'd_kv')
            args.hidden_size = config.getint(component, 'hidden_size')
            args.ffn_hidden_size = config.getint(component, 'd_ff')
            args.vocab_size = config.getint(component, 'vocab_size')
            args.n_positions = config.getint(component,
                                             'n_positions',
                                             fallback=512)
            args.has_position_embedding = config.getboolean(
                component, 'has_position_embedding',
                fallback=False)  # TODO: hardcoded here
            args.has_token_type_embedding = config.getboolean(
                component, 'has_token_type_embedding', fallback=False)
            args.has_embedding_layernorm = config.getboolean(
                component, 'has_embedding_layernorm', fallback=False)
            args.has_embedding_scale = config.getboolean(component,
                                                         'has_embedding_scale',
                                                         fallback=False)
            args.q_scaling = config.getfloat(component,
                                             'q_scaling',
                                             fallback=1.0)
            args.has_attention_qkvo_bias = config.getboolean(
                component, 'has_attention_qkvo_bias', fallback=False)
            args.has_mlp_bias = config.getboolean(component,
                                                  'has_mlp_bias',
                                                  fallback=False)
            args.has_model_final_layernorm = config.getboolean(
                component, 'has_model_final_layernorm', fallback=True)
            args.layernorm_eps = config.getfloat(component,
                                                 'layer_norm_epsilon')
            args.layernorm_position = layernorm_position_map[config.get(
                component, 'layernorm_position',
                fallback='pre_layernorm')]  # TODO: hardcoded here
            args.layernorm_type = layernorm_type_map[config.get(
                component, 'layernorm_type', fallback='RmsNorm')]
            args.hidden_act = config.get(component, 'dense_act_fn')
            args.gated_act = True
            args.mlp_type = mlp_type_map['GatedMLP' if args.
                                         gated_act else 'MLP']
            args.has_lm_head_bias = config.getboolean(
                component,  # TODO: T5 with bias
                'has_lm_head_bias',
                fallback=False)
            args.relative_attention = config.getboolean(component,
                                                        'relative_attention',
                                                        fallback=True)
            args.num_buckets = config.getint(component,
                                             'relative_attention_num_buckets')
            args.max_distance = config.getint(
                component, 'relative_attention_max_distance')
            args.logits_dtype = config.get(component,
                                           'logits_dtype',
                                           fallback='float32')
            args.rescale_before_lm_head = config.getboolean(
                component, 'tie_word_embeddings'
            )  # default is True (for T5), but False for Flan-T5
            args.encoder_hidden_size = config.getint('decoder', 'hidden_size')
            args.encoder_num_heads = config.getint('decoder', 'num_heads')
            args.encoder_head_size = config.getint('decoder', 'd_kv')
            args.position_embedding_type = config.get(
                'structure', 'position_embedding_type')
            args.decoder_start_token_id = config.getint(
                'decoder', 'decoder_start_token_id')

        else:
            assert False, 'Unsupported component!'
        return args

    decoder_args = parse_StructEqTable_config_by_component(config, "decoder", args)
    return None, decoder_args


def convert_StructEqTable_weights_to_tllm_safetensors(config, component, params):
    weights = {}

    mapping = config.mapping

    convert_weight_to_dtype(params, config.dtype)
    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    n_head = config.num_attention_heads
    head_size = config.head_size
    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5

    hf_param_prefix = f'{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm'

    def get_attn_module_name(component, layer, attn_type):
        return f'{component}.layer.{int(layer)}.{attn_type}.attention'

    weights['embedding.vocab_embedding.weight'] = reshape(
        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
        hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'

        hidden_layer_name_split = {
            f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
                "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
        }

        hidden_layer_name_no_split = {
            f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
                "shape": None
            },
            f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
                "shape": None
            },
        }

        if config.gated_act:
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
            })

        hidden_layer_name_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
        })
        hidden_layer_name_no_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
                "shape": None
            },
        })
        self_attn_module_name = get_attn_module_name(
            component, layer_idx, 'encoder_decoder_attention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,
                mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attention')
================================================
SYMBOL INDEX (112 symbols across 11 files)
================================================

FILE: setup.py
  function write_version_to_file (line 5) | def write_version_to_file(version, target_file):

FILE: struct_eqtable/__init__.py
  function get_model_name (line 15) | def get_model_name(model_path):
  function build_model (line 31) | def build_model(model_ckpt='U4R/StructTable-InternVL2-1B', **kwargs):

FILE: struct_eqtable/internvl/conversation.py
  class SeparatorStyle (line 13) | class SeparatorStyle(IntEnum):
  class Conversation (line 37) | class Conversation:
    method get_prompt (line 61) | def get_prompt(self) -> str:
    method set_system_message (line 251) | def set_system_message(self, system_message: str):
    method append_message (line 255) | def append_message(self, role: str, message: str):
    method update_last_message (line 259) | def update_last_message(self, message: str):
    method to_gradio_chatbot (line 267) | def to_gradio_chatbot(self):
    method to_openai_api_messages (line 277) | def to_openai_api_messages(self):
    method copy (line 289) | def copy(self):
    method dict (line 304) | def dict(self):
  function register_conv_template (line 318) | def register_conv_template(template: Conversation, override: bool = False):
  function get_conv_template (line 328) | def get_conv_template(name: str) -> Conversation:

FILE: struct_eqtable/internvl/internvl.py
  class InternVL (line 8) | class InternVL(nn.Module):
    method __init__ (line 9) | def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_...
    method init_model (line 29) | def init_model(self, model_path):
    method init_image_processor (line 39) | def init_image_processor(self, image_processor_path):
    method init_tokenizer (line 45) | def init_tokenizer(self, tokenizer_path):
    method format_image_tokens (line 58) | def format_image_tokens(self, path_num):
    method forward (line 61) | def forward(self, images, output_format='latex', **kwargs):
    method find_closest_aspect_ratio (line 130) | def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width...
    method dynamic_preprocess (line 145) | def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=...

FILE: struct_eqtable/internvl/internvl_lmdeploy.py
  class InternVL_LMDeploy (line 11) | class InternVL_LMDeploy(nn.Module):
    method __init__ (line 12) | def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_...
    method init_tokenizer (line 30) | def init_tokenizer(self, tokenizer_path):
    method init_model (line 37) | def init_model(self, model_path):
    method forward (line 49) | def forward(self, images, output_format='latex', **kwargs):

FILE: struct_eqtable/pix2s/pix2s.py
  class Pix2Struct (line 7) | class Pix2Struct(nn.Module):
    method __init__ (line 8) | def __init__(self, model_path='U4R/StructTable-base', max_new_tokens=1...
    method postprocess_latex_code (line 21) | def postprocess_latex_code(self, code):
    method init_model (line 26) | def init_model(self, model_path):
    method init_image_processor (line 30) | def init_image_processor(self, image_processor_path):
    method forward (line 33) | def forward(self, image, **kwargs):

FILE: struct_eqtable/pix2s/pix2s_trt.py
  function trt_dtype_to_torch (line 23) | def trt_dtype_to_torch(dtype):
  class Pix2StructTensorRT (line 36) | class Pix2StructTensorRT(nn.Module):
    method __init__ (line 38) | def __init__(self, model_path, tensorrt_path, batch_size=1, max_new_to...
    method postprocess_latex_code (line 78) | def postprocess_latex_code(self, code):
    method init_image_processor (line 83) | def init_image_processor(self):
    method init_tokenizer (line 87) | def init_tokenizer(self):
    method init_image_encoder (line 92) | def init_image_encoder(self):
    method init_llm (line 100) | def init_llm(self):
    method __call__ (line 112) | def __call__(self, image, **kwargs):
    method preprocess (line 135) | def preprocess(self, warmup, pre_prompt, post_prompt, image,
    method generate (line 168) | def generate(self, pre_prompt, post_prompt, image, decoder_input_ids,
    method get_visual_features (line 223) | def get_visual_features(self, image, attention_mask):
    method setup_fake_prompts (line 259) | def setup_fake_prompts(self, visual_features, pre_input_ids, post_inpu...
    method ptuning_setup (line 283) | def ptuning_setup(self, prompt_table, input_ids, input_lengths):
    method setup_inputs (line 312) | def setup_inputs(self, input_text, raw_image):
    method run (line 353) | def run(self, flattened_patches, attention_mask, max_new_tokens):
  function read_config (line 390) | def read_config(config_path):
  class Mapping (line 469) | class Mapping(object):
    method __init__ (line 470) | def __init__(
    method get_node_rank (line 549) | def get_node_rank(self, rank: int):
    method get_local_rank (line 552) | def get_local_rank(self, rank: int):
    method has_tp (line 555) | def has_tp(self):
    method is_last_pp_rank (line 558) | def is_last_pp_rank(self):
    method is_first_pp_rank (line 561) | def is_first_pp_rank(self):
    method has_pp (line 564) | def has_pp(self):
    method prev_pp_rank (line 567) | def prev_pp_rank(self):
    method next_pp_rank (line 573) | def next_pp_rank(self):
    method has_moe_tp (line 579) | def has_moe_tp(self):
    method has_moe_ep (line 582) | def has_moe_ep(self):
    method pp_layers (line 585) | def pp_layers(self, num_layers: int) -> List[int]:
    method ep_experts (line 591) | def ep_experts(self, num_experts: int) -> List[int]:
  function get_engine_name (line 598) | def get_engine_name(rank):
  class TRTLLMEncDecModel (line 601) | class TRTLLMEncDecModel:
    method __init__ (line 603) | def __init__(
    method from_engine (line 718) | def from_engine(cls,
    method process_input (line 734) | def process_input(self,
    method encoder_run (line 767) | def encoder_run(self,
    method generate (line 913) | def generate(self,

FILE: tools/demo/demo.py
  function parse_config (line 9) | def parse_config():
  function main (line 23) | def main():

FILE: tools/tensorrt_utils/build_visual_engine.py
  function parse_arguments (line 30) | def parse_arguments():
  class VisionEngineBuilder (line 64) | class VisionEngineBuilder:
    method __init__ (line 66) | def __init__(self, args):
    method build (line 78) | def build(self):
  function export_visual_wrapper_onnx (line 109) | def export_visual_wrapper_onnx(visual_wrapper,
  function build_trt_engine (line 127) | def build_trt_engine(model_type,
  function build_blip2_engine (line 212) | def build_blip2_engine(args):
  function build_pix2struct_engine (line 252) | def build_pix2struct_engine(args):
  function build_StructEqTable_engine (line 301) | def build_StructEqTable_engine(args):
  function build_llava_engine (line 350) | def build_llava_engine(args):
  function build_vila_engine (line 385) | def build_vila_engine(args):
  function build_nougat_engine (line 429) | def build_nougat_engine(args):
  function build_cogvlm_engine (line 457) | def build_cogvlm_engine(args):
  function build_fuyu_engine (line 493) | def build_fuyu_engine(args):
  function build_neva_engine (line 533) | def build_neva_engine(args):
  function build_video_neva_engine (line 606) | def build_video_neva_engine(args):
  function build_kosmos_engine (line 687) | def build_kosmos_engine(args):
  function build_phi_engine (line 723) | def build_phi_engine(args):

FILE: tools/tensorrt_utils/convert_checkpoint.py
  function copy_args_to_component_config (line 32) | def copy_args_to_component_config(component_config, args):
  function parse_t5_config (line 38) | def parse_t5_config(args, hf_model):
  function convert_t5_weights_to_tllm_safetensors (line 152) | def convert_t5_weights_to_tllm_safetensors(config, component, params):
  function parse_nmt_config (line 320) | def parse_nmt_config(args, model):
  function convert_nmt_weights_to_tllm_safetensors (line 441) | def convert_nmt_weights_to_tllm_safetensors(config, component, params,
  function parse_bart_config (line 598) | def parse_bart_config(args, hf_model):
  function convert_bart_weights_to_tllm_safetensors (line 748) | def convert_bart_weights_to_tllm_safetensors(config, component, params):
  function parse_pix2struct_config (line 938) | def parse_pix2struct_config(args, hf_model):
  function convert_pix2struct_weights_to_tllm_safetensors (line 1038) | def convert_pix2struct_weights_to_tllm_safetensors(config, component, pa...
  function parse_StructEqTable_config (line 1186) | def parse_StructEqTable_config(args, hf_model):
  function convert_StructEqTable_weights_to_tllm_safetensors (line 1286) | def convert_StructEqTable_weights_to_tllm_safetensors(config, component,...
  function get_model (line 1434) | def get_model(args):
  function convert_checkpoint (line 1458) | def convert_checkpoint(args):
  function convert (line 1635) | def convert(worker_rank, world_size, args, model_config, convert_args,

FILE: tools/tensorrt_utils/helper.py
  function split (line 10) | def split(v: Union[np.ndarray, torch.Tensor],
  function reshape (line 30) | def reshape(v: torch.Tensor, shape=None):
  function fuse_qkv_one_layer (line 37) | def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_s...
  function get_qkv_module_name (line 71) | def get_qkv_module_name(model_type):
  function convert_weight_to_dtype (line 91) | def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],