Repository: infinigence/Infini-Megrez-Omni
Branch: main
Commit: 031e8160e384
Files: 21
Total size: 110.6 KB
Directory structure:
gitextract_nm6z7f7i/
├── LICENSE
├── README.md
├── README_zh.md
├── data/
│ └── train/
│ └── records.jsonl
├── example_chat_hf.py
├── finetune/
│ ├── dataset.py
│ ├── ds_config_zero2.json
│ ├── finetune.py
│ ├── finetune.sh
│ ├── requirements.txt
│ └── trainer.py
├── gradio_app.py
├── requirements.txt
└── vllm_demo/
├── example_infer_vllm.py
├── megrezo.py
├── requirements.txt
├── try_minicpm_v.py
├── try_qwen_vl.py
├── vllm_profling.py
├── vllm_profling_minicpm.py
└── vllm_profling_qwen.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2024 OpenBMB
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
# Megrez-3B-Omni: The First Open-Source End-Side Full Modality Understanding Model
<p align="center">
<img src="assets/megrez_logo.png" width="400"/>
<p>
<p align="center">
📄 <a href="assets/Megrez_Omni_Technical_Report.pdf">Paper</a>
🤗 <a href="https://huggingface.co/Infinigence/Megrez-3B-Omni">Huggingface</a>   |   🤖<a href="https://www.modelscope.cn/models/InfiniAI/Megrez-3B-Omni">Modelscope</a>   |   🖥️ <a href="https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni">Demo</a>   |   📖 <a href="assets/wechat-official.jpg">WeChat Official</a>   |   💬 <a href="assets/wechat-group.jpg">WeChat Groups</a>  
</p>
<strong>[中文](./README_zh.md) | English</strong>
</div>
## Introduction
**Megrez-3B-Omni** is an on-device multimodal understanding model developed by [Infinigence AI](https://cloud.infini-ai.com/platform/ai). It extends the Megrez-3B-Instruct model with support for image, text, and audio modalities, and achieves state-of-the-art accuracy in all three domains:
- Image Understanding: By utilizing SigLip-400M to construct image tokens, Megrez-3B-Omni outperforms models with more parameters, such as LLaVA-NeXT-Yi-34B. It ranks among the most accurate image understanding models on multiple mainstream benchmarks, including MME, MMMU, and OCRBench, and performs well on tasks such as scene understanding and OCR.
- Language Understanding: Megrez-3B-Omni retains text understanding capability without significant trade-offs: its accuracy differs from that of its single-modal counterpart (Megrez-3B-Instruct) by less than 2%, and it maintains state-of-the-art performance on benchmarks such as C-EVAL, MMLU/MMLU Pro, and AlignBench, still outperforming previous-generation 14B models.
- Speech Understanding: Equipped with the encoder head of Qwen2-Audio/whisper-large-v3, the model supports Chinese and English speech input, multi-turn conversations, and voice questions about input images. It responds to voice commands directly with text and achieves leading results across multiple benchmarks.
## Evaluation
- The left image compares the performance of Megrez-3B-Omni with other open-source models on mainstream image multimodal tasks.
- The right image shows the performance of Megrez-3B-Omni on the OpenCompass test set. Image reference: [InternVL 2.5 Blog Post](https://internvl.github.io/blog/2024-12-05-InternVL-2.5/).
You can find detailed accuracy metrics on the [Megrez-3B-Omni-HF](https://huggingface.co/Infinigence/Megrez-3B-Omni) page.
<div style="display: flex; justify-content: space-between;">
<img src="assets/multitask.jpg" alt="Comparison of Image Understanding Capabilities" style="width: 45%;">
<img src="assets/opencompass.jpg" alt="OpenCompass Benchmark Performance" style="width: 45%;">
</div>
### Inference Speed
| | image_tokens | prefill (tokens/s) | decode (tokens/s) |
|----------------|:------------:|:------------------:|:-----------------:|
| Megrez-3B-Omni | 448 | 6312.66 | 1294.9 |
| Qwen2-VL-2B | 1378 | 7349.39 | 685.66 |
| MiniCPM-V-2_6 | 448 | 2167.09 | 452.51 |
Setup:
- The testing environment utilizes an NVIDIA H100 GPU with vLLM. Each test includes 128 text tokens and a 720×1480 image as input, producing 128 output tokens, with `num_seqs` fixed at 8.
- Under this setup, Qwen2-VL-2B decodes more slowly than Megrez-3B-Omni despite having a smaller base LLM, because it produces far more image tokens when encoding an image of this size, which slows actual inference.
## Model Demo
【GIF】
## Install
Install runtime dependencies with the following command:
```shell
pip install -r requirements.txt
```
The audio-related functionality relies on **FFmpeg** for audio processing. If you are using a Debian or Debian-based system, you can install FFmpeg with the following command:
```bash
sudo apt-get install ffmpeg
```
For other operating systems, please refer to the [official FFmpeg documentation](https://ffmpeg.org/download.html) for installation instructions.
## Inference
### Conversation with Multimodal Data
You can use the following script to chat with our model. Note that you should replace `PATH_TO_PRETRAINED_MODEL` with the path to the downloaded model checkpoint.
```python
import torch
from transformers import AutoModelForCausalLM
path = "{{PATH_TO_PRETRAINED_MODEL}}" # Change this to the path of the model.
model = (
AutoModelForCausalLM.from_pretrained(
path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
.eval()
.cuda()
)
messages = [
{
"role": "user",
"content": {
"text": "Please describe the content of the image.",
"image": "./data/sample_image.jpg",
},
},
]
MAX_NEW_TOKENS = 100
response = model.chat(
messages,
sampling=False,
max_new_tokens=MAX_NEW_TOKENS,
)
print(response)
```
You can also find a complete script in [example_chat_hf.py](example_chat_hf.py).
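The `messages` structure above extends naturally to multiple turns and mixed modalities, matching the format of the sample records in [data/train/records.jsonl](./data/train/records.jsonl). A minimal sketch (the audio path and assistant reply are illustrative placeholders, not files shipped with the repo):

```python
import json

# Build a multi-turn conversation in the same format as the single-turn
# example above: each message has a "role" and a "content" dict whose keys
# may include "text", "image", and "audio".
messages = [
    {
        "role": "user",
        "content": {
            "text": "Please describe the content of the image.",
            "image": "./data/sample_image.jpg",
        },
    },
    {
        # Illustrative model reply carried forward as conversation history.
        "role": "assistant",
        "content": {"text": "The image shows a city skyline at dusk."},
    },
    {
        # A follow-up turn may mix modalities; here, a spoken question
        # about the image ("./data/sample_audio.wav" is a hypothetical path).
        "role": "user",
        "content": {
            "text": "",
            "audio": "./data/sample_audio.wav",
        },
    },
]

# Messages built from file paths are plain JSON, so the history can be
# logged or stored alongside the dataset records.
print(json.dumps(messages[0], ensure_ascii=False))
```

The full history (earlier user turns plus assistant replies) is passed back to `model.chat` on each round.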
### Inference with vLLM
We provide a reference implementation of inference with vLLM framework. You can find the model definition in [vllm_demo/megrezo.py](vllm_demo/megrezo.py).
1. Install vLLM
```shell
pip install vllm==0.6.3.post1 flash_attn==2.5.8 xformers==0.0.27.post2
```
**Note**: To use vLLM for inference, it is essential to install specific versions of the dependencies. Other versions may lead to interface incompatibility risks. If you encounter any issues, feel free to [open an issue](https://github.com/infinigence/Infini-Megrez-Omni/issues/new).
2. Run the inference script
Since vLLM does not officially support MegrezO yet, you need to import the module first:
```python
from vllm import ModelRegistry
from megrezo import MegrezOModel
ModelRegistry.register_model("MegrezO", MegrezOModel)
```
Then, you can run inference with the following code:
```python
from PIL import Image
from vllm import LLM
from vllm import SamplingParams
# Load the model.
model_path = "{{PATH_TO_HF_PRETRAINED_MODEL}}" # Change this to the path of the model.
llm = LLM(
model_path,
trust_remote_code=True,
gpu_memory_utilization=0.5,
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=1000,
repetition_penalty=1.2,
stop=["<|turn_end|>", "<|eos|>"],
)
img = Image.open("../data/sample_image.jpg")
conversation = [
{
"role": "user",
"content": {
"text": "图片的内容是什么?",
"image": img,
},
},
]
# Convert the conversation to vLLM acceptable format.
prompt = llm.get_tokenizer().apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
vllm_inputs = [
{
"prompt": prompt,
"multi_modal_data": {
"image": img,
},
}
]
# Generate the outputs.
outputs = llm.generate(
vllm_inputs,
sampling_params,
)
# Print the outputs.
for output in outputs:
print(output.outputs[0].text)
```
You can find a complete script in [vllm_demo/example_infer_vllm.py](vllm_demo/example_infer_vllm.py).
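`llm.generate` accepts a list of inputs, so several conversations can be batched in one call. The sketch below shows only the data flow: `format_prompt` is a stand-in so the example runs without loading the model (the real prompt must come from `llm.get_tokenizer().apply_chat_template` as above), and `"<|turn_end|>"` matches the stop tokens used earlier.

```python
# Illustrative batching sketch; format_prompt is a placeholder for
# apply_chat_template, used here only so the structure can be shown
# without the model. "<|turn_end|>" matches the stop tokens above.
def format_prompt(conversation):
    parts = []
    for msg in conversation:
        parts.append(f"{msg['role']}: {msg['content']['text']}<|turn_end|>")
    return "".join(parts)

conversations = [
    [{"role": "user", "content": {"text": "图片的内容是什么?", "image": "img_0"}}],
    [{"role": "user", "content": {"text": "Describe the image.", "image": "img_1"}}],
]

vllm_inputs = [
    {
        "prompt": format_prompt(conv),
        # Each request carries its own image(s) in multi_modal_data.
        "multi_modal_data": {"image": conv[0]["content"]["image"]},
    }
    for conv in conversations
]
print(len(vllm_inputs))
```

Passing this list to `llm.generate(vllm_inputs, sampling_params)` returns one output per conversation.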
## Chat with MegrezO using Gradio
We provide online and local demos powered by Hugging Face Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>.
### WebUI Demonstration
<div align="center" style="display: flex; justify-content: space-between;">
<img src="assets/gradio_demo.jpg" style="width: 80%;">
</div>
### Online Demo
Please try out our online Demo here: [🤗Megrez-3B-Omni](https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni)
### Local WebUI Demo
You can easily deploy your own local WebUI to chat with MegrezO using Gradio.
1. Install dependencies:
```shell
pip install -r requirements.txt
```
2. Launch the Gradio app.
You need to specify the `model_path` and `port` in the command line. The `model_path` is the path to the model checkpoint, and the `port` is the port number for the local server. By default, the `port` is `7860`.
```shell
python gradio_app.py --model_path {model_path} --port {port}
```
Then, you can visit `http://localhost:7860` in your browser to interact with the model.
Feel free to modify the `gradio_app.py` to customize the input and output interfaces. For more information, please refer to the [Gradio documentation](https://gradio.app/docs).
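For reference, the documented command-line interface can be sketched as below. This is a minimal, assumed reconstruction; the actual argument handling lives in `gradio_app.py` and may differ.

```python
import argparse

# Minimal sketch of the documented CLI: --model_path is required,
# --port defaults to 7860 as stated above. Hypothetical reconstruction,
# not the actual gradio_app.py code.
def build_parser():
    parser = argparse.ArgumentParser(description="Launch the MegrezO Gradio demo.")
    parser.add_argument("--model_path", required=True,
                        help="Path to the model checkpoint.")
    parser.add_argument("--port", type=int, default=7860,
                        help="Port for the local server.")
    return parser

# Example invocation with an illustrative checkpoint path.
args = build_parser().parse_args(["--model_path", "/models/Megrez-3B-Omni"])
print(args.port)
```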
## Fine-Tuning the Model
We provide a [fine-tuning example](./finetune/) based on [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [accelerate](https://github.com/huggingface/accelerate).
### Data Preparation
We have constructed a sample dataset based on the [ALLaVA-4V/allava_laion](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/tree/main/allava_laion) dataset:
- **Dialogue**: [data/train/records.jsonl](./data/train/records.jsonl)
- **Images**: [data/train/images](./data/train/images)
- **Audio**: [data/train/audio](./data/train/audio), created by converting dialogue text into speech using TTS.
You can also prepare your own dataset following the same format.
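Each line of `records.jsonl` is one JSON object with a `conversations` list of user/assistant turns, where user turns may carry relative `image` and `audio` paths. A sketch of building one record in that shape (the assistant text here is abbreviated from the shipped sample):

```python
import json

# One training record in the same shape as data/train/records.jsonl.
record = {
    "conversations": [
        {
            "role": "user",
            "content": {
                "text": "What material is the water bottle likely made from?",
                "image": "images/100280844.jpeg",
                "audio": "audios/0_4_0_chattts.wav",
            },
        },
        {
            "role": "assistant",
            "content": {"text": "The bottle is likely metal, given its reflective surface."},
        },
    ]
}

# JSONL stores one record per line; append one such line per sample.
line = json.dumps(record, ensure_ascii=False)
assert "\n" not in line  # jsonl requires exactly one record per line
```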
### Dependencies Installation
Install the required dependencies with the following command:
```bash
pip install deepspeed accelerate
```
### Full-Parameter Fine-Tuning
To run the fine-tuning example, execute the following commands. Be sure to replace the model path in the script with the path to your downloaded model.
```bash
cd finetune
sh finetune.sh
```
You can customize the modules to fine-tune by setting the parameters:
`tune_vision_encoder`, `tune_vision_proj`, `tune_llm`, `tune_audio_encoder`, and `tune_audio_proj`.
### Notes
1. **Recommended Hardware**: Please use at least two GPUs with 80GB memory for fine-tuning.
2. **If GPU memory is insufficient**:
- Adjust the `model_max_length` and `per_device_train_batch_size` parameters.
- Disable specific modules for fine-tuning to reduce memory usage.
- Optimize memory consumption by configuring the `zero_optimization` parameters in DeepSpeed.
3. **For better inference results**:
- We recommend placing images in the first round of the conversation for better inference results. Audio and text have no such restriction and can be interleaved freely.
- For Automatic Speech Recognition (ASR), simply set `content['text']` to "Convert speech to text."
- For OCR, enabling sampling may introduce language-model hallucinations that alter the recognized text, so consider disabling sampling at inference time (`sampling=False`). Note, however, that disabling sampling may cause the model to repeat itself.
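The ASR usage in the notes above can be sketched as a message whose text is the fixed instruction and whose audio field points at the clip to transcribe (the audio path below is a hypothetical example file):

```python
# ASR sketch: the prompt text is the instruction quoted in the notes
# above; the audio path is an illustrative placeholder.
asr_messages = [
    {
        "role": "user",
        "content": {
            "text": "Convert speech to text.",
            "audio": "./data/sample_audio.wav",
        },
    },
]

# For OCR-style tasks, the notes suggest deterministic decoding, e.g.:
# response = model.chat(asr_messages, sampling=False, max_new_tokens=100)
print(asr_messages[0]["content"]["text"])
```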
## Open Source License and Usage Statement
- **License**: The code in this repository is open-sourced under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
- **Hallucination**: Large models inherently have hallucination issues. Users should not completely trust the content generated by the model.
- **Values and Safety**: While we have made every effort to ensure compliance of the data used during training, the large volume and complexity of the data may still lead to unforeseen issues. We disclaim any liability for problems arising from the use of this open-source model, including but not limited to data security issues, public opinion risks, or risks and problems caused by misleading, misuse, propagation, or improper utilization of the model.
================================================
FILE: README_zh.md
================================================
<div align="center">
# Megrez-3B-Omni: 首个端侧全模态理解开源模型
<p align="center">
<img src="assets/megrez_logo.png" width="400"/>
<p>
<p align="center">
📄 <a href="assets/Megrez_Omni_Technical_Report.pdf">Paper</a>
🤗 <a href="https://huggingface.co/Infinigence/Megrez-3B-Omni">Huggingface</a>   |   🤖<a href="https://www.modelscope.cn/models/InfiniAI/Megrez-3B-Omni">Modelscope</a>   |   🖥️ <a href="https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni">Demo</a>   |   📖 <a href="assets/wechat-official.jpg">WeChat Official</a>   |   💬 <a href="assets/wechat-group.jpg">WeChat Groups</a>  
</p>
<strong>中文 | [English](./README.md)</strong>
</div>
## 模型简介
Megrez-3B-Omni是由无问芯穹([Infinigence AI](https://cloud.infini-ai.com/platform/ai))研发的**端侧全模态**理解模型,基于无问大语言模型Megrez-3B-Instruct扩展,同时具备图片、文本、音频三种模态数据的理解分析能力,在三个方面均取得最优精度
- 在图像理解方面,基于SigLip-400M构建图像Token,在OpenCompass榜单上(综合8个主流多模态评测基准)平均得分66.2,超越LLaVA-NeXT-Yi-34B等更大参数规模的模型。Megrez-3B-Omni也是在MME、MMMU、OCRBench等测试集上目前精度最高的图像理解模型之一,在场景理解、OCR等方面具有良好表现。
- 在语言理解方面,Megrez-3B-Omni并未牺牲模型的文本处理能力,综合能力较单模态版本(Megrez-3B-Instruct)精度变化小于2%,保持在C-EVAL、MMLU/MMLU Pro、AlignBench等多个测试集上的最优精度优势,依然取得超越上一代14B模型的能力表现
- 在语音理解方面,采用Qwen2-Audio/whisper-large-v3的Encoder作为语音输入,支持中英文语音输入及多轮对话,支持对输入图片的语音提问,根据语音指令直接响应文本,在多项基准任务上取得了领先的结果
## 评测结果
- 左图为Megrez-3B-Omni与其他开源模型在主流图片多模态任务上的性能比较
- 右图为Megrez-3B-Omni在OpenCompass测试集上的表现,图片引用自:[InternVL 2.5 Blog Post](https://internvl.github.io/blog/2024-12-05-InternVL-2.5/)
<div style="display: flex; justify-content: space-between;">
<img src="assets/multitask.jpg" alt="Image 1" style="width: 45%;">
<img src="assets/opencompass.jpg" alt="Image 2" style="width: 45%;">
</div>
详细精度见 [Megrez-3B-Omni-HF](https://huggingface.co/Infinigence/Megrez-3B-Omni)
### 推理速度
| | image_tokens | prefill (tokens/s) | decode (tokens/s) |
|----------------|:------------:|:------------------:|:-----------------:|
| Megrez-3B-Omni | 448 | 6312.66 | 1294.9 |
| Qwen2-VL-2B | 1378 | 7349.39 | 685.66 |
| MiniCPM-V-2_6 | 448 | 2167.09 | 452.51 |
实验设置:
- 测试环境为 NVIDIA H100 GPU,使用 vLLM 推理,输入为128个文本 token 和一张 720×1480 的图片,输出128个 token,`num_seqs` 固定为8。
- 在此设置下,Qwen2-VL-2B 的 decode 速度低于 Megrez-3B-Omni:虽然其基座 LLM 更小,但编码上述尺寸的图片后产生的 image token 数量更多,影响了实际推理速度。
## 模型演示
【GIF】
## 安装
使用如下命令安装依赖:
```shell
pip install -r requirements.txt
```
音频功能依赖ffmpeg进行音频处理,如果您使用 Debian 相关的系统,可以通过以下命令安装:
```shell
sudo apt-get install ffmpeg
```
对于其他的操作系统,请参考 [ffmpeg 官方文档](https://ffmpeg.org/download.html) 进行安装。
## 模型推理
### 使用多模态数据进行多轮对话
请使用如下脚本进行推理。请将 `PATH_TO_PRETRAINED_MODEL` 替换为下载的模型权重的路径。
```python
import torch
from transformers import AutoModelForCausalLM
path = "{{PATH_TO_PRETRAINED_MODEL}}" # 更改为模型的路径
model = (
AutoModelForCausalLM.from_pretrained(
path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
.eval()
.cuda()
)
messages = [
{
"role": "user",
"content": {
"text": "Please describe the content of the image.",
"image": "./data/sample_image.jpg",
},
},
]
MAX_NEW_TOKENS = 100
response = model.chat(
messages,
sampling=False,
max_new_tokens=MAX_NEW_TOKENS,
)
print(response)
```
完整的示例见:[example_chat_hf.py](example_chat_hf.py).
### 使用 vLLM 进行推理
我们提供了一个基于 vLLM 框架的推理参考实现。您可以在 [vllm_demo/megrezo.py](vllm_demo/megrezo.py) 中找到模型定义。
推理步骤如下:
1. 安装 vLLM
```shell
pip install vllm==0.6.3.post1 flash_attn==2.5.8 xformers==0.0.27.post2
```
**注意**:使用 vLLM 推理需要安装特定版本的依赖,其他版本可能存在接口不一致的风险。有任何问题欢迎[提issue](https://github.com/infinigence/Infini-Megrez-Omni/issues/new)。
2. 运行推理脚本
vLLM 尚未正式支持 MegrezO,因此您需要先导入我们定义的模块:
```python
from vllm import ModelRegistry
from megrezo import MegrezOModel
ModelRegistry.register_model("MegrezO", MegrezOModel)
```
然后,您可以使用以下代码运行推理:
```python
from PIL import Image
from vllm import LLM
from vllm import SamplingParams
model_path = "{{PATH_TO_HF_PRETRAINED_MODEL}}" # 更改为模型的路径
llm = LLM(
model_path,
trust_remote_code=True,
gpu_memory_utilization=0.5,
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=1000,
repetition_penalty=1.2,
stop=["<|turn_end|>", "<|eos|>"],
)
img = Image.open("../data/sample_image.jpg")
conversation = [
{
"role": "user",
"content": {
"text": "图片的内容是什么?",
"image": img,
},
},
]
# 将对话转换为 vLLM 可接受的格式。
prompt = llm.get_tokenizer().apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
vllm_inputs = [
{
"prompt": prompt,
"multi_modal_data": {
"image": img,
},
}
]
# 生成输出
outputs = llm.generate(
vllm_inputs,
sampling_params,
)
# 打印输出
for output in outputs:
print(output.outputs[0].text)
```
完整的示例见:[vllm_demo/example_infer_vllm.py](vllm_demo/example_infer_vllm.py).
## 使用 Gradio 与 MegrezO 对话
我们提供基于 Hugging Face Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> 实现的在线和本地 Demo。
### WebUI 演示
<div align="center" style="display: flex; justify-content: space-between;">
<img src="assets/gradio_demo.jpg" style="width: 80%;">
</div>
### 在线 Demo
欢迎试用在线 Demo: [🤗Megrez-3B-Omni](https://huggingface.co/spaces/Infinigence/Megrez-3B-Omni)。
### 本地 Demo
使用如下命令部署本地 Gradio 应用:
1. 安装依赖:
```shell
pip install -r requirements.txt
```
2. 启动 Gradio 应用
您需要在命令行中指定 `model_path` 和 `port`。`model_path` 是模型的路径,`port` 是本地服务器的端口号。默认情况下,`port` 是 `7860`。
```shell
python gradio_app.py --model_path {model_path} --port {port}
```
然后,您可以在浏览器中访问 `http://localhost:7860` 与模型对话。
如需自定义输入和输出接口,请修改 `gradio_app.py`。更多信息请参考 [Gradio 文档](https://gradio.app/docs)。
## 微调模型
我们提供了一个基于 [DeepSpeed](https://github.com/microsoft/DeepSpeed) 和 [accelerate](https://github.com/huggingface/accelerate) 的[微调示例](./finetune/)。
### 数据准备
我们基于[ALLaVA-4V/allava_laion](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/tree/main/allava_laion)构造了一个示例数据集:
- **对话**:[data/train/records.jsonl](./data/train/records.jsonl)
- **图片**:[data/train/images](./data/train/images)
- **音频**:[data/train/audio](./data/train/audio),是通过将对话中的文本使用TTS转换为语音得到的。
您也可以按照上述格式准备自己的数据集。
### 依赖安装
```shell
pip install deepspeed accelerate
```
### 全参微调
使用如下命令运行我们的微调示例,请注意将脚本中的模型路径替换成您下载的模型路径。
```shell
cd finetune
sh finetune.sh
```
您可以通过设置`tune_vision_encoder`、`tune_vision_proj`、`tune_llm`、`tune_audio_encoder`、`tune_audio_proj`来选择需要微调的模块。
### 注意事项
- 推荐使用至少2张拥有80G显存的GPU进行微调。
- 在显存不足的情况下:
- 请尝试调整`model_max_length`和`per_device_train_batch_size`。
- 请尝试关闭需要微调的模块以便减少显存占用。
- 请尝试调整deepspeed的`zero_optimization`参数来优化显存使用。
- 使用时
- 请将图片尽量在首轮输入以保证推理效果,语音和文本无此限制,可以自由切换
- 语音识别(ASR)场景下,只需要将content['text']修改为“将语音转化为文字。”
- OCR场景下开启采样可能会引入语言模型幻觉导致的文字变化,可考虑关闭采样进行推理(sampling=False),但关闭采样可能引入模型复读
## 开源协议及使用声明
- **协议**:本仓库中代码依照 [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 协议开源。
- **幻觉**:大模型天然存在幻觉问题,用户使用过程中请勿完全相信模型生成的内容。
- **价值观及安全性**:本模型已尽全力确保训练过程中使用的数据的合规性,但由于数据的大体量及复杂性,仍有可能存在一些无法预见的问题。如果出现使用本开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
================================================
FILE: data/train/records.jsonl
================================================
{"conversations":[{"role":"user","content":{"text":"Considering the items the figurine is holding, what scene or event from its source material might it be representing?","image":"images/100230144.jpeg","audio":"audios/0_0_0_chattts.wav"}},{"role":"assistant","content":{"text":"Based on the items the figurine is holding, which are a rifle and an iron, it likely represents a scene from the source material where the character is setting up traps to protect his home. These items are indicative of the resourceful and inventive methods the character uses in the movie \"Home Alone\" to create humorous but effective traps for intruders."}}]}
{"conversations":[{"role":"user","content":{"text":"How does the inclusion of gemstones affect the overall aesthetic and possible value of the ring?","image":"images/100277305.jpeg","audio":"audios/0_1_0_chattts.wav"}},{"role":"assistant","content":{"text":"The inclusion of gemstones along the edges of the leaf design enhances the ring's aesthetic by adding a touch of brilliance and luxury. The small, clear gemstones catch the light, creating a sparkling effect that contrasts with the silver's polished surface and the leaf's textured details. This sparkling accent likely increases the ring's visual appeal, making it stand out more as an accessory. In terms of value, the presence of gemstones, which could be diamonds, would likely raise the ring's monetary worth. Gemstones, especially diamonds, are prized for their beauty and durability, and their addition to jewelry pieces often signifies a higher quality and class of the item, thereby increasing its market value."}}]}
{"conversations":[{"role":"user","content":{"text":"What might the significance of the figure holding the 'O' in relation to the setting sun symbolize in the context of this album cover?","image":"images/100234216.jpeg","audio":"audios/0_2_0_chattts.wav"}},{"role":"assistant","content":{"text":"The silhouette of the figure holding the 'O' of the \"WOW\" so that it appears to be part of the setting sun likely symbolizes an element of enlightenment or revelation, which is often a theme in Christian music. It could represent the divine or an epiphany, echoing the 'wow' factor that the album title suggests. This imagery might be intended to convey a sense of awe or an extraordinary experience, possibly reflecting the emotional or spiritual impact that the songs aim to impart on the listener."}}]}
{"conversations":[{"role":"user","content":{"text":"What might be the significance of the character's clothing and instrument choice in relation to the type of event being promoted?","image":"images/100223054.jpeg","audio":"audios/0_3_0_chattts.wav"}},{"role":"assistant","content":{"text":"The character's clothing and instrument choice are indicative of a relaxed and informal atmosphere, which aligns with the musical genres mentioned on the poster\u2014rock and hard rock. The gray cap and black shirt could be associated with a laid-back, perhaps slightly rebellious aesthetic that is often linked to rock music cultures. The red bass guitar signifies that music is central to the event and hints that the character may represent a member of the band \"PuBFlieS,\" suggesting they play rock or hard rock music, as bass guitars are fundamental in these genres. The cheerful demeanor of the character along with the casual attire suggests that the event aims to be welcoming and enjoyable, catering to fans of these music genres."}}]}
{"conversations":[{"role":"user","content":{"text":"What material is the water bottle likely made from?","image":"images/100280844.jpeg","audio":"audios/0_4_0_chattts.wav"}},{"role":"assistant","content":{"text":"Based on the image alone, the water bottle is likely made from metal, inferred from the reflective and smooth surface, which is characteristic of metal water bottles. The metallic sheen and lack of any seams or typical plastic texturing support this assumption. Additionally, the way the light reflects off the surface suggests a material that is more reflective than plastic or glass, which is often the case with stainless steel or aluminum bottles."}}]}
{"conversations":[{"role":"user","content":{"text":"What historical period might the attire of the person in the painting suggest?","image":"images/100284327.jpeg","audio":"audios/0_5_0_chattts.wav"}},{"role":"assistant","content":{"text":"The attire of the person in the painting suggests a historical period that could be the late Middle Ages or the Renaissance, specifically within Europe. This is indicated by the style of the hat, which resembles a wide-brimmed hat known as a Gainsborough hat, popular in different forms from the 17th century onwards, and the monastic-like robes, which were common in religious orders during the medieval period. However, the exact historical period cannot be definitively determined from the attire alone due to the painting's anachronistic nature."}}]}
{"conversations":[{"role":"user","content":{"text":"Could the absence of any ports or connectors on the visible side suggest that this is a wireless device?","image":"images/100245026.jpeg","audio":"audios/0_6_0_chattts.wav"}},{"role":"assistant","content":{"text":"Based on the visible features in the image, the absence of any ports, buttons, or connectors on the surface of the device does suggest that it could be a wireless device. This characteristic is common in modern wireless technology, where devices are often designed to be sleek and unobtrusive, relying on wireless connectivity for their functionality. The simplicity and cleanliness of the design further imply that it could be a wireless charging pad, which typically requires no physical connections to the devices it charges."}}]}
{"conversations":[{"role":"user","content":{"text":"Considering the individual's attire and the setting, what could be the possible theme or context of this photo shoot?","image":"images/10688.jpeg","audio":"audios/3_0_0_chattts.wav"}},{"role":"assistant","content":{"text":"The theme or context of the photo shoot could be a conceptual representation of rebellion or resistance. The military-style jacket with patches, the helmet, and the setting that resembles a prison all suggest a narrative of combatting authority or standing against confinement. The use of fashion to portray this narrative indicates a possible commentary on individualism and defiance."}}]}
{"conversations":[{"role":"user","content":{"text":"What could be the possible association between the two logos presented in the image, and how might they relate to the content listed in the slide?","image":"images/104042.jpeg","audio":"audios/2_0_0_chattts.wav"}},{"role":"assistant","content":{"text":"The possible association between the two logos and the content of the slide suggests a partnership or a collaborative project focused on recycling and waste electrical and electronic equipment (WEEE). The \"LIFE +\" logo is associated with an EU environmental initiative, and \"RECYCLING SYMPRAXIS\" suggests a practice or a consortium working towards recycling. The date and word \"PHILOXENIA\" hint at an event, possibly a conference or seminar that took place in 2010. The second logo, which is less identifiable, likely represents the organization responsible for the content of the presentation, in this case, \"Q-PLAN Northern Greece\", which seems to be the coordinator or the main body overseeing the implementation of the state-of-the-art technologies and applications in WEEE recycling. The contents listed in the slide would be topics discussed in relation to these technologies and their applications."}}]}
{"conversations":[{"role":"user","content":{"text":"What might the three stars above the team crest signify in the context of soccer achievements?","image":"images/100271334.jpeg","audio":"audios/0_7_0_chattts.wav"}},{"role":"assistant","content":{"text":"The three stars above the team crest traditionally represent major honors or championships won by the team. In many soccer leagues, a star is added to the team's crest for a set number of league or major tournament victories. For instance, a club might add a star for every ten league titles they win. Therefore, these stars are likely indicative of the team's historical success, possibly in their domestic league or international competitions."}}]}
================================================
FILE: example_chat_hf.py
================================================
# -*- encoding: utf-8 -*-
# File: example_chat_hf.py
# Description: None
import torch
from transformers import AutoModelForCausalLM
path = "{{PATH_TO_HF_PRETRAINED_MODEL}}"  # Change this to the path of the model.
model = (
AutoModelForCausalLM.from_pretrained(
path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
.eval()
.cuda()
)
prompt = "Please describe the content of the image."
# Chat with text and image
messages = [
{
"role": "user",
"content": {
"text": prompt,
"image": "./data/sample_image.jpg",
},
},
]
# Chat with audio and image
# messages = [
# {
# "role": "user",
# "content": {
# "image": "./data/sample_image.jpg",
# "audio": "./data/sample_audio.m4a",
# },
# },
# ]
MAX_NEW_TOKENS = 100
response = model.chat(
messages,
sampling=False,
max_new_tokens=MAX_NEW_TOKENS,
)
print(response)
================================================
FILE: finetune/dataset.py
================================================
# -*- encoding: utf-8 -*-
# File: dataset.py
# Description: None
import os
import numpy as np
import torch
from torch.utils.data import Dataset
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
raw_data_list,
processor,
process_func,
dataset_prefix="",
):
super(SupervisedDataset, self).__init__()
self.raw_data_list = raw_data_list
self.processor = processor
self.process_func = process_func
self.dataset_prefix = dataset_prefix
def __len__(self):
return len(self.raw_data_list)
def check_ret(self, ret):
flag = True
for key in ret.keys():
value_list = ret[key]
if not isinstance(value_list, list):
value_list = [value_list]
for value in value_list:
if isinstance(value, torch.Tensor):
if torch.isnan(value).any():
flag = False
if torch.isinf(value).any():
flag = False
return flag
def check_audio(self, ret):
flag = True
for audio in ret["msgs_audio"]:
if (audio["input_audio_lengths"][:, 1] == 0).any():
flag = False
return flag
def prepare_labels(self, data):
def prepare_labels(tokenizer, input_ids, padding_value=-100):
# Only the content after <|role_start|>assistant<|role_end|> contributes to the loss.
def find_start_header_idxs():
start_header_tokens = tokenizer.encode("<|role_start|>assistant<|role_end|>", add_special_tokens=False)
start_header_idxs = np.where(input_ids == start_header_tokens[-1])[0]
kept_start_header_idxs = []
for start_header_idx in start_header_idxs:
keep = True
for i in range(1, len(start_header_tokens)):
if start_header_tokens[-(i + 1)] != input_ids[start_header_idx - i]:
keep = False
break
if keep:
kept_start_header_idxs.append(start_header_idx)
return kept_start_header_idxs
turn_end_token_id = tokenizer.encode("<|turn_end|>")[0]
start_header_idxs = find_start_header_idxs()
end_header_idxs = np.where(input_ids == turn_end_token_id)[0]
label_mask = np.zeros_like(input_ids, dtype=np.bool_)
def find_next_greater_number(lst, num):
next_greater = None
for n in lst:
if n > num:
if next_greater is None or n < next_greater:
next_greater = n
return next_greater
nr_tokens = len(input_ids)
for start_head_idx in start_header_idxs:
start_idx = start_head_idx + 1
end_idx = find_next_greater_number(end_header_idxs, start_head_idx)
# Guard against a missing <|turn_end|> after this header.
end_idx = nr_tokens - 1 if end_idx is None else end_idx
end_idx = min(end_idx + 1, nr_tokens)
label_mask[start_idx:end_idx] = True
labels = torch.ones(input_ids.shape[0] + 1) * padding_value
labels[: input_ids.shape[0]] = input_ids
labels[: input_ids.shape[0]][~label_mask] = padding_value
labels = labels[1:]
return labels.long()
return prepare_labels(self.processor.tokenizer, data["input_ids"])
def add_dataset_prefix(self, item):
conv = item["conversations"]
for i in range(len(conv)):
content = conv[i]["content"]
if "image" in content:
content["image"] = os.path.join(self.dataset_prefix, content["image"])
if "audio" in content:
content["audio"] = os.path.join(self.dataset_prefix, content["audio"])
return conv
def __getitem__(self, i):
raw_data_item = self.raw_data_list[i]
item = self.add_dataset_prefix(raw_data_item)
processed_data = self.processor(
item,
add_generation_prompt=False,
apply_data_collator=False,
)
if "labels" not in processed_data:
processed_data["labels"] = self.prepare_labels(processed_data)
return processed_data
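The span-selection rule in `prepare_labels` (unmask everything between the assistant role marker and the next `<|turn_end|>`) can be checked in isolation. A minimal sketch with made-up token ids (1 stands in for the last token of the role marker, 9 for `<|turn_end|>`):

```python
import numpy as np

# Toy sequence: 1 = assistant role marker, 9 = <|turn_end|>, rest = ordinary tokens.
input_ids = np.array([5, 1, 7, 8, 9, 5, 1, 6, 9])
start_idxs = np.where(input_ids == 1)[0]
end_idxs = np.where(input_ids == 9)[0]

mask = np.zeros_like(input_ids, dtype=bool)
for s in start_idxs:
    # The first <|turn_end|> after the marker closes the assistant span,
    # and the end token itself is included in the loss.
    e = min((e for e in end_idxs if e > s), default=len(input_ids) - 1)
    mask[s + 1 : min(e + 1, len(input_ids))] = True

print(mask.astype(int).tolist())  # → [0, 0, 1, 1, 1, 0, 0, 1, 1]
```

Positions where the mask is `False` are set to the `-100` padding value so `CrossEntropyLoss` ignores them; the actual implementation additionally shifts labels by one for next-token prediction.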
================================================
FILE: finetune/ds_config_zero2.json
================================================
{
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
}
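DeepSpeed requires `train_batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size`; here `train_batch_size` is `"auto"`, so the HF Trainer resolves it. A quick sanity check of that arithmetic, assuming the 2-GPU `torchrun` launch in `finetune.sh`:

```python
import json

# Values mirror finetune/ds_config_zero2.json; "auto" fields are filled in by
# the HF Trainer, but micro batch size and grad accumulation are fixed here.
config = json.loads('{"train_micro_batch_size_per_gpu": 1, "gradient_accumulation_steps": 1}')

world_size = 2  # matches --nproc_per_node=2 in finetune.sh
global_batch = (
    config["train_micro_batch_size_per_gpu"]
    * config["gradient_accumulation_steps"]
    * world_size
)
print(global_batch)  # → 2
```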
================================================
FILE: finetune/finetune.py
================================================
# -*- encoding: utf-8 -*-
# File: finetune.py
# Description: None
import glob
import json
import logging
import os
from dataclasses import dataclass
from dataclasses import field
from functools import partial
from glob import glob
from typing import Dict, List, Literal, Optional, Tuple, Union
import torch
import transformers
from accelerate.utils import DistributedType
from dataset import SupervisedDataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from trainer import MegrezOTrainer
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor
from transformers import AutoTokenizer
from transformers.integrations import deepspeed
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2")
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
dataset_prefix: str = field(default="data", metadata={"help": "Prefix for the multimodal data."})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=2048,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
tune_vision_encoder: Optional[bool] = field(default=True)
tune_vision_proj: Optional[bool] = field(default=True)
tune_llm: Optional[bool] = field(default=True)
tune_audio_encoder: Optional[bool] = field(default=True)
tune_audio_proj: Optional[bool] = field(default=True)
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
scale_resolution: Optional[int] = field(default=448)
remove_unused_columns: Optional[bool] = field(default=False)
@dataclass
class LoraArguments:
lora_r: int = 64
lora_alpha: int = 64
lora_dropout: float = 0.05
lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
lora_weight_path: str = ""
lora_bias: str = "none"
q_lora: bool = False
lora_modules_to_save: str = ""
lora_layer_replication: Optional[List[Tuple[int, int]]] = None
lora_layers_to_transform: Optional[List[int]] = None
lora_layers_pattern: Optional[str] = None
def maybe_zero_3(param):
if hasattr(param, "ds_id"):
assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias.items():
if k in lora_bias_names:
to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
return to_return
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
"""Collects the state dict and dump to disk."""
# check if zero3 mode enabled
if deepspeed.is_deepspeed_zero3_enabled():
state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
else:
if trainer.args.use_lora:
state_dict = get_peft_state_maybe_zero_3(trainer.model.named_parameters(), bias)
else:
state_dict = trainer.model.state_dict()
if trainer.args.should_save and trainer.args.local_rank == 0:
trainer._save(output_dir, state_dict=state_dict)
def make_supervised_data_module(
data_args,
processor,
process_func,
data_collator=None,
max_length=2048,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
rank0_print("Loading data...")
with open(data_args.data_path, "r") as f:
raw_data_list = [json.loads(line) for line in f]
train_dataset = SupervisedDataset(
raw_data_list,
processor,
process_func,
data_args.dataset_prefix,
)
eval_dataset = None
return dict(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=partial(data_collator, max_length=max_length, collate_labels=True),
)
def get_parameter_number(model):
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
return {"Total": all_param, "Trainable": trainable_params}
local_rank = 0
def load_model_from_pretrained(model_path, dtype=torch.bfloat16):
model = AutoModelForCausalLM.from_pretrained(
model_path, _attn_implementation="flash_attention_2", trust_remote_code=True, torch_dtype=dtype
)
return model
def load_tokenizer_from_pretrained(model_path):
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
return tokenizer
def train():
global local_rank
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments))
(
model_args,
data_args,
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()
if getattr(training_args, "deepspeed", None):
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
local_rank = training_args.local_rank
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
device_map = None
if lora_args.q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
logging.warning("FSDP and ZeRO3 are incompatible with QLoRA.")
model = load_model_from_pretrained(model_args.model_name_or_path, dtype=compute_dtype)
tokenizer = load_tokenizer_from_pretrained(model_args.model_name_or_path)
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
model.tune_llm = training_args.tune_llm
model.tune_vision = training_args.tune_vision_encoder or training_args.tune_vision_proj
model.tune_audio = training_args.tune_audio_encoder or training_args.tune_audio_proj
if not training_args.tune_vision_encoder:
model.vision.vpm.requires_grad_(False)
if not training_args.tune_vision_proj:
model.vision.resampler.requires_grad_(False)
if not training_args.tune_llm:
model.llm.requires_grad_(False)
if not training_args.tune_audio_encoder:
model.audio.requires_grad_(False)
model.audio.audio.proj.requires_grad_(True)
if model.audio.audio.audio_bos_eos_token is not None:
model.audio.audio.audio_bos_eos_token.requires_grad_(True)
if not training_args.tune_audio_proj:
model.audio.audio.proj.requires_grad_(False)
if model.audio.audio.audio_bos_eos_token is not None:
model.audio.audio.audio_bos_eos_token.requires_grad_(False)
rank0_print(get_parameter_number(model))
data_module = make_supervised_data_module(
data_args=data_args,
processor=processor,
process_func=None,
data_collator=processor.data_collator,
max_length=training_args.model_max_length,
)
if training_args.lr_scheduler_type == "cosine_with_min_lr":
training_args.lr_scheduler_kwargs = {"min_lr_rate": 0.1}
trainer = MegrezOTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module,
)
train_dataset = trainer.train_dataset
nr_data = len(train_dataset)
rank0_print("nr dataset: {}".format(nr_data))
checkpoint_pattern = os.path.join(training_args.output_dir, "checkpoint*")
# Keep only checkpoint-<step> directories, ordered by step number.
checkpoint_paths = [p for p in glob(checkpoint_pattern) if p.split("-")[-1].isdigit()]
checkpoint_paths = sorted(checkpoint_paths, key=lambda x: int(x.split("-")[-1]))
load_checkpoint = True
if load_checkpoint and checkpoint_paths:
checkpoint_path = checkpoint_paths[-1]
rank0_print("Continue Checkpoint Training: {}".format(checkpoint_path))
trainer.train(checkpoint_path)
else:
trainer.train()
trainer.save_state()
final_path = os.path.join(training_args.output_dir, "final")
os.makedirs(final_path, exist_ok=True)
rank0_print("save final path to {}".format(final_path))
safe_save_model_for_hf_trainer(trainer, final_path)
if __name__ == "__main__":
train()
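The checkpoint-resume logic above filters to numeric `checkpoint-<step>` suffixes and sorts by step count rather than lexicographically. A small sketch of why the numeric key matters (directory names are made up):

```python
paths = [
    "out/checkpoint-9",
    "out/checkpoint-100",
    "out/checkpoint-20",
    "out/checkpoint-best",  # non-numeric suffix: skipped
]

# Lexicographic order would put checkpoint-100 before checkpoint-20, so the
# resume logic sorts on the integer step instead and takes the last entry.
numeric = [p for p in paths if p.split("-")[-1].isdigit()]
ordered = sorted(numeric, key=lambda p: int(p.split("-")[-1]))
latest = ordered[-1]

print(ordered)  # → ['out/checkpoint-9', 'out/checkpoint-20', 'out/checkpoint-100']
print(latest)   # → out/checkpoint-100
```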
================================================
FILE: finetune/finetune.sh
================================================
DATA_PATH=$(pwd)/../data/train/records.jsonl
DATASET_PREFIX=$(pwd)/../data/train/
CURRENT_TIME=$(date +%Y%m%d_%H%M%S)
OUTPUT_DIR=$(pwd)/test_finetune/$CURRENT_TIME
LOGGING_DIR=$(pwd)/test_finetune_log
MODEL_PATH=""
torchrun --nproc_per_node=2 finetune.py \
--data_path $DATA_PATH \
--dataset_prefix $DATASET_PREFIX \
--output_dir $OUTPUT_DIR \
--logging_dir $LOGGING_DIR \
--model_name_or_path $MODEL_PATH \
--num_train_epochs 10 \
--deepspeed ds_config_zero2.json \
--prediction_loss_only false \
--bf16 true \
--fp16 false \
--do_train \
--tune_vision_encoder true \
--tune_vision_proj true \
--tune_llm true \
--tune_audio_encoder false \
--tune_audio_proj true \
--model_max_length 2048 \
--max_slice_nums 9 \
--scale_resolution 448 \
--logging_strategy "steps" \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_steps 1000 \
--save_total_limit 100 \
--learning_rate 1e-6 \
--weight_decay 0.1 \
--adam_beta2 0.98 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1
================================================
FILE: finetune/requirements.txt
================================================
deepspeed
accelerate
================================================
FILE: finetune/trainer.py
================================================
# -*- encoding: utf-8 -*-
# File: trainer.py
# Description: None
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import torch
import torch.nn as nn
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_pt_utils import nested_detach
class MegrezOTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
if "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
self.model.vision.resampler.pos_embed = self.model.vision.resampler.pos_embed.to(self.model.device)
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(self.model.vision.resampler.attn.parameters(), modifier_rank=0):
if not self.args.use_lora:
outputs = self.model(data=inputs, use_cache=False)
else:
outputs = self.model.base_model(data=inputs, use_cache=False)
else:
if not self.args.use_lora:
outputs = self.model(data=inputs, use_cache=False)
else:
outputs = self.model.base_model(data=inputs, use_cache=False)
if labels is not None:
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
labels = labels.view(-1).long().contiguous()
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to evaluate.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (`bool`):
Whether or not to return the loss only.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Return:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
return_loss = self.can_return_loss
loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels or loss_without_labels:
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
labels = None
with torch.no_grad():
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
logits = outputs[1:]
else:
loss = None
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]
if prediction_loss_only:
return (loss, None, None)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
return (loss, logits, labels)
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: int
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
del inputs
torch.cuda.empty_cache()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
from transformers.trainer import amp
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
self.accelerator.backward(loss)
else:
self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
================================================
FILE: gradio_app.py
================================================
# -*- encoding: utf-8 -*-
# File: gradio_app.py
# Description: None
import threading
from copy import deepcopy
from typing import Dict, List
import gradio as gr
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import TextIteratorStreamer
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
AUDIO_EXTENSIONS = (".mp3", ".wav")
DEFAULT_SAMPLING_PARAMS = {
"top_p": 0.8,
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
"num_beams": 1,
"repetition_penalty": 1.2,
}
MAX_NEW_TOKENS = 1024
def main(model_path: str, port: int):
if gr.NO_RELOAD:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = (
AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
.eval()
.cuda()
)
iterable_streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True,
timeout=30,
)
def history2messages(history: List[Dict]) -> List[Dict]:
"""
Transform gradio history to chat messages.
"""
messages = []
cur_message = dict()
for item in history:
if item["role"] == "assistant":
if len(cur_message) > 0:
messages.append(deepcopy(cur_message))
cur_message = dict()
messages.append(deepcopy(item))
continue
if "role" not in cur_message:
cur_message["role"] = "user"
if "content" not in cur_message:
cur_message["content"] = dict()
if "metadata" not in item:
item["metadata"] = {"title": None}
if item["metadata"]["title"] is None:
cur_message["content"]["text"] = item["content"]
elif item["metadata"]["title"] == "image":
cur_message["content"]["image"] = item["content"][0]
elif item["metadata"]["title"] == "audio":
cur_message["content"]["audio"] = item["content"][0]
if len(cur_message) > 0:
messages.append(cur_message)
return messages
def check_messages(history, message, audio):
audios = []
images = []
for file_msg in message["files"]:
if file_msg.endswith(AUDIO_EXTENSIONS) or file_msg.endswith(VIDEO_EXTENSIONS):
audios.append(file_msg)
elif file_msg.endswith(IMAGE_EXTENSIONS):
images.append(file_msg)
else:
filename = file_msg.split("/")[-1]
raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")
if len(audios) > 1:
raise gr.Error("Please upload only one audio file.")
if len(images) > 1:
raise gr.Error("Please upload only one image file.")
if audio is not None:
if len(audios) > 0:
raise gr.Error("Please upload only one audio file or record audio.")
audios.append(audio)
# Append the message to the history
for image in images:
history.append({"role": "user", "content": (image,), "metadata": {"title": "image"}})
for audio in audios:
history.append({"role": "user", "content": (audio,), "metadata": {"title": "audio"}})
if message["text"] is not None:
history.append({"role": "user", "content": message["text"]})
return history, gr.MultimodalTextbox(value=None, interactive=False)
def bot(
history: list,
top_p: float,
top_k: int,
temperature: float,
repetition_penalty: float,
max_new_tokens: int = MAX_NEW_TOKENS,
regenerate: bool = False,
):
sampling_params = {
"top_p": top_p,
"top_k": top_k,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
}
if regenerate:
history = history[:-1]
msgs = history2messages(history)
th = threading.Thread(
target=model.chat,
kwargs=dict(
input_msgs=msgs,
sampling=True,
streamer=iterable_streamer,
max_new_tokens=max_new_tokens,
**sampling_params,
),
)
th.start()
response = ""
for subtext in iterable_streamer:
response += subtext
yield history + [{"role": "assistant", "content": response}]
th.join()
return response
def change_state(state):
return gr.update(visible=not state), not state
with gr.Blocks() as demo:
chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height=800)
sampling_params_group_hidden_state = gr.State(False)
with gr.Row(equal_height=True):
audio_input = gr.Audio(
sources=["microphone", "upload"],
type="filepath",
scale=4,
)
chat_input = gr.MultimodalTextbox(
file_count="multiple",
show_label=False,
scale=10,
file_types=["image", "audio"],
# stop_btn=True,
)
with gr.Column(scale=1, min_width=150):
with gr.Row(equal_height=True):
regenerate_btn = gr.Button("Regenerate", variant="primary")
clear_btn = gr.ClearButton(
[chat_input, audio_input, chatbot],
)
with gr.Row():
sampling_params_toggle_btn = gr.Button("Sampling Parameters")
with gr.Group(visible=False) as sampling_params_group:
with gr.Row():
temperature = gr.Slider(
minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS["temperature"], label="Temperature"
)
repetition_penalty = gr.Slider(
minimum=0,
maximum=2,
value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
label="Repetition Penalty",
)
with gr.Row():
top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")
with gr.Row():
max_new_tokens = gr.Slider(
minimum=1,
maximum=MAX_NEW_TOKENS,
value=MAX_NEW_TOKENS,
label="Max New Tokens",
interactive=True,
)
sampling_params_toggle_btn.click(
change_state,
sampling_params_group_hidden_state,
[sampling_params_group, sampling_params_group_hidden_state],
)
chat_msg = chat_input.submit(
check_messages,
[chatbot, chat_input, audio_input],
[chatbot, chat_input],
)
bot_msg = chat_msg.then(
bot,
inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
outputs=chatbot,
api_name="bot_response",
)
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
regenerate_btn.click(
bot,
inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
outputs=chatbot,
)
demo.launch(server_port=port)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--port", type=int, default=7680)
args = parser.parse_args()
main(args.model_path, args.port)
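# Example usage, from the repository root (the checkpoint path below is a
# placeholder; point it at your local model directory):
#   python gradio_app.py --model_path /path/to/megrez-o --port 7680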
================================================
FILE: requirements.txt
================================================
transformers>=4.44.0
tokenizers>=0.20.3
accelerate
datasets
gradio
================================================
FILE: vllm_demo/example_infer_vllm.py
================================================
# -*- encoding: utf-8 -*-
# File: example_infer_vllm.py
# Description: Run a single MegrezO image+text request with vLLM.
from PIL import Image
from vllm import LLM
from vllm import ModelRegistry
from vllm import SamplingParams
from megrezo import MegrezOModel
ModelRegistry.register_model("MegrezO", MegrezOModel)
# Load the model.
model_path = "{{PATH_TO_HF_PRETRAINED_MODEL}}"  # Change this to the path of the model.
llm = LLM(
model_path,
trust_remote_code=True,
gpu_memory_utilization=0.5,
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=1000,
repetition_penalty=1.2,
stop=["<|turn_end|>", "<|eos|>"],
)
img = Image.open("../data/sample_image.jpg")
conversation = [
{
"role": "user",
"content": {
"text": "图片的内容是什么?",
"image": img,
},
},
]
# Convert the conversation to vLLM acceptable format.
prompt = llm.get_tokenizer().apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
vllm_inputs = [
{
"prompt": prompt,
"multi_modal_data": {
"image": img,
},
}
]
# Generate the outputs.
outputs = llm.generate(
vllm_inputs,
sampling_params,
)
# Print the outputs.
for output in outputs:
print(output.outputs[0].text)
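# The single-request flow above generalizes to batch inference: vLLM accepts a
# list of inputs and schedules them together in one call. A minimal sketch
# reusing the `prompt`, `img`, and `sampling_params` defined above (the batch
# size of 4 is arbitrary):
batch_inputs = [
    {
        "prompt": prompt,
        "multi_modal_data": {"image": img},
    }
    for _ in range(4)
]
for output in llm.generate(batch_inputs, sampling_params):
    print(output.outputs[0].text)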
================================================
FILE: vllm_demo/megrezo.py
================================================
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MegrezO model compatible with HuggingFace weights."""
from functools import lru_cache
from functools import partial
from typing import Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import Tensor
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import DecoderOnlyInputs
from vllm.inputs import InputContext
from vllm.inputs import token_inputs
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import VllmModelForTextGeneration
from vllm.model_executor.models.idefics2_vision_model import Idefics2VisionTransformer
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.sequence import SequenceData
from vllm.transformers_utils.processor import get_processor
RawImageType = Union[Image.Image, torch.Tensor]
RawAudioType = Union[bytes, torch.Tensor]
cached_get_processor = lru_cache(get_processor)
class MegrezORawImageInput(TypedDict):
"""Input mapper input with auxiliary data for computing image bounds."""
image: RawImageType
class MegrezOAudioInput(TypedDict):
type: Literal["audio"]
data: RawAudioType
class MegrezOAudioTensorInput(TypedDict):
type: Literal["audio_tensor"]
input_audios: torch.Tensor
input_audio_lengths: torch.Tensor
audio_span_tokens: torch.Tensor
class MegrezOImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that the image size may vary, so we pass it as a list
instead of a batched tensor.
"""
tgt_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
patch_attention_mask: torch.Tensor
"""
Shape: `(batch_size * num_images, num_patches, num_patches)`
"""
class MegrezOImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor.
"""
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
def insert_audio_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
inserted_bounds = inserted_bounds.long()
for idx in range(len(inserted_embeddings)):
bid = inserted_bounds[idx][0]
start_id = inserted_bounds[idx][1]
end_id = inserted_bounds[idx][2]
embedding = inserted_embeddings[idx]
text_embeddings[start_id + 1 : end_id] = embedding
return text_embeddings
def insert_image_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
inserted_bounds = inserted_bounds.long()
for idx in range(len(inserted_embeddings)):
bid = inserted_bounds[idx][0]
start_id = inserted_bounds[idx][1]
end_id = inserted_bounds[idx][2]
embedding = inserted_embeddings[idx]
text_embeddings[start_id:end_id] = embedding
return text_embeddings
MegrezOImageInputs = MegrezOImagePixelInputs
MegrezOAudioInputs = MegrezOAudioTensorInput
# region: Resampler
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
class Resampler(nn.Module):
def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
max_size: Tuple[int, int] = (70, 70),
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.num_queries = num_queries
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False, quant_config=quant_config, prefix=prefix)
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.do_post_projection = True
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
self.max_size = max_size
self._set_2d_pos_cache(self.max_size)
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
def _set_2d_pos_cache(self, max_size: Tuple[int, int], device: torch.types.Device = "cpu") -> None:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, max_size, version=(2, 5))
pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
self.register_buffer("pos_embed", pos_embed, persistent=False)
def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.types.Device) -> None:
max_h = tgt_sizes[:, 0].max().item()
max_w = tgt_sizes[:, 1].max().item()
assert isinstance(max_h, int) and isinstance(max_w, int)
if max_h > self.max_size[0] or max_w > self.max_size[1]:
self.max_size = (
max(max_h, self.max_size[0]),
max(max_w, self.max_size[1]),
)
self._set_2d_pos_cache(self.max_size, device)
def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]
device = x.device
dtype = x.dtype
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
self._adjust_pos_cache(tgt_sizes, device=device)
max_patch_len = patch_len.max().item()
assert isinstance(max_patch_len, int)
key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)
pos_embed = []
for i in range(bs):
tgt_h, tgt_w = tgt_sizes[i].tolist()
pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D
key_padding_mask[i, patch_len[i] :] = True
pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
1, 0, 2
) # BLD => L * B * D
x, _ = self.kv_proj(x) # B * L * D
x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
q = self.ln_q(self.query) # Q * D
out = self.attn(
self._repeat(q, bs), # Q * B * D
x + pos_embed, # L * B * D + L * B * D
x,
key_padding_mask=key_padding_mask,
)[0]
# out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D
x = self.ln_post(x)
x = x @ self.proj
return x
# endregion
# region: AudioEncoder
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
# return super().forward(x.float()).type(x.dtype)
return super().forward(x).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk += mask
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(
self,
n_mels: int,
n_ctx: int,
n_state: int,
n_head: int,
n_layer: int,
output_dim: int = 512,
avg_pool: bool = True,
add_audio_bos_eos_token: bool = True,
**kwargs,
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
# self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.positional_embedding = nn.Parameter(sinusoids(n_ctx, n_state), requires_grad=False)
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
if avg_pool:
self.avg_pooler = nn.AvgPool1d(2, stride=2)
else:
self.avg_pooler = None
self.proj = nn.Linear(n_state, output_dim)
if add_audio_bos_eos_token:
self.audio_bos_eos_token = nn.Embedding(2, output_dim)
else:
self.audio_bos_eos_token = None
self.output_dim = output_dim
self.n_head = n_head
def forward(self, x: Tensor, padding_mask: Tensor = None, audio_lengths: Tensor = None):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = x.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
if audio_lengths is not None:
input_mel_len = audio_lengths[:, 0] * 2
max_mel_len_in_batch = input_mel_len.max()
x = x[:, :, :max_mel_len_in_batch]
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1) # B, L, D
bsz = x.size(0)
src_len = x.size(1)
self.input_positional_embedding = self.positional_embedding[:src_len]
assert (
x.shape[1:] == self.input_positional_embedding.shape
), f"incorrect audio shape: {x.shape[1:], self.input_positional_embedding.shape}"
x = (x + self.input_positional_embedding).to(x.dtype)
if padding_mask is not None:
padding_mask = padding_mask.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
batch_src_len = padding_mask.size(1)
x = x[:, :batch_src_len, :]
padding_mask = padding_mask.view(bsz, -1, batch_src_len)
padding_mask_ = padding_mask.all(1)
x[padding_mask_] = 0
key_padding_mask = (
padding_mask_.view(bsz, 1, 1, batch_src_len)
.expand(-1, self.n_head, -1, -1)
.reshape(bsz, self.n_head, 1, batch_src_len)
)
new_padding_mask = torch.zeros_like(key_padding_mask, dtype=x.dtype)
padding_mask = new_padding_mask.masked_fill(key_padding_mask, float("-inf"))
for block in self.blocks:
x = block(x, mask=padding_mask)
if self.avg_pooler:
x = x.permute(0, 2, 1)
x = self.avg_pooler(x)
x = x.permute(0, 2, 1)
x = self.ln_post(x)
x = self.proj(x)
if self.audio_bos_eos_token is not None:
bos = self.audio_bos_eos_token.weight[0][None, :]
eos = self.audio_bos_eos_token.weight[1][None, :]
else:
bos, eos = None, None
return x, bos, eos
def encode(
self,
input_audios: Tensor,
input_audio_lengths: Tensor,
audio_span_tokens: List,
):
real_input_audio_lens = input_audio_lengths[:, 0].tolist()
max_len_in_batch = max(real_input_audio_lens)
padding_mask = torch.ones([input_audios.size(0), max_len_in_batch]).to(
dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
)
for index in range(len(input_audios)):
padding_mask[index, : input_audio_lengths[index][0].item()] = 0
x, bos, eos = self(input_audios, padding_mask, input_audio_lengths)
output_audios = []
for i in range(len(audio_span_tokens)):
audio_span = audio_span_tokens[i]
audio = x[i][: audio_span - 2]
if bos is not None:
audio = torch.concat([bos, audio, eos])
assert len(audio) == audio_span
output_audios.append(audio)
return output_audios
class AudioModel(torch.nn.Module):
def __init__(self, config):
super(AudioModel, self).__init__()
self.config = config
self.audio = AudioEncoder(**config.audio_config.to_dict())
def forward(self, audio_info):
audios = audio_info["input_audios"][0]
input_audio_lengths = audio_info["input_audio_lengths"][0]
audio_span_tokens = audio_info["audio_span_tokens"][0]
audios_features = self.audio.encode(audios, input_audio_lengths, audio_span_tokens)
return audios_features
# endregion
def get_max_megrezo_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config()
return getattr(hf_config, "query_num", 64) * 10
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
return SequenceData.from_prompt_token_counts((0, seq_len))
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, num_images: int):
width = height = hf_config.vision_config.image_size
imgs = [MegrezORawImageInput(image=Image.new("RGB", (width, height), color=0)) for _ in range(num_images)]
return {"image": imgs}
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config()
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) # skip audio for now
return (seq_data, mm_data)
def input_processor_for_megrezo(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or ("image" not in multi_modal_data and "audio" not in multi_modal_data):
return inputs
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=model_config.trust_remote_code)
processor = cached_get_processor(model_config.model, trust_remote_code=model_config.trust_remote_code)
prompt = inputs.get("prompt")
token_ids = inputs.get("prompt_token_ids")
if prompt is None:
prompt = tokenizer.decode(token_ids)
images = multi_modal_data.get("image")
audios = multi_modal_data.get("audio")
prompt, multimodal_inputs = processor.process_multimodal_inputs(
prompt,
images=images,
audios=audios,
return_tensors="pt",
)
text_encodings = tokenizer(
prompt,
return_tensors="pt",
padding=True,
padding_side="left",
)
encodings = processor.merge_encodings(text_encodings, multimodal_inputs)
data = processor.data_collator([encodings])
new_prompt = tokenizer.decode(data["input_ids"][0])
new_multi_modal_data = {
"image": data["image_encoding"],
"audio": data["audio_encoding"],
}
return token_inputs(
prompt_token_ids=data["input_ids"][0],
prompt=new_prompt,
multi_modal_data=new_multi_modal_data,
)
def input_mapper_for_megrezo(ctx: InputContext, data: object):
return MultiModalInputs(data)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_megrezo)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_megrezo)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", 3000)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_megrezo_image_tokens)
@INPUT_REGISTRY.register_input_processor(input_processor_for_megrezo)
class MegrezOModel(nn.Module, VllmModelForTextGeneration, SupportsMultiModal, SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings`, but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model
# and config classes.
self.config = config
self.multimodal_config = multimodal_config
self.llm = self.init_llm(config, cache_config, quant_config, prefix="model")
self.vision = self.init_vision_module(config, quant_config, prefix="vpm")
param_dtype = torch.get_default_dtype()
self.vision.to(dtype=param_dtype)
self.audio = self.init_audio_module(config, quant_config)
self.audio.to(dtype=param_dtype)
self.vision_dim = self.vision.embeddings.embed_dim
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(
self.embed_dim, self.vision_dim, quant_config=quant_config, prefix="vision.resampler"
)
self.resampler.to(device="cuda", dtype=param_dtype)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config, prefix="llm.lm_head"
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
self._called_cnt = 0
def get_vision_hidden_states(
self,
pixel_values,
tgt_sizes,
patch_attn_mask,
) -> torch.Tensor:
device = self.vision.embeddings.position_embedding.weight.device
dtype = self.vision.embeddings.position_embedding.weight.dtype
pixel_values = torch.stack([(image.to(device) - 127.5) / 127.5 for image in pixel_values]).type(dtype)
vision_embedding = self.vision(
pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
)
return self.resampler(vision_embedding, tgt_sizes)
def compose_embeddings(self, mini_batch):
input_ids = mini_batch["input_ids"]
image_encoding = mini_batch.get("image_encoding")
audio_encoding = mini_batch.get("audio_encoding")
embeddings_text = self.llm.model.embed_tokens(input_ids)
input_embeds = embeddings_text
if image_encoding:
pixel_values = image_encoding["pixel_values"][0]
tgt_sizes = image_encoding["tgt_sizes"][0]
patch_attention_mask = image_encoding["patch_attention_mask"][0]
bounds_image = image_encoding["image_bounds"][0]
device = self.vision.embeddings.position_embedding.weight.device
dtype = self.vision.embeddings.position_embedding.weight.dtype
embeddings_image = self.get_vision_hidden_states(
pixel_values.to(device, dtype),
tgt_sizes,
patch_attention_mask.to(device),
)
input_embeds = insert_image_embeddings(embeddings_text, embeddings_image, bounds_image)
if audio_encoding:
embeddings_audio = self.audio(audio_encoding)
bounds_audio = audio_encoding["audio_bounds"][0]
# The insert_* helpers modify the tensor in place; thread `input_embeds`
# through so audio splicing composes with any image splicing above.
input_embeds = insert_audio_embeddings(input_embeds, embeddings_audio, bounds_audio)
return input_embeds
def _parse_inputs(self, input_ids: torch.Tensor, **kwargs):
if kwargs.get("pixel_values") is not None:
image_encoding = {
"pixel_values": kwargs.get("pixel_values"),
"tgt_sizes": kwargs.get("tgt_sizes"),
"patch_attention_mask": kwargs.get("patch_attention_mask"),
"image_bounds": kwargs.get("image_bounds"),
}
else:
image_encoding = None
if kwargs.get("input_audios") is not None:
audio_encoding = {
"input_audios": kwargs.get("input_audios"),
"input_audio_lengths": kwargs.get("input_audio_lengths"),
"audio_span_tokens": kwargs.get("audio_span_tokens"),
"audio_bounds": kwargs.get("audio_bounds"),
}
else:
audio_encoding = None
return {
"input_ids": input_ids,
"image_encoding": image_encoding,
"audio_encoding": audio_encoding,
}
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any,
) -> torch.Tensor:
if intermediate_tensors is not None:
embeddings = None
else:
mini_batch = self._parse_inputs(input_ids, **kwargs)
embeddings = self.compose_embeddings(mini_batch)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
output = self.llm(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=embeddings,
)
self._called_cnt += 1
return output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
keys_to_modify_mapping = {
"llm.lm_head": "lm_head",
"vision.resampler": "resampler",
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for key_to_modify, new_key in keys_to_modify_mapping.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# if "audio.positional_embedding" in name:
# continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
if name in params_dict:
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
else:
print(f"Skipping loading of {name}")
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
else:
print(f"Skipping loading of {name}")
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(language_model="llm", connector="resampler", tower_model="vpm")
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(
LlamaModel(
config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
name=prefix,
)
def init_audio_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
return AudioModel(config)
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
model = LLMWrapper(
Idefics2VisionTransformer(config.vision_config),
name=prefix,
)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(
self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
resampler = Resampler(
num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix,
)
return resampler
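# A small self-contained check of the embedding-splice helpers defined above,
# kept under a __main__ guard so it never runs during vLLM model loading. The
# (batch_id, start, end) bound layout is the one consumed by
# `insert_image_embeddings`; the toy tensor sizes are arbitrary.
if __name__ == "__main__":
    text = torch.zeros(6, 4)
    patches = [torch.ones(2, 4)]
    bounds = torch.tensor([[0, 2, 4]])
    merged = insert_image_embeddings(text, patches, bounds)
    assert merged[2:4].eq(1).all() and merged[:2].eq(0).all()
    print("insert_image_embeddings: OK")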
================================================
FILE: vllm_demo/requirements.txt
================================================
vllm==0.6.3.post1
flash_attn==2.5.8
xformers==0.0.27.post2
================================================
FILE: vllm_demo/try_minicpm_v.py
================================================
from transformers import AutoTokenizer
from PIL import Image
from vllm import LLM, SamplingParams
MODEL_NAME = "/mnt/public/algm/models/MiniCPM-V-2_6/"
image = Image.open("../data/sample_image.jpg").convert("RGB")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
llm = LLM(
model=MODEL_NAME,
trust_remote_code=True,
gpu_memory_utilization=1,
max_model_len=2048
)
messages = [{
"role":
"user",
"content":
# Number of images
"(<image>./</image>)" + \
"\nWhat is the content of this image?"
}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Single Inference
inputs = {
"prompt": prompt,
"multi_modal_data": {
"image": image
# Multi images, the number of images should be equal to that of `(<image>./</image>)`
# "image": [image, image]
},
}
# Batch Inference
# inputs = [{
# "prompt": prompt,
# "multi_modal_data": {
# "image": image
# },
# } for _ in range(2)]
# Stop tokens for MiniCPM-V 2.6
stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
sampling_params = SamplingParams(
stop_token_ids=stop_token_ids,
use_beam_search=True,
temperature=0,
best_of=3,
max_tokens=1024
)
outputs = llm.generate(inputs, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
================================================
FILE: vllm_demo/try_qwen_vl.py
================================================
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
"/mnt/public/algm/models/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
)
# default processor
processor = AutoProcessor.from_pretrained("/mnt/public/algm/models/Qwen2-VL-2B-Instruct")
prompt = "hi" * (128 - 1)
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "../data/sample_image.jpg",
},
{"type": "text", "text": prompt},
],
}
]
# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda")
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
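The `generated_ids_trimmed` step above slices the echoed prompt tokens off the front of each generated sequence before decoding. In isolation, the same slicing logic looks like this (a self-contained sketch with plain lists standing in for the token-ID tensors):

```python
def trim_prompt(input_ids, generated_ids):
    """Drop the echoed prompt tokens from each generated sequence,
    keeping only the newly generated tail."""
    return [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]


# Prompt tokens [1, 2, 3] are echoed at the start of the output;
# only [7, 8] were newly generated.
print(trim_prompt([[1, 2, 3]], [[1, 2, 3, 7, 8]]))  # → [[7, 8]]
```

This works because `model.generate` for decoder-only models returns full sequences (prompt + completion), so the completion starts exactly at index `len(input_ids)`.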
================================================
FILE: vllm_demo/vllm_profling.py
================================================
# -*- encoding: utf-8 -*-
# File: vllm_profling.py
# Description: None
from PIL import Image
from vllm import LLM
from vllm import ModelRegistry
from vllm import SamplingParams
from megrezo import MegrezOModel
ModelRegistry.register_model("MegrezO", MegrezOModel)
# Load the model.
model_path = "/mnt/algorithm/user_dir/zhoudong/workspace/models/megrez-o" # Change this to the path of the model.
llm = LLM(
model_path,
trust_remote_code=True,
gpu_memory_utilization=0.9,
max_num_seqs=8,
)
num_requests = 100
input_len = 128
output_length = 128
# prepare data
prompt = "hi" * (input_len - 1)
sampling_params = SamplingParams(
temperature=0,
max_tokens=output_length,
repetition_penalty=1.2,
stop=["<|turn_end|>", "<|eos|>"],
ignore_eos=True,
)
img = Image.open("../data/sample_image.jpg")
conversation = [
{
"role": "user",
"content": {
"text": prompt,
"image": img,
},
},
]
# Convert the conversation to vLLM acceptable format.
prompt = llm.get_tokenizer().apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
vllm_inputs = [
{
"prompt": prompt,
"multi_modal_data": {
"image": img,
},
}
for _ in range(num_requests)
]
# Generate the outputs.
outputs = llm.generate(
vllm_inputs,
sampling_params,
)
# Print the outputs.
# for output in outputs:
# print(output.outputs[0].text)
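The profiling script above issues `num_requests` identical requests but never reports any numbers. A minimal throughput summary can be computed from the generated outputs; the helper below is our own sketch (vLLM's `RequestOutput` exposes per-request `outputs[0].token_ids`, which gives the output-token counts):

```python
import time


def throughput_stats(num_output_tokens, elapsed_s):
    """Aggregate decode throughput from per-request output-token counts
    and total wall-clock time."""
    total = sum(num_output_tokens)
    return {
        "total_output_tokens": total,
        "tokens_per_s": total / elapsed_s,
        "mean_tokens_per_request": total / len(num_output_tokens),
    }


# Hypothetical usage with the script above (untested sketch):
# start = time.perf_counter()
# outputs = llm.generate(vllm_inputs, sampling_params)
# elapsed = time.perf_counter() - start
# stats = throughput_stats(
#     [len(o.outputs[0].token_ids) for o in outputs], elapsed)

# Three requests of ~128 output tokens finishing in 4 s:
print(throughput_stats([128, 128, 120], 4.0)["tokens_per_s"])  # → 94.0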
================================================
FILE: vllm_demo/vllm_profling_minicpm.py
================================================
from transformers import AutoTokenizer
from PIL import Image
from vllm import LLM, SamplingParams
model_path = "/mnt/public/algm/models/MiniCPM-V-2_6/"
image = Image.open("../data/sample_image.jpg").convert("RGB")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
llm = LLM(
model=model_path,
gpu_memory_utilization=0.9,
max_num_seqs=8,
trust_remote_code=True,
max_model_len=4096
)
num_requests = 100
input_len = 128
output_length = 128
# prepare data
prompt = "hi" * (input_len - 1)
sampling_params = SamplingParams(
temperature=0,
max_tokens=output_length,
repetition_penalty=1.2,
ignore_eos=True,
)
messages = [{
    "role": "user",
    # Prepend one "(<image>./</image>)" placeholder per image.
    "content": "(<image>./</image>)" + prompt,
}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Single Inference
llm_inputs = [{
"prompt": prompt,
"multi_modal_data": {
"image": image
},
} for _ in range(num_requests)]
outputs = llm.generate(llm_inputs, sampling_params=sampling_params)
================================================
FILE: vllm_demo/vllm_profling_qwen.py
================================================
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
# Load the model.
model_path = "/mnt/public/algm/models/Qwen2-VL-2B-Instruct" # Change this to the path of the model.
llm = LLM(
model=model_path,
limit_mm_per_prompt={"image": 10, "video": 10},
gpu_memory_utilization=0.9,
max_num_seqs=8,
)
num_requests = 100
input_len = 128
output_length = 128
# prepare data
prompt = "hi" * (input_len - 1)
sampling_params = SamplingParams(
temperature=0,
max_tokens=output_length,
repetition_penalty=1.2,
ignore_eos=True,
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
{
"type": "image",
"image": "../data/sample_image.jpg",
"min_pixels": 224 * 224,
"max_pixels": 1024 * 1024,
},
{"type": "text", "text": prompt},
],
},
]
processor = AutoProcessor.from_pretrained(model_path)
prompt = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)
mm_data = {}
if image_inputs is not None:
mm_data["image"] = image_inputs
if video_inputs is not None:
mm_data["video"] = video_inputs
llm_inputs = [
{
"prompt": prompt,
"multi_modal_data": mm_data,
}
for _ in range(num_requests)
]
outputs = llm.generate(llm_inputs, sampling_params=sampling_params)
================================================
SYMBOL INDEX (81 symbols across 5 files)
================================================
FILE: finetune/dataset.py
class SupervisedDataset (line 13) | class SupervisedDataset(Dataset):
method __init__ (line 16) | def __init__(
method __len__ (line 29) | def __len__(self):
method check_ret (line 32) | def check_ret(self, ret):
method check_audio (line 46) | def check_audio(self, ret):
method prepare_labels (line 53) | def prepare_labels(self, data):
method add_dataset_prefix (line 100) | def add_dataset_prefix(self, item):
method __getitem__ (line 111) | def __getitem__(self, i):
FILE: finetune/finetune.py
class ModelArguments (line 30) | class ModelArguments:
class DataArguments (line 35) | class DataArguments:
class TrainingArguments (line 42) | class TrainingArguments(transformers.TrainingArguments):
class LoraArguments (line 61) | class LoraArguments:
function maybe_zero_3 (line 75) | def maybe_zero_3(param):
function get_peft_state_maybe_zero_3 (line 86) | def get_peft_state_maybe_zero_3(named_params, bias):
function rank0_print (line 114) | def rank0_print(*args):
function safe_save_model_for_hf_trainer (line 119) | def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
function make_supervised_data_module (line 133) | def make_supervised_data_module(
function get_parameter_number (line 160) | def get_parameter_number(model):
function load_model_from_pretrained (line 178) | def load_model_from_pretrained(model_path, dtype=torch.bfloat16):
function load_tokenizer_from_pretrained (line 185) | def load_tokenizer_from_pretrained(model_path):
function train (line 190) | def train():
FILE: finetune/trainer.py
class MegrezOTrainer (line 16) | class MegrezOTrainer(Trainer):
method compute_loss (line 17) | def compute_loss(self, model, inputs, return_outputs=False):
method prediction_step (line 55) | def prediction_step(
method training_step (line 140) | def training_step(
FILE: gradio_app.py
function main (line 31) | def main(model_path: str, port: int):
FILE: vllm_demo/megrezo.py
class MegrezORawImageInput (line 75) | class MegrezORawImageInput(TypedDict):
class MegrezOAudioInput (line 81) | class MegrezOAudioInput(TypedDict):
class MegrezOAudioTensorInput (line 87) | class MegrezOAudioTensorInput(TypedDict):
class MegrezOImagePixelInputs (line 95) | class MegrezOImagePixelInputs(TypedDict):
class MegrezOImageEmbeddingInputs (line 118) | class MegrezOImageEmbeddingInputs(TypedDict):
function insert_audio_embeddings (line 136) | def insert_audio_embeddings(text_embeddings, inserted_embeddings, insert...
function insert_image_embeddings (line 149) | def insert_image_embeddings(text_embeddings, inserted_embeddings, insert...
class Resampler (line 169) | class Resampler(nn.Module):
method __init__ (line 171) | def __init__(
method _init_weights (line 211) | def _init_weights(self, m: nn.Module) -> None:
method _repeat (line 220) | def _repeat(self, query, N: int):
method _set_2d_pos_cache (line 223) | def _set_2d_pos_cache(self, max_size: Tuple[int, int], device: torch.t...
method _adjust_pos_cache (line 228) | def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.typ...
method forward (line 240) | def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.T...
class LayerNorm (line 288) | class LayerNorm(nn.LayerNorm):
method forward (line 289) | def forward(self, x: Tensor) -> Tensor:
class Linear (line 294) | class Linear(nn.Linear):
method forward (line 295) | def forward(self, x: Tensor) -> Tensor:
class Conv1d (line 303) | class Conv1d(nn.Conv1d):
method _conv_forward (line 304) | def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tens...
function sinusoids (line 308) | def sinusoids(length, channels, max_timescale=10000):
class MultiHeadAttention (line 317) | class MultiHeadAttention(nn.Module):
method __init__ (line 318) | def __init__(self, n_state: int, n_head: int):
method forward (line 326) | def forward(
method qkv_attention (line 348) | def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optiona...
class ResidualAttentionBlock (line 363) | class ResidualAttentionBlock(nn.Module):
method __init__ (line 364) | def __init__(self, n_state: int, n_head: int, cross_attention: bool = ...
method forward (line 377) | def forward(
class AudioEncoder (line 391) | class AudioEncoder(nn.Module):
method __init__ (line 392) | def __init__(
method forward (line 427) | def forward(self, x: Tensor, padding_mask: Tensor = None, audio_length...
method encode (line 481) | def encode(
class AudioModel (line 506) | class AudioModel(torch.nn.Module):
method __init__ (line 508) | def __init__(self, config):
method forward (line 513) | def forward(self, audio_info):
function get_max_megrezo_image_tokens (line 524) | def get_max_megrezo_image_tokens(ctx: InputContext):
function dummy_seq_data_for_minicpmv (line 529) | def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
function dummy_image_for_minicpmv (line 533) | def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedCon...
function dummy_data_for_minicpmv (line 539) | def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, mm_counts: ...
function input_processor_for_megrezo (line 548) | def input_processor_for_megrezo(ctx: InputContext, inputs: DecoderOnlyIn...
function input_mapper_for_megrezo (line 592) | def input_mapper_for_megrezo(ctx: InputContext, data: object):
class MegrezOModel (line 601) | class MegrezOModel(nn.Module, VllmModelForTextGeneration, SupportsMultiM...
method __init__ (line 608) | def __init__(
method get_vision_hidden_states (line 647) | def get_vision_hidden_states(
method compose_embeddings (line 665) | def compose_embeddings(self, mini_batch):
method _parse_inputs (line 694) | def _parse_inputs(self, input_ids: torch.Tensor, **kwargs):
method forward (line 721) | def forward(
method compute_logits (line 753) | def compute_logits(
method sample (line 761) | def sample(
method load_weights (line 769) | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
method get_mm_mapping (line 835) | def get_mm_mapping(self) -> MultiModelKeys:
method init_llm (line 841) | def init_llm(
method init_audio_module (line 859) | def init_audio_module(
method init_vision_module (line 867) | def init_vision_module(
method init_resampler (line 881) | def init_resampler(