Showing preview only (309K chars total). Download the full file or copy to clipboard to get everything.
Repository: zhiyuanyou/DeQA-Score
Branch: main
Commit: 8e4abc826ed8
Files: 42
Total size: 294.2 KB
Directory structure:
gitextract_56m520fi/
├── .gitignore
├── LICENSE
├── README.md
├── build_soft_labels/
│ └── gen_soft_label.py
├── preprocessor/
│ └── tokenizer.model
├── pyproject.toml
├── scripts/
│ ├── eval_dist.sh
│ ├── eval_score.sh
│ ├── infer.sh
│ ├── infer_lora.sh
│ ├── train.sh
│ └── train_lora.sh
├── src/
│ ├── __init__.py
│ ├── constants.py
│ ├── conversation.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── pair_dataset.py
│ │ ├── single_dataset.py
│ │ └── utils.py
│ ├── evaluate/
│ │ ├── __init__.py
│ │ ├── cal_distribution_gap.py
│ │ ├── cal_plcc_srcc.py
│ │ ├── eval_qbench_mcq.py
│ │ ├── iqa_eval.py
│ │ └── scorer.py
│ ├── mm_utils.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── builder.py
│ │ ├── configuration_mplug_owl2.py
│ │ ├── convert_mplug_owl2_weight_to_hf.py
│ │ ├── modeling_attn_mask_utils.py
│ │ ├── modeling_llama2.py
│ │ ├── modeling_mplug_owl2.py
│ │ ├── utils.py
│ │ └── visual_encoder.py
│ ├── train/
│ │ ├── mplug_owl2_trainer.py
│ │ └── train_mem.py
│ └── utils.py
└── tests/
├── datasets/
│ ├── test_pair_dataset.py
│ └── test_uncertainty_levels.py
└── model/
├── test_find_prefix.py
└── test_grad.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# source file related
*__pycache__*
*.pyc
*.o
*.so
*.egg
*.egg-info
# training related
*log*
*.log
*.pth
*.pt
# result related
*answer*
*ckpt*
*.json
*.jsonl
res_*
# mac hidden
*.DS_Store*
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 Depicted image Quality Assessment (DepictQA / DeQA)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
<div align="center">
<h1>Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution</h1>
<div>
<a href="https://zhiyuanyou.github.io/" target="_blank">Zhiyuan You</a><sup>12</sup>,
<a href="https://caixin98.github.io/" target="_blank">Xin Cai</a><sup>2</sup>,
<a href="https://www.jasongt.com/" target="_blank">Jinjin Gu</a><sup>4</sup>,
<a href="https://tianfan.info/" target="_blank">Tianfan Xue</a><sup>235</sup><sup>#</sup>,
<a href="https://xpixel.group/2010/01/20/chaodong.html" target="_blank">Chao Dong</a><sup>134</sup><sup>#</sup>
</div>
<div>
<sup>1</sup>Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences, <sup>2</sup>Multimedia Laboratory, The Chinese University of Hong Kong,
<sup>3</sup>Shanghai AI Laboratory, <sup>4</sup>Shenzhen University of Advanced Technology, <sup>5</sup>CPII under InnoHK
</div>
<div><sup>#</sup>Corresponding author.</div>
<div>
<a href="https://depictqa.github.io/deqa-score/" target="_blank"><strong>Homepage</strong></a> |
<strong>Model Weights</strong> (
<a href="https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3" target="_blank"><strong>Full Tuning</strong></a> /
<a href="https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3" target="_blank"><strong>LoRA Tuning</strong></a>
) |
<a href="https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score" target="_blank"><strong>Datasets</strong></a> |
<a href="https://arxiv.org/abs/2501.11561" target="_blank"><strong>Paper</strong></a>
</div>
<h2>Motivation</h2>
<div style="width: 100%; text-align: center; margin:auto;">
<img style="width: 75%" src="fig/teaser.png">
</div>
<h2>Model Architecture</h2>
<div style="width: 100%; text-align: center; margin:auto;">
<img style="width: 100%" src="fig/model.png">
</div>
</div>
## [Installation Free!] Quicker Start with Hugging Face AutoModel
[2025.12] Thanks to @[lyf1212](https://github.com/lyf1212)'s suggestion, we add support on `transformers==4.46.3` with minor code modifications. See [details](https://github.com/zhiyuanyou/DeQA-Score/issues/32).
The following code could be run directly with `transformers==4.36.1`. No need to install this GitHub repo.
```python
import requests
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"zhiyuanyou/DeQA-Score-Mix3",
trust_remote_code=True,
attn_implementation="eager",
torch_dtype=torch.float16,
device_map="auto",
)
from PIL import Image
# The inputs should be a list of multiple PIL images
model.score(
[Image.open(requests.get(
"https://raw.githubusercontent.com/zhiyuanyou/DeQA-Score/main/fig/singapore_flyer.jpg", stream=True
).raw)]
)
```
## Installation
If you only need to infer / evaluate:
```shell
git clone https://github.com/zhiyuanyou/DeQA-Score.git
cd DeQA-Score
pip install -e .
```
For training, you need to further install additional dependencies as follows:
```shell
pip install -e ".[train]"
pip install flash_attn --no-build-isolation
```
## Quick Start
### Image Quality Scorer
- CLI Interface
```shell
python src/evaluate/scorer.py --img_path fig/singapore_flyer.jpg
```
- Python API
```python
from src import Scorer
from PIL import Image
scorer = Scorer()
img_list = [Image.open("fig/singapore_flyer.jpg")] # can be a list of multiple PIL images
print(scorer(img_list).tolist())
```
## Training, Inference & Evaluation
### Datasets
<a id="datasets"></a>
- Download our meta files from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score).
- Download source images from [KonIQ](https://database.mmsp-kn.de/koniq-10k-database.html),
[SPAQ](https://github.com/h4nwei/SPAQ),
[KADID](https://database.mmsp-kn.de/kadid-10k-database.html),
[PIPAL](https://github.com/HaomingCai/PIPAL-dataset),
[LIVE-Wild](https://live.ece.utexas.edu/research/ChallengeDB/index.html),
[AGIQA](https://github.com/lcysyzxdxc/AGIQA-3k-Database),
[TID2013](https://www.ponomarenko.info/tid2013.htm),
and [CSIQ](https://s2.smu.edu/~eclarson/csiq.html).
- Arrange the folders as follows:
```
|-- DeQA-Score
|-- Data-DeQA-Score
|-- KONIQ
|-- images/*.jpg
|-- metas
|-- SPAQ
|-- images/*.jpg
|-- metas
|-- KADID10K
|-- images/*.png
|-- metas
|-- PIPAL
|-- images/Distortion_*/*.bmp
|-- metas
|-- LIVE-WILD
|-- images/*.bmp
|-- metas
|-- AGIQA3K
|-- images/*.jpg
|-- metas
|-- TID2013
|-- images/distorted_images/*.bmp
|-- metas
|-- CSIQ
|-- images/dst_imgs/*/*.png
|-- metas
```
### Pretrained Weights
<a id="pretrained_weights"></a>
We provide two model weights (full tuning and LoRA tuning) with similar performance.
| | Training Datasets | Weights |
|-----|-----|-----|
| Full Tuning | KonIQ, SPAQ, KADID | [Huggingface Full](https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3) |
| LoRA Tuning | KonIQ, SPAQ, KADID | [Huggingface LoRA](https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3) |
Download one of the above model weights, then arrange the folders as follows:
```
|-- DeQA-Score
|-- checkpoints
|-- DeQA-Score-Mix3
|-- DeQA-Score-LoRA-Mix3
```
If you would like to use the LoRA tuning weights, you need to download the base mPLUG-Owl2 weights from [Huggingface mPLUG-Owl2](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b), then arrange the folders as follows:
```
|-- DeQA-Score
|-- ModelZoo
|-- mplug-owl2-llama2-7b
```
### Inference
After preparing the datasets, you can infer using pre-trained **DeQA-Score** or **DeQA-Score-LoRA**:
```shell
sh scripts/infer.sh $ONE_GPU_ID
```
```shell
sh scripts/infer_lora.sh $ONE_GPU_ID
```
### Evaluation
After inference, you can evaluate the inference results:
- SRCC / PLCC for quality score.
```shell
sh scripts/eval_score.sh
```
- KL Divergence / JS Divergence / Wasserstein Distance for score distribution.
```shell
sh scripts/eval_dist.sh
```
### Fine-tuning
Fine-tuning needs to download the mPLUG-Owl2 weights as in [Pretrained Weights](#pretrained_weights).
#### LoRA Fine-tuning
- Only **2 RTX3090 GPUs** are required. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
```shell
sh scripts/train_lora.sh $GPU_IDs
```
#### Full Fine-tuning from the Scratch
- At least **8 A6000 GPUs** or **4 A100 GPUs** will be enough. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
```shell
sh scripts/train.sh $GPU_IDs
```
## Soft Label Construction
- Download `split.json` (training & test split info) and `mos.json` (mos & std info) of KonIQ, SPAQ, and KADID from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score), and arrange the folders as in [Datasets](#datasets).
- Run the following scripts to construct the distribution-based soft labels.
```shell
cd build_soft_labels
python gen_soft_label.py
```
## Acknowledgements
This work is based on [Q-Align](https://github.com/Q-Future/Q-Align). Sincerely thanks for this awesome work.
## Citation
If you find our work useful for your research and applications, please cite using the BibTeX:
```bibtex
@inproceedings{deqa_score,
title={Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution},
author={You, Zhiyuan and Cai, Xin and Gu, Jinjin and Xue, Tianfan and Dong, Chao},
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={14483--14494},
year={2025}
}
@article{depictqa_v2,
title={Enhancing Descriptive Image Quality Assessment with A Large-scale Multi-modal Dataset},
author={You, Zhiyuan and Gu, Jinjin and Cai, Xin and Li, Zheyuan and Zhu, Kaiwen and Dong, Chao and Xue, Tianfan},
journal={IEEE Transactions on Image Processing},
year={2025}
}
@inproceedings{depictqa_v1,
title={Depicting Beyond Scores: Advancing Image Quality Assessment through Multi-modal Language Models},
author={You, Zhiyuan and Li, Zheyuan and Gu, Jinjin and Yin, Zhenfei and Xue, Tianfan and Dong, Chao},
booktitle={European Conference on Computer Vision},
pages={259--276},
year={2024}
}
```
================================================
FILE: build_soft_labels/gen_soft_label.py
================================================
import argparse
import json
import numpy as np
import os
import random
from scipy.stats import norm, pearsonr, spearmanr
def parse_args():
parser = argparse.ArgumentParser(description="label parameters for DeQA-Score")
parser.add_argument("--config", type=str, default="./config.json")
args = parser.parse_args()
return args
questions = [
"What do you think about the quality of this image?",
"Can you rate the quality of this picture?",
"Can you judge the quality of this image?",
"How would you rate the quality of this image?",
"How would you judge the quality of this image?",
"What is your quality rating for this image?",
"What's your opinion on the quality of this picture?",
"Rate the quality of this image.",
"Could you evaluate the quality of this image?",
"How do you assess the quality of this image?",
]
def calculate_srcc_plcc(pred, mos):
srcc, _ = spearmanr(pred, mos)
plcc, _ = pearsonr(pred, mos)
return srcc, plcc
def get_level(mos, min_mos, max_mos):
eps = 1e-8
texts = ["bad", "poor", "fair", "good", "excellent"]
for idx in range(1, len(texts) + 1):
mos_left = min_mos + (idx - 1) / 5 * (max_mos - min_mos) - eps
mos_right = min_mos + idx / 5 * (max_mos - min_mos) + eps
if mos > mos_left and mos <= mos_right:
level = idx
break
text = texts[level - 1]
return text
def adjust_gaussian_bar(probs, score):
"""
alpha * (a + b + c + d + e) + 5 * beta = 1
alpha * (5a + 4b + 3c + 2d + e) + 15 beta = score
==>
alpha * A + 5 * beta = 1
alpha * B + 15 * beta = score
"""
A = np.array(probs).sum()
B = np.inner(np.array(probs), np.array([5, 4, 3, 2, 1]))
alpha = (score - 3) / (B - 3. * A + 1e-9)
beta = (1. - alpha * A) / 5.
return alpha, beta
def get_binary_probs(mos, min_mos=1.0, max_mos=5.0):
eps = 1e-8
probs = [0, 0, 0, 0, 0]
for idx in range(1, len(probs)):
mos_left = min_mos + (idx - 1) / 4 * (max_mos - min_mos) - eps
mos_right = min_mos + idx / 4 * (max_mos - min_mos) + eps
if mos > mos_left and mos <= mos_right:
probs[idx - 1] = (mos_right - mos) / (mos_right - mos_left)
probs[idx] = (mos - mos_left) / (mos_right - mos_left)
break
assert np.array((np.array(probs) == 0)).sum() == 3
assert round(np.array(probs).sum(), 5) == 1
probs = probs[::-1] # should start with "excellent" & end with "bad"
return probs
def main(cfg):
density_type = cfg["density_type"] # ["pdf", "cdf"]
thre_std = cfg["thre_std"]
thre_diff = cfg["thre_diff"]
with open(cfg["split_json"]) as fr:
split = json.load(fr)
with open(cfg["mos_json"]) as fr:
mos_dict = json.load(fr)
save_train = cfg["save_train"]
save_test = cfg["save_test"]
img_dir = cfg["img_dir"]
moses, stds, imgs = [], [], []
for img in mos_dict:
moses.append(mos_dict[img]["mos"])
stds.append(mos_dict[img]["std"])
imgs.append(img)
max_mos = max([float(_) for _ in moses])
min_mos = min([float(_) for _ in moses])
num_binary, idx = 0, 0
preds, gts, raw_diffs, diffs, alphas, betas = [], [], [], [], [], []
train_metas, test_metas = [], []
for img, mos_str, std_str in zip(imgs, moses, stds):
mos, std = float(mos_str), float(std_str)
if os.path.basename(img) in split["train"]:
training = True
elif os.path.basename(img) in split["test"]:
training = False
else:
idx += 1
# print(idx, img)
continue
text = get_level(mos, min_mos, max_mos)
query = random.choice(questions)
resp = answer.replace("{}", text)
# norm mos and std
mos_norm = 4 * (mos - min_mos) / (max_mos - min_mos) + 1 # [0, 1] -> [1, 5]
std_norm = 4 * std / (max_mos - min_mos)
# ["excellent", "good", "fair", "poor", "bad"] -> [5, 4, 3, 2, 1]
probs = []
for x in range(5, 0, -1):
if density_type == "cdf":
# better for smaller std dataset (see Appendix) like SPAQ
prob = norm.cdf(x+0.5, mos_norm, std_norm) - norm.cdf(x-0.5, mos_norm, std_norm)
else:
# better for larger std dataset (see Appendix) like KonIQ and KADID
assert density_type == "pdf"
prob = norm.pdf(x, loc=mos_norm, scale=std_norm)
probs.append(prob)
mos_rec = np.inner(np.array(probs), np.array([5, 4, 3, 2, 1]))
raw_diff = abs(mos_rec - mos_norm)
raw_diffs.append(raw_diff)
alpha, beta = adjust_gaussian_bar(probs, mos_norm)
probs_norm = [max(_ * alpha + beta, 0) for _ in probs]
mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1]))
diff = abs(mos_rec - mos_norm)
if std_norm < thre_std or diff > thre_diff:
# if std is too small, use binary probs (see Appendix)
probs_norm = get_binary_probs(mos_norm)
mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1]))
diff, alpha, beta = abs(mos_rec - mos_norm), 1., 0.
num_binary += 1
preds.append(mos_rec)
gts.append(mos_norm)
diffs.append(diff)
alphas.append(alpha)
betas.append(beta)
meta = {
"id": os.path.basename(img) + f"->{mos_str}",
"image": os.path.join(img_dir, img),
"gt_score": mos,
"gt_score_norm": mos_norm,
"level_probs_org": probs,
"level_probs": probs_norm,
"std": std,
"std_norm": std_norm,
}
if training:
conversations = [
{
"from": "human",
"value": query + "\n<|image|>",
},
{
"from": "gpt",
"value": resp,
},
]
meta["conversations"] = conversations
train_metas.append(meta)
else:
del meta["level_probs_org"]
del meta["level_probs"]
test_metas.append(meta)
print("=" * 100)
print(f"save {len(train_metas)} into {save_train}")
with open(save_train, "w") as fw:
fw.write(json.dumps(train_metas, indent=4))
print(f"save {len(test_metas)} into {save_test}")
with open(save_test, "w") as fw:
fw.write(json.dumps(test_metas, indent=4))
srcc, plcc = calculate_srcc_plcc(preds, gts)
print("srcc:", srcc, "plcc:", plcc)
print("[raw_diff]", "l1:", sum(raw_diffs) / len(raw_diffs), "l2:", np.sqrt((np.array(raw_diffs)**2).mean()))
print("[diff]", "l1:", sum(diffs) / len(diffs), "l2:", np.sqrt((np.array(diffs)**2).mean()))
print("[alpha]", "mean:", np.mean(alphas), "std:", np.std(alphas))
print("[beta]", "mean:", np.mean(betas), "std:", np.std(betas))
print("binary / all:", num_binary, "/", len(train_metas) + len(test_metas))
if __name__ == "__main__":
args = parse_args()
with open(args.config) as fr:
cfg = json.load(fr)
answer = cfg["answer"]
for dataset in cfg["dataset_params"]:
random.seed(131)
main(cfg["dataset_params"][dataset])
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "DeQA-Score"
version = "1.2.0"
description = "Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution (based on mPLUG-Owl2)"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"torch==2.0.1", "torchvision==0.15.2",
"transformers==4.36.1", "tokenizers==0.15.0", "sentencepiece==0.1.99", "shortuuid",
"accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
"pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
"gradio==3.35.2", "gradio_client==0.2.9",
"requests", "httpx==0.24.0", "uvicorn", "fastapi", "icecream",
"einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", "decord", "scipy",
]
[project.optional-dependencies]
train = ["deepspeed==0.9.5", "ninja", "wandb"]
[project.urls]
"Bug Tracker" = "https://github.com/zhiyuanyou/DeQA-Score/issues"
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
================================================
FILE: scripts/eval_dist.sh
================================================
export PYTHONPATH=./:$PYTHONPATH
res_dir=./results/res_deqa_mix3/
gt_dir=../Data-DeQA-Score/
python src/evaluate/cal_distribution_gap.py \
--level_names excellent good fair poor bad \
--pred_paths $res_dir/test_koniq_2k.json \
$res_dir/test_spaq_2k.json \
$res_dir/test_kadid_2k.json \
$res_dir/test_pipal_5k.json \
$res_dir/test_livew_1k.json \
$res_dir/test_agiqa_3k.json \
$res_dir/test_tid2013_3k.json \
$res_dir/test_csiq_866.json \
--gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \
$gt_dir/SPAQ/metas/test_spaq_2k.json \
$gt_dir/KADID10K/metas/test_kadid_2k.json \
$gt_dir/PIPAL/metas/test_pipal_5k.json \
$gt_dir/LIVE-WILD/metas/test_livew_1k.json \
$gt_dir/AGIQA3K/metas/test_agiqa_3k.json \
$gt_dir/TID2013/metas/test_tid2013_3k.json \
$gt_dir/CSIQ/metas/test_csiq_866.json \
================================================
FILE: scripts/eval_score.sh
================================================
export PYTHONPATH=./:$PYTHONPATH
res_dir=./results/res_deqa_mix3/
gt_dir=../Data-DeQA-Score/
python src/evaluate/cal_plcc_srcc.py \
--level_names excellent good fair poor bad \
--pred_paths $res_dir/test_koniq_2k.json \
$res_dir/test_spaq_2k.json \
$res_dir/test_kadid_2k.json \
$res_dir/test_pipal_5k.json \
$res_dir/test_livew_1k.json \
$res_dir/test_agiqa_3k.json \
$res_dir/test_tid2013_3k.json \
$res_dir/test_csiq_866.json \
--gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \
$gt_dir/SPAQ/metas/test_spaq_2k.json \
$gt_dir/KADID10K/metas/test_kadid_2k.json \
$gt_dir/PIPAL/metas/test_pipal_5k.json \
$gt_dir/LIVE-WILD/metas/test_livew_1k.json \
$gt_dir/AGIQA3K/metas/test_agiqa_3k.json \
$gt_dir/TID2013/metas/test_tid2013_3k.json \
$gt_dir/CSIQ/metas/test_csiq_866.json \
================================================
FILE: scripts/infer.sh
================================================
export CUDA_VISIBLE_DEVICES=$1
export PYTHONPATH=./:$PYTHONPATH
python src/evaluate/iqa_eval.py \
--level-names excellent good fair poor bad \
--model-path checkpoints/DeQA-Score-Mix3/ \
--save-dir results/res_deqa_mix3/ \
--preprocessor-path ./preprocessor/ \
--root-dir ../Data-DeQA-Score/ \
--meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \
../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \
../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \
../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \
../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \
../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \
../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \
../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \
================================================
FILE: scripts/infer_lora.sh
================================================
export CUDA_VISIBLE_DEVICES=$1
export PYTHONPATH=./:$PYTHONPATH
python src/evaluate/iqa_eval.py \
--level-names excellent good fair poor bad \
--model-path checkpoints/DeQA-Score-LoRA-Mix3/ \
--model-base ../ModelZoo/mplug-owl2-llama2-7b/ \
--save-dir results/res_deqa_lora_mix3/ \
--preprocessor-path ./preprocessor/ \
--root-dir ../Data-DeQA-Score/ \
--meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \
../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \
../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \
../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \
../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \
../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \
../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \
../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \
================================================
FILE: scripts/train.sh
================================================
#!/bin/bash
export PYTHONPATH=./:$PYTHONPATH
LOAD="../ModelZoo/mplug-owl2-llama2-7b/"
deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \
--deepspeed scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
--dataset_type pair \
--level_prefix "The quality of the image is" \
--level_names excellent good fair poor bad \
--softkl_loss True \
--weight_rank 1.0 \
--weight_softkl 1.0 \
--weight_next_token 0.05 \
--continuous_rating_loss True \
--closeset_rating_loss True \
--use_fix_std True \
--detach_pred_std True \
--data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \
../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \
../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \
--data_weights 1 1 1 \
--image_folder ../Data-DeQA-Score/ \
--output_dir ./checkpoints/deqa_mix3_rank1.0_next0.05_kl1.0/ \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--num_train_epochs 3 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--tune_visual_abstractor True \
--freeze_vision_model False \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to tensorboard
================================================
FILE: scripts/train_lora.sh
================================================
#!/bin/bash
export PYTHONPATH=./:$PYTHONPATH
LOAD="../ModelZoo/mplug-owl2-llama2-7b/"
deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \
--deepspeed scripts/zero3.json \
--lora_enable True \
--model_name_or_path $LOAD \
--version v1 \
--dataset_type pair \
--level_prefix "The quality of the image is" \
--level_names excellent good fair poor bad \
--softkl_loss True \
--weight_rank 1.0 \
--weight_softkl 1.0 \
--weight_next_token 0.05 \
--continuous_rating_loss True \
--closeset_rating_loss True \
--use_fix_std True \
--detach_pred_std True \
--data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \
../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \
../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \
--data_weights 1 1 1 \
--image_folder ../Data-DeQA-Score/ \
--output_dir ./checkpoints/deqa_lora_mix3_rank1.0_next0.05_kl1.0 \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--num_train_epochs 3 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--tune_visual_abstractor True \
--freeze_vision_model False \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to tensorboard
================================================
FILE: src/__init__.py
================================================
from .model import MPLUGOwl2LlamaForCausalLM
from .evaluate import Scorer
================================================
FILE: src/constants.py
================================================
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "./demo_logs"
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<|image|>"
================================================
FILE: src/conversation.py
================================================
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
from src.constants import DEFAULT_IMAGE_TOKEN
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
TWO_NO_SYS = auto()
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if len(messages) > 0 and type(messages[0][1]) is tuple:
messages = self.messages.copy()
init_role, init_msg = messages[0].copy()
# init_msg = init_msg[0].replace("<image>", "").strip()
# if 'mmtag' in self.version:
# messages[0] = (init_role, init_msg)
# messages.insert(0, (self.roles[0], "<Image><image></Image>"))
# messages.insert(1, (self.roles[1], "Received."))
# else:
# messages[0] = (init_role, "<image>\n" + init_msg)
init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.MPT:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0: message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if longest_edge != max(image.size):
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace('<|image|>', '').strip()
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_mplug_owl2 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO_NO_SYS,
sep=" ",
sep2="</s>",
)
# default_conversation = conv_vicuna_v1
default_conversation = conv_mplug_owl2
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"mplug_owl2": conv_mplug_owl2,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())
================================================
FILE: src/datasets/__init__.py
================================================
from .pair_dataset import make_pair_data_module
from .single_dataset import make_single_data_module
def make_data_module(tokenizer, data_args):
if data_args.dataset_type == "single":
return make_single_data_module(tokenizer, data_args)
elif data_args.dataset_type == "pair":
return make_pair_data_module(tokenizer, data_args)
else:
raise ValueError
================================================
FILE: src/datasets/pair_dataset.py
================================================
import copy
import json
import os
import random
from dataclasses import dataclass
from typing import Dict, Sequence
import torch
import transformers
from PIL import Image
from torch.utils.data import Dataset
from src.constants import IGNORE_INDEX
from .utils import (expand2square, load_video, preprocess,
preprocess_multimodal, rank0_print)
class PairDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
data_paths,
data_weights,
tokenizer: transformers.PreTrainedTokenizer,
data_args,
):
super(PairDataset, self).__init__()
dataset_list = [] # list (different datasets) of list (samples in one dataset)
for data_path, data_weight in zip(data_paths, data_weights):
data_list = json.load(open(data_path, "r"))
dataset_list.append(data_list * data_weight)
self.dataset_list = dataset_list
# Construct nums_data, nums_data[i] is the number of samples in 0-i th datasets
nums_eachdata = [len(_) for _ in self.dataset_list]
nums_predata = copy.deepcopy(nums_eachdata)
for idx in range(1, len(nums_predata)):
nums_predata[idx] = nums_predata[idx] + nums_predata[idx - 1]
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.nums_eachdata = nums_eachdata
self.nums_predata = nums_predata
self.data_args = data_args
assert self.nums_predata[-1] == sum(self.nums_eachdata)
def __len__(self):
return self.nums_predata[-1]
@property
def lengths(self):
length_list = []
for dataset in self.dataset_list:
for sample in dataset:
img_tokens = 128 if "image" in sample else 0
length_list.append(
sum(len(conv["value"].split()) for conv in sample["conversations"])
+ img_tokens
)
return length_list
@property
def modality_lengths(self):
length_list = []
for dataset in self.dataset_list:
for sample in dataset:
cur_len = sum(
len(conv["value"].split()) for conv in sample["conversations"]
)
cur_len = cur_len if "image" in sample else -cur_len
length_list.append(cur_len)
return length_list
def next_rand(self):
return random.randint(0, len(self) - 1)
def __getitem__(self, i):
while True:
try:
# Get idx_dataset, idx_sample
if i < self.nums_predata[0]:
idx_dataset = 0
idx_sample = i
else:
for idx_dataset in range(1, len(self.nums_predata)):
if (
i < self.nums_predata[idx_dataset]
and i >= self.nums_predata[idx_dataset - 1]
):
idx_sample = i - self.nums_predata[idx_dataset - 1]
break
# Sample two items
item_A = self.get_one_item(idx_dataset, idx_sample)
while True:
idx_sample_B = random.randint(
0, self.nums_eachdata[idx_dataset] - 1
)
if idx_sample_B != idx_sample:
break
item_B = self.get_one_item(idx_dataset, idx_sample_B)
return {
"item_A": item_A,
"item_B": item_B,
}
except Exception as ex:
print(ex)
i = self.next_rand()
continue
def get_one_item(self, idx_dataset, idx_sample) -> Dict[str, torch.Tensor]:
# For IQA data, i must be int
sources = [self.dataset_list[idx_dataset][idx_sample]]
sources_org = copy.deepcopy(sources)
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if "image" in sources_org[0]:
image_file = sources[0]["image"]
image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
if isinstance(image_file, list):
# Multiple Images as Input
image = [
Image.open(os.path.join(image_folder, imfile)).convert("RGB")
for imfile in image_file
]
if self.data_args.image_aspect_ratio == "pad":
image = [
expand2square(
img,
tuple(int(x * 255) for x in processor.image_mean),
)
for img in image
]
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
elif os.path.join(image_folder, image_file).endswith("mp4"):
# Video as Input
image = load_video(os.path.join(image_folder, image_file))
if self.data_args.image_aspect_ratio == "pad":
image = [
expand2square(
img,
tuple(int(x * 255) for x in processor.image_mean),
)
for img in image
]
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = Image.open(os.path.join(image_folder, image_file)).convert(
"RGB"
)
if self.data_args.image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_mean)
)
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
)
else:
# Without images
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=("image" in sources_org[0]),
)
data_dict = dict(
input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0],
)
# default task_type: "score", gt_socre & std: -10000, level_probs: [-10000] * 5
data_dict["task_type"] = sources_org[0].get("task_type", "score")
data_dict["gt_score"] = sources_org[0].get("gt_score", -10000)
data_dict["std"] = sources_org[0].get("std", -10000)
data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5)
# image exist in the data
if "image" in sources_org[0]:
data_dict["image_file"] = image_file
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
return data_dict
@dataclass
class DataCollatorForPairDataset(object):
"""Collate examples for pair fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
instances_A = [instance["item_A"] for instance in instances]
instances_B = [instance["item_B"] for instance in instances]
batch = {
"input_type": "pair",
"item_A": self.collate_one(instances_A),
"item_B": self.collate_one(instances_B),
}
return batch
def collate_one(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
input_ids = input_ids[:, : self.tokenizer.model_max_length]
labels = labels[:, : self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
batch["task_types"] = [instance["task_type"] for instance in instances]
batch["gt_scores"] = torch.tensor([instance["gt_score"] for instance in instances])
batch["stds"] = torch.tensor([instance["std"] for instance in instances])
batch["level_probs"] = torch.tensor([instance["level_probs"] for instance in instances])
if "image" in instances[0]:
images = [instance["image"] for instance in instances]
if all(x is not None and x.shape == images[0].shape for x in images):
batch["images"] = torch.stack(images)
else:
batch["images"] = images
batch["image_files"] = [instance["image_file"] for instance in instances]
return batch
def make_pair_data_module(
tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = PairDataset(
tokenizer=tokenizer,
data_paths=data_args.data_paths,
data_weights=data_args.data_weights,
data_args=data_args,
)
data_collator = DataCollatorForPairDataset(tokenizer=tokenizer)
return dict(
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
)
================================================
FILE: src/datasets/single_dataset.py
================================================
import copy
import json
import os
from dataclasses import dataclass
from typing import Dict, Sequence
import torch
import transformers
from PIL import Image
from torch.utils.data import Dataset
from src.constants import IGNORE_INDEX
from .utils import (expand2square, load_video, preprocess,
preprocess_multimodal, rank0_print)
class SingleDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
data_paths: str,
data_weights: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args,
):
super(SingleDataset, self).__init__()
list_data_dict = []
for data_path, data_weight in zip(data_paths, data_weights):
data_dict = json.load(open(data_path, "r"))
list_data_dict += data_dict * data_weight
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if "image" in sample else 0
length_list.append(
sum(len(conv["value"].split()) for conv in sample["conversations"])
+ img_tokens
)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(
len(conv["value"].split()) for conv in sample["conversations"]
)
cur_len = cur_len if "image" in sample else -cur_len
length_list.append(cur_len)
return length_list
def next_rand(self):
import random
return random.randint(0, len(self) - 1)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
while True:
try:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
sources_org = copy.deepcopy(sources)
assert (
len(sources) == 1
), "Don't know why it is wrapped to a list" # FIXME
if "image" in sources_org[0]:
image_file = sources_org[0]["image"]
image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
from pathlib import Path
# if not Path(os.path.join(image_folder, image_file)).exists():
# i = self.next_rand()
# continue
if isinstance(image_file, list):
# Multiple Images as Input
try:
image = [
Image.open(os.path.join(image_folder, imfile)).convert(
"RGB"
)
for imfile in image_file
]
except Exception as ex:
print(ex)
i = self.next_rand()
continue
if self.data_args.image_aspect_ratio == "pad":
image = [
expand2square(
img,
tuple(int(x * 255) for x in processor.image_mean),
)
for img in image
]
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
elif os.path.join(image_folder, image_file).endswith("mp4"):
# Video as Input
image = load_video(os.path.join(image_folder, image_file))
if self.data_args.image_aspect_ratio == "pad":
image = [
expand2square(
img,
tuple(int(x * 255) for x in processor.image_mean),
)
for img in image
]
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
try:
image = Image.open(
os.path.join(image_folder, image_file)
).convert("RGB")
except Exception as ex:
print(ex)
i = self.next_rand()
continue
if self.data_args.image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_mean)
)
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
)
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=("image" in sources_org[0]),
)
if isinstance(i, int):
data_dict = dict(
input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0],
)
# default task_type: "score", level_probs: [-10000] * 5
data_dict["task_type"] = sources_org[0].get("task_type", "score")
data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5)
# image exist in the data
if "image" in sources_org[0]:
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict["image"] = torch.zeros(
3, crop_size["height"], crop_size["width"]
)
return data_dict
except Exception as ex:
print(ex)
i = self.next_rand()
continue
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
input_ids = input_ids[:, : self.tokenizer.model_max_length]
labels = labels[:, : self.tokenizer.model_max_length]
batch = dict(
input_type="single",
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
batch["task_types"] = [instance["task_type"] for instance in instances]
batch["level_probs"] = torch.tensor([instance["level_probs"] for instance in instances])
if "image" in instances[0]:
images = [instance["image"] for instance in instances]
if all(x is not None and x.shape == images[0].shape for x in images):
batch["images"] = torch.stack(images)
else:
batch["images"] = images
return batch
def make_single_data_module(
tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = SingleDataset(
tokenizer=tokenizer,
data_paths=data_args.data_paths,
data_weights=data_args.data_weights,
data_args=data_args,
)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
)
================================================
FILE: src/datasets/utils.py
================================================
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.distributed as dist
import transformers
from PIL import Image
from src import conversation as conversation_lib
from src.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
from src.mm_utils import tokenizer_image_token
def rank0_print(*args):
try:
if dist.get_rank() == 0:
print(*args)
except:
print(*args)
@dataclass
class DataArguments:
data_paths: List[str] = field(default_factory=lambda: [])
lazy_preprocess: bool = False
is_multimodal: bool = False
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = "square"
image_grid_pinpoints: Optional[str] = field(default=None)
def _tokenize_fn(
strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(target, tokenized_lens, speakers):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def _add_speaker_and_signal(header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"
conversation = header
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
from_str = "unknown"
sentence["value"] = (
BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
)
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
return conversation
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence["value"]:
sentence["value"] = (
sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
)
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
sentence["value"] = sentence["value"].strip()
replace_token = DEFAULT_IMAGE_TOKEN
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, replace_token
)
return sources
def preprocess_v1(
sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert (
conv.sep_style == conversation_lib.SeparatorStyle.TWO
or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
)
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1 + 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 3
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
round_len -= 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = (
source[0]["value"]
+ source[1]["value"]
+ conversation_lib.default_conversation.sep
)
conversations.append(conversation)
# tokenize conversations
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
) -> Dict:
"""
Given a list of sources, each is a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
if (
conversation_lib.default_conversation.sep_style
== conversation_lib.SeparatorStyle.PLAIN
):
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.version.startswith("v1"):
return preprocess_v1(sources, tokenizer, has_image=has_image)
# add end signal and concatenate together
conversations = []
for source in sources:
header = f"{conversation_lib.default_conversation.system}\n\n"
conversation = _add_speaker_and_signal(header, source)
conversations.append(conversation)
# tokenize conversations
def get_tokenize_len(prompts):
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
if has_image:
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
]
else:
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
if has_image:
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
else:
tokenized_lens = _tokenize_fn(
[header] + [s["value"] for s in source], tokenizer
)["input_ids_lens"]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers)
return dict(input_ids=input_ids, labels=targets)
def load_video(video_file):
from decord import VideoReader
vr = VideoReader(video_file)
# Get video frame rate
fps = vr.get_avg_fps()
# Calculate frame indices for 1fps
frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
frames = vr.get_batch(frame_indices).asnumpy()
return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
================================================
FILE: src/evaluate/__init__.py
================================================
from .scorer import Scorer
================================================
FILE: src/evaluate/cal_distribution_gap.py
================================================
import argparse
import json
import numpy as np
def parse_args():
parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score")
parser.add_argument("--level_names", type=str, required=True, nargs="+")
parser.add_argument("--pred_paths", type=str, required=True, nargs="+")
parser.add_argument("--gt_paths", type=str, required=True, nargs="+")
parser.add_argument("--use_openset_probs", action="store_true")
args = parser.parse_args()
return args
def kl_divergence(mu_1, mu_2, sigma_1, sigma_2):
"""
Calculate the Kullback-Leibler (KL) divergence between two Gaussian distributions for numpy arrays.
Parameters:
mu_1 (np.array): Mean of the first distribution (array of size N).
mu_2 (np.array): Mean of the second distribution (array of size N).
sigma_1 (np.array): Standard deviation of the first distribution (array of size N).
sigma_2 (np.array): Standard deviation of the second distribution (array of size N).
Returns:
np.array: KL divergence from distribution 1 to distribution 2 (array of size N).
"""
eps = 1e-8
return np.log(sigma_2 / (sigma_1 + eps)) + (sigma_1**2 + (mu_1 - mu_2)**2) / (2 * sigma_2**2 + eps) - 0.5
def js_divergence(mu_1, mu_2, sigma_1, sigma_2):
"""
Calculate the Jensen-Shannon (JS) divergence between two Gaussian distributions for numpy arrays.
Parameters:
mu_1 (np.array): Mean of the first distribution (array of size N).
mu_2 (np.array): Mean of the second distribution (array of size N).
sigma_1 (np.array): Standard deviation of the first distribution (array of size N).
sigma_2 (np.array): Standard deviation of the second distribution (array of size N).
Returns:
np.array: JS divergence between the two distributions (array of size N).
"""
# Midpoint distribution parameters
mu_m = 0.5 * (mu_1 + mu_2)
sigma_m = np.sqrt(0.5 * (sigma_1**2 + sigma_2**2))
# JS divergence as the average of the KL divergences
return 0.5 * kl_divergence(mu_1, mu_m, sigma_1, sigma_m) + 0.5 * kl_divergence(mu_2, mu_m, sigma_2, sigma_m)
def wasserstein_distance(mu_1, mu_2, sigma_1, sigma_2):
"""
Calculate the Wasserstein distance between two Gaussian distributions for numpy arrays.
Parameters:
mu_1 (np.array): Mean of the first distribution (array of size N).
mu_2 (np.array): Mean of the second distribution (array of size N).
sigma_1 (np.array): Standard deviation of the first distribution (array of size N).
sigma_2 (np.array): Standard deviation of the second distribution (array of size N).
Returns:
np.array: Wasserstein distance between the two distributions (array of size N).
"""
return np.sqrt((mu_1 - mu_2)**2 + (sigma_1 - sigma_2)**2)
def cal_score(level_names, logits=None, probs=None, use_openset_probs=False):
if use_openset_probs:
assert logits is None
probs = np.array([probs[_] for _ in level_names])
else:
assert probs is None
logprobs = np.array([logits[_] for _ in level_names])
probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
score = np.inner(probs, np.array([5., 4., 3., 2., 1.]))
return score, probs
def cal_std(score, probs):
variance = (np.array([5., 4., 3., 2., 1.]) - score) * (np.array([5., 4., 3., 2., 1.]) - score)
variance = np.inner(probs, variance)
std = np.sqrt(variance)
return std
if __name__ == "__main__":
args = parse_args()
level_names = args.level_names
pred_paths = args.pred_paths
gt_paths = args.gt_paths
use_openset_probs = args.use_openset_probs
for pred_path, gt_path in zip(pred_paths, gt_paths):
print("=" * 100)
print("Pred: ", pred_path)
print("GT: ", gt_path)
# load predict results
pred_metas = []
with open(pred_path) as fr:
for line in fr:
pred_meta = json.loads(line)
pred_metas.append(pred_meta)
# load gt results
with open(gt_path) as fr:
gt_metas = json.load(fr)
pred_metas.sort(key=lambda x: x["id"])
gt_metas.sort(key=lambda x: x["id"])
mu_preds = []
std_preds = []
mu_gts = []
std_gts = []
for pred_meta, gt_meta in zip(pred_metas, gt_metas):
assert pred_meta["id"] == gt_meta["id"]
if use_openset_probs:
pred_score, probs = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=True)
else:
pred_score, probs = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=False)
pred_std = cal_std(pred_score, probs)
mu_preds.append(pred_score)
std_preds.append(pred_std)
mu_gts.append(gt_meta["gt_score_norm"])
std_gts.append(gt_meta["std_norm"])
mu_preds = np.array(mu_preds)
std_preds = np.array(std_preds)
mu_gts = np.array(mu_gts)
std_gts = np.array(std_gts)
kl = kl_divergence(mu_gts, mu_preds, std_gts, std_preds).mean()
js = js_divergence(mu_gts, mu_preds, std_gts, std_preds).mean()
wd = wasserstein_distance(mu_gts, mu_preds, std_gts, std_preds).mean()
print(f"KL: {kl}")
print(f"JS: {js}")
print(f"WD: {wd}")
================================================
FILE: src/evaluate/cal_plcc_srcc.py
================================================
import argparse
import json
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import pearsonr, spearmanr
def parse_args():
parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score")
parser.add_argument("--level_names", type=str, required=True, nargs="+")
parser.add_argument("--pred_paths", type=str, required=True, nargs="+")
parser.add_argument("--gt_paths", type=str, required=True, nargs="+")
parser.add_argument("--use_openset_probs", action="store_true")
args = parser.parse_args()
return args
def calculate_srcc(pred, mos):
srcc, _ = spearmanr(pred, mos)
return srcc
def calculate_plcc(pred, mos):
plcc, _ = pearsonr(pred, mos)
return plcc
def fit_curve(x, y, curve_type="logistic_4params"):
r"""Fit the scale of predict scores to MOS scores using logistic regression suggested by VQEG.
The function with 4 params is more commonly used.
The 5 params function takes from DBCNN:
- https://github.com/zwx8981/DBCNN/blob/master/dbcnn/tools/verify_performance.m
"""
assert curve_type in [
"logistic_4params",
"logistic_5params",
], f"curve type should be in [logistic_4params, logistic_5params], but got {curve_type}."
betas_init_4params = [np.max(y), np.min(y), np.mean(x), np.std(x) / 4.0]
def logistic_4params(x, beta1, beta2, beta3, beta4):
yhat = (beta1 - beta2) / (1 + np.exp(-(x - beta3) / beta4)) + beta2
return yhat
betas_init_5params = [10, 0, np.mean(y), 0.1, 0.1]
def logistic_5params(x, beta1, beta2, beta3, beta4, beta5):
logistic_part = 0.5 - 1.0 / (1 + np.exp(beta2 * (x - beta3)))
yhat = beta1 * logistic_part + beta4 * x + beta5
return yhat
if curve_type == "logistic_4params":
logistic = logistic_4params
betas_init = betas_init_4params
elif curve_type == "logistic_5params":
logistic = logistic_5params
betas_init = betas_init_5params
betas, _ = curve_fit(logistic, x, y, p0=betas_init, maxfev=10000)
yhat = logistic(x, *betas)
return yhat
def cal_score(level_names, logits=None, probs=None, use_openset_probs=False):
if use_openset_probs:
assert logits is None
probs = np.array([probs[_] for _ in level_names])
else:
assert probs is None
logprobs = np.array([logits[_] for _ in level_names])
probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
score = np.inner(probs, np.array([5., 4., 3., 2., 1.]))
return score
if __name__ == "__main__":
args = parse_args()
level_names = args.level_names
pred_paths = args.pred_paths
gt_paths = args.gt_paths
use_openset_probs = args.use_openset_probs
for pred_path, gt_path in zip(pred_paths, gt_paths):
print("=" * 100)
print("Pred: ", pred_path)
print("GT: ", gt_path)
# load predict results
pred_metas = []
with open(pred_path) as fr:
for line in fr:
pred_meta = json.loads(line)
pred_metas.append(pred_meta)
# load gt results
with open(gt_path) as fr:
gt_metas = json.load(fr)
preds = []
gts = []
for pred_meta, gt_meta in zip(pred_metas, gt_metas):
assert pred_meta["id"] == gt_meta["id"]
if use_openset_probs:
pred_score = cal_score(level_names, probs=pred_meta["probs"], use_openset_probs=True)
else:
pred_score = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=False)
preds.append(pred_score)
gts.append(gt_meta["gt_score"])
preds_fit = fit_curve(preds, gts)
srcc = calculate_srcc(preds_fit, gts)
plcc = calculate_plcc(preds_fit, gts)
print(f"SRCC: {srcc}")
print(f"PLCC: {plcc}")
================================================
FILE: src/evaluate/eval_qbench_mcq.py
================================================
import argparse
import torch
from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from src.conversation import conv_templates, SeparatorStyle
from src.model.builder import load_pretrained_model
from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
import json
from tqdm import tqdm
import os
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def load_image(image_file):
if image_file.startswith('http://') or image_file.startswith('https://'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def main(args):
# Model
disable_torch_init()
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
os.makedirs(args.save_dir, exist_ok=True)
with open(args.meta_path) as f:
llvqa_data = json.load(f)
pbar = tqdm(total=len(llvqa_data))
conv_mode = "mplug_owl2"
if args.conv_mode is not None and conv_mode != args.conv_mode:
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
else:
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
roles = conv.roles
correct = 0
for i, llddata in enumerate((llvqa_data)):
filename = llddata["img_path"]
message = llddata["question"] + "\n"
for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
message += f"{choice} {ans}\n"
if "correct_ans" in llddata and ans == llddata["correct_ans"]:
correct_choice = choice[0]
message = message + "Answer with the option's letter from the given choices directly.\n"
inp = message
conv = conv_templates[args.conv_mode].copy()
inp = "The input image:" + DEFAULT_IMAGE_TOKEN + inp
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)
image = load_image(os.path.join(args.root_dir, filename))
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(model.device)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
attention_mask=torch.ones_like(input_ids),
images=image_tensor,
do_sample=False,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
num_beams=1,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria])
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
llddata["response"] = outputs
if correct_choice in outputs:
correct += 1
pbar.update(1)
pbar.set_description("[Running Accuracy]: {:.4f},[Response]: {}, [Correct Ans]: {}, , [Prog]: {}".format(correct/(i+1), outputs, llddata.get("correct_ans", -1), i+1))
save_path = os.path.join(args.save_dir, os.path.basename(args.meta_path))
with open(save_path, "a") as fw:
fw.write(json.dumps(llddata) + "\n")
if args.debug:
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--root-dir", type=str, required=True)
parser.add_argument("--save-dir", type=str, required=True)
parser.add_argument("--meta-path", type=str, required=True)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--image-aspect-ratio", type=str, default='pad')
args = parser.parse_args()
main(args)
================================================
FILE: src/evaluate/iqa_eval.py
================================================
import argparse
import json
import os
from collections import defaultdict
from io import BytesIO
import requests
import torch
from PIL import Image
from tqdm import tqdm
from src.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from src.conversation import conv_templates
from src.mm_utils import get_model_name_from_path, tokenizer_image_token
from src.model.builder import load_pretrained_model
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def load_image(image_file):
if image_file.startswith("http://") or image_file.startswith("https://"):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
return image
def main(args):
# Model
disable_torch_init()
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
args.model_path,
args.model_base,
model_name,
args.load_8bit,
args.load_4bit,
device=args.device,
preprocessor_path=args.preprocessor_path,
)
meta_paths = args.meta_paths
root_dir = args.root_dir
batch_size = args.batch_size
save_dir = args.save_dir
os.makedirs(save_dir, exist_ok=True)
with_prob = args.with_prob
conv_mode = "mplug_owl2"
inp = "How would you rate the quality of this image?"
conv = conv_templates[conv_mode].copy()
inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
conv.append_message(conv.roles[0], inp)
image = None
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt() + " The quality of the image is"
toks = args.level_names
print(toks)
ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
print(ids_)
input_ids = (
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
.unsqueeze(0)
.to(args.device)
)
for meta_path in meta_paths:
with open(meta_path) as f:
iqadata = json.load(f)
image_tensors = []
batch_data = []
imgs_handled = []
save_path = os.path.join(save_dir, os.path.basename(meta_path))
if os.path.exists(save_path):
with open(save_path) as fr:
for line in fr:
meta_res = json.loads(line)
imgs_handled.append(meta_res["image"])
meta_name = os.path.basename(meta_path)
for i, llddata in enumerate(tqdm(iqadata, desc=f"Evaluating [{meta_name}]")):
try:
filename = llddata["image"]
except:
filename = llddata["img_path"]
if filename in imgs_handled:
continue
llddata["logits"] = defaultdict(float)
llddata["probs"] = defaultdict(float)
image = load_image(os.path.join(root_dir, filename))
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(
image, tuple(int(x * 255) for x in image_processor.image_mean)
)
image_tensor = (
image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
.half()
.to(args.device)
)
image_tensors.append(image_tensor)
batch_data.append(llddata)
if (i + 1) % batch_size == 0 or i == len(iqadata) - 1:
with torch.inference_mode():
output_logits = model(
input_ids=input_ids.repeat(len(image_tensors), 1),
images=torch.cat(image_tensors, 0),
)["logits"][:, -1]
if with_prob:
output_probs = torch.softmax(output_logits, dim=1)
for j, xllddata in enumerate(batch_data):
for tok, id_ in zip(toks, ids_):
xllddata["logits"][tok] += output_logits[j, id_].item()
if with_prob:
xllddata["probs"][tok] += output_probs[j, id_].item()
meta_res = {
"id": xllddata["id"],
"image": xllddata["image"],
"gt_score": xllddata["gt_score"],
"logits": xllddata["logits"],
}
if with_prob:
meta_res["probs"] = xllddata["probs"]
with open(save_path, "a") as fw:
fw.write(json.dumps(meta_res) + "\n")
image_tensors = []
batch_data = []
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--preprocessor-path", type=str, default=None)
parser.add_argument("--meta-paths", type=str, required=True, nargs="+")
parser.add_argument("--root-dir", type=str, required=True)
parser.add_argument("--save-dir", type=str, default="results")
parser.add_argument("--level-names", type=str, required=True, nargs="+")
parser.add_argument("--with-prob", type=bool, default=False) # whether to save openset prob
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--image-aspect-ratio", type=str, default="pad")
args = parser.parse_args()
main(args)
================================================
FILE: src/evaluate/scorer.py
================================================
from PIL import Image
import torch.nn as nn
import torch
from typing import List
from src.model.builder import load_pretrained_model
from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
class Scorer(nn.Module):
def __init__(self, pretrained="zhiyuanyou/DeQA-Score-Mix3", device="cuda:0"):
super().__init__()
tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(model.device)
self.tokenizer = tokenizer
self.model = model
self.image_processor = image_processor
self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
def expand2square(self, pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def forward(self, image: List[Image.Image]):
image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
with torch.inference_mode():
image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
output_logits = self.model(
input_ids=self.input_ids.repeat(image_tensor.shape[0], 1),
images=image_tensor
)["logits"][:,-1, self.preferential_ids_]
return torch.softmax(output_logits, -1) @ self.weight_tensor
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="zhiyuanyou/DeQA-Score-Mix3")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
args = parser.parse_args()
scorer = Scorer(pretrained=args.model_path, device=args.device)
print(scorer([Image.open(args.img_path)]).tolist())
================================================
FILE: src/mm_utils.py
================================================
from PIL import Image
from io import BytesIO
import base64
import torch
from transformers import StoppingCriteria
from src.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
from icecream import ic
def load_image_from_base64(image):
return Image.open(BytesIO(base64.b64decode(image)))
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(images, image_processor, model_cfg=None):
if model_cfg is not None:
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
else:
image_aspect_ratio = 'resize'
new_images = []
if image_aspect_ratio == 'pad':
for image in images:
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
elif image_aspect_ratio == 'resize':
for image in images:
max_edge = max(image.size)
image = image.resize((max_edge, max_edge))
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
else:
return image_processor(images, return_tensors='pt')['pixel_values']
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
def get_model_name_from_path(model_path):
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith('checkpoint-'):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
self.max_keyword_len = 0
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
if len(cur_keyword_ids) > self.max_keyword_len:
self.max_keyword_len = len(cur_keyword_ids)
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
================================================
FILE: src/model/__init__.py
================================================
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
from .configuration_mplug_owl2 import MPLUGOwl2Config
================================================
FILE: src/model/builder.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
from src.model import *
def load_pretrained_model(
model_path,
model_base,
model_name,
load_8bit=False,
load_4bit=False,
device_map="auto",
device="cuda",
preprocessor_path=None,
):
kwargs = {"device_map": device_map}
if device != "cuda":
kwargs["device_map"] = {"": device}
if load_8bit:
kwargs["load_in_8bit"] = True
elif load_4bit:
kwargs["load_in_4bit"] = True
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
else:
kwargs["torch_dtype"] = torch.float16
if preprocessor_path is None:
preprocessor_path = model_path
if "deqa" in model_name.lower():
# Load LLaVA model
if "lora" in model_name.lower() and model_base is None:
warnings.warn(
"There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
)
if "lora" in model_name.lower() and model_base is not None:
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
print("Loading mPLUG-Owl2 from base model...")
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
)
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
if model.lm_head.weight.shape[0] != token_num:
model.lm_head.weight = torch.nn.Parameter(
torch.empty(
token_num, tokem_dim, device=model.device, dtype=model.dtype
)
)
model.model.embed_tokens.weight = torch.nn.Parameter(
torch.empty(
token_num, tokem_dim, device=model.device, dtype=model.dtype
)
)
print("Loading additional mPLUG-Owl2 weights...")
if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
non_lora_trainables = torch.load(
os.path.join(model_path, "non_lora_trainables.bin"),
map_location="cpu",
)
print(non_lora_trainables.keys())
else:
# this is probably from HF Hub
from huggingface_hub import hf_hub_download
def load_from_hf(repo_id, filename, subfolder=None):
cache_file = hf_hub_download(
repo_id=repo_id, filename=filename, subfolder=subfolder
)
return torch.load(cache_file, map_location="cpu")
non_lora_trainables = load_from_hf(
model_path, "non_lora_trainables.bin"
)
non_lora_trainables = {
(k[17:] if k.startswith("base_model.model.") else k): v
for k, v in non_lora_trainables.items()
}
model.load_state_dict(non_lora_trainables, strict=False)
from peft import PeftModel
print("Loading LoRA weights...")
model = PeftModel.from_pretrained(model, model_path)
print("Merging LoRA weights...")
model = model.merge_and_unload()
print("Model is loaded...")
elif model_base is not None:
# this may be mm projector only
print("Loading mPLUG-Owl2 from base model...")
tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
)
else:
tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
else:
# Load language model
if model_base is not None:
# PEFT model
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
model_base, low_cpu_mem_usage=True, **kwargs
)
print(f"Loading LoRA weights from {model_path}")
model = PeftModel.from_pretrained(model, model_path)
print(f"Merging weights")
model = model.merge_and_unload()
print("Convert to FP16...")
model.to(torch.float16)
else:
tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
# vision_tower = model.get_model().vision_model
# print(vision_tower.device)
# vision_tower.to(device=device, dtype=torch.float16)
image_processor = CLIPImageProcessor.from_pretrained(preprocessor_path)
if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048
return tokenizer, model, image_processor, context_len
================================================
FILE: src/model/configuration_mplug_owl2.py
================================================
# Copyright (c) Alibaba.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING
class LlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
Llama 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
```python
>>> from transformers import LlamaModel, LlamaConfig
>>> # Initializing a LLaMA llama-7b style configuration
>>> configuration = LlamaConfig()
>>> # Initializing a model from the llama-7b style configuration
>>> model = LlamaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
class MplugOwlVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate
a
mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration defaults will yield a similar configuration to that of the mPLUG-Owl
[x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 32):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
```"""
model_type = "mplug_owl_vision_model"
def __init__(
self,
hidden_size=1024,
intermediate_size=4096,
projection_dim=768,
num_hidden_layers=24,
num_attention_heads=16,
num_channels=3,
image_size=448,
patch_size=14,
hidden_act="quick_gelu",
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
use_flash_attn=False,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.use_flash_attn = use_flash_attn
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from MplugOwlConfig
if config_dict.get("model_type") == "mplug-owl":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class MplugOwlVisualAbstractorConfig(PretrainedConfig):
model_type = "mplug_owl_visual_abstract"
def __init__(
self,
num_learnable_queries=64,
hidden_size=1024,
num_hidden_layers=6,
num_attention_heads=16,
intermediate_size=2816,
attention_probs_dropout_prob=0.,
initializer_range=0.02,
layer_norm_eps=1e-6,
encoder_hidden_size=1024,
grid_size=None,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_learnable_queries = num_learnable_queries
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.encoder_hidden_size = encoder_hidden_size
self.grid_size = grid_size if grid_size else 32
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the visual_abstractor config dict if we are loading from MplugOwlConfig
if config_dict.get("model_type") == "mplug-owl":
config_dict = config_dict["abstractor_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
DEFAULT_VISUAL_CONFIG = {
"visual_model": MplugOwlVisionConfig().to_dict(),
"visual_abstractor": MplugOwlVisualAbstractorConfig().to_dict()
}
class MPLUGOwl2Config(LlamaConfig):
model_type = "mplug_owl2"
def __init__(self, visual_config=None, **kwargs):
if visual_config is None:
self.visual_config = DEFAULT_VISUAL_CONFIG
else:
self.visual_config = visual_config
super().__init__(
**kwargs,
)
if __name__ == "__main__":
print(MplugOwlVisionConfig().to_dict())
================================================
FILE: src/model/convert_mplug_owl2_weight_to_hf.py
================================================
# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import math
import os
import shutil
import warnings
import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
try:
from transformers import LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
LlamaTokenizerFast = None
"""
Sample usage:
```
python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
--input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
```
Thereafter, models can be loaded via:
```py
from transformers import LlamaForCausalLM, LlamaTokenizer
model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
70: 28672} # should be (2/3)*4*d, but it isn't exaclty that
llama_s2hidden = {7: 4096, 13: 5120, 32: 6656, 65: 8192, 70: 8192}
def compute_intermediate_size(n):
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path,
input_base_path,
model_size,
num_input_shards=1,
num_output_shards=2,
skip_permute=True,
norm_eps=1e-05):
# if os.path.exists(model_path):
# shutil.rmtree(model_path)
os.makedirs(model_path, exist_ok=True)
# tmp_model_path = os.path.join(model_path, "tmp")
tmp_model_path = model_path
os.makedirs(tmp_model_path, exist_ok=True)
num_shards = num_input_shards
n_layers = llama_s2layer[model_size]
n_heads = llama_s2heads[model_size]
n_heads_per_shard = n_heads // num_shards
n_dense = llama_s2dense[model_size]
n_hidden = llama_s2hidden[model_size]
hidden_per_head = n_hidden // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
# permute for sliced rotary
def permute(w, skip_permute=skip_permute):
if skip_permute:
return w
return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if num_shards==1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
# /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
if os.path.exists(os.path.join(input_base_path, 'release')):
filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
elif input_base_path.split('/')[-1].startswith('iter_'):
iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
load_dir = '/'.join(input_base_path.split('/')[:-1])
filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
else:
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = 'iter_{:07d}'.format(int(metastring))
filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
original_filename = filename
loaded = torch.load(filename, map_location="cpu")['model']['language_model']
else:
# Sharded
filenames = []
for i in range(num_shards):
if os.path.exists(os.path.join(input_base_path, 'release')):
filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
else:
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = 'iter_{:07d}'.format(int(metastring))
filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
filenames.append(filename)
loaded = [
torch.load(filenames[i], map_location="cpu")['model']['language_model']
for i in range(num_shards)
]
print('Llama-Megatron Loaded!')
param_count = 0
index_dict = {"weight_map": {}}
print(f'Weighted Converting for {n_layers} layers...')
for layer_i in range(n_layers):
print(layer_i)
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
}
else:
raise NotImplemented
# else:
# # Sharded
# # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
# # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
# # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
# state_dict = {
# f"model.layers.{layer_i}.input_layernorm.weight": loaded[0]['encoder'][
# f"layers.{layer_i}.input_layernorm.multiway.0.weight"
# ].clone(),
# f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0]['encoder'][
# f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"
# ].clone(),
# }
# wqs, wks, wvs, ffn_w1s, ffn_w3s = [], [], [], [], []
# for shard_idx in range(num_shards):
# wqs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"])
# wks.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"])
# wvs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"])
# state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
# torch.cat(
# [
# wq.view(n_heads_per_shard, hidden_per_head, n_hidden)
# for wq in range(wqs)
# ],
# dim=0,
# ).reshape(n_hidden, n_hidden)
# )
# state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
# torch.cat(
# [
# wk.view(n_heads_per_shard, hidden_per_head, n_hidden)
# for wk in range(wks)
# ],
# dim=0,
# ).reshape(n_hidden, n_hidden)
# )
# state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
# [
# wv.view(n_heads_per_shard, hidden_per_head, n_hidden)
# for wv in range(wvs)
# ],
# dim=0,
# ).reshape(n_hidden, n_hidden)
# state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
# [loaded[i]['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"] for i in range(num_shards)], dim=1
# )
# state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"] for i in range(num_shards)], dim=0
# )
# state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"] for i in range(num_shards)], dim=1
# )
# state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"] for i in range(num_shards)], dim=0
# )
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
print(f'Sharded file saved to {filename}')
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
if num_shards==1:
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
"model.norm.weight": loaded['encoder']['norm.weight'],
"lm_head.weight": loaded['encoder']['lm_head.weight'],
}
else:
state_dict = {
"model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
"model.norm.weight": loaded[0]['encoder']['norm.weight'],
"lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
}
loaded_all = torch.load(original_filename, map_location="cpu")['model']
# Vision Part
state_dict.update({
"model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
"model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
"model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
"model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
"model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
"model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
"model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
})
for v_layer_idx in range(24):
state_dict.update({
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
})
# Abstractor Part
state_dict.update({
"model.visual_abstractor.query_embeds": loaded_all['vision_abstractor']['learnable_queries'],
"model.visual_abstractor.visual_fc.bias": loaded_all['vision_abstractor']['visual_fc']['bias'],
"model.visual_abstractor.visual_fc.weight": loaded_all['vision_abstractor']['visual_fc']['weight'],
"model.visual_abstractor.vit_eos": loaded_all['vision_abstractor']['vit_eos'],
})
for v_layer_idx in range(6):
state_dict.update({
# f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.k_pos_embed":
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.weight"],
# f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.q_pos_embed": "pytorch_model-00004-of-00004.bin",
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.weight"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.bias"],
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.weight"],
})
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
config = MPLUGOwl2Config()
config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
del loaded_all
gc.collect()
def write_tokenizer(tokenizer_path, input_tokenizer_path):
# Initialize the tokenizer based on the `spm` model
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
tokenizer = tokenizer_class(input_tokenizer_path)
tokenizer.save_pretrained(tokenizer_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA_Megatron weights",
)
parser.add_argument(
"--model_size",
type=int,
default=7,
choices=[7, 13, 30, 65, 70],
)
parser.add_argument(
"--num_input_shards",
type=int,
default=1,
)
parser.add_argument(
"--num_output_shards",
type=int,
default=1,
)
parser.add_argument('--skip_permute', action='store_true')
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
num_input_shards=args.num_input_shards,
num_output_shards=args.num_output_shards,
skip_permute=args.skip_permute
)
if __name__ == "__main__":
main()
================================================
FILE: src/model/modeling_attn_mask_utils.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
class AttentionMaskConverter:
"""
A utility attention mask class that allows one to:
- Create a causal 4d mask
- Create a causal 4d mask with slided window
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
key_value_length) that can be multiplied with attention scores
Parameters:
is_causal (`bool`):
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
sliding_window (`int`, *optional*):
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
"""
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
self.is_causal = is_causal
self.sliding_window = sliding_window
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError(
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
)
def to_causal_4d(
self,
batch_size: int,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
"""
if not self.is_causal:
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
# If shape is not cached, create a new causal mask and cache it
input_shape = (batch_size, query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if input_shape[-1] > 1 or self.sliding_window is not None:
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
return causal_4d_mask
def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=attention_mask_2d.device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
elif self.sliding_window is not None:
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
attention_mask_2d.device
)
expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
return expanded_4d_mask
@staticmethod
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window + 1
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
attention_mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
inputs_embeds (`torch.Tensor`):
The embedded inputs as a torch Tensor.
past_key_values_length (`int`):
The length of the key value cache.
sliding_window (`int`, *optional*):
If the model uses windowed attention, a sliding window should be passed.
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = input_shape[-1] + past_key_values_length
# 4d mask is passed through the layers
if attention_mask is not None:
attention_mask = attn_mask_converter.to_4d(
attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
)
else:
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
return attention_mask
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _create_4d_causal_attention_mask(
input_shape: Union[torch.Size, Tuple, List],
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
Args:
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
device (`int`):
The torch device the created mask shall have.
sliding_window (`int`, *optional*):
If the model uses windowed attention, a sliding window should be passed.
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = past_key_values_length + input_shape[-1]
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
)
return attention_mask
================================================
FILE: src/model/modeling_llama2.py
================================================
import math
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import copy
import os
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import transformers
from transformers.models.llama.modeling_llama import *
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from .configuration_mplug_owl2 import LlamaConfig
class MultiwayNetwork(nn.Module):
def __init__(self, module_provider, num_multiway=2):
super(MultiwayNetwork, self).__init__()
self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
def forward(self, hidden_states, multiway_indices):
if len(self.multiway) == 1:
return self.multiway[0](hidden_states)
output_hidden_states = torch.empty_like(hidden_states)
for idx, subway in enumerate(self.multiway):
local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
if hidden.numel():
output = subway(hidden)
if isinstance(output, tuple):
output = output[0]
output = output.squeeze(1)
output_hidden_states[local_indices] = output
return output_hidden_states.contiguous()
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = MultiwayNetwork(module_provider=partial(
nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
)
self.v_proj = MultiwayNetwork(module_provider=partial(
nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
modality_indicators: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, )
key_states = self.k_proj(hidden_states, modality_indicators)
value_states = self.v_proj(hidden_states, modality_indicators)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaFlashAttention2(LlamaAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
modality_indicators: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states, modality_indicators)
value_states = self.v_proj(hidden_states, modality_indicators)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class LlamaSdpaAttention(LlamaAttention):
"""
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from LlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
modality_indicators: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states, modality_indicators)
value_states = self.v_proj(hidden_states, modality_indicators)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
LLAMA_ATTENTION_CLASSES = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
}
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = MultiwayNetwork(module_provider=partial(
LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
))
self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
))
def forward(
self,
hidden_states: torch.Tensor,
modality_indicators: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states, modality_indicators)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def model_forward(
self,
input_ids: torch.LongTensor = None,
modality_indicators: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
modality_indicators,
attention_mask,
position_ids,
)
else:
layer_outputs = decoder_layer(
hidden_states,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def causal_model_forward(
self,
input_ids: torch.LongTensor = None,
modality_indicators: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def replace_llama_modality_adaptive():
transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
transformers.models.llama.modeling_llama.LlamaFlashAttention2 = LlamaFlashAttention2
transformers.models.llama.modeling_llama.LlamaSdpaAttention = LlamaSdpaAttention
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
if __name__ == "__main__":
replace_llama_modality_adaptive()
config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
model = transformers.LlamaForCausalLM(config)
print(model)
================================================
FILE: src/model/modeling_mplug_owl2.py
================================================
# Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
from transformers import (AutoConfig, AutoModelForCausalLM, LlamaForCausalLM,
LlamaModel)
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mplug_owl2 import (MPLUGOwl2Config, MplugOwlVisionConfig,
MplugOwlVisualAbstractorConfig)
from .modeling_llama2 import replace_llama_modality_adaptive
from .utils import extend_list, find_prefix
from .visual_encoder import MplugOwlVisionModel, MplugOwlVisualAbstractorModel
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<|image|>"
from icecream import ic
class MPLUGOwl2MetaModel:
def __init__(self, config):
super(MPLUGOwl2MetaModel, self).__init__(config)
self.vision_model = MplugOwlVisionModel(
MplugOwlVisionConfig(**config.visual_config["visual_model"])
)
self.visual_abstractor = MplugOwlVisualAbstractorModel(
MplugOwlVisualAbstractorConfig(**config.visual_config["visual_abstractor"]),
config.hidden_size,
)
def get_vision_tower(self):
vision_model = getattr(self, "vision_model", None)
if type(vision_model) is list:
vision_model = vision_model[0]
return vision_model
def get_visual_abstractor(self):
visual_abstractor = getattr(self, "visual_abstractor", None)
if type(visual_abstractor) is list:
visual_abstractor = visual_abstractor[0]
return visual_abstractor
class MPLUGOwl2MetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def encode_images(self, images):
image_features = self.get_model().vision_model(images).last_hidden_state
image_features = (
self.get_model()
.visual_abstractor(encoder_hidden_states=image_features)
.last_hidden_state
)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, attention_mask, past_key_values, labels, images
):
if images is None or input_ids.shape[1] == 1:
if (
past_key_values is not None
and images is not None
and input_ids.shape[1] == 1
):
attention_mask = torch.ones(
(attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
multiway_indices = torch.zeros_like(input_ids).long().to(self.device)
return (
input_ids,
multiway_indices,
attention_mask,
past_key_values,
None,
labels,
)
if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images)
new_input_embeds = []
new_modality_indicators = []
new_labels = [] if labels is not None else None
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
# FIXME: this is a hacky fix, for deepspeed zero3 to work
half_len = cur_input_ids.shape[0] // 2
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(
cur_input_ids[:half_len]
)
cur_input_embeds_2 = self.get_model().embed_tokens(
cur_input_ids[half_len:]
)
cur_input_embeds = torch.cat(
[cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2],
dim=0,
)
new_input_embeds.append(cur_input_embeds)
cur_modality_indicators = (
torch.zeros(len(cur_input_embeds)).long().to(self.device)
)
new_modality_indicators.append(cur_modality_indicators)
if labels is not None:
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
cur_new_input_embeds = []
cur_modality_indicators = []
if labels is not None:
gitextract_56m520fi/
├── .gitignore
├── LICENSE
├── README.md
├── build_soft_labels/
│ └── gen_soft_label.py
├── preprocessor/
│ └── tokenizer.model
├── pyproject.toml
├── scripts/
│ ├── eval_dist.sh
│ ├── eval_score.sh
│ ├── infer.sh
│ ├── infer_lora.sh
│ ├── train.sh
│ └── train_lora.sh
├── src/
│ ├── __init__.py
│ ├── constants.py
│ ├── conversation.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── pair_dataset.py
│ │ ├── single_dataset.py
│ │ └── utils.py
│ ├── evaluate/
│ │ ├── __init__.py
│ │ ├── cal_distribution_gap.py
│ │ ├── cal_plcc_srcc.py
│ │ ├── eval_qbench_mcq.py
│ │ ├── iqa_eval.py
│ │ └── scorer.py
│ ├── mm_utils.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── builder.py
│ │ ├── configuration_mplug_owl2.py
│ │ ├── convert_mplug_owl2_weight_to_hf.py
│ │ ├── modeling_attn_mask_utils.py
│ │ ├── modeling_llama2.py
│ │ ├── modeling_mplug_owl2.py
│ │ ├── utils.py
│ │ └── visual_encoder.py
│ ├── train/
│ │ ├── mplug_owl2_trainer.py
│ │ └── train_mem.py
│ └── utils.py
└── tests/
├── datasets/
│ ├── test_pair_dataset.py
│ └── test_uncertainty_levels.py
└── model/
├── test_find_prefix.py
└── test_grad.py
SYMBOL INDEX (248 symbols across 26 files)
FILE: build_soft_labels/gen_soft_label.py
function parse_args (line 9) | def parse_args():
function calculate_srcc_plcc (line 30) | def calculate_srcc_plcc(pred, mos):
function get_level (line 36) | def get_level(mos, min_mos, max_mos):
function adjust_gaussian_bar (line 49) | def adjust_gaussian_bar(probs, score):
function get_binary_probs (line 64) | def get_binary_probs(mos, min_mos=1.0, max_mos=5.0):
function main (line 80) | def main(cfg):
FILE: src/conversation.py
class SeparatorStyle (line 6) | class SeparatorStyle(Enum):
class Conversation (line 17) | class Conversation:
method get_prompt (line 30) | def get_prompt(self):
method append_message (line 119) | def append_message(self, role, message):
method get_images (line 122) | def get_images(self, return_pil=False):
method to_gradio_chatbot (line 172) | def to_gradio_chatbot(self):
method copy (line 203) | def copy(self):
method dict (line 214) | def dict(self):
FILE: src/datasets/__init__.py
function make_data_module (line 5) | def make_data_module(tokenizer, data_args):
FILE: src/datasets/pair_dataset.py
class PairDataset (line 19) | class PairDataset(Dataset):
method __init__ (line 22) | def __init__(
method __len__ (line 49) | def __len__(self):
method lengths (line 53) | def lengths(self):
method modality_lengths (line 65) | def modality_lengths(self):
method next_rand (line 76) | def next_rand(self):
method __getitem__ (line 79) | def __getitem__(self, i):
method get_one_item (line 112) | def get_one_item(self, idx_dataset, idx_sample) -> Dict[str, torch.Ten...
class DataCollatorForPairDataset (line 214) | class DataCollatorForPairDataset(object):
method __call__ (line 219) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
method collate_one (line 229) | def collate_one(self, instances: Sequence[Dict]) -> Dict[str, torch.Te...
function make_pair_data_module (line 263) | def make_pair_data_module(
FILE: src/datasets/single_dataset.py
class SingleDataset (line 18) | class SingleDataset(Dataset):
method __init__ (line 21) | def __init__(
method __len__ (line 39) | def __len__(self):
method lengths (line 43) | def lengths(self):
method modality_lengths (line 54) | def modality_lengths(self):
method next_rand (line 64) | def next_rand(self):
method __getitem__ (line 69) | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
class DataCollatorForSupervisedDataset (line 194) | class DataCollatorForSupervisedDataset(object):
method __call__ (line 199) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
function make_single_data_module (line 231) | def make_single_data_module(
FILE: src/datasets/utils.py
function rank0_print (line 22) | def rank0_print(*args):
class DataArguments (line 31) | class DataArguments:
function _tokenize_fn (line 40) | def _tokenize_fn(
function _mask_targets (line 67) | def _mask_targets(target, tokenized_lens, speakers):
function _add_speaker_and_signal (line 78) | def _add_speaker_and_signal(header, source, get_conversation=True):
function preprocess_multimodal (line 100) | def preprocess_multimodal(sources: Sequence[str], data_args: DataArgumen...
function preprocess_v1 (line 122) | def preprocess_v1(
function preprocess_plain (line 210) | def preprocess_plain(
function preprocess (line 239) | def preprocess(
function load_video (line 292) | def load_video(video_file):
function expand2square (line 306) | def expand2square(pil_img, background_color):
FILE: src/evaluate/cal_distribution_gap.py
function parse_args (line 7) | def parse_args():
function kl_divergence (line 17) | def kl_divergence(mu_1, mu_2, sigma_1, sigma_2):
function js_divergence (line 34) | def js_divergence(mu_1, mu_2, sigma_1, sigma_2):
function wasserstein_distance (line 55) | def wasserstein_distance(mu_1, mu_2, sigma_1, sigma_2):
function cal_score (line 71) | def cal_score(level_names, logits=None, probs=None, use_openset_probs=Fa...
function cal_std (line 83) | def cal_std(score, probs):
FILE: src/evaluate/cal_plcc_srcc.py
function parse_args (line 9) | def parse_args():
function calculate_srcc (line 19) | def calculate_srcc(pred, mos):
function calculate_plcc (line 24) | def calculate_plcc(pred, mos):
function fit_curve (line 29) | def fit_curve(x, y, curve_type="logistic_4params"):
function cal_score (line 65) | def cal_score(level_names, logits=None, probs=None, use_openset_probs=Fa...
FILE: src/evaluate/eval_qbench_mcq.py
function disable_torch_init (line 21) | def disable_torch_init():
function load_image (line 30) | def load_image(image_file):
function main (line 39) | def main(args):
FILE: src/evaluate/iqa_eval.py
function disable_torch_init (line 18) | def disable_torch_init():
function load_image (line 28) | def load_image(image_file):
function main (line 37) | def main(args):
FILE: src/evaluate/scorer.py
class Scorer (line 14) | class Scorer(nn.Module):
method __init__ (line 15) | def __init__(self, pretrained="zhiyuanyou/DeQA-Score-Mix3", device="cu...
method expand2square (line 28) | def expand2square(self, pil_img, background_color):
method forward (line 41) | def forward(self, image: List[Image.Image]):
FILE: src/mm_utils.py
function load_image_from_base64 (line 11) | def load_image_from_base64(image):
function expand2square (line 15) | def expand2square(pil_img, background_color):
function process_images (line 29) | def process_images(images, image_processor, model_cfg=None):
function tokenizer_image_token (line 53) | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOK...
function get_model_name_from_path (line 75) | def get_model_name_from_path(model_path):
class KeywordsStoppingCriteria (line 86) | class KeywordsStoppingCriteria(StoppingCriteria):
method __init__ (line 87) | def __init__(self, keywords, tokenizer, input_ids):
method __call__ (line 101) | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTe...
FILE: src/model/builder.py
function load_pretrained_model (line 31) | def load_pretrained_model(
FILE: src/model/configuration_mplug_owl2.py
class LlamaConfig (line 15) | class LlamaConfig(PretrainedConfig):
method __init__ (line 99) | def __init__(
method _rope_scaling_validation (line 154) | def _rope_scaling_validation(self):
class MplugOwlVisionConfig (line 176) | class MplugOwlVisionConfig(PretrainedConfig):
method __init__ (line 218) | def __init__(
method from_pretrained (line 253) | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os....
class MplugOwlVisualAbstractorConfig (line 269) | class MplugOwlVisualAbstractorConfig(PretrainedConfig):
method __init__ (line 272) | def __init__(
method from_pretrained (line 299) | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os....
class MPLUGOwl2Config (line 321) | class MPLUGOwl2Config(LlamaConfig):
method __init__ (line 323) | def __init__(self, visual_config=None, **kwargs):
FILE: src/model/convert_mplug_owl2_weight_to_hf.py
function compute_intermediate_size (line 65) | def compute_intermediate_size(n):
function read_json (line 69) | def read_json(path):
function write_json (line 74) | def write_json(text, path):
function write_model (line 79) | def write_model(model_path,
function write_tokenizer (line 346) | def write_tokenizer(tokenizer_path, input_tokenizer_path):
function main (line 354) | def main():
FILE: src/model/modeling_attn_mask_utils.py
class AttentionMaskConverter (line 19) | class AttentionMaskConverter:
method __init__ (line 35) | def __init__(self, is_causal: bool, sliding_window: Optional[int] = No...
method to_causal_4d (line 44) | def to_causal_4d(
method to_4d (line 77) | def to_4d(
method _make_causal_mask (line 120) | def _make_causal_mask(
method _expand_mask (line 150) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Opti...
function _prepare_4d_causal_attention_mask (line 164) | def _prepare_4d_causal_attention_mask(
function _prepare_4d_attention_mask (line 204) | def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, t...
function _create_4d_causal_attention_mask (line 220) | def _create_4d_causal_attention_mask(
FILE: src/model/modeling_llama2.py
function _get_unpad_data (line 22) | def _get_unpad_data(attention_mask):
class MultiwayNetwork (line 40) | class MultiwayNetwork(nn.Module):
method __init__ (line 42) | def __init__(self, module_provider, num_multiway=2):
method forward (line 47) | def forward(self, hidden_states, multiway_indices):
class LlamaAttention (line 67) | class LlamaAttention(nn.Module):
method __init__ (line 70) | def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
method _init_rope (line 106) | def _init_rope(self):
method _shape (line 133) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
method forward (line 136) | def forward(
class LlamaFlashAttention2 (line 210) | class LlamaFlashAttention2(LlamaAttention):
method __init__ (line 217) | def __init__(self, *args, **kwargs):
method forward (line 225) | def forward(
method _flash_attention_forward (line 316) | def _flash_attention_forward(
method _upad_input (line 375) | def _upad_input(self, query_layer, key_layer, value_layer, attention_m...
class LlamaSdpaAttention (line 414) | class LlamaSdpaAttention(LlamaAttention):
method forward (line 422) | def forward(
class LlamaDecoderLayer (line 510) | class LlamaDecoderLayer(nn.Module):
method __init__ (line 511) | def __init__(self, config: LlamaConfig, layer_idx):
method forward (line 524) | def forward(
function model_forward (line 581) | def model_forward(
function causal_model_forward (line 726) | def causal_model_forward(
function replace_llama_modality_adaptive (line 820) | def replace_llama_modality_adaptive():
FILE: src/model/modeling_mplug_owl2.py
class MPLUGOwl2MetaModel (line 45) | class MPLUGOwl2MetaModel:
method __init__ (line 46) | def __init__(self, config):
method get_vision_tower (line 56) | def get_vision_tower(self):
method get_visual_abstractor (line 62) | def get_visual_abstractor(self):
class MPLUGOwl2MetaForCausalLM (line 69) | class MPLUGOwl2MetaForCausalLM(ABC):
method get_model (line 71) | def get_model(self):
method encode_images (line 74) | def encode_images(self, images):
method prepare_inputs_labels_for_multimodal (line 83) | def prepare_inputs_labels_for_multimodal(
class MPLUGOwl2LlamaModel (line 324) | class MPLUGOwl2LlamaModel(MPLUGOwl2MetaModel, LlamaModel):
method __init__ (line 327) | def __init__(self, config: MPLUGOwl2Config):
class MPLUGOwl2LlamaForCausalLM (line 331) | class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausal...
method __init__ (line 334) | def __init__(self, config):
method get_model (line 343) | def get_model(self):
method forward (line 346) | def forward(self, input_type=None, **kwargs):
method softkl_loss (line 380) | def softkl_loss(self, logits, labels, level_probs):
method forward_single (line 406) | def forward_single(
method get_score (line 514) | def get_score(self, item):
method get_subitem (line 558) | def get_subitem(self, item, task_type):
method forward_pair (line 585) | def forward_pair(self, item_A, item_B, **kwargs):
method rating_loss (line 669) | def rating_loss(
method binary_rating_loss (line 695) | def binary_rating_loss(self, pred_scores_A, gt_scores_A, pred_scores_B...
method prepare_inputs_for_generation (line 709) | def prepare_inputs_for_generation(
FILE: src/model/utils.py
function extend_list (line 5) | def extend_list(data_list, n, min_n):
function find_prefix (line 13) | def find_prefix(input_ids, prefix):
function auto_upgrade (line 35) | def auto_upgrade(config):
FILE: src/model/visual_encoder.py
function get_abs_pos (line 14) | def get_abs_pos(abs_pos, tgt_size):
function get_2d_sincos_pos_embed (line 33) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
function get_2d_sincos_pos_embed_from_grid (line 51) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
function get_1d_sincos_pos_embed_from_grid (line 62) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
class MplugOwlVisionEmbeddings (line 84) | class MplugOwlVisionEmbeddings(nn.Module):
method __init__ (line 85) | def __init__(self, config):
method forward (line 108) | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
class MplugOwlVisionAttention (line 121) | class MplugOwlVisionAttention(nn.Module):
method __init__ (line 124) | def __init__(self, config):
method _shape (line 141) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
method forward (line 144) | def forward(
class QuickGELU (line 224) | class QuickGELU(nn.Module):
method forward (line 225) | def forward(self, x: torch.Tensor):
class MplugOwlMLP (line 229) | class MplugOwlMLP(nn.Module):
method __init__ (line 230) | def __init__(self, config):
method forward (line 237) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class MplugOwlVisionEncoderLayer (line 244) | class MplugOwlVisionEncoderLayer(nn.Module):
method __init__ (line 245) | def __init__(self, config):
method forward (line 253) | def forward(
class MplugOwlVisionEncoder (line 292) | class MplugOwlVisionEncoder(nn.Module):
method __init__ (line 302) | def __init__(self, config):
method forward (line 308) | def forward(
class MplugOwlVisionModel (line 384) | class MplugOwlVisionModel(PreTrainedModel):
method __init__ (line 388) | def __init__(self, config):
method forward (line 400) | def forward(
method get_input_embeddings (line 445) | def get_input_embeddings(self):
class MplugOwlVisualAbstractorMLP (line 449) | class MplugOwlVisualAbstractorMLP(nn.Module):
method __init__ (line 450) | def __init__(self, config):
method forward (line 461) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class MplugOwlVisualAbstractorMultiHeadAttention (line 468) | class MplugOwlVisualAbstractorMultiHeadAttention(nn.Module):
method __init__ (line 469) | def __init__(self, config):
method save_attn_gradients (line 507) | def save_attn_gradients(self, attn_gradients):
method get_attn_gradients (line 510) | def get_attn_gradients(self):
method save_attention_map (line 513) | def save_attention_map(self, attention_map):
method get_attention_map (line 516) | def get_attention_map(self):
method transpose_for_scores (line 519) | def transpose_for_scores(self, x):
method forward (line 524) | def forward(
class MplugOwlVisualAbstractorCrossOutput (line 586) | class MplugOwlVisualAbstractorCrossOutput(nn.Module):
method __init__ (line 587) | def __init__(self, config):
method forward (line 594) | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Ten...
class MplugOwlVisualAbstractorAttention (line 600) | class MplugOwlVisualAbstractorAttention(nn.Module):
method __init__ (line 601) | def __init__(self, config):
method prune_heads (line 609) | def prune_heads(self, heads):
method forward (line 627) | def forward(
class MplugOwlVisualAbstractorLayer (line 657) | class MplugOwlVisualAbstractorLayer(nn.Module):
method __init__ (line 658) | def __init__(self, config, layer_idx):
method forward (line 668) | def forward(
class MplugOwlVisualAbstractorEncoder (line 693) | class MplugOwlVisualAbstractorEncoder(nn.Module):
method __init__ (line 694) | def __init__(self, config):
method forward (line 702) | def forward(
class MplugOwlVisualAbstractorModel (line 757) | class MplugOwlVisualAbstractorModel(PreTrainedModel):
method __init__ (line 759) | def __init__(self, config, language_hidden_size):
method _prune_heads (line 770) | def _prune_heads(self, heads_to_prune):
method get_extended_attention_mask (line 778) | def get_extended_attention_mask(
method forward (line 822) | def forward(
FILE: src/train/mplug_owl2_trainer.py
function maybe_zero_3 (line 12) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_mm_adapter_state_maybe_zero_3 (line 26) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
function split_to_even_chunks (line 32) | def split_to_even_chunks(indices, lengths, num_chunks):
function get_modality_length_grouped_indices (line 54) | def get_modality_length_grouped_indices(lengths, batch_size, world_size,...
function get_length_grouped_indices (line 82) | def get_length_grouped_indices(lengths, batch_size, world_size, generato...
class LengthGroupedSampler (line 93) | class LengthGroupedSampler(Sampler):
method __init__ (line 99) | def __init__(
method __len__ (line 116) | def __len__(self):
method __iter__ (line 119) | def __iter__(self):
class MPLUGOwl2Trainer (line 127) | class MPLUGOwl2Trainer(Trainer):
method _get_train_sampler (line 129) | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
method create_optimizer (line 144) | def create_optimizer(self):
method _save_checkpoint (line 227) | def _save_checkpoint(self, model, trial, metrics=None):
method _save (line 230) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
FILE: src/train/train_mem.py
function rank0_print (line 38) | def rank0_print(*args):
class ModelArguments (line 44) | class ModelArguments:
class DataArguments (line 51) | class DataArguments:
class TrainingArguments (line 63) | class TrainingArguments(transformers.TrainingArguments):
function maybe_zero_3 (line 149) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_peft_state_maybe_zero_3 (line 167) | def get_peft_state_maybe_zero_3(named_params, bias):
function get_peft_state_non_lora_maybe_zero_3 (line 192) | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only...
function get_mm_adapter_state_maybe_zero_3 (line 202) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
function find_all_lora_names (line 214) | def find_all_lora_names(model):
function safe_save_model_for_hf_trainer (line 228) | def safe_save_model_for_hf_trainer(
function smart_tokenizer_and_embedding_resize (line 248) | def smart_tokenizer_and_embedding_resize(
function train (line 275) | def train():
FILE: src/utils.py
function build_logger (line 17) | def build_logger(logger_name, logger_filename):
class StreamToLogger (line 60) | class StreamToLogger(object):
method __init__ (line 64) | def __init__(self, logger, log_level=logging.INFO):
method __getattr__ (line 70) | def __getattr__(self, attr):
method write (line 73) | def write(self, buf):
method flush (line 87) | def flush(self):
function disable_torch_init (line 93) | def disable_torch_init():
function violates_moderation (line 102) | def violates_moderation(text):
function pretty_print_semaphore (line 123) | def pretty_print_semaphore(semaphore):
FILE: tests/datasets/test_pair_dataset.py
class DataArguments (line 13) | class DataArguments:
FILE: tests/model/test_find_prefix.py
function find_prefix (line 5) | def find_prefix(input_ids, prefix):
FILE: tests/model/test_grad.py
class MyModel (line 5) | class MyModel(nn.Module):
method __init__ (line 6) | def __init__(self, dim_in, closeset):
method forward (line 12) | def forward(self, x_A, x_B, gt):
method get_score (line 20) | def get_score(self, logits):
method rating_loss (line 34) | def rating_loss(self, pred_scores_A, pred_scores_B, gt):
Condensed preview — 42 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (314K chars).
[
{
"path": ".gitignore",
"chars": 192,
"preview": "# source file related\n*__pycache__*\n*.pyc\n*.o\n*.so\n*.egg\n*.egg-info\n\n# training related\n*log*\n*.log\n*.pth\n*.pt\n\n# result"
},
{
"path": "LICENSE",
"chars": 1108,
"preview": "MIT License\n\nCopyright (c) 2025 Depicted image Quality Assessment (DepictQA / DeQA)\n\nPermission is hereby granted, free "
},
{
"path": "README.md",
"chars": 8283,
"preview": "<div align=\"center\">\n <h1>Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribut"
},
{
"path": "build_soft_labels/gen_soft_label.py",
"chars": 7335,
"preview": "import argparse\nimport json\nimport numpy as np\nimport os\nimport random\nfrom scipy.stats import norm, pearsonr, spearmanr"
},
{
"path": "pyproject.toml",
"chars": 1304,
"preview": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"DeQA-Score\"\nve"
},
{
"path": "scripts/eval_dist.sh",
"chars": 856,
"preview": "export PYTHONPATH=./:$PYTHONPATH\n\nres_dir=./results/res_deqa_mix3/\ngt_dir=../Data-DeQA-Score/\n\npython src/evaluate/cal_d"
},
{
"path": "scripts/eval_score.sh",
"chars": 849,
"preview": "export PYTHONPATH=./:$PYTHONPATH\n\nres_dir=./results/res_deqa_mix3/\ngt_dir=../Data-DeQA-Score/\n\npython src/evaluate/cal_p"
},
{
"path": "scripts/infer.sh",
"chars": 775,
"preview": "export CUDA_VISIBLE_DEVICES=$1\nexport PYTHONPATH=./:$PYTHONPATH\n\npython src/evaluate/iqa_eval.py \\\n\t--level-names excell"
},
{
"path": "scripts/infer_lora.sh",
"chars": 835,
"preview": "export CUDA_VISIBLE_DEVICES=$1\nexport PYTHONPATH=./:$PYTHONPATH\n\npython src/evaluate/iqa_eval.py \\\n\t--level-names excell"
},
{
"path": "scripts/train.sh",
"chars": 1624,
"preview": "#!/bin/bash\nexport PYTHONPATH=./:$PYTHONPATH\n\nLOAD=\"../ModelZoo/mplug-owl2-llama2-7b/\"\n\ndeepspeed --include localhost:$1"
},
{
"path": "scripts/train_lora.sh",
"chars": 1653,
"preview": "#!/bin/bash\nexport PYTHONPATH=./:$PYTHONPATH\n\nLOAD=\"../ModelZoo/mplug-owl2-llama2-7b/\"\n\ndeepspeed --include localhost:$1"
},
{
"path": "src/__init__.py",
"chars": 73,
"preview": "from .model import MPLUGOwl2LlamaForCausalLM\nfrom .evaluate import Scorer"
},
{
"path": "src/constants.py",
"chars": 192,
"preview": "CONTROLLER_HEART_BEAT_EXPIRATION = 30\nWORKER_HEART_BEAT_INTERVAL = 15\n\nLOGDIR = \"./demo_logs\"\n\n# Model Constants\nIGNORE_"
},
{
"path": "src/conversation.py",
"chars": 12753,
"preview": "import dataclasses\nfrom enum import auto, Enum\nfrom typing import List, Tuple\nfrom src.constants import DEFAULT_IMAGE_TO"
},
{
"path": "src/datasets/__init__.py",
"chars": 387,
"preview": "from .pair_dataset import make_pair_data_module\nfrom .single_dataset import make_single_data_module\n\n\ndef make_data_modu"
},
{
"path": "src/datasets/pair_dataset.py",
"chars": 10676,
"preview": "import copy\nimport json\nimport os\nimport random\nfrom dataclasses import dataclass\nfrom typing import Dict, Sequence\n\nimp"
},
{
"path": "src/datasets/single_dataset.py",
"chars": 9631,
"preview": "import copy\nimport json\nimport os\nfrom dataclasses import dataclass\nfrom typing import Dict, Sequence\n\nimport torch\nimpo"
},
{
"path": "src/datasets/utils.py",
"chars": 10306,
"preview": "import copy\nfrom dataclasses import dataclass, field\nfrom typing import Dict, List, Optional, Sequence\n\nfrom PIL import "
},
{
"path": "src/evaluate/__init__.py",
"chars": 26,
"preview": "from .scorer import Scorer"
},
{
"path": "src/evaluate/cal_distribution_gap.py",
"chars": 5372,
"preview": "import argparse\nimport json\n\nimport numpy as np\n\n\ndef parse_args():\n parser = argparse.ArgumentParser(description=\"ev"
},
{
"path": "src/evaluate/cal_plcc_srcc.py",
"chars": 3899,
"preview": "import argparse\nimport json\n\nimport numpy as np\nfrom scipy.optimize import curve_fit\nfrom scipy.stats import pearsonr, s"
},
{
"path": "src/evaluate/eval_qbench_mcq.py",
"chars": 5399,
"preview": "import argparse\nimport torch\n\nfrom src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN\nfrom src.conversation imp"
},
{
"path": "src/evaluate/iqa_eval.py",
"chars": 6755,
"preview": "import argparse\nimport json\nimport os\nfrom collections import defaultdict\nfrom io import BytesIO\n\nimport requests\nimport"
},
{
"path": "src/evaluate/scorer.py",
"chars": 2789,
"preview": "from PIL import Image\n\nimport torch.nn as nn\nimport torch\n\nfrom typing import List\n\nfrom src.model.builder import load_p"
},
{
"path": "src/mm_utils.py",
"chars": 4424,
"preview": "from PIL import Image\nfrom io import BytesIO\nimport base64\n\nimport torch\nfrom transformers import StoppingCriteria\nfrom "
},
{
"path": "src/model/__init__.py",
"chars": 112,
"preview": "from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM\nfrom .configuration_mplug_owl2 import MPLUGOwl2Config"
},
{
"path": "src/model/builder.py",
"chars": 6590,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "src/model/configuration_mplug_owl2.py",
"chars": 15505,
"preview": "# Copyright (c) Alibaba.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root dire"
},
{
"path": "src/model/convert_mplug_owl2_weight_to_hf.py",
"chars": 23109,
"preview": "# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License,"
},
{
"path": "src/model/modeling_attn_mask_utils.py",
"chars": 10121,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "src/model/modeling_llama2.py",
"chars": 37054,
"preview": "import math\nimport warnings\nfrom functools import partial\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\n"
},
{
"path": "src/model/modeling_mplug_owl2.py",
"chars": 31627,
"preview": "# Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)\n#\n# Licensed under the Apache License, Version 2.0"
},
{
"path": "src/model/utils.py",
"chars": 1986,
"preview": "import torch\nfrom transformers import AutoConfig\n\n\ndef extend_list(data_list, n, min_n):\n if min_n == 0:\n retu"
},
{
"path": "src/model/visual_encoder.py",
"chars": 38105,
"preview": "import math\nfrom typing import Any, Optional, Tuple, Union\n\nfrom transformers.modeling_outputs import BaseModelOutput, B"
},
{
"path": "src/train/mplug_owl2_trainer.py",
"chars": 10476,
"preview": "import os\nfrom typing import List, Optional\n\nimport torch\nfrom icecream import ic\nfrom torch.utils.data import Sampler\nf"
},
{
"path": "src/train/train_mem.py",
"chars": 18759,
"preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
},
{
"path": "src/utils.py",
"chars": 3989,
"preview": "import datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\n\nimport requests\n\nfrom mplug_owl2.constants "
},
{
"path": "tests/datasets/test_pair_dataset.py",
"chars": 1991,
"preview": "from dataclasses import dataclass, field\nfrom typing import List, Optional\n\nimport transformers\nfrom torch.utils.data im"
},
{
"path": "tests/datasets/test_uncertainty_levels.py",
"chars": 466,
"preview": "from transformers import AutoTokenizer\n\nif __name__ == \"__main__\":\n model_base = \"./preprocessor\"\n tokenizer = Aut"
},
{
"path": "tests/model/test_find_prefix.py",
"chars": 1536,
"preview": "import torch\nfrom transformers import AutoTokenizer\n\n\ndef find_prefix(input_ids, prefix):\n \"\"\"\n input_ids: [B, N1]"
},
{
"path": "tests/model/test_grad.py",
"chars": 2355,
"preview": "import torch\nimport torch.nn as nn\n\n\nclass MyModel(nn.Module):\n def __init__(self, dim_in, closeset):\n super()"
}
]
// ... and 1 more files (download for full content)
About this extraction
This page contains the full source code of the zhiyuanyou/DeQA-Score GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 42 files (294.2 KB), approximately 70.5k tokens, and a symbol index with 248 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.