[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2024 OpenBMB\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# Megrez-3B-Omni: The First Open-Source End-Side Full Modality Understanding Model\n\n<p align=\"center\">\n    <img src=\"assets/megrez_logo.png\" width=\"400\"/>\n<p>\n<p align=\"center\">\n    📄 <a href=\"assets/Megrez_Omni_Technical_Report.pdf\">Paper</a>\n    🤗 <a href=\"https://huggingface.co/Infinigence/Megrez-3B-Omni\">Huggingface</a>&nbsp&nbsp | &nbsp&nbsp🤖<a href=\"https://www.modelscope.cn/models/InfiniAI/Megrez-3B-Omni\">Modelscope</a>&nbsp&nbsp | &nbsp&nbsp🖥️ <a href=\"https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni\">Demo</a>&nbsp&nbsp | &nbsp&nbsp📖 <a href=\"assets/wechat-official.jpg\">WeChat Official</a>&nbsp&nbsp | &nbsp&nbsp💬 <a href=\"assets/wechat-group.jpg\">WeChat Groups</a>&nbsp&nbsp\n</p>\n\n<strong>[中文](./README_zh.md) | English</strong>\n\n</div>\n\n## Introduction\n**Megrez-3B-Omni** is an on-device multimodal understanding LLM model developed by **Infinigence AI** ([Infinigence AI](https://cloud.infini-ai.com/platform/ai)). It is an extension of the Megrez-3B-Instruct model and supports analysis of image, text, and audio modalities. The model achieves state-of-the-art accuracy in all three domains:\n- Image Understanding: By utilizing SigLip-400M for constructing image tokens, Megrez-3B-Omni outperforms models with more parameters such as LLaVA-NeXT-Yi-34B. It is one of the best image understanding models among multiple mainstream benchmarks, including MME, MMMU, and OCRBench. It demonstrates excellent performance in tasks such as scene understanding and OCR.\n- Language Understanding: Megrez-3B-Omni retains text understanding capabilities without significant trade-offs. Compared to its single-modal counterpart (Megrez-3B-Instruct), the accuracy variation is less than 2%, maintaining state-of-the-art performance on benchmarks like C-EVAL, MMLU/MMLU Pro, and AlignBench. It also outperforms previous-generation models with 14B parameters.\n- Speech Understanding: Equipped with the encoder head of Qwen2-Audio/whisper-large-v3, the model supports both Chinese and English speech input, multi-turn conversations, and voice-based questions about input images. It can directly respond to voice commands with text and achieved leading results across multiple benchmarks.\n\n## Evaluation\n\n- The left image compares the performance of Megrez-3B-Omni with other open-source models on mainstream image multimodal tasks.\n- The right image shows the performance of Megrez-3B-Omni on the OpenCompass test set. Image reference: [InternVL 2.5 Blog Post](https://internvl.github.io/blog/2024-12-05-InternVL-2.5/).  \n\nYou can find detailed accuracy metrics on the [Megrez-3B-Omni-HF](https://huggingface.co/Infinigence/Megrez-3B-Omni) page.  \n\n<div style=\"display: flex; justify-content: space-between;\">\n  <img src=\"assets/multitask.jpg\" alt=\"Comparison of Image Understanding Capabilities\" style=\"width: 45%;\">\n  <img src=\"assets/opencompass.jpg\" alt=\"OpenCompass Benchmark Performance\" style=\"width: 45%;\">\n</div>\n\n### Inference Speed\n\n|                | image_tokens | prefill (tokens/s) | decode (tokens/s) |\n|----------------|:------------:|:------------------:|:-----------------:|\n| Megrez-3B-Omni |      448     |       6312.66      |       1294.9      |\n| Qwen2-VL-2B    |     1378     |       7349.39      |       685.66      |\n| MiniCPM-V-2_6  |      448     |       2167.09      |       452.51      |\n\nSetup:  \n- The testing environment utilizes an NVIDIA H100 GPU with vLLM. 
Each test uses 128 text tokens and one 720×1480 image as input and produces 128 output tokens, with `num_seqs` fixed at 8.\n- Under this setup, Qwen2-VL-2B decodes more slowly than Megrez-3B-Omni despite its smaller base LLM, because encoding an image of this size produces far more image tokens, which slows down actual inference.\n\n## Model Demo\n\n【GIF】\n\n## Install\n\nInstall the runtime dependencies with the following command:\n\n```shell\npip install -r requirements.txt\n```\n\nThe audio-related functionality relies on **FFmpeg** for audio processing. If you are using a Debian or Debian-based system, you can install FFmpeg with the following command:\n\n```bash\nsudo apt-get install ffmpeg\n```\n\nFor other operating systems, please refer to the [official FFmpeg documentation](https://ffmpeg.org/download.html) for installation instructions.\n\n## Inference\n\n### Conversation with Multimodal Data\n\nYou can use the following script to chat with our model. Note that you should replace `PATH_TO_PRETRAINED_MODEL` with the path to the downloaded model checkpoint.\n\n```python\nimport torch\nfrom transformers import AutoModelForCausalLM\n\npath = \"{{PATH_TO_PRETRAINED_MODEL}}\"  # Change this to the path of the model.\n\nmodel = (\n    AutoModelForCausalLM.from_pretrained(\n        path,\n        trust_remote_code=True,\n        torch_dtype=torch.bfloat16,\n    )\n    .eval()\n    .cuda()\n)\n\nmessages = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"text\": \"Please describe the content of the image.\",\n            \"image\": \"./data/sample_image.jpg\",\n        },\n    },\n]\n\nMAX_NEW_TOKENS = 100\nresponse = model.chat(\n    messages,\n    sampling=False,\n    max_new_tokens=MAX_NEW_TOKENS,\n)\nprint(response)\n```\n\nYou can also find a complete script in [example_chat_hf.py](example_chat_hf.py).\n\n
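The `content` dict is not limited to text plus image; it can also carry audio, so you can ask a voice question about an image. Below is a minimal variant, mirroring the commented-out example in [example_chat_hf.py](example_chat_hf.py); the sample file paths are placeholders to adapt to your own data:\n\n```python\n# Voice question about an image (sketch; assumes `model` is loaded as above).\nmessages = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"image\": \"./data/sample_image.jpg\",\n            \"audio\": \"./data/sample_audio.m4a\",\n        },\n    },\n]\n\nresponse = model.chat(\n    messages,\n    sampling=False,\n    max_new_tokens=100,\n)\nprint(response)\n```\n\n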
### Inference with vLLM\n\nWe provide a reference implementation of inference with the vLLM framework. You can find the model definition in [vllm_demo/megrezo.py](vllm_demo/megrezo.py).\n\n1. Install vLLM\n\n```shell\npip install vllm==0.6.3.post1 flash_attn==2.5.8 xformers==0.0.27.post2\n```\n\n**Note**: vLLM inference requires these specific dependency versions; other versions may have incompatible interfaces. If you encounter any issues, feel free to [open an issue](https://github.com/infinigence/Infini-Megrez-Omni/issues/new).\n\n2. Run the inference script\n\nSince vLLM does not officially support MegrezO yet, you need to import the module first:\n\n```python\nfrom vllm import ModelRegistry\nfrom megrezo import MegrezOModel\n\nModelRegistry.register_model(\"MegrezO\", MegrezOModel)\n```\n\nThen, you can run inference with the following code:\n\n```python\nfrom PIL import Image\nfrom vllm import LLM\nfrom vllm import SamplingParams\n\n\n# Load the model.\nmodel_path = \"{{PATH_TO_HF_PRETRAINED_MODEL}}\"  # Change this to the path of the model.\nllm = LLM(\n    model_path,\n    trust_remote_code=True,\n    gpu_memory_utilization=0.5,\n)\n\nsampling_params = SamplingParams(\n    temperature=0,\n    max_tokens=1000,\n    repetition_penalty=1.2,\n    stop=[\"<|turn_end|>\", \"<|eos|>\"],\n)\n\nimg = Image.open(\"../data/sample_image.jpg\")\n\nconversation = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"text\": \"图片的内容是什么？\",\n            \"image\": img,\n        },\n    },\n]\n\n# Convert the conversation to vLLM's input format.\nprompt = llm.get_tokenizer().apply_chat_template(\n    conversation,\n    tokenize=False,\n    add_generation_prompt=True,\n)\nvllm_inputs = [\n    {\n        \"prompt\": prompt,\n        \"multi_modal_data\": {\n            \"image\": img,\n        },\n    }\n]\n\n# Generate the outputs.\noutputs = llm.generate(\n    vllm_inputs,\n    sampling_params,\n)\n\n# Print the outputs.\nfor output in outputs:\n    print(output.outputs[0].text)\n```\n\nYou can find a complete script in [vllm_demo/example_infer_vllm.py](vllm_demo/example_infer_vllm.py).\n\n## Chat with MegrezO using Gradio\n\nWe provide online and local demos powered by Hugging Face Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>.\n\n### WebUI Demonstration\n\n<div align=\"center\" style=\"display: flex; justify-content: space-between;\">\n  <img src=\"assets/gradio_demo.jpg\" style=\"width: 80%;\">\n</div>\n\n### Online Demo\n\nPlease try out our online demo here: [🤗Megrez-3B-Omni](https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni)\n\n### Local WebUI Demo\n\nYou can easily deploy your own local WebUI to chat with MegrezO using Gradio.\n\n1. Install dependencies:\n\n```shell\npip install -r requirements.txt\n```\n\n2. Launch the Gradio app.\n\nYou need to specify `model_path` and `port` on the command line: `model_path` is the path to the model checkpoint, and `port` is the port number for the local server (default `7860`).\n\n```shell\npython gradio_app.py --model_path {model_path} --port {port}\n```\n\nThen, you can visit `http://localhost:7860` in your browser to interact with the model.\n\nFeel free to modify `gradio_app.py` to customize the input and output interfaces. For more information, please refer to the [Gradio documentation](https://gradio.app/docs).\n\n## Fine-Tuning the Model\n\nWe provide a [fine-tuning example](./finetune/) based on [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [accelerate](https://github.com/huggingface/accelerate).\n\n### Data Preparation\n\nWe have constructed a sample dataset based on the [ALLaVA-4V/allava_laion](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/tree/main/allava_laion) dataset:\n\n- **Dialogue**: [data/train/records.jsonl](./data/train/records.jsonl)\n- **Images**: [data/train/images](./data/train/images)\n- **Audio**: [data/train/audio](./data/train/audio), created by converting the dialogue text into speech with TTS.\n\nYou can also prepare your own dataset following the same format, illustrated below.\n\n
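Each line of [data/train/records.jsonl](./data/train/records.jsonl) is a JSON object holding a `conversations` list; the `image` and `audio` fields are paths relative to the `dataset_prefix` passed to the training script. A record from the sample dataset looks like this (the assistant text is shortened here for readability):\n\n```json\n{\"conversations\": [{\"role\": \"user\", \"content\": {\"text\": \"What material is the water bottle likely made from?\", \"image\": \"images/100280844.jpeg\", \"audio\": \"audios/0_4_0_chattts.wav\"}}, {\"role\": \"assistant\", \"content\": {\"text\": \"Based on the image alone, the water bottle is likely made from metal...\"}}]}\n```\n\n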
### Dependencies Installation\n\nInstall the required dependencies with the following command:\n\n```bash\npip install deepspeed accelerate\n```\n\n### Full-Parameter Fine-Tuning\n\nTo run the fine-tuning example, execute the following commands. Be sure to replace the model path in the script with the path to your downloaded model.\n\n```bash\ncd finetune\n\nsh finetune.sh\n```\n\nYou can customize which modules to fine-tune by setting the parameters\n`tune_vision_encoder`, `tune_vision_proj`, `tune_llm`, `tune_audio_encoder`, and `tune_audio_proj`.\n\n### Notes\n\n1. **Recommended Hardware**: Please use at least two GPUs with 80GB of memory for fine-tuning.\n2. **If GPU memory is insufficient**:\n   - Adjust the `model_max_length` and `per_device_train_batch_size` parameters.\n   - Disable some of the tunable modules to reduce memory usage.\n   - Optimize memory consumption by configuring the `zero_optimization` parameters in DeepSpeed.\n3. **For better inference results**:\n   - We recommend placing images in the first round of the conversation; audio and text have no such restriction and can be switched freely.\n   - For Automatic Speech Recognition (ASR), simply change `content['text']` to \"Convert speech to text.\"\n   - In OCR scenarios, enabling sampling may introduce language-model hallucinations that alter the recognized text, so consider disabling sampling at inference time (`sampling=False`). Be aware that disabling sampling may lead to repetitive output.\n\n## Open Source License and Usage Statement\n\n- **License**: The code in this repository is open-sourced under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) license.\n- **Hallucination**: Large models inherently suffer from hallucinations, so users should not fully trust model-generated content.\n- **Values and Safety**: While we have made every effort to ensure the compliance of the data used during training, its large volume and complexity may still lead to unforeseen issues. We disclaim any liability for problems arising from the use of this open-source model, including but not limited to data security issues, public-opinion risks, or any risks and problems caused by the model being misled, misused, disseminated, or improperly exploited.\n"
  },
  {
    "path": "README_zh.md",
    "content": "<div align=\"center\">\n\n# Megrez-3B-Omni: 首个端侧全模态理解开源模型\n\n<p align=\"center\">\n    <img src=\"assets/megrez_logo.png\" width=\"400\"/>\n<p>\n<p align=\"center\">\n    📄 <a href=\"assets/Megrez_Omni_Technical_Report.pdf\">Paper</a>\n    🤗 <a href=\"https://huggingface.co/Infinigence/Megrez-3B-Omni\">Huggingface</a>&nbsp&nbsp | &nbsp&nbsp🤖<a href=\"https://www.modelscope.cn/models/InfiniAI/Megrez-3B-Omni\">Modelscope</a>&nbsp&nbsp | &nbsp&nbsp🖥️ <a href=\"https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni\">Demo</a>&nbsp&nbsp | &nbsp&nbsp📖 <a href=\"assets/wechat-official.jpg\">WeChat Official</a>&nbsp&nbsp | &nbsp&nbsp💬 <a href=\"assets/wechat-group.jpg\">WeChat Groups</a>&nbsp&nbsp\n</p>\n\n<strong>中文 | [English](./README.md)</strong>\n\n</div>\n\n## 模型简介\nMegrez-3B-Omni是由无问芯穹（[Infinigence AI](https://cloud.infini-ai.com/platform/ai)）研发的**端侧全模态**理解模型，基于无问大语言模型Megrez-3B-Instruct扩展，同时具备图片、文本、音频三种模态数据的理解分析能力，在三个方面均取得最优精度\n- 在图像理解方面，基于SigLip-400M构建图像Token，在OpenCompass榜单上（综合8个主流多模态评测基准）平均得分66.2，超越LLaVA-NeXT-Yi-34B等更大参数规模的模型。Megrez-3B-Omni也是在MME、MMMU、OCRBench等测试集上目前精度最高的图像理解模型之一，在场景理解、OCR等方面具有良好表现。\n- 在语言理解方面，Megrez-3B-Omni并未牺牲模型的文本处理能力，综合能力较单模态版本（Megrez-3B-Instruct）精度变化小于2%，保持在C-EVAL、MMLU/MMLU Pro、AlignBench等多个测试集上的最优精度优势，依然取得超越上一代14B模型的能力表现\n- 在语音理解方面，采用Qwen2-Audio/whisper-large-v3的Encoder作为语音输入，支持中英文语音输入及多轮对话，支持对输入图片的语音提问，根据语音指令直接响应文本，在多项基准任务上取得了领先的结果\n\n## 评测结果\n- 左图为Megrez-3B-Omni与其他开源模型在主流图片多模态任务上的性能比较\n- 右图为Megrez-3B-Omni在OpenCompass测试集上表现，图片引用自： [InternVL 2.5 Blog Post](https://internvl.github.io/blog/2024-12-05-InternVL-2.5/)*\n<div style=\"display: flex; justify-content: space-between;\">\n  <img src=\"assets/multitask.jpg\" alt=\"Image 1\" style=\"width: 45%;\">\n  <img src=\"assets/opencompass.jpg\" alt=\"Image 2\" style=\"width: 45%;\">\n</div>\n\n详细精度见 [Megrez-3B-Omni-HF](https://huggingface.co/Infinigence/Megrez-3B-Omni)\n\n### 推理速度\n|                | image_tokens | prefill (tokens/s) | decode (tokens/s) |\n|----------------|:------------:|:------------------:|:-----------------:|\n| Megrez-3B-Omni |      448     |       6312.66      |       1294.9      |\n| Qwen2-VL-2B    |     1378     |       7349.39      |       685.66      |\n| MiniCPM-V-2_6  |      448     |       2167.09      |       452.51      |\n\n实验设置：\n- 测试环境为NVIDIA H100下VLLM下输入128个Text token和一张 720*1480的图片，输出128个token，num_seqs固定为8。\n- Qwen2-VL-2B的在此实验下的decode速度小于Megrez-3B-Omni，虽然其具备更小的基座LLM，但是编码上述大小图片后的image_token相较Megrez-3B-Omni较多，影响实际推理速度。\n\n## 模型演示\n【GIF】\n\n## 安装\n使用如下命令安装依赖：\n\n```shell\npip install -r requirements.txt\n```\n\n音频功能依赖ffmpeg进行音频处理，如果您使用 Debian 相关的系统，可以通过以下命令安装：\n\n```shell\nsudo apt-get install ffmpeg\n```\n\n对于其他的操作系统，请参考 [ffmpeg 官方文档](https://ffmpeg.org/download.html) 进行安装。\n\n\n## 模型推理\n\n### 使用多模态数据进行多轮对话\n\n请使用如下脚本进行推理。请将 `PATH_TO_PRETRAINED_MODEL` 替换为下载的模型权重的路径。\n```python\nimport torch\nfrom transformers import AutoModelForCausalLM\n\npath = \"{{PATH_TO_PRETRAINED_MODEL}}\"  # 更改为模型的路径\n\nmodel = (\n    AutoModelForCausalLM.from_pretrained(\n        path,\n        trust_remote_code=True,\n        torch_dtype=torch.bfloat16,\n    )\n    .eval()\n    .cuda()\n)\n\nmessages = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"text\": \"Please describe the content of the image.\",\n            \"image\": \"./data/sample_image.jpg\",\n        },\n    },\n]\n\nMAX_NEW_TOKENS = 100\nresponse = model.chat(\n    messages,\n    sampling=False,\n    
### 使用 vLLM 进行推理\n\n我们提供了一个基于 vLLM 框架的推理参考实现。您可以在 [vllm_demo/megrezo.py](vllm_demo/megrezo.py) 中找到模型定义。\n\n推理步骤如下：\n\n1. 安装 vLLM\n\n```shell\npip install vllm==0.6.3.post1 flash_attn==2.5.8 xformers==0.0.27.post2\n```\n\n**注意**：使用 vLLM 推理需要安装特定版本的依赖，其他版本可能存在接口不一致的风险。有任何问题欢迎[提 issue](https://github.com/infinigence/Infini-Megrez-Omni/issues/new)。\n\n2. 运行推理脚本\n\nvLLM 尚未正式支持 MegrezO，因此您需要先导入我们定义的模块：\n\n```python\nfrom vllm import ModelRegistry\nfrom megrezo import MegrezOModel\n\nModelRegistry.register_model(\"MegrezO\", MegrezOModel)\n```\n\n然后，您可以使用以下代码运行推理：\n\n```python\nfrom PIL import Image\nfrom vllm import LLM\nfrom vllm import SamplingParams\n\n\nmodel_path = \"{{PATH_TO_HF_PRETRAINED_MODEL}}\"  # 更改为模型的路径\nllm = LLM(\n    model_path,\n    trust_remote_code=True,\n    gpu_memory_utilization=0.5,\n)\n\nsampling_params = SamplingParams(\n    temperature=0,\n    max_tokens=1000,\n    repetition_penalty=1.2,\n    stop=[\"<|turn_end|>\", \"<|eos|>\"],\n)\n\nimg = Image.open(\"../data/sample_image.jpg\")\n\nconversation = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"text\": \"图片的内容是什么？\",\n            \"image\": img,\n        },\n    },\n]\n\n# 将对话转换为 vLLM 可接受的格式。\nprompt = llm.get_tokenizer().apply_chat_template(\n    conversation,\n    tokenize=False,\n    add_generation_prompt=True,\n)\nvllm_inputs = [\n    {\n        \"prompt\": prompt,\n        \"multi_modal_data\": {\n            \"image\": img,\n        },\n    }\n]\n\n# 生成输出\noutputs = llm.generate(\n    vllm_inputs,\n    sampling_params,\n)\n\n# 打印输出\nfor output in outputs:\n    print(output.outputs[0].text)\n```\n\n完整的示例见：[vllm_demo/example_infer_vllm.py](vllm_demo/example_infer_vllm.py)。\n\n## 使用 Gradio 与 MegrezO 对话\n\n我们提供基于 Hugging Face Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> 实现的在线和本地 Demo。\n\n### WebUI 演示\n\n<div align=\"center\" style=\"display: flex; justify-content: space-between;\">\n  <img src=\"assets/gradio_demo.jpg\" style=\"width: 80%;\">\n</div>\n\n### 在线 Demo\n\n欢迎试用在线 Demo：[🤗Megrez-3B-Omni](https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni)。\n\n### 本地 Demo\n\n使用如下命令部署本地 Gradio 应用：\n\n1. 安装依赖：\n\n```shell\npip install -r requirements.txt\n```\n\n2. 启动 Gradio 应用\n\n您需要在命令行中指定 `model_path` 和 `port`。`model_path` 是模型的路径，`port` 是本地服务器的端口号，默认为 `7860`。\n\n```shell\npython gradio_app.py --model_path {model_path} --port {port}\n```\n\n然后，您可以在浏览器中访问 `http://localhost:7860` 与模型对话。\n\n如需自定义输入和输出接口，请修改 `gradio_app.py`。更多信息请参考 [Gradio 文档](https://gradio.app/docs)。\n\n## 微调模型\n\n我们提供了一个基于 [DeepSpeed](https://github.com/microsoft/DeepSpeed) 和 [accelerate](https://github.com/huggingface/accelerate) 的[微调示例](./finetune/)。\n\n### 数据准备\n\n我们基于 [ALLaVA-4V/allava_laion](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/tree/main/allava_laion) 数据集构造了一个示例数据集：\n\n- **对话**：[data/train/records.jsonl](./data/train/records.jsonl)\n- **图片**：[data/train/images](./data/train/images)\n- **音频**：[data/train/audio](./data/train/audio)，通过将对话中的文本使用TTS转换为语音得到。\n\n您也可以按照上述格式准备自己的数据集，示例格式见下。\n\n
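作为参考，[data/train/records.jsonl](./data/train/records.jsonl) 的每一行是一个包含 `conversations` 列表的 JSON 对象，其中 `image` 和 `audio` 字段是相对于训练脚本中 `dataset_prefix` 的路径。示例数据集中的一条记录如下（assistant 的文本此处已缩短）：\n\n```json\n{\"conversations\": [{\"role\": \"user\", \"content\": {\"text\": \"What material is the water bottle likely made from?\", \"image\": \"images/100280844.jpeg\", \"audio\": \"audios/0_4_0_chattts.wav\"}}, {\"role\": \"assistant\", \"content\": {\"text\": \"Based on the image alone, the water bottle is likely made from metal...\"}}]}\n```\n\n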
### 依赖安装\n\n```shell\npip install deepspeed accelerate\n```\n\n### 全参微调\n\n使用如下命令运行我们的微调示例，请注意将脚本中的模型路径替换成您下载的模型路径。\n\n```shell\ncd finetune\n\nsh finetune.sh\n```\n\n您可以通过设置`tune_vision_encoder`、`tune_vision_proj`、`tune_llm`、`tune_audio_encoder`、`tune_audio_proj`来选择需要微调的模块。\n\n### 注意事项\n\n- 推荐使用至少2张拥有80GB显存的GPU进行微调。\n- 在显存不足的情况下：\n  - 请尝试调整`model_max_length`和`per_device_train_batch_size`。\n  - 请尝试关闭部分需要微调的模块以减少显存占用。\n  - 请尝试调整DeepSpeed的`zero_optimization`参数来优化显存使用。\n- 使用建议：\n  - 请尽量将图片放在首轮输入以保证推理效果；语音和文本无此限制，可以自由切换。\n  - 语音识别（ASR）场景下，只需要将`content['text']`修改为“将语音转化为文字。”\n  - OCR场景下开启采样可能会引入语言模型幻觉导致的文字变化，可考虑关闭采样进行推理（`sampling=False`），但关闭采样可能引入模型复读。\n\n## 开源协议及使用声明\n\n- **协议**：本仓库中代码依照 [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 协议开源。\n- **幻觉**：大模型天然存在幻觉问题，用户使用过程中请勿完全相信模型生成的内容。\n- **价值观及安全性**：本模型已尽全力确保训练过程中使用的数据的合规性，但由于数据的大体量及复杂性，仍有可能存在一些无法预见的问题。如果出现使用本开源模型而导致的任何问题，包括但不限于数据安全问题、公共舆论风险，或模型被误导、滥用、传播或不当利用所带来的任何风险和问题，我们将不承担任何责任。\n"
  },
  {
    "path": "data/train/records.jsonl",
    "content": "{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"Considering the items the figurine is holding, what scene or event from its source material might it be representing?\",\"image\":\"images/100230144.jpeg\",\"audio\":\"audios/0_0_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"Based on the items the figurine is holding, which are a rifle and an iron, it likely represents a scene from the source material where the character is setting up traps to protect his home. These items are indicative of the resourceful and inventive methods the character uses in the movie \\\"Home Alone\\\" to create humorous but effective traps for intruders.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"How does the inclusion of gemstones affect the overall aesthetic and possible value of the ring?\",\"image\":\"images/100277305.jpeg\",\"audio\":\"audios/0_1_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"The inclusion of gemstones along the edges of the leaf design enhances the ring's aesthetic by adding a touch of brilliance and luxury. The small, clear gemstones catch the light, creating a sparkling effect that contrasts with the silver's polished surface and the leaf's textured details. This sparkling accent likely increases the ring's visual appeal, making it stand out more as an accessory. In terms of value, the presence of gemstones, which could be diamonds, would likely raise the ring's monetary worth. Gemstones, especially diamonds, are prized for their beauty and durability, and their addition to jewelry pieces often signifies a higher quality and class of the item, thereby increasing its market value.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"What might the significance of the figure holding the 'O' in relation to the setting sun symbolize in the context of this album cover?\",\"image\":\"images/100234216.jpeg\",\"audio\":\"audios/0_2_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"The silhouette of the figure holding the 'O' of the \\\"WOW\\\" so that it appears to be part of the setting sun likely symbolizes an element of enlightenment or revelation, which is often a theme in Christian music. It could represent the divine or an epiphany, echoing the 'wow' factor that the album title suggests. This imagery might be intended to convey a sense of awe or an extraordinary experience, possibly reflecting the emotional or spiritual impact that the songs aim to impart on the listener.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"What might be the significance of the character's clothing and instrument choice in relation to the type of event being promoted?\",\"image\":\"images/100223054.jpeg\",\"audio\":\"audios/0_3_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"The character's clothing and instrument choice are indicative of a relaxed and informal atmosphere, which aligns with the musical genres mentioned on the poster\\u2014rock and hard rock. The gray cap and black shirt could be associated with a laid-back, perhaps slightly rebellious aesthetic that is often linked to rock music cultures. The red bass guitar signifies that music is central to the event and hints that the character may represent a member of the band \\\"PuBFlieS,\\\" suggesting they play rock or hard rock music, as bass guitars are fundamental in these genres. 
The cheerful demeanor of the character along with the casual attire suggests that the event aims to be welcoming and enjoyable, catering to fans of these music genres.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"What material is the water bottle likely made from?\",\"image\":\"images/100280844.jpeg\",\"audio\":\"audios/0_4_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"Based on the image alone, the water bottle is likely made from metal, inferred from the reflective and smooth surface, which is characteristic of metal water bottles. The metallic sheen and lack of any seams or typical plastic texturing support this assumption. Additionally, the way the light reflects off the surface suggests a material that is more reflective than plastic or glass, which is often the case with stainless steel or aluminum bottles.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"What historical period might the attire of the person in the painting suggest?\",\"image\":\"images/100284327.jpeg\",\"audio\":\"audios/0_5_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"The attire of the person in the painting suggests a historical period that could be the late Middle Ages or the Renaissance, specifically within Europe. This is indicated by the style of the hat, which resembles a wide-brimmed hat known as a Gainsborough hat, popular in different forms from the 17th century onwards, and the monastic-like robes, which were common in religious orders during the medieval period. However, the exact historical period cannot be definitively determined from the attire alone due to the painting's anachronistic nature.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"Could the absence of any ports or connectors on the visible side suggest that this is a wireless device?\",\"image\":\"images/100245026.jpeg\",\"audio\":\"audios/0_6_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"Based on the visible features in the image, the absence of any ports, buttons, or connectors on the surface of the device does suggest that it could be a wireless device. This characteristic is common in modern wireless technology, where devices are often designed to be sleek and unobtrusive, relying on wireless connectivity for their functionality. The simplicity and cleanliness of the design further imply that it could be a wireless charging pad, which typically requires no physical connections to the devices it charges.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"Considering the individual's attire and the setting, what could be the possible theme or context of this photo shoot?\",\"image\":\"images/10688.jpeg\",\"audio\":\"audios/3_0_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"The theme or context of the photo shoot could be a conceptual representation of rebellion or resistance. The military-style jacket with patches, the helmet, and the setting that resembles a prison all suggest a narrative of combatting authority or standing against confinement. 
The use of fashion to portray this narrative indicates a possible commentary on individualism and defiance.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"What could be the possible association between the two logos presented in the image, and how might they relate to the content listed in the slide?\",\"image\":\"images/104042.jpeg\",\"audio\":\"audios/2_0_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"The possible association between the two logos and the content of the slide suggests a partnership or a collaborative project focused on recycling and waste electrical and electronic equipment (WEEE). The \\\"LIFE +\\\" logo is associated with an EU environmental initiative, and \\\"RECYCLING SYMPRAXIS\\\" suggests a practice or a consortium working towards recycling. The date and word \\\"PHILOXENIA\\\" hint at an event, possibly a conference or seminar that took place in 2010. The second logo, which is less identifiable, likely represents the organization responsible for the content of the presentation, in this case, \\\"Q-PLAN Northern Greece\\\", which seems to be the coordinator or the main body overseeing the implementation of the state-of-the-art technologies and applications in WEEE recycling. The contents listed in the slide would be topics discussed in relation to these technologies and their applications.\"}}]}\n{\"conversations\":[{\"role\":\"user\",\"content\":{\"text\":\"What might the three stars above the team crest signify in the context of soccer achievements?\",\"image\":\"images/100271334.jpeg\",\"audio\":\"audios/0_7_0_chattts.wav\"}},{\"role\":\"assistant\",\"content\":{\"text\":\"The three stars above the team crest traditionally represent major honors or championships won by the team. In many soccer leagues, a star is added to the team's crest for a set number of league or major tournament victories. For instance, a club might add a star for every ten league titles they win. Therefore, these stars are likely indicative of the team's historical success, possibly in their domestic league or international competitions.\"}}]}\n"
  },
  {
    "path": "example_chat_hf.py",
    "content": "# -*- encoding: utf-8 -*-\n# File: example_chat_hf.py\n# Description: None\n\nimport torch\nfrom transformers import AutoModelForCausalLM\n\npath = \"/mnt/algorithm/user_dir/zhoudong/workspace/models/megrez-o\"  # Change this to the path of the model.\n\nmodel = (\n    AutoModelForCausalLM.from_pretrained(\n        path,\n        trust_remote_code=True,\n        torch_dtype=torch.bfloat16,\n        attn_implementation=\"flash_attention_2\",\n    )\n    .eval()\n    .cuda()\n)\nprompt = \"hi\" * (128 - 1) \n# Chat with text and image\nmessages = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"text\": prompt,\n            \"image\": \"./data/sample_image.jpg\",\n        },\n    },\n]\n\n# Chat with audio and image\n# messages = [\n#     {\n#         \"role\": \"user\",\n#         \"content\": {\n#             \"image\": \"./data/sample_image.jpg\",\n#             \"audio\": \"./data/sample_audio.m4a\",\n#         },\n#     },\n# ]\n\nMAX_NEW_TOKENS = 100\nresponse = model.chat(\n    messages,\n    sampling=False,\n    max_new_tokens=MAX_NEW_TOKENS,\n)\nprint(response)\n"
  },
  {
    "path": "finetune/dataset.py",
    "content": "# -*- encoding: utf-8 -*-\n# File: dataset.py\n# Description: None\n\nimport os\n\nimport numpy as np\nfrom regex import F\nimport torch\nfrom torch.utils.data import Dataset\n\n\nclass SupervisedDataset(Dataset):\n    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n\n    def __init__(\n        self,\n        raw_data_list,\n        processor,\n        process_func,\n        dataset_prefix=\"\",\n    ):\n        super(SupervisedDataset, self).__init__()\n        self.raw_data_list = raw_data_list\n        self.processor = processor\n        self.process_func = process_func\n        self.dataset_prefix = dataset_prefix\n\n    def __len__(self):\n        return len(self.raw_data_list)\n\n    def check_ret(self, ret):\n        flag = True\n        for key in ret.keys():\n            value_list = ret[key]\n            if not isinstance(value_list, list):\n                value_list = [value_list]\n            for value in value_list:\n                if isinstance(value, torch.Tensor):\n                    if torch.isnan(value).any():\n                        flag = False\n                    if torch.isinf(value).any():\n                        flag = False\n        return flag\n\n    def check_audio(self, ret):\n        flag = True\n        for audio in ret[\"msgs_audio\"]:\n            if (audio[\"input_audio_lengths\"][:, 1] == 0).any():\n                flag = False\n        return flag\n\n    def prepare_labels(self, data):\n\n        def prepare_labels(tokenizer, input_ids, padding_value=-100):\n            # <|role_start|>assistant<|role_end|> 后面的内容才是需要算loss的部分\n            def find_start_header_idxs():\n                start_header_tokens = tokenizer.encode(\"<|role_start|>assistant<|role_end|>\", add_special_tokens=False)\n                start_header_idxs = np.where(input_ids == start_header_tokens[-1])[0]\n\n                kept_start_header_idxs = []\n                for start_header_idx in start_header_idxs:\n                    keep = True\n                    for i in range(1, len(start_header_tokens)):\n                        if start_header_tokens[-(i + 1)] != input_ids[start_header_idx - i]:\n                            keep = False\n                            break\n                    if keep:\n                        kept_start_header_idxs.append(start_header_idx)\n                return kept_start_header_idxs\n\n            turn_end_token_id = tokenizer.encode(\"<|turn_end|>\")[0]\n            start_header_idxs = find_start_header_idxs()\n            end_header_idxs = np.where(input_ids == turn_end_token_id)[0]\n            label_mask = np.zeros_like(input_ids, dtype=np.bool_)\n\n            def find_next_greater_number(lst, num):\n                next_greater = None\n                for n in lst:\n                    if n > num:\n                        if next_greater is None or n < next_greater:\n                            next_greater = n\n                return next_greater\n\n            nr_tokens = len(input_ids)\n            for start_head_idx in start_header_idxs:\n                start_idx = start_head_idx + 1\n                end_idx = find_next_greater_number(end_header_idxs, start_head_idx)\n                end_idx = min(end_idx + 1, nr_tokens)\n                label_mask[start_idx:end_idx] = True\n\n            labels = torch.ones(input_ids.shape[0] + 1) * padding_value\n            labels[: input_ids.shape[0]] = input_ids\n            labels[: input_ids.shape[0]][~label_mask] = padding_value\n            labels = labels[1:]\n            
return labels.long()\n\n        return prepare_labels(self.processor.tokenizer, data[\"input_ids\"])\n\n    def add_dataset_prefix(self, item):\n        conv = item[\"conversations\"]\n        for i in range(len(conv)):\n            content = conv[i][\"content\"]\n            if \"image\" in content:\n                content[\"image\"] = os.path.join(self.dataset_prefix, content[\"image\"])\n            if \"audio\" in content:\n                content[\"audio\"] = os.path.join(self.dataset_prefix, content[\"audio\"])\n\n        return conv\n\n    def __getitem__(self, i):\n        raw_data_item = self.raw_data_list[i]\n        item = self.add_dataset_prefix(raw_data_item)\n        processed_data = self.processor(\n            item,\n            add_generation_prompt=False,\n            apply_data_collator=False,\n        )\n        if \"labels\" not in processed_data:\n            processed_data[\"labels\"] = self.prepare_labels(processed_data)\n\n        return processed_data\n"
  },
  {
    "path": "finetune/ds_config_zero2.json",
    "content": "{\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"betas\": \"auto\",\n            \"eps\": \"auto\",\n            \"weight_decay\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"offload_optimizer\": {\n            \"device\": \"none\",\n            \"pin_memory\": true\n        },\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 2e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": 2e8,\n        \"contiguous_gradients\": true\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 100,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": 1,\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "finetune/finetune.py",
    "content": "# -*- encoding: utf-8 -*-\n# File: finetune.py\n# Description: None\n\n\nimport glob\nimport json\nimport logging\nimport os\nfrom dataclasses import dataclass\nfrom dataclasses import field\nfrom functools import partial\nfrom glob import glob\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport torch\nimport transformers\nfrom accelerate.utils import DistributedType\nfrom dataset import SupervisedDataset\nfrom deepspeed import zero\nfrom deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\nfrom trainer import MegrezOTrainer\nfrom transformers import AutoModelForCausalLM\nfrom transformers import AutoProcessor\nfrom transformers import AutoTokenizer\nfrom transformers.integrations import deepspeed\n\n\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"openbmb/MiniCPM-V-2\")\n\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None, metadata={\"help\": \"Path to the training data.\"})\n    eval_data_path: str = field(default=None, metadata={\"help\": \"Path to the evaluation data.\"})\n    dataset_prefix: str = field(default=\"data\", metadata={\"help\": \"Prefix for the multimodal data.\"})\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    model_max_length: int = field(\n        default=2048,\n        metadata={\"help\": \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"},\n    )\n    tune_vision_encoder: Optional[bool] = field(default=True)\n    tune_vision_proj: Optional[bool] = field(default=True)\n    tune_llm: Optional[bool] = field(default=True)\n    tune_audio_encoder: Optional[bool] = field(default=True)\n    tune_audio_proj: Optional[bool] = field(default=True)\n    use_lora: Optional[bool] = field(default=False)\n    max_slice_nums: Optional[int] = field(default=9)\n    scale_resolution: Optional[int] = field(default=448)\n    remove_unused_columns: Optional[bool] = field(default=False)\n\n\n@dataclass\nclass LoraArguments:\n    lora_r: int = 64\n    lora_alpha: int = 64\n    lora_dropout: float = 0.05\n    lora_target_modules: str = r\"llm\\..*layers\\.\\d+\\.self_attn\\.(q_proj|k_proj|v_proj)\"\n    lora_weight_path: str = \"\"\n    lora_bias: str = \"none\"\n    q_lora: bool = False\n    lora_modules_to_save: str = \"\"\n    lora_layer_replication: Optional[List[Tuple[int, int]]] = None\n    lora_layers_to_transform: Optional[List[int]] = None\n    lora_layers_pattern: Optional[str] = None\n\n\ndef maybe_zero_3(param):\n    if hasattr(param, \"ds_id\"):\n        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n               
 lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        for k, t in maybe_lora_bias:\n            if bias_name in lora_bias_names:\n                to_return[bias_name] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}\n    return to_return\n\n\nlocal_rank = None\n\n\ndef rank0_print(*args):\n    if local_rank == 0:\n        print(*args)\n\n\ndef safe_save_model_for_hf_trainer(trainer, output_dir: str, bias=\"none\"):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n    # check if zero3 mode enabled\n    if deepspeed.is_deepspeed_zero3_enabled():\n        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()\n    else:\n        if trainer.args.use_lora:\n            state_dict = get_peft_state_maybe_zero_3(trainer.model.named_parameters(), bias)\n        else:\n            state_dict = trainer.model.state_dict()\n    if trainer.args.should_save and trainer.args.local_rank == 0:\n        trainer._save(output_dir, state_dict=state_dict)\n\n\ndef make_supervised_data_module(\n    data_args,\n    processor,\n    process_func,\n    data_collator=None,\n    max_length=2048,\n) -> Dict:\n    \"\"\"Make dataset and collator for supervised fine-tuning.\"\"\"\n    rank0_print(\"Loading data...\")\n\n    with open(data_args.data_path, \"r\") as f:\n        raw_data_list = [json.loads(line) for line in f]\n        train_dataset = SupervisedDataset(\n            raw_data_list,\n            processor,\n            process_func,\n            data_args.dataset_prefix,\n        )\n\n    eval_dataset = None\n    return dict(\n        train_dataset=train_dataset,\n        eval_dataset=eval_dataset,\n        data_collator=partial(data_collator, max_length=max_length, collate_labels=True),\n    )\n\n\ndef get_parameter_number(model):\n    trainable_params, all_param = 0, 0\n    for param in model.parameters():\n        num_params = param.numel()\n        # if using DS Zero 3 and the weights are initialized empty\n        if num_params == 0 and hasattr(param, \"ds_numel\"):\n            num_params = param.ds_numel\n\n        all_param += num_params\n        if param.requires_grad:\n            trainable_params += num_params\n\n    return {\"Total\": all_param, \"Trainable\": trainable_params}\n\n\nlocal_rank = 0\n\n\ndef load_model_from_pretrained(model_path, dtype=torch.bfloat16):\n    model = AutoModelForCausalLM.from_pretrained(\n        model_path, _attn_implementation=\"flash_attention_2\", trust_remote_code=True, torch_dtype=dtype\n    )\n    return model\n\n\ndef load_tokenizer_from_pretrained(model_path):\n    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)\n    return tokenizer\n\n\ndef train():\n    global local_rank\n    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments))\n\n    (\n        model_args,\n        data_args,\n        training_args,\n        lora_args,\n    ) = parser.parse_args_into_dataclasses()\n\n    if getattr(training_args, \"deepspeed\", None):\n        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED\n\n    compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)\n\n    local_rank = training_args.local_rank\n    world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n    ddp = world_size != 1\n    device_map = None\n    if lora_args.q_lora:\n      
  device_map = {\"\": int(os.environ.get(\"LOCAL_RANK\") or 0)} if ddp else None\n        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():\n            logging.warning(\"FSDP or ZeRO3 are not incompatible with QLoRA.\")\n\n    model = load_model_from_pretrained(model_args.model_name_or_path, dtype=compute_dtype)\n    tokenizer = load_tokenizer_from_pretrained(model_args.model_name_or_path)\n    processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)\n\n    model.tune_llm = training_args.tune_llm\n    model.tune_vision = training_args.tune_vision_encoder or training_args.tune_vision_proj\n    model.tune_audio = training_args.tune_audio_encoder or training_args.tune_audio_proj\n\n    if not training_args.tune_vision_encoder:\n        model.vision.vpm.requires_grad_(False)\n    if not training_args.tune_vision_proj:\n        model.vision.resampler.requires_grad_(False)\n    if not training_args.tune_llm:\n        model.llm.requires_grad_(False)\n    if not training_args.tune_audio_encoder:\n        model.audio.requires_grad_(False)\n        model.audio.audio.proj.requires_grad_(True)\n        if model.audio.audio.audio_bos_eos_token is not None:\n            model.audio.audio.audio_bos_eos_token.requires_grad_(True)\n    if not training_args.tune_audio_proj:\n        model.audio.audio.proj.requires_grad_(False)\n        if model.audio.audio.audio_bos_eos_token is not None:\n            model.audio.audio.audio_bos_eos_token.requires_grad_(False)\n\n    rank0_print(get_parameter_number(model))\n    data_module = make_supervised_data_module(\n        data_args=data_args,\n        processor=processor,\n        process_func=None,\n        data_collator=processor.data_collator,\n        max_length=training_args.model_max_length,\n    )\n    if training_args.lr_scheduler_type == \"cosine_with_min_lr\":\n        training_args.lr_scheduler_kwargs = {\"min_lr_rate\": 0.1}\n    trainer = MegrezOTrainer(\n        model=model,\n        tokenizer=tokenizer,\n        args=training_args,\n        **data_module,\n    )\n\n    train_dataset = trainer.train_dataset\n    nr_data = len(train_dataset)\n    rank0_print(\"nr dataset: {}\".format(nr_data))\n\n    checkpoint_path = os.path.join(training_args.output_dir, \"checkpoint*\")\n    checkpoint_paths = sorted(list(glob(checkpoint_path)))\n\n    valid_checkpoint_paths = []\n    for checkpoint_path in checkpoint_paths:\n        checkpoint_num = checkpoint_path.split(\"-\")[-1]\n        if checkpoint_num.isdigit():\n            valid_checkpoint_paths.append(checkpoint_path)\n    checkpoint_paths = sorted(list(valid_checkpoint_paths))\n    checkpoint_paths = sorted(checkpoint_paths, key=lambda x: int(x.split(\"-\")[-1]))\n    checkpoint_paths = list(checkpoint_paths)\n    load_checkpoint = True\n\n    if load_checkpoint and checkpoint_paths:\n        checkpoint_path = checkpoint_paths[-1]\n        rank0_print(\"Continue Checkpoint Training: {}\".format(checkpoint_path))\n        trainer.train(checkpoint_path)\n    else:\n        trainer.train()\n\n    trainer.save_state()\n    final_path = os.path.join(training_args.output_dir, \"final\")\n    os.makedirs(final_path, exist_ok=True)\n    rank0_print(\"save final path to {}\".format(final_path))\n    safe_save_model_for_hf_trainer(trainer, final_path)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "finetune/finetune.sh",
    "content": "DATA_PATH=$(pwd)/../data/train/records.jsonl\nDATASET_PREFIX=$(pwd)/../data/train/\nCURRENT_TIME=$(date +%Y%m%d_%H%M%S)\nOUTPUT_DIR=$(pwd)/test_finetune/$CURRENT_TIME\nLOGGING_DIR=$(pwd)/test_finetune_log\nMODEL_PATH=\"\"\n\ntorchrun --nproc_per_node=2 finetune.py \\\n    --data_path $DATA_PATH \\\n    --dataset_prefix $DATASET_PREFIX \\\n    --output_dir $OUTPUT_DIR \\\n    --logging_dir $LOGGING_DIR \\\n    --model_name_or_path $MODEL_PATH \\\n    --learning_rate 1e-5 \\\n    --num_train_epochs 10 \\\n    --deepspeed ds_config_zero2.json \\\n    --prediction_loss_only false \\\n    --bf16 true \\\n    --fp16 false \\\n    --do_train \\\n    --tune_vision_encoder true \\\n    --tune_vision_proj true \\\n    --tune_llm true \\\n    --tune_audio_encoder false \\\n    --tune_audio_proj true \\\n    --model_max_length 2048 \\\n    --max_slice_nums 9 \\\n    --scale_resolution 448 \\\n    --logging_strategy \"steps\" \\\n    --per_device_train_batch_size 1 \\\n    --per_device_eval_batch_size 1 \\\n    --gradient_accumulation_steps 1 \\\n    --save_steps 1000 \\\n    --save_total_limit 100 \\\n    --learning_rate 1e-6 \\\n    --weight_decay 0.1 \\\n    --adam_beta2 0.98 \\\n    --warmup_ratio 0.01 \\\n    --lr_scheduler_type \"cosine\" \\\n    --logging_steps 1\n"
  },
  {
    "path": "finetune/requirements.txt",
    "content": "deepspeed\naccelerate"
  },
  {
    "path": "finetune/trainer.py",
    "content": "# -*- encoding: utf-8 -*-\n# File: trainer.py\n# Description: None\n\n\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport deepspeed\nimport torch\nimport torch.nn as nn\nfrom transformers import Trainer\nfrom transformers.integrations import is_deepspeed_zero3_enabled\nfrom transformers.trainer_pt_utils import nested_detach\n\n\nclass MegrezOTrainer(Trainer):\n    def compute_loss(self, model, inputs, return_outputs=False):\n        if \"labels\" in inputs:\n            labels = inputs.pop(\"labels\")\n        else:\n            labels = None\n\n        self.model.vision.resampler.pos_embed = self.model.vision.resampler.pos_embed.to(self.model.device)\n        if is_deepspeed_zero3_enabled():\n            with deepspeed.zero.GatheredParameters(self.model.vision.resampler.attn.parameters(), modifier_rank=0):\n                if not self.args.use_lora:\n                    outputs = self.model(data=inputs, use_cache=False)\n                else:\n                    outputs = self.model.base_model(data=inputs, use_cache=False)\n        else:\n            if not self.args.use_lora:\n                outputs = self.model(data=inputs, use_cache=False)\n            else:\n                outputs = self.model.base_model(data=inputs, use_cache=False)\n\n        if labels is not None:\n            # Flatten the tokens\n            loss_fct = nn.CrossEntropyLoss()\n            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()\n            labels = labels.view(-1).long().contiguous()\n            # Enable model parallelism\n            labels = labels.to(logits.device)\n            loss = loss_fct(logits, labels)\n        else:\n            if isinstance(outputs, dict) and \"loss\" not in outputs:\n                raise ValueError(\n                    \"The model did not return a loss from the inputs, only the following keys: \"\n                    f\"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}.\"\n                )\n            # We don't use .loss here since the model may return tuples instead of ModelOutput.\n            loss = outputs[\"loss\"] if isinstance(outputs, dict) else outputs[0]\n\n        return (loss, outputs) if return_outputs else loss\n\n    def prediction_step(\n        self,\n        model: nn.Module,\n        inputs: Dict[str, Union[torch.Tensor, Any]],\n        prediction_loss_only: bool,\n        ignore_keys: Optional[List[str]] = None,\n    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:\n        \"\"\"\n        Perform an evaluation step on `model` using `inputs`.\n\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (`nn.Module`):\n                The model to evaluate.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. 
Check your model's documentation for all accepted arguments.\n            prediction_loss_only (`bool`):\n                Whether or not to return the loss only.\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n\n        Return:\n            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,\n            logits and labels (each being optional).\n        \"\"\"\n        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)\n        # For CLIP-like models capable of returning loss values.\n        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`\n        # is `True` in `model.forward`.\n        return_loss = inputs.get(\"return_loss\", None)\n        if return_loss is None:\n            return_loss = self.can_return_loss\n        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False\n\n        inputs = self._prepare_inputs(inputs)\n        if ignore_keys is None:\n            if hasattr(self.model, \"config\"):\n                ignore_keys = getattr(self.model.config, \"keys_to_ignore_at_inference\", [])\n            else:\n                ignore_keys = []\n\n        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.\n        if has_labels or loss_without_labels:\n            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))\n            if len(labels) == 1:\n                labels = labels[0]\n        else:\n            labels = None\n\n        with torch.no_grad():\n            if has_labels or loss_without_labels:\n                with self.compute_loss_context_manager():\n                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)\n                loss = loss.mean().detach()\n\n                if isinstance(outputs, dict):\n                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + [\"loss\"])\n                else:\n                    logits = outputs[1:]\n            else:\n                loss = None\n                with self.compute_loss_context_manager():\n                    outputs = model(**inputs)\n                if isinstance(outputs, dict):\n                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)\n                else:\n                    logits = outputs\n                # TODO: this needs to be fixed and made cleaner later.\n                if self.args.past_index >= 0:\n                    self._past = outputs[self.args.past_index - 1]\n\n        if prediction_loss_only:\n            return (loss, None, None)\n\n        logits = nested_detach(logits)\n        if len(logits) == 1:\n            logits = logits[0]\n\n        return (loss, logits, labels)\n\n    def training_step(\n        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: int\n    ) -> torch.Tensor:\n        \"\"\"\n        Perform a training step on a batch of inputs.\n\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (`nn.Module`):\n                The model to train.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n                The 
dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. Check your model's documentation for all accepted arguments.\n\n        Return:\n            `torch.Tensor`: The tensor with training loss on this batch.\n        \"\"\"\n        model.train()\n        inputs = self._prepare_inputs(inputs)\n\n        with self.compute_loss_context_manager():\n            loss = self.compute_loss(model, inputs)\n\n        del inputs\n        torch.cuda.empty_cache()\n\n        if self.args.n_gpu > 1:\n            loss = loss.mean()  # mean() to average on multi-gpu parallel training\n\n        if self.use_apex:\n            from transformers.trainer import amp\n\n            with amp.scale_loss(loss, self.optimizer) as scaled_loss:\n                scaled_loss.backward()\n        else:\n            if is_deepspeed_zero3_enabled():\n                with deepspeed.zero.GatheredParameters(self.model.vision.resampler.attn.parameters(), modifier_rank=0):\n                    self.accelerator.backward(loss)\n            else:\n                self.accelerator.backward(loss)\n\n        return loss.detach() / self.args.gradient_accumulation_steps\n
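\n\n# Note: under DeepSpeed ZeRO-3 the resampler's attention weights are partitioned\n# across ranks, which is why compute_loss() and training_step() gather them with\n# deepspeed.zero.GatheredParameters before running forward/backward on them.\n"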
  },
  {
    "path": "gradio_app.py",
    "content": "# -*- encoding: utf-8 -*-\n# File: app.py\n# Description: None\n\n\nimport threading\nfrom copy import deepcopy\nfrom typing import Dict, List\n\nimport gradio as gr\nimport torch\nfrom transformers import AutoModelForCausalLM\nfrom transformers import AutoTokenizer\nfrom transformers import TextIteratorStreamer\n\nIMAGE_EXTENSIONS = (\".jpg\", \".jpeg\", \".png\", \".bmp\", \".tiff\", \".webp\")\nVIDEO_EXTENSIONS = (\".mp4\", \".mkv\", \".mov\", \".avi\", \".flv\", \".wmv\", \".webm\", \".m4v\")\nAUDIO_EXTENSIONS = (\".mp3\", \".wav\")\n\nDEFAULT_SAMPLING_PARAMS = {\n    \"top_p\": 0.8,\n    \"top_k\": 100,\n    \"temperature\": 0.7,\n    \"do_sample\": True,\n    \"num_beams\": 1,\n    \"repetition_penalty\": 1.2,\n}\nMAX_NEW_TOKENS = 1024\n\n\ndef main(model_path: str, port: int):\n\n    if gr.NO_RELOAD:\n        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n        model = (\n            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)\n            .eval()\n            .cuda()\n        )\n        iterable_streamer = TextIteratorStreamer(\n            tokenizer,\n            skip_prompt=True,\n            skip_special_tokens=True,\n            timeout=30,\n        )\n\n    def history2messages(history: List[Dict]) -> List[Dict]:\n        \"\"\"\n        Transform gradio history to chat messages.\n        \"\"\"\n        messages = []\n        cur_message = dict()\n        for item in history:\n            if item[\"role\"] == \"assistant\":\n                if len(cur_message) > 0:\n                    messages.append(deepcopy(cur_message))\n                    cur_message = dict()\n                messages.append(deepcopy(item))\n                continue\n\n            if \"role\" not in cur_message:\n                cur_message[\"role\"] = \"user\"\n            if \"content\" not in cur_message:\n                cur_message[\"content\"] = dict()\n\n            if \"metadata\" not in item:\n                item[\"metadata\"] = {\"title\": None}\n            if item[\"metadata\"][\"title\"] is None:\n                cur_message[\"content\"][\"text\"] = item[\"content\"]\n            elif item[\"metadata\"][\"title\"] == \"image\":\n                cur_message[\"content\"][\"image\"] = item[\"content\"][0]\n            elif item[\"metadata\"][\"title\"] == \"audio\":\n                cur_message[\"content\"][\"audio\"] = item[\"content\"][0]\n        if len(cur_message) > 0:\n            messages.append(cur_message)\n        return messages\n\n    def check_messages(history, message, audio):\n        audios = []\n        images = []\n\n        for file_msg in message[\"files\"]:\n            if file_msg.endswith(AUDIO_EXTENSIONS) or file_msg.endswith(VIDEO_EXTENSIONS):\n                audios.append(file_msg)\n            elif file_msg.endswith(IMAGE_EXTENSIONS):\n                images.append(file_msg)\n            else:\n                filename = file_msg.split(\"/\")[-1]\n                raise gr.Error(f\"Unsupported file type: {filename}. 
It should be an image or audio file.\")\n\n        if len(audios) > 1:\n            raise gr.Error(\"Please upload only one audio file.\")\n\n        if len(images) > 1:\n            raise gr.Error(\"Please upload only one image file.\")\n\n        if audio is not None:\n            if len(audios) > 0:\n                raise gr.Error(\"Please upload only one audio file or record audio.\")\n            audios.append(audio)\n\n        # Append the message to the history\n        for image in images:\n            history.append({\"role\": \"user\", \"content\": (image,), \"metadata\": {\"title\": \"image\"}})\n\n        for audio in audios:\n            history.append({\"role\": \"user\", \"content\": (audio,), \"metadata\": {\"title\": \"audio\"}})\n\n        if message[\"text\"] is not None:\n            history.append({\"role\": \"user\", \"content\": message[\"text\"]})\n\n        return history, gr.MultimodalTextbox(value=None, interactive=False)\n\n    def bot(\n        history: list,\n        top_p: float,\n        top_k: int,\n        temperature: float,\n        repetition_penalty: float,\n        max_new_tokens: int = MAX_NEW_TOKENS,\n        regenerate: bool = False,\n    ):\n        sampling_params = {\n            \"top_p\": top_p,\n            \"top_k\": top_k,\n            \"temperature\": temperature,\n            \"repetition_penalty\": repetition_penalty,\n        }\n\n        if regenerate:\n            history = history[:-1]\n\n        msgs = history2messages(history)\n        th = threading.Thread(\n            target=model.chat,\n            kwargs=dict(\n                input_msgs=msgs,\n                sampling=True,\n                streamer=iterable_streamer,\n                max_new_tokens=max_new_tokens,\n                **sampling_params,\n            ),\n        )\n        th.start()\n\n        response = \"\"\n        for subtext in iterable_streamer:\n            response += subtext\n            yield history + [{\"role\": \"assistant\", \"content\": response}]\n\n        th.join()\n        return response\n\n    def change_state(state):\n        return gr.update(visible=not state), not state\n\n    with gr.Blocks() as demo:\n        chatbot = gr.Chatbot(elem_id=\"chatbot\", bubble_full_width=False, type=\"messages\", height=800)\n\n        sampling_params_group_hidden_state = gr.State(False)\n\n        with gr.Row(equal_height=True):\n            audio_input = gr.Audio(\n                sources=[\"microphone\", \"upload\"],\n                type=\"filepath\",\n                scale=4,\n            )\n            chat_input = gr.MultimodalTextbox(\n                file_count=\"multiple\",\n                show_label=False,\n                scale=10,\n                file_types=[\"image\", \"audio\"],\n                # stop_btn=True,\n            )\n            with gr.Column(scale=1, min_width=150):\n                with gr.Row(equal_height=True):\n                    regenerate_btn = gr.Button(\"Regenerate\", variant=\"primary\")\n                    clear_btn = gr.ClearButton(\n                        [chat_input, audio_input, chatbot],\n                    )\n\n        with gr.Row():\n            sampling_params_toggle_btn = gr.Button(\"Sampling Parameters\")\n\n        with gr.Group(visible=False) as sampling_params_group:\n            with gr.Row():\n                temperature = gr.Slider(\n                    minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS[\"temperature\"], label=\"Temperature\"\n                )\n                
repetition_penalty = gr.Slider(\n                    minimum=0,\n                    maximum=2,\n                    value=DEFAULT_SAMPLING_PARAMS[\"repetition_penalty\"],\n                    label=\"Repetition Penalty\",\n                )\n\n            with gr.Row():\n                top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS[\"top_p\"], label=\"Top-p\")\n                top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS[\"top_k\"], label=\"Top-k\")\n\n            with gr.Row():\n                max_new_tokens = gr.Slider(\n                    minimum=1,\n                    maximum=MAX_NEW_TOKENS,\n                    value=MAX_NEW_TOKENS,\n                    label=\"Max New Tokens\",\n                    interactive=True,\n                )\n\n        sampling_params_toggle_btn.click(\n            change_state,\n            sampling_params_group_hidden_state,\n            [sampling_params_group, sampling_params_group_hidden_state],\n        )\n\n        chat_msg = chat_input.submit(\n            check_messages,\n            [chatbot, chat_input, audio_input],\n            [chatbot, chat_input],\n        )\n        bot_msg = chat_msg.then(\n            bot,\n            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],\n            outputs=chatbot,\n            api_name=\"bot_response\",\n        )\n        bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n\n        regenerate_btn.click(\n            bot,\n            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],\n            outputs=chatbot,\n        )\n\n    demo.launch(server_port=port)\n\n\nif __name__ == \"__main__\":\n\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model_path\", type=str, required=True)\n    parser.add_argument(\"--port\", type=int, default=7680)\n    args = parser.parse_args()\n\n    main(args.model_path, args.port)\n"
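\n# Example launch (assuming a local checkpoint directory):\n#   python gradio_app.py --model_path /path/to/megrez-o --port 7680\n"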
  },
  {
    "path": "requirements.txt",
    "content": "transformers>=4.44.0\ntokenizers>=0.20.3\naccelerate\ndatasets\ngradio\n"
  },
  {
    "path": "vllm_demo/example_infer_vllm.py",
    "content": "# -*- encoding: utf-8 -*-\n# File: example_infer_vllm.py\n# Description: None\n\nfrom PIL import Image\nfrom vllm import LLM\nfrom vllm import ModelRegistry\nfrom vllm import SamplingParams\n\nfrom megrezo import MegrezOModel\n\nModelRegistry.register_model(\"MegrezO\", MegrezOModel)\n\n# Load the model.\n# model_path = \"{{PATH_TO_HF_PRETRAINED_MODEL}}\"  # Change this to the path of the model.\nmodel_path = \"/mnt/algorithm/user_dir/zhoudong/workspace/models/megrez-o\"  # Change this to the path of the model.\nllm = LLM(\n    model_path,\n    trust_remote_code=True,\n    gpu_memory_utilization=0.5,\n)\n\nsampling_params = SamplingParams(\n    temperature=0,\n    max_tokens=1000,\n    repetition_penalty=1.2,\n    stop=[\"<|turn_end|>\", \"<|eos|>\"],\n)\n\nimg = Image.open(\"../data/sample_image.jpg\")\n\nconversation = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"text\": \"图片的内容是什么？\",\n            \"image\": img,\n        },\n    },\n]\n\n# Convert the conversation to vLLM acceptable format.\nprompt = llm.get_tokenizer().apply_chat_template(\n    conversation,\n    tokenize=False,\n    add_generation_prompt=True,\n)\nvllm_inputs = [\n    {\n        \"prompt\": prompt,\n        \"multi_modal_data\": {\n            \"image\": img,\n        },\n    }\n]\n\n# Generate the outputs.\noutputs = llm.generate(\n    vllm_inputs,\n    sampling_params,\n)\n\n# Print the outputs.\nfor output in outputs:\n    print(output.outputs[0].text)\n"
  },
  {
    "path": "vllm_demo/megrezo.py",
    "content": "# coding=utf-8\n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py\n# Copyright 2023 The vLLM team.\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Inference-only MegrezO model compatible with HuggingFace weights.\"\"\"\n\nfrom functools import lru_cache\nfrom functools import partial\nfrom typing import Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.types\nfrom PIL import Image\nfrom torch import Tensor\nfrom torch import nn\nfrom torch.nn.init import trunc_normal_\nfrom transformers import PretrainedConfig\nfrom vllm.attention import AttentionMetadata\nfrom vllm.config import CacheConfig\nfrom vllm.config import MultiModalConfig\nfrom vllm.inputs import INPUT_REGISTRY\nfrom vllm.inputs import DecoderOnlyInputs\nfrom vllm.inputs import InputContext\nfrom vllm.inputs import token_inputs\nfrom vllm.model_executor.layers.linear import ReplicatedLinear\nfrom vllm.model_executor.layers.logits_processor import LogitsProcessor\nfrom vllm.model_executor.layers.quantization import QuantizationConfig\nfrom vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed\nfrom vllm.model_executor.layers.sampler import Sampler\nfrom vllm.model_executor.layers.sampler import SamplerOutput\nfrom vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead\nfrom vllm.model_executor.model_loader.weight_utils import default_weight_loader\nfrom vllm.model_executor.models import VllmModelForTextGeneration\nfrom vllm.model_executor.models.idefics2_vision_model import Idefics2VisionTransformer\nfrom vllm.model_executor.models.interfaces import SupportsMultiModal\nfrom vllm.model_executor.models.interfaces import SupportsPP\nfrom vllm.model_executor.models.llama import LlamaModel\nfrom vllm.model_executor.models.module_mapping import MultiModelKeys\nfrom vllm.model_executor.models.utils import LLMWrapper\nfrom vllm.model_executor.models.utils import is_pp_missing_parameter\nfrom vllm.model_executor.sampling_metadata import SamplingMetadata\nfrom vllm.multimodal import MULTIMODAL_REGISTRY\nfrom vllm.multimodal.base import MultiModalInputs\nfrom vllm.multimodal.utils import cached_get_tokenizer\nfrom vllm.sequence import IntermediateTensors\nfrom vllm.sequence import SequenceData\nfrom vllm.transformers_utils.processor import get_processor\n\nRawImageType = Union[Image.Image, torch.Tensor]\nRawAudioType = Union[bytes, torch.Tensor]\n\ncached_get_processor = lru_cache(get_processor)\n\n\nclass 
MegrezORawImageInput(TypedDict):\n    \"\"\"Input mapper input with auxiliary data for computing image bounds.\"\"\"\n\n    image: RawImageType\n\n\nclass MegrezOAudioInput(TypedDict):\n    type: Literal[\"audio\"]\n\n    data: RawAudioType\n\n\nclass MegrezOAudioTensorInput(TypedDict):\n    type: Literal[\"audio_tensor\"]\n\n    input_audios: torch.Tensor\n    input_audio_lengths: torch.Tensor\n    audio_span_tokens: torch.Tensor\n\n\nclass MegrezOImagePixelInputs(TypedDict):\n    type: Literal[\"pixel_values\"]\n    pixel_values: torch.Tensor\n    \"\"\"\n    Shape: `(batch_size * num_images, num_channels, height, width)`\n\n    Note that the image size may vary, so we pass it as a list\n    instead of a batched tensor.\n    \"\"\"\n\n    tgt_sizes: torch.Tensor\n    \"\"\"\n    Shape: `(batch_size * num_images, 2)`\n\n    This should be in `(height, width)` format.\n    \"\"\"\n\n    patch_attention_mask: torch.Tensor\n    \"\"\"\n    Shape: `(batch_size * num_images, num_patches, num_patches)`\n    \"\"\"\n\n\nclass MegrezOImageEmbeddingInputs(TypedDict):\n    type: Literal[\"image_embeds\"]\n    data: torch.Tensor\n    \"\"\"\n    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`\n\n    `hidden_size` must match the hidden size of language model backbone.\n    \"\"\"\n\n    image_bounds: torch.Tensor\n    \"\"\"\n    Shape: `(batch_size * num_images, 2)`\n\n    This should be in `(start, stop)` format.\n    \"\"\"\n\n\ndef insert_audio_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):\n    # Each bound is (batch_idx, start_id, end_id).\n    inserted_bounds = inserted_bounds.long()\n    for idx in range(len(inserted_embeddings)):\n        start_id = inserted_bounds[idx][1]\n        end_id = inserted_bounds[idx][2]\n        embedding = inserted_embeddings[idx]\n        text_embeddings[start_id + 1 : end_id] = embedding\n    return text_embeddings\n\n\ndef insert_image_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):\n    # Each bound is (batch_idx, start_id, end_id).\n    inserted_bounds = inserted_bounds.long()\n    for idx in range(len(inserted_embeddings)):\n        start_id = inserted_bounds[idx][1]\n        end_id = inserted_bounds[idx][2]\n        embedding = inserted_embeddings[idx]\n        text_embeddings[start_id:end_id] = embedding\n    return text_embeddings\n\n\nMegrezOImageInputs = MegrezOImagePixelInputs\nMegrezOAudioInputs = MegrezOAudioTensorInput\n\n# region: Resampler\nDEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)\n\n\nclass Resampler(nn.Module):\n\n    def __init__(\n        self,\n        num_queries: int,\n        embed_dim: int,\n        num_heads: int,\n        kv_dim: Optional[int] = None,\n        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,\n        max_size: Tuple[int, int] = (70, 70),\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n\n        self.num_queries = num_queries\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n\n        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))\n        trunc_normal_(self.query, std=0.02)\n        if kv_dim is not None and kv_dim != embed_dim:\n            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False, quant_config=quant_config, prefix=prefix)\n        else:\n            # Maintain the same return value with ReplicatedLinear.forward\n            self.kv_proj = lambda *args, **kwargs: (  # 
type: ignore # noqa\n                nn.Identity()(*args, **kwargs),\n                None,\n            )\n\n        self.attn = nn.MultiheadAttention(embed_dim, num_heads)\n        self.ln_q = norm_layer(embed_dim)\n        self.ln_kv = norm_layer(embed_dim)\n        self.do_post_projection = True\n        self.ln_post = norm_layer(embed_dim)\n        self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))\n\n        self.max_size = max_size\n        self._set_2d_pos_cache(self.max_size)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m: nn.Module) -> None:\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def _repeat(self, query, N: int):\n        return query.unsqueeze(1).repeat(1, N, 1)\n\n    def _set_2d_pos_cache(self, max_size: Tuple[int, int], device: torch.types.Device = \"cpu\") -> None:\n        pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, max_size, version=(2, 5))\n        pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)\n        self.register_buffer(\"pos_embed\", pos_embed, persistent=False)\n\n    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.types.Device) -> None:\n        max_h = tgt_sizes[:, 0].max().item()\n        max_w = tgt_sizes[:, 1].max().item()\n        assert isinstance(max_h, int) and isinstance(max_w, int)\n\n        if max_h > self.max_size[0] or max_w > self.max_size[1]:\n            self.max_size = (\n                max(max_h, self.max_size[0]),\n                max(max_w, self.max_size[1]),\n            )\n            self._set_2d_pos_cache(self.max_size, device)\n\n    def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:\n        assert x.shape[0] == tgt_sizes.shape[0]\n        bs = x.shape[0]\n\n        device = x.device\n        dtype = x.dtype\n\n        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]\n\n        self._adjust_pos_cache(tgt_sizes, device=device)\n\n        max_patch_len = patch_len.max().item()\n        assert isinstance(max_patch_len, int)\n\n        key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)\n\n        pos_embed = []\n        for i in range(bs):\n            tgt_h, tgt_w = tgt_sizes[i].tolist()\n            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype))  # patches * D\n            key_padding_mask[i, patch_len[i] :] = True\n        pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(\n            1, 0, 2\n        )  # BLD => L * B * D\n        x, _ = self.kv_proj(x)  # B * L * D\n        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D\n\n        q = self.ln_q(self.query)  # Q * D\n\n        out = self.attn(\n            self._repeat(q, bs),  # Q * B * D\n            x + pos_embed,  # L * B * D +  L * B * D\n            x,\n            key_padding_mask=key_padding_mask,\n        )[0]\n        #  out: Q * B * D\n        x = out.permute(1, 0, 2)  # B * Q * D\n\n        x = self.ln_post(x)\n        x = x @ self.proj\n        return x\n\n\n# endregion\n\n# region: AudioEncoder\n\n\nclass LayerNorm(nn.LayerNorm):\n    def forward(self, x: Tensor) -> Tensor:\n        # return 
super().forward(x.float()).type(x.dtype)\n        return super().forward(x).type(x.dtype)\n\n\nclass Linear(nn.Linear):\n    def forward(self, x: Tensor) -> Tensor:\n        return F.linear(\n            x,\n            self.weight.to(x.dtype),\n            None if self.bias is None else self.bias.to(x.dtype),\n        )\n\n\nclass Conv1d(nn.Conv1d):\n    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:\n        return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))\n\n\ndef sinusoids(length, channels, max_timescale=10000):\n    \"\"\"Returns sinusoids for positional embedding\"\"\"\n    assert channels % 2 == 0\n    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)\n    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))\n    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]\n    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)\n\n\nclass MultiHeadAttention(nn.Module):\n    def __init__(self, n_state: int, n_head: int):\n        super().__init__()\n        self.n_head = n_head\n        self.query = Linear(n_state, n_state)\n        self.key = Linear(n_state, n_state, bias=False)\n        self.value = Linear(n_state, n_state)\n        self.out = Linear(n_state, n_state)\n\n    def forward(\n        self,\n        x: Tensor,\n        xa: Optional[Tensor] = None,\n        mask: Optional[Tensor] = None,\n        kv_cache: Optional[dict] = None,\n    ):\n        q = self.query(x)\n\n        if kv_cache is None or xa is None or self.key not in kv_cache:\n            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;\n            # otherwise, perform key/value projections for self- or cross-attention as usual.\n            k = self.key(x if xa is None else xa)\n            v = self.value(x if xa is None else xa)\n        else:\n            # for cross-attention, calculate keys and values once and reuse in subsequent calls.\n            k = kv_cache[self.key]\n            v = kv_cache[self.value]\n\n        wv, qk = self.qkv_attention(q, k, v, mask)\n        return self.out(wv), qk\n\n    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):\n        n_batch, n_ctx, n_state = q.shape\n        scale = (n_state // self.n_head) ** -0.25\n        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale\n        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale\n        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n\n        qk = q @ k\n        if mask is not None:\n            qk += mask\n\n        w = F.softmax(qk, dim=-1).to(q.dtype)\n        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):\n        super().__init__()\n\n        self.attn = MultiHeadAttention(n_state, n_head)\n        self.attn_ln = LayerNorm(n_state)\n\n        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None\n        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None\n\n        n_mlp = n_state * 4\n        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))\n        self.mlp_ln = LayerNorm(n_state)\n\n    def forward(\n        self,\n        x: Tensor,\n        xa: Optional[Tensor] = None,\n        
mask: Optional[Tensor] = None,\n        kv_cache: Optional[dict] = None,\n    ):\n        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]\n        if self.cross_attn:\n            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]\n        x = x + self.mlp(self.mlp_ln(x))\n        return x\n\n\nclass AudioEncoder(nn.Module):\n    def __init__(\n        self,\n        n_mels: int,\n        n_ctx: int,\n        n_state: int,\n        n_head: int,\n        n_layer: int,\n        output_dim: int = 512,\n        avg_pool: bool = True,\n        add_audio_bos_eos_token: bool = True,\n        **kwargs,\n    ):\n        super().__init__()\n        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)\n        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)\n        # self.register_buffer(\"positional_embedding\", sinusoids(n_ctx, n_state))\n        self.positional_embedding = nn.Parameter(sinusoids(n_ctx, n_state), requires_grad=False)\n\n        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(\n            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]\n        )\n        self.ln_post = LayerNorm(n_state)\n\n        if avg_pool:\n            self.avg_pooler = nn.AvgPool1d(2, stride=2)\n        else:\n            self.avg_pooler = None\n        self.proj = nn.Linear(n_state, output_dim)\n        if add_audio_bos_eos_token:\n            self.audio_bos_eos_token = nn.Embedding(2, output_dim)\n        else:\n            self.audio_bos_eos_token = None\n        self.output_dim = output_dim\n        self.n_head = n_head\n\n    def forward(self, x: Tensor, padding_mask: Tensor = None, audio_lengths: Tensor = None):\n        \"\"\"\n        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)\n            the mel spectrogram of the audio\n        \"\"\"\n        x = x.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)\n        if audio_lengths is not None:\n            input_mel_len = audio_lengths[:, 0] * 2\n            max_mel_len_in_batch = input_mel_len.max()\n            x = x[:, :, :max_mel_len_in_batch]\n        x = F.gelu(self.conv1(x))\n        x = F.gelu(self.conv2(x))\n        x = x.permute(0, 2, 1)  # B, L, D\n        bsz = x.size(0)\n        src_len = x.size(1)\n\n        self.input_positional_embedding = self.positional_embedding[:src_len]\n        assert (\n            x.shape[1:] == self.input_positional_embedding.shape\n        ), f\"incorrect audio shape: {x.shape[1:], self.input_positional_embedding.shape}\"\n        x = (x + self.input_positional_embedding).to(x.dtype)\n        if padding_mask is not None:\n            padding_mask = padding_mask.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)\n            batch_src_len = padding_mask.size(1)\n            x = x[:, :batch_src_len, :]\n            padding_mask = padding_mask.view(bsz, -1, batch_src_len)\n            padding_mask_ = padding_mask.all(1)\n            x[padding_mask_] = 0\n            key_padding_mask = (\n                padding_mask_.view(bsz, 1, 1, batch_src_len)\n                .expand(-1, self.n_head, -1, -1)\n                .reshape(bsz, self.n_head, 1, batch_src_len)\n            )\n            new_padding_mask = torch.zeros_like(key_padding_mask, dtype=x.dtype)\n            padding_mask = new_padding_mask.masked_fill(key_padding_mask, float(\"-inf\"))\n\n        for block in self.blocks:\n            x = block(x, mask=padding_mask)\n\n        if self.avg_pooler:\n         
   x = x.permute(0, 2, 1)\n            x = self.avg_pooler(x)\n            x = x.permute(0, 2, 1)\n\n        x = self.ln_post(x)\n        x = self.proj(x)\n\n        if self.audio_bos_eos_token is not None:\n            bos = self.audio_bos_eos_token.weight[0][None, :]\n            eos = self.audio_bos_eos_token.weight[1][None, :]\n        else:\n            bos, eos = None, None\n        return x, bos, eos\n\n    def encode(\n        self,\n        input_audios: Tensor,\n        input_audio_lengths: Tensor,\n        audio_span_tokens: List,\n    ):\n        real_input_audio_lens = input_audio_lengths[:, 0].tolist()\n        max_len_in_batch = max(real_input_audio_lens)\n        padding_mask = torch.ones([input_audios.size(0), max_len_in_batch]).to(\n            dtype=self.conv1.weight.dtype, device=self.conv1.weight.device\n        )\n        for index in range(len(input_audios)):\n            padding_mask[index, : input_audio_lengths[index][0].item()] = 0\n        x, bos, eos = self(input_audios, padding_mask, input_audio_lengths)\n        output_audios = []\n        for i in range(len(audio_span_tokens)):\n            audio_span = audio_span_tokens[i]\n            audio = x[i][: audio_span - 2]\n            if bos is not None:\n                audio = torch.concat([bos, audio, eos])\n            assert len(audio) == audio_span\n            output_audios.append(audio)\n        return output_audios\n\n\nclass AudioModel(torch.nn.Module):\n\n    def __init__(self, config):\n        super(AudioModel, self).__init__()\n        self.config = config\n        self.audio = AudioEncoder(**config.audio_config.to_dict())\n\n    def forward(self, audio_info):\n        audios = audio_info[\"input_audios\"][0]\n        input_audio_lengths = audio_info[\"input_audio_lengths\"][0]\n        audio_span_tokens = audio_info[\"audio_span_tokens\"][0]\n        audios_features = self.audio.encode(audios, input_audio_lengths, audio_span_tokens)\n        return audios_features\n\n\n# endregion\n\n\ndef get_max_megrezo_image_tokens(ctx: InputContext):\n    hf_config = ctx.get_hf_config()\n    return getattr(hf_config, \"query_num\", 64) * 10\n\n\ndef dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):\n    return SequenceData.from_prompt_token_counts((0, seq_len))\n\n\ndef dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, num_images: int):\n    width = height = hf_config.vision_config.image_size\n    imgs = [MegrezORawImageInput(image=Image.new(\"RGB\", (width, height), color=0)) for _ in range(num_images)]\n    return {\"image\": imgs}\n\n\ndef dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):\n    hf_config = ctx.get_hf_config()\n    num_images = mm_counts[\"image\"]\n\n    seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)\n    mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)  # skip audio for now\n    return (seq_data, mm_data)\n\n\ndef input_processor_for_megrezo(ctx: InputContext, inputs: DecoderOnlyInputs):\n    multi_modal_data = inputs.get(\"multi_modal_data\")\n    if multi_modal_data is None or (\"image\" not in multi_modal_data and \"audio\" not in multi_modal_data):\n        return inputs\n\n    model_config = ctx.model_config\n    tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=model_config.trust_remote_code)\n    processor = cached_get_processor(model_config.model, trust_remote_code=model_config.trust_remote_code)\n\n    prompt = inputs.get(\"prompt\")\n    token_ids = 
inputs.get(\"prompt_token_ids\")\n    if prompt is None:\n        prompt = tokenizer.decode(token_ids)\n\n    images = multi_modal_data.get(\"image\")\n    audios = multi_modal_data.get(\"audio\")\n    prompt, multimodal_inputs = processor.process_multimodal_inputs(\n        prompt,\n        images=images,\n        audios=audios,\n        return_tensors=\"pt\",\n    )\n    text_encodings = tokenizer(\n        prompt,\n        return_tensors=\"pt\",\n        padding=True,\n        padding_side=\"left\",\n    )\n    encodings = processor.merge_encodings(text_encodings, multimodal_inputs)\n    data = processor.data_collator([encodings])\n\n    new_prompt = tokenizer.decode(data[\"input_ids\"][0])\n    new_multi_modal_data = {\n        \"image\": data[\"image_encoding\"],\n        \"audio\": data[\"audio_encoding\"],\n    }\n\n    return token_inputs(\n        prompt_token_ids=data[\"input_ids\"][0],\n        prompt=new_prompt,\n        multi_modal_data=new_multi_modal_data,\n    )\n\n\ndef input_mapper_for_megrezo(ctx: InputContext, data: object):\n    return MultiModalInputs(data)\n\n\n@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_megrezo)\n@MULTIMODAL_REGISTRY.register_input_mapper(\"audio\", input_mapper_for_megrezo)\n@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(\"audio\", 3000)\n@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_megrezo_image_tokens)\n@INPUT_REGISTRY.register_input_processor(input_processor_for_megrezo)\nclass MegrezOModel(nn.Module, VllmModelForTextGeneration, SupportsMultiModal, SupportsPP):\n\n    packed_modules_mapping = {\n        \"qkv_proj\": [\"q_proj\", \"k_proj\", \"v_proj\"],\n        \"gate_up_proj\": [\"gate_proj\", \"up_proj\"],\n    }\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        multimodal_config: MultiModalConfig,\n        cache_config: Optional[CacheConfig] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__()\n        # All MiniCPM-V models disable `tie_word_embeddings` but\n        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot\n        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model\n        # and config class\n        self.config = config\n        self.multimodal_config = multimodal_config\n\n        self.llm = self.init_llm(config, cache_config, quant_config, prefix=\"model\")\n        self.vision = self.init_vision_module(config, quant_config, prefix=\"vpm\")\n        param_dtype = torch.get_default_dtype()\n        self.vision.to(dtype=param_dtype)\n\n        self.audio = self.init_audio_module(config, quant_config)\n        self.audio.to(dtype=param_dtype)\n\n        self.vision_dim = self.vision.embeddings.embed_dim\n        self.embed_dim = self.config.hidden_size\n        self.resampler = self.init_resampler(\n            self.embed_dim, self.vision_dim, quant_config=quant_config, prefix=\"vision.resampler\"\n        )\n        self.resampler.to(device=\"cuda\", dtype=param_dtype)\n        self.lm_head = ParallelLMHead(\n            config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=\"llm.lm_head\"\n        )\n        self.logits_processor = LogitsProcessor(config.vocab_size)\n        self.sampler = Sampler()\n\n        self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors\n\n        self._called_cnt = 0\n\n    def get_vision_hidden_states(\n        self,\n        pixel_values,\n        tgt_sizes,\n        patch_attn_mask,\n    ) -> 
torch.Tensor:\n\n        device = self.vision.embeddings.position_embedding.weight.device\n        dtype = self.vision.embeddings.position_embedding.weight.dtype\n        pixel_values = torch.stack([(image.to(device) - 127.5) / 127.5 for image in pixel_values]).type(dtype)\n        vision_embedding = self.vision(\n            pixel_values.type(dtype),\n            patch_attention_mask=patch_attn_mask,\n            tgt_sizes=tgt_sizes,\n        )\n\n        return self.resampler(vision_embedding, tgt_sizes)\n\n    def compose_embeddings(self, mini_batch):\n        input_ids = mini_batch[\"input_ids\"]\n        image_encoding = mini_batch.get(\"image_encoding\")\n        audio_encoding = mini_batch.get(\"audio_encoding\")\n\n        embeddings_text = self.llm.model.embed_tokens(input_ids)\n        input_embeds = embeddings_text\n        if image_encoding:\n            pixel_values = image_encoding[\"pixel_values\"][0]\n            tgt_sizes = image_encoding[\"tgt_sizes\"][0]\n            patch_attention_mask = image_encoding[\"patch_attention_mask\"][0]\n            bounds_image = image_encoding[\"image_bounds\"][0]\n            device = self.vision.embeddings.position_embedding.weight.device\n            dtype = self.vision.embeddings.position_embedding.weight.dtype\n\n            embeddings_image = self.get_vision_hidden_states(\n                pixel_values.to(device, dtype),\n                tgt_sizes,\n                patch_attention_mask.to(device),\n            )\n            input_embeds = insert_image_embeddings(embeddings_text, embeddings_image, bounds_image)\n\n        if audio_encoding:\n            embeddings_audio = self.audio(audio_encoding)\n            bounds_audio = audio_encoding[\"audio_bounds\"][0]\n            input_embeds = insert_audio_embeddings(input_embeds, embeddings_audio, bounds_audio)\n\n        return input_embeds\n\n    def _parse_inputs(self, input_ids: torch.Tensor, **kwargs):\n        if kwargs.get(\"pixel_values\") is not None:\n            image_encoding = {\n                \"pixel_values\": kwargs.get(\"pixel_values\"),\n                \"tgt_sizes\": kwargs.get(\"tgt_sizes\"),\n                \"patch_attention_mask\": kwargs.get(\"patch_attention_mask\"),\n                \"image_bounds\": kwargs.get(\"image_bounds\"),\n            }\n        else:\n            image_encoding = None\n\n        if kwargs.get(\"input_audios\") is not None:\n            audio_encoding = {\n                \"input_audios\": kwargs.get(\"input_audios\"),\n                \"input_audio_lengths\": kwargs.get(\"input_audio_lengths\"),\n                \"audio_span_tokens\": kwargs.get(\"audio_span_tokens\"),\n                \"audio_bounds\": kwargs.get(\"audio_bounds\"),\n            }\n        else:\n            audio_encoding = None\n\n        return {\n            \"input_ids\": input_ids,\n            \"image_encoding\": image_encoding,\n            \"audio_encoding\": audio_encoding,\n        }\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        positions: torch.Tensor,\n        kv_caches: List[torch.Tensor],\n        attn_metadata: AttentionMetadata,\n        intermediate_tensors: Optional[IntermediateTensors] = None,\n        **kwargs: Any,\n    ) -> torch.Tensor:\n        if intermediate_tensors is not None:\n            embeddings = None\n        else:\n            mini_batch = self._parse_inputs(input_ids, **kwargs)\n            embeddings = self.compose_embeddings(mini_batch)\n\n        # always pass the input via `inputs_embeds`\n        # to make sure the computation graph is consistent\n        # for `torch.compile` integration\n        input_ids = None\n\n        output = self.llm(\n            input_ids=input_ids,\n            positions=positions,\n            kv_caches=kv_caches,\n            attn_metadata=attn_metadata,\n            intermediate_tensors=intermediate_tensors,\n            inputs_embeds=embeddings,\n        )\n\n        self._called_cnt += 1\n        return output\n\n    def compute_logits(\n        self,\n        hidden_states: torch.Tensor,\n        sampling_metadata: SamplingMetadata,\n    ) -> Optional[torch.Tensor]:\n        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)\n        return logits\n\n    def sample(\n        self,\n        logits: torch.Tensor,\n        sampling_metadata: SamplingMetadata,\n    ) -> Optional[SamplerOutput]:\n        next_tokens = self.sampler(logits, sampling_metadata)\n        return next_tokens\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            (\".qkv_proj\", \".q_proj\", \"q\"),\n            (\".qkv_proj\", \".k_proj\", \"k\"),\n            (\".qkv_proj\", \".v_proj\", \"v\"),\n            (\".gate_up_proj\", \".gate_proj\", 0),\n            (\".gate_up_proj\", \".up_proj\", 1),\n        ]\n\n        keys_to_modify_mapping = {\n            \"llm.lm_head\": \"lm_head\",\n            \"vision.resampler\": \"resampler\",\n        }\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            for key_to_modify, new_key in keys_to_modify_mapping.items():\n                if key_to_modify in name:\n                    name = name.replace(key_to_modify, new_key)\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n            if \"rotary_emb.cos_cached\" in name or \"rotary_emb.sin_cached\" in name:\n                # Models trained using ColossalAI may include these tensors in\n                # the checkpoint. Skip them.\n                continue\n            # if \"audio.positional_embedding\" in name:\n            #     continue\n\n            for param_name, weight_name, shard_id in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if is_pp_missing_parameter(name, self):\n                    continue\n\n                if name in params_dict:\n                    param = params_dict[name]\n                    weight_loader = param.weight_loader\n                    weight_loader(param, loaded_weight, shard_id)\n                else:\n                    print(f\"Skipping loading of {name}\")\n\n                break\n            else:\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name is None:\n                    continue\n\n                if is_pp_missing_parameter(name, self):\n                    continue\n\n                if name in params_dict:\n                    param = params_dict[name]\n                    weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                    weight_loader(param, loaded_weight)\n                else:\n                    print(f\"Skipping loading of {name}\")\n\n    def get_mm_mapping(self) -> MultiModelKeys:\n        \"\"\"\n        Get the module prefix in multimodal models\n        \"\"\"\n        return MultiModelKeys.from_string_field(language_model=\"llm\", connector=\"resampler\", tower_model=\"vpm\")\n\n    def init_llm(\n        self,\n        config: PretrainedConfig,\n        cache_config: Optional[CacheConfig] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> nn.Module:\n\n        return LLMWrapper(\n            LlamaModel(\n                config,\n                cache_config=cache_config,\n                quant_config=quant_config,\n                prefix=prefix,\n            ),\n            name=prefix,\n        )\n\n    def init_audio_module(\n        self,\n        config: PretrainedConfig,\n        quant_config: Optional[QuantizationConfig],\n        prefix: str = \"\",\n    ) -> nn.Module:\n        return AudioModel(config)\n\n    def init_vision_module(\n        self,\n        config: PretrainedConfig,\n        quant_config: Optional[QuantizationConfig],\n        prefix: str = \"\",\n    ) -> nn.Module:\n        model = LLMWrapper(\n            Idefics2VisionTransformer(config.vision_config),\n            name=prefix,\n        )\n        if self.config.drop_vision_last_layer:\n            model.encoder.layers = model.encoder.layers[:-1]\n        return model\n\n    def init_resampler(\n        self,\n        embed_dim: int,\n        vision_dim: int,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> nn.Module:\n        resampler = Resampler(\n            num_queries=self.config.query_num,\n            embed_dim=embed_dim,\n            num_heads=embed_dim // 128,\n            kv_dim=vision_dim,\n            quant_config=quant_config,\n            prefix=prefix,\n        )\n        return resampler\n
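\n\n# Usage note: register this implementation with vLLM before constructing the\n# engine (see example_infer_vllm.py):\n#\n#   from vllm import ModelRegistry\n#   ModelRegistry.register_model(\"MegrezO\", MegrezOModel)\n"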
  },
  {
    "path": "vllm_demo/requirements.txt",
    "content": "vllm==0.6.3.post1\nflash_attn==2.5.8\nxformers==0.0.27.post2"
  },
  {
    "path": "vllm_demo/try_minicpm_v.py",
    "content": "from transformers import AutoTokenizer\nfrom PIL import Image\nfrom vllm import LLM, SamplingParams\n\nMODEL_NAME = \"/mnt/public/algm/models/MiniCPM-V-2_6/\"\n\n\nimage = Image.open(\"../data/sample_image.jpg\").convert(\"RGB\")\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\nllm = LLM(\n    model=MODEL_NAME,\n    trust_remote_code=True,\n    gpu_memory_utilization=1,\n    max_model_len=2048\n)\n\nmessages = [{\n    \"role\":\n    \"user\",\n    \"content\":\n    # Number of images\n    \"(<image>./</image>)\" + \\\n    \"\\nWhat is the content of this image?\" \n}]\nprompt = tokenizer.apply_chat_template(\n    messages,\n    tokenize=False,\n    add_generation_prompt=True\n)\n\n# Single Inference\ninputs = {\n    \"prompt\": prompt,\n    \"multi_modal_data\": {\n        \"image\": image\n        # Multi images, the number of images should be equal to that of `(<image>./</image>)`\n        # \"image\": [image, image] \n    },\n}\n# Batch Inference\n# inputs = [{\n#     \"prompt\": prompt,\n#     \"multi_modal_data\": {\n#         \"image\": image\n#     },\n# } for _ in 2]\n\n\n# 2.6\nstop_tokens = ['<|im_end|>', '<|endoftext|>']\nstop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]\n\nsampling_params = SamplingParams(\n    stop_token_ids=stop_token_ids, \n    use_beam_search=True,\n    temperature=0, \n    best_of=3,\n    max_tokens=1024\n)\n\noutputs = llm.generate(inputs, sampling_params=sampling_params)\n\nprint(outputs[0].outputs[0].text)"
  },
  {
    "path": "vllm_demo/try_qwen_vl.py",
    "content": "from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor\nfrom qwen_vl_utils import process_vision_info\n\n# default: Load the model on the available device(s)\nmodel = Qwen2VLForConditionalGeneration.from_pretrained(\n    \"/mnt/public/algm/models/Qwen2-VL-2B-Instruct\", torch_dtype=\"auto\", device_map=\"auto\"\n)\n\n\n# default processer\nprocessor = AutoProcessor.from_pretrained(\"/mnt/public/algm/models/Qwen2-VL-2B-Instruct\")\n\n\nprompt = \"hi\" * (128 - 1) \nmessages = [\n    {\n        \"role\": \"user\",\n        \"content\": [\n            {\n                \"type\": \"image\",\n                \"image\": \"../data/sample_image.jpg\",\n            },\n            {\"type\": \"text\", \"text\": prompt},\n        ],\n    }\n]\n\n# Preparation for inference\ntext = processor.apply_chat_template(\n    messages, tokenize=False, add_generation_prompt=True\n)\nimage_inputs, video_inputs = process_vision_info(messages)\nimport pdb;pdb.set_trace()\ninputs = processor(\n    text=[text],\n    images=image_inputs,\n    videos=video_inputs,\n    padding=True,\n    return_tensors=\"pt\",\n)\ninputs = inputs.to(\"cuda\")\nimport pdb;pdb.set_trace()\n# Inference: Generation of the output\ngenerated_ids = model.generate(**inputs, max_new_tokens=128)\ngenerated_ids_trimmed = [\n    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n]\noutput_text = processor.batch_decode(\n    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n)\nprint(output_text)"
  },
  {
    "path": "vllm_demo/vllm_profling.py",
    "content": "# -*- encoding: utf-8 -*-\n# File: example_infer_vllm.py\n# Description: None\n\nfrom PIL import Image\nfrom vllm import LLM\nfrom vllm import ModelRegistry\nfrom vllm import SamplingParams\n\nfrom megrezo import MegrezOModel\n\nModelRegistry.register_model(\"MegrezO\", MegrezOModel)\n\n# Load the model.\nmodel_path = \"/mnt/algorithm/user_dir/zhoudong/workspace/models/megrez-o\"  # Change this to the path of the model.\nllm = LLM(\n    model_path,\n    trust_remote_code=True,\n    gpu_memory_utilization=0.9,\n    max_num_seqs=8,\n)\n\nnum_requests = 100\ninput_len = 128\noutput_length = 128\n# prepare data \nprompt = \"hi\" * (input_len - 1) \nsampling_params = SamplingParams(\n    temperature=0,\n    max_tokens=output_length,\n    repetition_penalty=1.2,\n    stop=[\"<|turn_end|>\", \"<|eos|>\"],\n    ignore_eos=True,\n)\n\nimg = Image.open(\"../data/sample_image.jpg\")\n\nconversation = [\n    {\n        \"role\": \"user\",\n        \"content\": {\n            \"text\": prompt,\n            \"image\": img,\n        },\n    },\n]\n\n# Convert the conversation to vLLM acceptable format.\nprompt = llm.get_tokenizer().apply_chat_template(\n    conversation,\n    tokenize=False,\n    add_generation_prompt=True,\n)\nvllm_inputs = [\n    {\n        \"prompt\": prompt,\n        \"multi_modal_data\": {\n            \"image\": img,\n        },\n    }\n    for _ in range(num_requests)\n]\n\n# Generate the outputs.\noutputs = llm.generate(\n    vllm_inputs,\n    sampling_params,\n)\n\n# Print the outputs.\n# for output in outputs:\n#     print(output.outputs[0].text)\n"
  },
  {
    "path": "vllm_demo/vllm_profling_minicpm.py",
    "content": "from transformers import AutoTokenizer\nfrom PIL import Image\nfrom vllm import LLM, SamplingParams\n\n\nmodel_path = \"/mnt/public/algm/models/MiniCPM-V-2_6/\"\nimage = Image.open(\"../data/sample_image.jpg\").convert(\"RGB\")\ntokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\nllm = LLM(\n    model=model_path,\n    gpu_memory_utilization=0.9,\n    max_num_seqs=8,\n    trust_remote_code=True,\n    max_model_len=4096\n)\n\nnum_requests = 100\ninput_len = 128\noutput_length = 128\n# prepare data \nprompt = \"hi\" * (input_len - 1) \nsampling_params = SamplingParams(\n    temperature=0,\n    max_tokens=output_length,\n    repetition_penalty=1.2,\n    ignore_eos=True,\n)\n\n\nmessages = [{\n    \"role\":\n    \"user\",\n    \"content\":\n    # Number of images\n    \"(<image>./</image>)\" + \\\n    prompt\n}]\nprompt = tokenizer.apply_chat_template(\n    messages,\n    tokenize=False,\n    add_generation_prompt=True\n)\n\n# Single Inference\nllm_inputs = [{\n    \"prompt\": prompt,\n    \"multi_modal_data\": {\n        \"image\": image\n    },\n} for _ in range(num_requests)]\n\n\n\n\n\noutputs = llm.generate(llm_inputs, sampling_params=sampling_params)\n"
  },
  {
    "path": "vllm_demo/vllm_profling_qwen.py",
    "content": "from transformers import AutoProcessor\nfrom vllm import LLM, SamplingParams\nfrom qwen_vl_utils import process_vision_info\n\n\n# Load the model.\nmodel_path = \"/mnt/public/algm/models/Qwen2-VL-2B-Instruct\"  # Change this to the path of the model.\n\nllm = LLM(\n    model=model_path,\n    limit_mm_per_prompt={\"image\": 10, \"video\": 10},\n    gpu_memory_utilization=0.9,\n    max_num_seqs=8,\n)\n\nnum_requests = 100\ninput_len = 128\noutput_length = 128\n# prepare data \nprompt = \"hi\" * (input_len - 1) \nsampling_params = SamplingParams(\n    temperature=0,\n    max_tokens=output_length,\n    repetition_penalty=1.2,\n    ignore_eos=True,\n)\n\nmessages = [\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\n        \"role\": \"user\",\n        \"content\": [\n            {\n                \"type\": \"image\",\n                \"image\": \"../data/sample_image.jpg\",\n                \"min_pixels\": 224 * 224,\n                \"max_pixels\": 1024 * 1024,\n            },\n            {\"type\": \"text\", \"text\": prompt},\n        ],\n    },\n]\nprocessor = AutoProcessor.from_pretrained(model_path)\nprompt = processor.apply_chat_template(\n    messages,\n    tokenize=False,\n    add_generation_prompt=True,\n)\nimage_inputs, video_inputs = process_vision_info(messages)\n\nmm_data = {}\nif image_inputs is not None:\n    mm_data[\"image\"] = image_inputs\nif video_inputs is not None:\n    mm_data[\"video\"] = video_inputs\n\nllm_inputs = [\n        {\n        \"prompt\": prompt,\n        \"multi_modal_data\": mm_data,\n    }\n    for _ in range(num_requests)\n]\n\n\noutputs = llm.generate(llm_inputs, sampling_params=sampling_params)\n"
  }
]