Repository: zhiyuanyou/DeQA-Score Branch: main Commit: 8e4abc826ed8 Files: 42 Total size: 294.2 KB Directory structure: gitextract_56m520fi/ ├── .gitignore ├── LICENSE ├── README.md ├── build_soft_labels/ │ └── gen_soft_label.py ├── preprocessor/ │ └── tokenizer.model ├── pyproject.toml ├── scripts/ │ ├── eval_dist.sh │ ├── eval_score.sh │ ├── infer.sh │ ├── infer_lora.sh │ ├── train.sh │ └── train_lora.sh ├── src/ │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── pair_dataset.py │ │ ├── single_dataset.py │ │ └── utils.py │ ├── evaluate/ │ │ ├── __init__.py │ │ ├── cal_distribution_gap.py │ │ ├── cal_plcc_srcc.py │ │ ├── eval_qbench_mcq.py │ │ ├── iqa_eval.py │ │ └── scorer.py │ ├── mm_utils.py │ ├── model/ │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── configuration_mplug_owl2.py │ │ ├── convert_mplug_owl2_weight_to_hf.py │ │ ├── modeling_attn_mask_utils.py │ │ ├── modeling_llama2.py │ │ ├── modeling_mplug_owl2.py │ │ ├── utils.py │ │ └── visual_encoder.py │ ├── train/ │ │ ├── mplug_owl2_trainer.py │ │ └── train_mem.py │ └── utils.py └── tests/ ├── datasets/ │ ├── test_pair_dataset.py │ └── test_uncertainty_levels.py └── model/ ├── test_find_prefix.py └── test_grad.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # source file related *__pycache__* *.pyc *.o *.so *.egg *.egg-info # training related *log* *.log *.pth *.pt # result related *answer* *ckpt* *.json *.jsonl res_* # mac hidden *.DS_Store* ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2025 Depicted image Quality Assessment (DepictQA / DeQA) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), 
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================

Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution

Zhiyuan You<sup>1,2</sup>, Xin Cai<sup>2</sup>, Jinjin Gu<sup>4</sup>, Tianfan Xue<sup>2,3,5,#</sup>, Chao Dong<sup>1,3,4,#</sup>
<sup>1</sup>Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences, <sup>2</sup>Multimedia Laboratory, The Chinese University of Hong Kong, <sup>3</sup>Shanghai AI Laboratory, <sup>4</sup>Shenzhen University of Advanced Technology, <sup>5</sup>CPII under InnoHK
<sup>#</sup>Corresponding author.
Homepage | Model Weights ( Full Tuning / LoRA Tuning ) | Datasets | Paper

Motivation

Model Architecture

## [Installation Free!] Quicker Start with Hugging Face AutoModel [2025.12] Thanks to @[lyf1212](https://github.com/lyf1212)'s suggestion, we added support for `transformers==4.46.3` with minor code modifications. See [details](https://github.com/zhiyuanyou/DeQA-Score/issues/32). The following code can be run directly with `transformers==4.36.1`; there is no need to install this GitHub repo. ```python import requests import torch from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( "zhiyuanyou/DeQA-Score-Mix3", trust_remote_code=True, attn_implementation="eager", torch_dtype=torch.float16, device_map="auto", ) from PIL import Image # The input should be a list of one or more PIL images model.score( [Image.open(requests.get( "https://raw.githubusercontent.com/zhiyuanyou/DeQA-Score/main/fig/singapore_flyer.jpg", stream=True ).raw)] ) ``` ## Installation If you only need inference / evaluation: ```shell git clone https://github.com/zhiyuanyou/DeQA-Score.git cd DeQA-Score pip install -e . ``` For training, you also need to install the additional dependencies: ```shell pip install -e ".[train]" pip install flash_attn --no-build-isolation ``` ## Quick Start ### Image Quality Scorer - Command-line interface ```shell python src/evaluate/scorer.py --img_path fig/singapore_flyer.jpg ``` - Python API ```python from src import Scorer from PIL import Image scorer = Scorer() img_list = [Image.open("fig/singapore_flyer.jpg")] # can be a list of multiple PIL images print(scorer(img_list).tolist()) ``` ## Training, Inference & Evaluation ### Datasets - Download our meta files from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score).
- Download source images from [KonIQ](https://database.mmsp-kn.de/koniq-10k-database.html), [SPAQ](https://github.com/h4nwei/SPAQ), [KADID](https://database.mmsp-kn.de/kadid-10k-database.html), [PIPAL](https://github.com/HaomingCai/PIPAL-dataset), [LIVE-Wild](https://live.ece.utexas.edu/research/ChallengeDB/index.html), [AGIQA](https://github.com/lcysyzxdxc/AGIQA-3k-Database), [TID2013](https://www.ponomarenko.info/tid2013.htm), and [CSIQ](https://s2.smu.edu/~eclarson/csiq.html). - Arrange the folders as follows: ``` |-- DeQA-Score |-- Data-DeQA-Score |-- KONIQ |-- images/*.jpg |-- metas |-- SPAQ |-- images/*.jpg |-- metas |-- KADID10K |-- images/*.png |-- metas |-- PIPAL |-- images/Distortion_*/*.bmp |-- metas |-- LIVE-WILD |-- images/*.bmp |-- metas |-- AGIQA3K |-- images/*.jpg |-- metas |-- TID2013 |-- images/distorted_images/*.bmp |-- metas |-- CSIQ |-- images/dst_imgs/*/*.png |-- metas ``` ### Pretrained Weights We provide two model weights (full tuning and LoRA tuning) with similar performance. 
| | Training Datasets | Weights | |-----|-----|-----| | Full Tuning | KonIQ, SPAQ, KADID | [Huggingface Full](https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3) | | LoRA Tuning | KonIQ, SPAQ, KADID | [Huggingface LoRA](https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3) | Download one of the above model weights, then arrange the folders as follows: ``` |-- DeQA-Score |-- checkpoints |-- DeQA-Score-Mix3 |-- DeQA-Score-LoRA-Mix3 ``` If you would like to use the LoRA tuning weights, you also need to download the base mPLUG-Owl2 weights from [Huggingface mPLUG-Owl2](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b), then arrange the folders as follows: ``` |-- DeQA-Score |-- ModelZoo |-- mplug-owl2-llama2-7b ``` ### Inference After preparing the datasets, you can run inference with the pre-trained **DeQA-Score** or **DeQA-Score-LoRA**: ```shell sh scripts/infer.sh $ONE_GPU_ID ``` ```shell sh scripts/infer_lora.sh $ONE_GPU_ID ``` ### Evaluation After inference, you can evaluate the inference results: - SRCC / PLCC for quality score. ```shell sh scripts/eval_score.sh ``` - KL Divergence / JS Divergence / Wasserstein Distance for score distribution. ```shell sh scripts/eval_dist.sh ``` ### Fine-tuning Fine-tuning requires the mPLUG-Owl2 weights, downloaded as described in [Pretrained Weights](#pretrained-weights). #### LoRA Fine-tuning - Only **2 RTX3090 GPUs** are required. Revise `--data_paths` in the training script to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID. ```shell sh scripts/train_lora.sh $GPU_IDs ``` #### Full Fine-tuning from Scratch - At least **8 A6000 GPUs** or **4 A100 GPUs** are required. Revise `--data_paths` in the training script to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
```shell sh scripts/train.sh $GPU_IDs ``` ## Soft Label Construction - Download `split.json` (training & test split info) and `mos.json` (mos & std info) of KonIQ, SPAQ, and KADID from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score), and arrange the folders as in [Datasets](#datasets). - Run the following scripts to construct the distribution-based soft labels. ```shell cd build_soft_labels python gen_soft_label.py ``` ## Acknowledgements This work is based on [Q-Align](https://github.com/Q-Future/Q-Align). Sincerely thanks for this awesome work. ## Citation If you find our work useful for your research and applications, please cite using the BibTeX: ```bibtex @inproceedings{deqa_score, title={Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution}, author={You, Zhiyuan and Cai, Xin and Gu, Jinjin and Xue, Tianfan and Dong, Chao}, booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition}, pages={14483--14494}, year={2025} } @article{depictqa_v2, title={Enhancing Descriptive Image Quality Assessment with A Large-scale Multi-modal Dataset}, author={You, Zhiyuan and Gu, Jinjin and Cai, Xin and Li, Zheyuan and Zhu, Kaiwen and Dong, Chao and Xue, Tianfan}, journal={IEEE Transactions on Image Processing}, year={2025} } @inproceedings{depictqa_v1, title={Depicting Beyond Scores: Advancing Image Quality Assessment through Multi-modal Language Models}, author={You, Zhiyuan and Li, Zheyuan and Gu, Jinjin and Yin, Zhenfei and Xue, Tianfan and Dong, Chao}, booktitle={European Conference on Computer Vision}, pages={259--276}, year={2024} } ``` ================================================ FILE: build_soft_labels/gen_soft_label.py ================================================ import argparse import json import numpy as np import os import random from scipy.stats import norm, pearsonr, spearmanr def parse_args(): parser = argparse.ArgumentParser(description="label parameters for 
DeQA-Score") parser.add_argument("--config", type=str, default="./config.json") args = parser.parse_args() return args questions = [ "What do you think about the quality of this image?", "Can you rate the quality of this picture?", "Can you judge the quality of this image?", "How would you rate the quality of this image?", "How would you judge the quality of this image?", "What is your quality rating for this image?", "What's your opinion on the quality of this picture?", "Rate the quality of this image.", "Could you evaluate the quality of this image?", "How do you assess the quality of this image?", ] def calculate_srcc_plcc(pred, mos): srcc, _ = spearmanr(pred, mos) plcc, _ = pearsonr(pred, mos) return srcc, plcc def get_level(mos, min_mos, max_mos): eps = 1e-8 texts = ["bad", "poor", "fair", "good", "excellent"] for idx in range(1, len(texts) + 1): mos_left = min_mos + (idx - 1) / 5 * (max_mos - min_mos) - eps mos_right = min_mos + idx / 5 * (max_mos - min_mos) + eps if mos > mos_left and mos <= mos_right: level = idx break text = texts[level - 1] return text def adjust_gaussian_bar(probs, score): """ alpha * (a + b + c + d + e) + 5 * beta = 1 alpha * (5a + 4b + 3c + 2d + e) + 15 beta = score ==> alpha * A + 5 * beta = 1 alpha * B + 15 * beta = score """ A = np.array(probs).sum() B = np.inner(np.array(probs), np.array([5, 4, 3, 2, 1])) alpha = (score - 3) / (B - 3. * A + 1e-9) beta = (1. - alpha * A) / 5. 
return alpha, beta def get_binary_probs(mos, min_mos=1.0, max_mos=5.0): eps = 1e-8 probs = [0, 0, 0, 0, 0] for idx in range(1, len(probs)): mos_left = min_mos + (idx - 1) / 4 * (max_mos - min_mos) - eps mos_right = min_mos + idx / 4 * (max_mos - min_mos) + eps if mos > mos_left and mos <= mos_right: probs[idx - 1] = (mos_right - mos) / (mos_right - mos_left) probs[idx] = (mos - mos_left) / (mos_right - mos_left) break assert np.array((np.array(probs) == 0)).sum() == 3 assert round(np.array(probs).sum(), 5) == 1 probs = probs[::-1] # should start with "excellent" & end with "bad" return probs def main(cfg): density_type = cfg["density_type"] # ["pdf", "cdf"] thre_std = cfg["thre_std"] thre_diff = cfg["thre_diff"] with open(cfg["split_json"]) as fr: split = json.load(fr) with open(cfg["mos_json"]) as fr: mos_dict = json.load(fr) save_train = cfg["save_train"] save_test = cfg["save_test"] img_dir = cfg["img_dir"] moses, stds, imgs = [], [], [] for img in mos_dict: moses.append(mos_dict[img]["mos"]) stds.append(mos_dict[img]["std"]) imgs.append(img) max_mos = max([float(_) for _ in moses]) min_mos = min([float(_) for _ in moses]) num_binary, idx = 0, 0 preds, gts, raw_diffs, diffs, alphas, betas = [], [], [], [], [], [] train_metas, test_metas = [], [] for img, mos_str, std_str in zip(imgs, moses, stds): mos, std = float(mos_str), float(std_str) if os.path.basename(img) in split["train"]: training = True elif os.path.basename(img) in split["test"]: training = False else: idx += 1 # print(idx, img) continue text = get_level(mos, min_mos, max_mos) query = random.choice(questions) resp = answer.replace("{}", text) # norm mos and std mos_norm = 4 * (mos - min_mos) / (max_mos - min_mos) + 1 # [0, 1] -> [1, 5] std_norm = 4 * std / (max_mos - min_mos) # ["excellent", "good", "fair", "poor", "bad"] -> [5, 4, 3, 2, 1] probs = [] for x in range(5, 0, -1): if density_type == "cdf": # better for smaller std dataset (see Appendix) like SPAQ prob = norm.cdf(x+0.5, mos_norm, 
std_norm) - norm.cdf(x-0.5, mos_norm, std_norm) else: # better for larger std dataset (see Appendix) like KonIQ and KADID assert density_type == "pdf" prob = norm.pdf(x, loc=mos_norm, scale=std_norm) probs.append(prob) mos_rec = np.inner(np.array(probs), np.array([5, 4, 3, 2, 1])) raw_diff = abs(mos_rec - mos_norm) raw_diffs.append(raw_diff) alpha, beta = adjust_gaussian_bar(probs, mos_norm) probs_norm = [max(_ * alpha + beta, 0) for _ in probs] mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1])) diff = abs(mos_rec - mos_norm) if std_norm < thre_std or diff > thre_diff: # if std is too small, use binary probs (see Appendix) probs_norm = get_binary_probs(mos_norm) mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1])) diff, alpha, beta = abs(mos_rec - mos_norm), 1., 0. num_binary += 1 preds.append(mos_rec) gts.append(mos_norm) diffs.append(diff) alphas.append(alpha) betas.append(beta) meta = { "id": os.path.basename(img) + f"->{mos_str}", "image": os.path.join(img_dir, img), "gt_score": mos, "gt_score_norm": mos_norm, "level_probs_org": probs, "level_probs": probs_norm, "std": std, "std_norm": std_norm, } if training: conversations = [ { "from": "human", "value": query + "\n<|image|>", }, { "from": "gpt", "value": resp, }, ] meta["conversations"] = conversations train_metas.append(meta) else: del meta["level_probs_org"] del meta["level_probs"] test_metas.append(meta) print("=" * 100) print(f"save {len(train_metas)} into {save_train}") with open(save_train, "w") as fw: fw.write(json.dumps(train_metas, indent=4)) print(f"save {len(test_metas)} into {save_test}") with open(save_test, "w") as fw: fw.write(json.dumps(test_metas, indent=4)) srcc, plcc = calculate_srcc_plcc(preds, gts) print("srcc:", srcc, "plcc:", plcc) print("[raw_diff]", "l1:", sum(raw_diffs) / len(raw_diffs), "l2:", np.sqrt((np.array(raw_diffs)**2).mean())) print("[diff]", "l1:", sum(diffs) / len(diffs), "l2:", np.sqrt((np.array(diffs)**2).mean())) print("[alpha]", "mean:", 
np.mean(alphas), "std:", np.std(alphas)) print("[beta]", "mean:", np.mean(betas), "std:", np.std(betas)) print("binary / all:", num_binary, "/", len(train_metas) + len(test_metas)) if __name__ == "__main__": args = parse_args() with open(args.config) as fr: cfg = json.load(fr) answer = cfg["answer"] for dataset in cfg["dataset_params"]: random.seed(131) main(cfg["dataset_params"][dataset]) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "DeQA-Score" version = "1.2.0" description = "Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution (based on mPLUG-Owl2)" readme = "README.md" requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] dependencies = [ "torch==2.0.1", "torchvision==0.15.2", "transformers==4.36.1", "tokenizers==0.15.0", "sentencepiece==0.1.99", "shortuuid", "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0", "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2", "gradio==3.35.2", "gradio_client==0.2.9", "requests", "httpx==0.24.0", "uvicorn", "fastapi", "icecream", "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", "decord", "scipy", ] [project.optional-dependencies] train = ["deepspeed==0.9.5", "ninja", "wandb"] [project.urls] "Bug Tracker" = "https://github.com/zhiyuanyou/DeQA-Score/issues" [tool.setuptools.packages.find] exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] [tool.wheel] exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] ================================================ FILE: scripts/eval_dist.sh ================================================ export PYTHONPATH=./:$PYTHONPATH res_dir=./results/res_deqa_mix3/ gt_dir=../Data-DeQA-Score/ python 
src/evaluate/cal_distribution_gap.py \ --level_names excellent good fair poor bad \ --pred_paths $res_dir/test_koniq_2k.json \ $res_dir/test_spaq_2k.json \ $res_dir/test_kadid_2k.json \ $res_dir/test_pipal_5k.json \ $res_dir/test_livew_1k.json \ $res_dir/test_agiqa_3k.json \ $res_dir/test_tid2013_3k.json \ $res_dir/test_csiq_866.json \ --gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \ $gt_dir/SPAQ/metas/test_spaq_2k.json \ $gt_dir/KADID10K/metas/test_kadid_2k.json \ $gt_dir/PIPAL/metas/test_pipal_5k.json \ $gt_dir/LIVE-WILD/metas/test_livew_1k.json \ $gt_dir/AGIQA3K/metas/test_agiqa_3k.json \ $gt_dir/TID2013/metas/test_tid2013_3k.json \ $gt_dir/CSIQ/metas/test_csiq_866.json \ ================================================ FILE: scripts/eval_score.sh ================================================ export PYTHONPATH=./:$PYTHONPATH res_dir=./results/res_deqa_mix3/ gt_dir=../Data-DeQA-Score/ python src/evaluate/cal_plcc_srcc.py \ --level_names excellent good fair poor bad \ --pred_paths $res_dir/test_koniq_2k.json \ $res_dir/test_spaq_2k.json \ $res_dir/test_kadid_2k.json \ $res_dir/test_pipal_5k.json \ $res_dir/test_livew_1k.json \ $res_dir/test_agiqa_3k.json \ $res_dir/test_tid2013_3k.json \ $res_dir/test_csiq_866.json \ --gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \ $gt_dir/SPAQ/metas/test_spaq_2k.json \ $gt_dir/KADID10K/metas/test_kadid_2k.json \ $gt_dir/PIPAL/metas/test_pipal_5k.json \ $gt_dir/LIVE-WILD/metas/test_livew_1k.json \ $gt_dir/AGIQA3K/metas/test_agiqa_3k.json \ $gt_dir/TID2013/metas/test_tid2013_3k.json \ $gt_dir/CSIQ/metas/test_csiq_866.json \ ================================================ FILE: scripts/infer.sh ================================================ export CUDA_VISIBLE_DEVICES=$1 export PYTHONPATH=./:$PYTHONPATH python src/evaluate/iqa_eval.py \ --level-names excellent good fair poor bad \ --model-path checkpoints/DeQA-Score-Mix3/ \ --save-dir results/res_deqa_mix3/ \ --preprocessor-path ./preprocessor/ \ --root-dir 
../Data-DeQA-Score/ \ --meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \ ../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \ ../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \ ../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \ ../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \ ../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \ ../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \ ../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \ ================================================ FILE: scripts/infer_lora.sh ================================================ export CUDA_VISIBLE_DEVICES=$1 export PYTHONPATH=./:$PYTHONPATH python src/evaluate/iqa_eval.py \ --level-names excellent good fair poor bad \ --model-path checkpoints/DeQA-Score-LoRA-Mix3/ \ --model-base ../ModelZoo/mplug-owl2-llama2-7b/ \ --save-dir results/res_deqa_lora_mix3/ \ --preprocessor-path ./preprocessor/ \ --root-dir ../Data-DeQA-Score/ \ --meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \ ../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \ ../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \ ../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \ ../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \ ../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \ ../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \ ../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \ ================================================ FILE: scripts/train.sh ================================================ #!/bin/bash export PYTHONPATH=./:$PYTHONPATH LOAD="../ModelZoo/mplug-owl2-llama2-7b/" deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \ --deepspeed scripts/zero3.json \ --model_name_or_path $LOAD \ --version v1 \ --dataset_type pair \ --level_prefix "The quality of the image is" \ --level_names excellent good fair poor bad \ --softkl_loss True \ --weight_rank 1.0 \ --weight_softkl 1.0 \ --weight_next_token 0.05 \ --continuous_rating_loss True \ --closeset_rating_loss 
True \ --use_fix_std True \ --detach_pred_std True \ --data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \ ../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \ ../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \ --data_weights 1 1 1 \ --image_folder ../Data-DeQA-Score/ \ --output_dir ./checkpoints/deqa_mix3_rank1.0_next0.05_kl1.0/ \ --image_aspect_ratio pad \ --group_by_modality_length True \ --bf16 True \ --num_train_epochs 3 \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ --evaluation_strategy "no" \ --save_strategy "no" \ --learning_rate 2e-5 \ --weight_decay 0. \ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --tf32 True \ --model_max_length 2048 \ --gradient_checkpointing True \ --tune_visual_abstractor True \ --freeze_vision_model False \ --dataloader_num_workers 4 \ --lazy_preprocess True \ --report_to tensorboard ================================================ FILE: scripts/train_lora.sh ================================================ #!/bin/bash export PYTHONPATH=./:$PYTHONPATH LOAD="../ModelZoo/mplug-owl2-llama2-7b/" deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \ --deepspeed scripts/zero3.json \ --lora_enable True \ --model_name_or_path $LOAD \ --version v1 \ --dataset_type pair \ --level_prefix "The quality of the image is" \ --level_names excellent good fair poor bad \ --softkl_loss True \ --weight_rank 1.0 \ --weight_softkl 1.0 \ --weight_next_token 0.05 \ --continuous_rating_loss True \ --closeset_rating_loss True \ --use_fix_std True \ --detach_pred_std True \ --data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \ ../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \ ../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \ --data_weights 1 1 1 \ --image_folder ../Data-DeQA-Score/ \ --output_dir ./checkpoints/deqa_lora_mix3_rank1.0_next0.05_kl1.0 \ --image_aspect_ratio pad \ --group_by_modality_length True \ --bf16 
True \ --num_train_epochs 3 \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ --evaluation_strategy "no" \ --save_strategy "no" \ --learning_rate 2e-5 \ --weight_decay 0. \ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --tf32 True \ --model_max_length 2048 \ --gradient_checkpointing True \ --tune_visual_abstractor True \ --freeze_vision_model False \ --dataloader_num_workers 4 \ --lazy_preprocess True \ --report_to tensorboard ================================================ FILE: src/__init__.py ================================================ from .model import MPLUGOwl2LlamaForCausalLM from .evaluate import Scorer ================================================ FILE: src/constants.py ================================================ CONTROLLER_HEART_BEAT_EXPIRATION = 30 WORKER_HEART_BEAT_INTERVAL = 15 LOGDIR = "./demo_logs" # Model Constants IGNORE_INDEX = -100 IMAGE_TOKEN_INDEX = -200 DEFAULT_IMAGE_TOKEN = "<|image|>" ================================================ FILE: src/conversation.py ================================================ import dataclasses from enum import auto, Enum from typing import List, Tuple from src.constants import DEFAULT_IMAGE_TOKEN class SeparatorStyle(Enum): """Different separator style.""" SINGLE = auto() TWO = auto() TWO_NO_SYS = auto() MPT = auto() PLAIN = auto() LLAMA_2 = auto() @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" system: str roles: List[str] messages: List[List[str]] offset: int sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "###" sep2: str = None version: str = "Unknown" skip_next: bool = False def get_prompt(self): messages = self.messages if len(messages) > 0 and type(messages[0][1]) is tuple: messages = self.messages.copy() init_role, init_msg = messages[0].copy() # init_msg = init_msg[0].replace("", "").strip() # if 'mmtag' in self.version: # messages[0] = (init_role, 
init_msg) # messages.insert(0, (self.roles[0], "")) # messages.insert(1, (self.roles[1], "Received.")) # else: # messages[0] = (init_role, "\n" + init_msg) init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip() messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg) if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep for role, message in messages: if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + self.sep else: ret += role + ":" elif self.sep_style == SeparatorStyle.TWO: seps = [self.sep, self.sep2] ret = self.system + seps[0] for i, (role, message) in enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + seps[i % 2] else: ret += role + ":" elif self.sep_style == SeparatorStyle.TWO_NO_SYS: seps = [self.sep, self.sep2] ret = "" for i, (role, message) in enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + seps[i % 2] else: ret += role + ":" elif self.sep_style == SeparatorStyle.MPT: ret = self.system + self.sep for role, message in messages: if message: if type(message) is tuple: message, _, _ = message ret += role + message + self.sep else: ret += role elif self.sep_style == SeparatorStyle.LLAMA_2: wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" wrap_inst = lambda msg: f"[INST] {msg} [/INST]" ret = "" for i, (role, message) in enumerate(messages): if i == 0: assert message, "first message should not be none" assert role == self.roles[0], "first message should come from user" if message: if type(message) is tuple: message, _, _ = message if i == 0: message = wrap_sys(self.system) + message if i % 2 == 0: message = wrap_inst(message) ret += self.sep + message else: ret += " " + message + " " + self.sep2 else: ret += "" ret = ret.lstrip(self.sep) elif self.sep_style == SeparatorStyle.PLAIN: seps = [self.sep, self.sep2] ret = self.system for i, (role, message) in 
enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += message + seps[i % 2] else: ret += "" else: raise ValueError(f"Invalid style: {self.sep_style}") return ret def append_message(self, role, message): self.messages.append([role, message]) def get_images(self, return_pil=False): images = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO from PIL import Image msg, image, image_process_mode = msg if image_process_mode == "Pad": def expand2square(pil_img, background_color=(122, 116, 104)): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image = expand2square(image) elif image_process_mode in ["Default", "Crop"]: pass elif image_process_mode == "Resize": image = image.resize((336, 336)) else: raise ValueError(f"Invalid image_process_mode: {image_process_mode}") max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) longest_edge = int(shortest_edge * aspect_ratio) W, H = image.size if longest_edge != max(image.size): if H > W: H, W = longest_edge, shortest_edge else: H, W = shortest_edge, longest_edge image = image.resize((W, H)) if return_pil: images.append(image) else: buffered = BytesIO() image.save(buffered, format="PNG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() images.append(img_b64_str) return images def to_gradio_chatbot(self): ret = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO msg, image, 
image_process_mode = msg max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) longest_edge = int(shortest_edge * aspect_ratio) W, H = image.size if H > W: H, W = longest_edge, shortest_edge else: H, W = shortest_edge, longest_edge image = image.resize((W, H)) buffered = BytesIO() image.save(buffered, format="JPEG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() img_str = f'user upload image' msg = img_str + msg.replace('<|image|>', '').strip() ret.append([msg, None]) else: ret.append([msg, None]) else: ret[-1][-1] = msg return ret def copy(self): return Conversation( system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version) def dict(self): if len(self.get_images()) > 0: return { "system": self.system, "roles": self.roles, "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } return { "system": self.system, "roles": self.roles, "messages": self.messages, "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } conv_vicuna_v0 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ("Human", "What are the key differences between renewable and non-renewable energy sources?"), ("Assistant", "Renewable energy sources are those that can be replenished naturally in a relatively " "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " "Non-renewable energy sources, on the other hand, are finite and will eventually be " "depleted, such as coal, oil, and natural gas. 
Here are some key differences between " "renewable and non-renewable energy sources:\n" "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " "energy sources are finite and will eventually run out.\n" "2. Environmental impact: Renewable energy sources have a much lower environmental impact " "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " "and other negative effects.\n" "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " "have lower operational costs than non-renewable sources.\n" "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " "locations than non-renewable sources.\n" "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " "situations and needs, while non-renewable sources are more rigid and inflexible.\n" "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") ), offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", ) conv_vicuna_v1 = Conversation( system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), version="v1", messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", ) conv_mplug_owl2 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. 
" "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("USER", "ASSISTANT"), version="v1", messages=(), offset=0, sep_style=SeparatorStyle.TWO_NO_SYS, sep=" ", sep2="", ) # default_conversation = conv_vicuna_v1 default_conversation = conv_mplug_owl2 conv_templates = { "default": conv_vicuna_v0, "v0": conv_vicuna_v0, "v1": conv_vicuna_v1, "vicuna_v1": conv_vicuna_v1, "mplug_owl2": conv_mplug_owl2, } if __name__ == "__main__": print(default_conversation.get_prompt()) ================================================ FILE: src/datasets/__init__.py ================================================ from .pair_dataset import make_pair_data_module from .single_dataset import make_single_data_module def make_data_module(tokenizer, data_args): if data_args.dataset_type == "single": return make_single_data_module(tokenizer, data_args) elif data_args.dataset_type == "pair": return make_pair_data_module(tokenizer, data_args) else: raise ValueError ================================================ FILE: src/datasets/pair_dataset.py ================================================ import copy import json import os import random from dataclasses import dataclass from typing import Dict, Sequence import torch import transformers from PIL import Image from torch.utils.data import Dataset from src.constants import IGNORE_INDEX from .utils import (expand2square, load_video, preprocess, preprocess_multimodal, rank0_print) class PairDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__( self, data_paths, data_weights, tokenizer: transformers.PreTrainedTokenizer, data_args, ): super(PairDataset, self).__init__() dataset_list = [] # list (different datasets) of list (samples in one dataset) for data_path, data_weight in zip(data_paths, data_weights): data_list = json.load(open(data_path, "r")) dataset_list.append(data_list * data_weight) self.dataset_list = dataset_list # Construct nums_data, nums_data[i] is the number of samples 
in 0-i th datasets nums_eachdata = [len(_) for _ in self.dataset_list] nums_predata = copy.deepcopy(nums_eachdata) for idx in range(1, len(nums_predata)): nums_predata[idx] = nums_predata[idx] + nums_predata[idx - 1] rank0_print("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer self.nums_eachdata = nums_eachdata self.nums_predata = nums_predata self.data_args = data_args assert self.nums_predata[-1] == sum(self.nums_eachdata) def __len__(self): return self.nums_predata[-1] @property def lengths(self): length_list = [] for dataset in self.dataset_list: for sample in dataset: img_tokens = 128 if "image" in sample else 0 length_list.append( sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens ) return length_list @property def modality_lengths(self): length_list = [] for dataset in self.dataset_list: for sample in dataset: cur_len = sum( len(conv["value"].split()) for conv in sample["conversations"] ) cur_len = cur_len if "image" in sample else -cur_len length_list.append(cur_len) return length_list def next_rand(self): return random.randint(0, len(self) - 1) def __getitem__(self, i): while True: try: # Get idx_dataset, idx_sample if i < self.nums_predata[0]: idx_dataset = 0 idx_sample = i else: for idx_dataset in range(1, len(self.nums_predata)): if ( i < self.nums_predata[idx_dataset] and i >= self.nums_predata[idx_dataset - 1] ): idx_sample = i - self.nums_predata[idx_dataset - 1] break # Sample two items item_A = self.get_one_item(idx_dataset, idx_sample) while True: idx_sample_B = random.randint( 0, self.nums_eachdata[idx_dataset] - 1 ) if idx_sample_B != idx_sample: break item_B = self.get_one_item(idx_dataset, idx_sample_B) return { "item_A": item_A, "item_B": item_B, } except Exception as ex: print(ex) i = self.next_rand() continue def get_one_item(self, idx_dataset, idx_sample) -> Dict[str, torch.Tensor]: # For IQA data, i must be int sources = [self.dataset_list[idx_dataset][idx_sample]] sources_org = 
copy.deepcopy(sources) assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME if "image" in sources_org[0]: image_file = sources[0]["image"] image_folder = self.data_args.image_folder processor = self.data_args.image_processor if isinstance(image_file, list): # Multiple Images as Input image = [ Image.open(os.path.join(image_folder, imfile)).convert("RGB") for imfile in image_file ] if self.data_args.image_aspect_ratio == "pad": image = [ expand2square( img, tuple(int(x * 255) for x in processor.image_mean), ) for img in image ] image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] elif os.path.join(image_folder, image_file).endswith("mp4"): # Video as Input image = load_video(os.path.join(image_folder, image_file)) if self.data_args.image_aspect_ratio == "pad": image = [ expand2square( img, tuple(int(x * 255) for x in processor.image_mean), ) for img in image ] image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: image = Image.open(os.path.join(image_folder, image_file)).convert( "RGB" ) if self.data_args.image_aspect_ratio == "pad": image = expand2square( image, tuple(int(x * 255) for x in processor.image_mean) ) image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.data_args, ) else: # Without images sources = copy.deepcopy([e["conversations"] for e in sources]) data_dict = preprocess( sources, self.tokenizer, has_image=("image" in sources_org[0]), ) data_dict = dict( input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0], ) # default task_type: "score", gt_socre & std: -10000, level_probs: [-10000] * 5 data_dict["task_type"] = 
sources_org[0].get("task_type", "score") data_dict["gt_score"] = sources_org[0].get("gt_score", -10000) data_dict["std"] = sources_org[0].get("std", -10000) data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5) # image exist in the data if "image" in sources_org[0]: data_dict["image_file"] = image_file data_dict["image"] = image elif self.data_args.is_multimodal: # image does not exist in the data, but the model is multimodal crop_size = self.data_args.image_processor.crop_size data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"]) return data_dict @dataclass class DataCollatorForPairDataset(object): """Collate examples for pair fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: instances_A = [instance["item_A"] for instance in instances] instances_B = [instance["item_B"] for instance in instances] batch = { "input_type": "pair", "item_A": self.collate_one(instances_A), "item_B": self.collate_one(instances_B), } return batch def collate_one(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple( [instance[key] for instance in instances] for key in ("input_ids", "labels") ) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id ) labels = torch.nn.utils.rnn.pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX ) input_ids = input_ids[:, : self.tokenizer.model_max_length] labels = labels[:, : self.tokenizer.model_max_length] batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) batch["task_types"] = [instance["task_type"] for instance in instances] batch["gt_scores"] = torch.tensor([instance["gt_score"] for instance in instances]) batch["stds"] = torch.tensor([instance["std"] for instance in instances]) batch["level_probs"] = torch.tensor([instance["level_probs"] for instance in 
instances]) if "image" in instances[0]: images = [instance["image"] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): batch["images"] = torch.stack(images) else: batch["images"] = images batch["image_files"] = [instance["image_file"] for instance in instances] return batch def make_pair_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args ) -> Dict: """Make dataset and collator for supervised fine-tuning.""" train_dataset = PairDataset( tokenizer=tokenizer, data_paths=data_args.data_paths, data_weights=data_args.data_weights, data_args=data_args, ) data_collator = DataCollatorForPairDataset(tokenizer=tokenizer) return dict( train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator ) ================================================ FILE: src/datasets/single_dataset.py ================================================ import copy import json import os from dataclasses import dataclass from typing import Dict, Sequence import torch import transformers from PIL import Image from torch.utils.data import Dataset from src.constants import IGNORE_INDEX from .utils import (expand2square, load_video, preprocess, preprocess_multimodal, rank0_print) class SingleDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__( self, data_paths: str, data_weights: str, tokenizer: transformers.PreTrainedTokenizer, data_args, ): super(SingleDataset, self).__init__() list_data_dict = [] for data_path, data_weight in zip(data_paths, data_weights): data_dict = json.load(open(data_path, "r")) list_data_dict += data_dict * data_weight rank0_print("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer self.list_data_dict = list_data_dict self.data_args = data_args def __len__(self): return len(self.list_data_dict) @property def lengths(self): length_list = [] for sample in self.list_data_dict: img_tokens = 128 if "image" in sample else 0 length_list.append( sum(len(conv["value"].split()) 
for conv in sample["conversations"]) + img_tokens ) return length_list @property def modality_lengths(self): length_list = [] for sample in self.list_data_dict: cur_len = sum( len(conv["value"].split()) for conv in sample["conversations"] ) cur_len = cur_len if "image" in sample else -cur_len length_list.append(cur_len) return length_list def next_rand(self): import random return random.randint(0, len(self) - 1) def __getitem__(self, i) -> Dict[str, torch.Tensor]: while True: try: sources = self.list_data_dict[i] if isinstance(i, int): sources = [sources] sources_org = copy.deepcopy(sources) assert ( len(sources) == 1 ), "Don't know why it is wrapped to a list" # FIXME if "image" in sources_org[0]: image_file = sources_org[0]["image"] image_folder = self.data_args.image_folder processor = self.data_args.image_processor from pathlib import Path # if not Path(os.path.join(image_folder, image_file)).exists(): # i = self.next_rand() # continue if isinstance(image_file, list): # Multiple Images as Input try: image = [ Image.open(os.path.join(image_folder, imfile)).convert( "RGB" ) for imfile in image_file ] except Exception as ex: print(ex) i = self.next_rand() continue if self.data_args.image_aspect_ratio == "pad": image = [ expand2square( img, tuple(int(x * 255) for x in processor.image_mean), ) for img in image ] image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] elif os.path.join(image_folder, image_file).endswith("mp4"): # Video as Input image = load_video(os.path.join(image_folder, image_file)) if self.data_args.image_aspect_ratio == "pad": image = [ expand2square( img, tuple(int(x * 255) for x in processor.image_mean), ) for img in image ] image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: try: image = Image.open( os.path.join(image_folder, 
image_file) ).convert("RGB") except Exception as ex: print(ex) i = self.next_rand() continue if self.data_args.image_aspect_ratio == "pad": image = expand2square( image, tuple(int(x * 255) for x in processor.image_mean) ) image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] else: image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.data_args, ) else: sources = copy.deepcopy([e["conversations"] for e in sources]) data_dict = preprocess( sources, self.tokenizer, has_image=("image" in sources_org[0]), ) if isinstance(i, int): data_dict = dict( input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0], ) # default task_type: "score", level_probs: [-10000] * 5 data_dict["task_type"] = sources_org[0].get("task_type", "score") data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5) # image exist in the data if "image" in sources_org[0]: data_dict["image"] = image elif self.data_args.is_multimodal: # image does not exist in the data, but the model is multimodal crop_size = self.data_args.image_processor.crop_size data_dict["image"] = torch.zeros( 3, crop_size["height"], crop_size["width"] ) return data_dict except Exception as ex: print(ex) i = self.next_rand() continue @dataclass class DataCollatorForSupervisedDataset(object): """Collate examples for supervised fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple( [instance[key] for instance in instances] for key in ("input_ids", "labels") ) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id ) labels = torch.nn.utils.rnn.pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX ) input_ids = input_ids[:, : self.tokenizer.model_max_length] labels = labels[:, : 
self.tokenizer.model_max_length] batch = dict( input_type="single", input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) batch["task_types"] = [instance["task_type"] for instance in instances] batch["level_probs"] = torch.tensor([instance["level_probs"] for instance in instances]) if "image" in instances[0]: images = [instance["image"] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): batch["images"] = torch.stack(images) else: batch["images"] = images return batch def make_single_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args ) -> Dict: """Make dataset and collator for supervised fine-tuning.""" train_dataset = SingleDataset( tokenizer=tokenizer, data_paths=data_args.data_paths, data_weights=data_args.data_weights, data_args=data_args, ) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict( train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator ) ================================================ FILE: src/datasets/utils.py ================================================ import copy from dataclasses import dataclass, field from typing import Dict, List, Optional, Sequence from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True from dataclasses import dataclass, field from typing import List, Optional import torch import torch.distributed as dist import transformers from PIL import Image from src import conversation as conversation_lib from src.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX from src.mm_utils import tokenizer_image_token def rank0_print(*args): try: if dist.get_rank() == 0: print(*args) except: print(*args) @dataclass class DataArguments: data_paths: List[str] = field(default_factory=lambda: []) lazy_preprocess: bool = False is_multimodal: bool = False image_folder: Optional[str] = field(default=None) image_aspect_ratio: str = "square" image_grid_pinpoints: Optional[str] = 
field(default=None)


def _tokenize_fn(
    strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
    """Tokenize a list of strings.

    Each string is tokenized on its own and truncated to
    ``tokenizer.model_max_length``. Returns ``input_ids`` and ``labels``
    (the same list of 1-D id tensors), plus ``input_ids_lens`` and
    ``labels_lens`` (the same list of non-pad token counts).
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    # labels aliases input_ids (same list object); callers copy before masking.
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers):
    """Set, in place, the label positions that should not be trained on.

    ``tokenized_lens[0]`` is the header length (masked entirely). The
    remaining lengths pair up with ``speakers``; for "human" turns, every
    token of the turn except its first two is set to IGNORE_INDEX.
    """
    # cur_idx = 0
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            # NOTE(review): the +2 presumably leaves the turn's leading signal tokens unmasked — confirm.
            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round.

    Rewrites every ``sentence["value"]`` of ``source`` in place to
    "### <ROLE>: <text>" followed by a newline, where "human" maps to
    roles[0], "gpt" to roles[1] and anything else to "unknown". Returns
    ``header``, then the rewritten turns (only when ``get_conversation``),
    then a trailing "### ".
    """
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = "unknown"
        sentence["value"] = (
            BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
        )
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
    """Move the image token to the front of any turn that contains it.

    Mutates and returns ``sources``. Despite the annotations, ``sources`` is
    used as a list of conversations, each a list of dicts with a "value" key,
    and that same list is returned. Returns ``sources`` unchanged when
    ``data_args.is_multimodal`` is false.
    """
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence["value"]:
                # Strip every occurrence, then put a single token (plus newline) first.
                sentence["value"] = (
                    sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
                )
                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
                sentence["value"] = sentence["value"].strip()
            # Replacing the token with itself is a no-op; it is a hook kept from
            # templates that wrap the image token.
            replace_token = DEFAULT_IMAGE_TOKEN
            sentence["value"] = sentence["value"].replace(
                DEFAULT_IMAGE_TOKEN, replace_token
            )
    return sources


def preprocess_v1(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Render conversations with the default v1 template and mask non-assistant tokens.

    Returns ``input_ids`` and ``labels`` (a clone of ``input_ids`` in which the
    prompt prefix and each round's instruction are set to IGNORE_INDEX). If
    the recomputed length disagrees with the real token count, the whole
    sample is masked and a warning is printed.
    Requires ``conv.sep2`` to be non-empty, because ``str.split`` raises
    ValueError on an empty separator.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must alternate human / assistant.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert (
        conv.sep_style == conversation_lib.SeparatorStyle.TWO
        or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
    )

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # One entry per human/assistant round.
        rounds = conversation.split(conv.sep2)
        # NOTE(review): masks two leading tokens (BOS plus one more?) — presumably
        # tuned for this tokenizer; confirm.
        cur_len = 1 + 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # The "- 3" / "- 2" offsets compensate for special tokens that the
            # per-round tokenization adds (hard-coded for this tokenizer).
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 3
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2
            # NOTE(review): the collapsed source does not show whether this line
            # belongs to the else-branch or to both branches — confirm upstream.
            round_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Offsets are out of sync: drop the whole sample from the loss.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build "image token + answer + sep" sequences and mask the image prefix.

    Each source must have exactly two turns, and the first must contain
    DEFAULT_IMAGE_TOKEN. The first turn's text is overwritten in place with
    the bare image token.
    """
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
        source[0]["value"] = DEFAULT_IMAGE_TOKEN
        conversation = (
            source[0]["value"]
            + source[1]["value"]
            + conversation_lib.default_conversation.sep
        )
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        for prompt in conversations
    ]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.

    The default conversation template decides the path: PLAIN separator
    style delegates to preprocess_plain, and versions starting with "v1"
    delegate to preprocess_v1. The steps above apply only otherwise.
    """
    if (
        conversation_lib.default_conversation.sep_style
        == conversation_lib.SeparatorStyle.PLAIN
    ):
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    # tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
            for prompt in conversations
        ]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    # `header` is the same string for every source (last value from the loop above).
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn(
                [header] + [s["value"] for s in source], tokenizer
            )["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)


def load_video(video_file):
    """Decode ``video_file`` with decord and return PIL frames sampled at about 1 fps.

    Frame indices are int(fps * k) for k in range(int(n_frames / fps)).
    A video reporting fps == 0 would raise ZeroDivisionError.
    """
    from decord import VideoReader

    vr = VideoReader(video_file)

    # Get video frame rate
    fps = vr.get_avg_fps()

    # Calculate frame indices for 1fps
    frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]


def expand2square(pil_img, background_color):
    # Pad pil_img to a square canvas filled with background_color, centring the original.
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height),
background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result ================================================ FILE: src/evaluate/__init__.py ================================================ from .scorer import Scorer ================================================ FILE: src/evaluate/cal_distribution_gap.py ================================================ import argparse import json import numpy as np def parse_args(): parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score") parser.add_argument("--level_names", type=str, required=True, nargs="+") parser.add_argument("--pred_paths", type=str, required=True, nargs="+") parser.add_argument("--gt_paths", type=str, required=True, nargs="+") parser.add_argument("--use_openset_probs", action="store_true") args = parser.parse_args() return args def kl_divergence(mu_1, mu_2, sigma_1, sigma_2): """ Calculate the Kullback-Leibler (KL) divergence between two Gaussian distributions for numpy arrays. Parameters: mu_1 (np.array): Mean of the first distribution (array of size N). mu_2 (np.array): Mean of the second distribution (array of size N). sigma_1 (np.array): Standard deviation of the first distribution (array of size N). sigma_2 (np.array): Standard deviation of the second distribution (array of size N). Returns: np.array: KL divergence from distribution 1 to distribution 2 (array of size N). """ eps = 1e-8 return np.log(sigma_2 / (sigma_1 + eps)) + (sigma_1**2 + (mu_1 - mu_2)**2) / (2 * sigma_2**2 + eps) - 0.5 def js_divergence(mu_1, mu_2, sigma_1, sigma_2): """ Calculate the Jensen-Shannon (JS) divergence between two Gaussian distributions for numpy arrays. Parameters: mu_1 (np.array): Mean of the first distribution (array of size N). mu_2 (np.array): Mean of the second distribution (array of size N). sigma_1 (np.array): Standard deviation of the first distribution (array of size N). sigma_2 (np.array): Standard deviation of the second distribution (array of size N). 
Returns: np.array: JS divergence between the two distributions (array of size N). """ # Midpoint distribution parameters mu_m = 0.5 * (mu_1 + mu_2) sigma_m = np.sqrt(0.5 * (sigma_1**2 + sigma_2**2)) # JS divergence as the average of the KL divergences return 0.5 * kl_divergence(mu_1, mu_m, sigma_1, sigma_m) + 0.5 * kl_divergence(mu_2, mu_m, sigma_2, sigma_m) def wasserstein_distance(mu_1, mu_2, sigma_1, sigma_2): """ Calculate the Wasserstein distance between two Gaussian distributions for numpy arrays. Parameters: mu_1 (np.array): Mean of the first distribution (array of size N). mu_2 (np.array): Mean of the second distribution (array of size N). sigma_1 (np.array): Standard deviation of the first distribution (array of size N). sigma_2 (np.array): Standard deviation of the second distribution (array of size N). Returns: np.array: Wasserstein distance between the two distributions (array of size N). """ return np.sqrt((mu_1 - mu_2)**2 + (sigma_1 - sigma_2)**2) def cal_score(level_names, logits=None, probs=None, use_openset_probs=False): if use_openset_probs: assert logits is None probs = np.array([probs[_] for _ in level_names]) else: assert probs is None logprobs = np.array([logits[_] for _ in level_names]) probs = np.exp(logprobs) / np.sum(np.exp(logprobs)) score = np.inner(probs, np.array([5., 4., 3., 2., 1.])) return score, probs def cal_std(score, probs): variance = (np.array([5., 4., 3., 2., 1.]) - score) * (np.array([5., 4., 3., 2., 1.]) - score) variance = np.inner(probs, variance) std = np.sqrt(variance) return std if __name__ == "__main__": args = parse_args() level_names = args.level_names pred_paths = args.pred_paths gt_paths = args.gt_paths use_openset_probs = args.use_openset_probs for pred_path, gt_path in zip(pred_paths, gt_paths): print("=" * 100) print("Pred: ", pred_path) print("GT: ", gt_path) # load predict results pred_metas = [] with open(pred_path) as fr: for line in fr: pred_meta = json.loads(line) pred_metas.append(pred_meta) # load gt 
results with open(gt_path) as fr: gt_metas = json.load(fr) pred_metas.sort(key=lambda x: x["id"]) gt_metas.sort(key=lambda x: x["id"]) mu_preds = [] std_preds = [] mu_gts = [] std_gts = [] for pred_meta, gt_meta in zip(pred_metas, gt_metas): assert pred_meta["id"] == gt_meta["id"] if use_openset_probs: pred_score, probs = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=True) else: pred_score, probs = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=False) pred_std = cal_std(pred_score, probs) mu_preds.append(pred_score) std_preds.append(pred_std) mu_gts.append(gt_meta["gt_score_norm"]) std_gts.append(gt_meta["std_norm"]) mu_preds = np.array(mu_preds) std_preds = np.array(std_preds) mu_gts = np.array(mu_gts) std_gts = np.array(std_gts) kl = kl_divergence(mu_gts, mu_preds, std_gts, std_preds).mean() js = js_divergence(mu_gts, mu_preds, std_gts, std_preds).mean() wd = wasserstein_distance(mu_gts, mu_preds, std_gts, std_preds).mean() print(f"KL: {kl}") print(f"JS: {js}") print(f"WD: {wd}") ================================================ FILE: src/evaluate/cal_plcc_srcc.py ================================================ import argparse import json import numpy as np from scipy.optimize import curve_fit from scipy.stats import pearsonr, spearmanr def parse_args(): parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score") parser.add_argument("--level_names", type=str, required=True, nargs="+") parser.add_argument("--pred_paths", type=str, required=True, nargs="+") parser.add_argument("--gt_paths", type=str, required=True, nargs="+") parser.add_argument("--use_openset_probs", action="store_true") args = parser.parse_args() return args def calculate_srcc(pred, mos): srcc, _ = spearmanr(pred, mos) return srcc def calculate_plcc(pred, mos): plcc, _ = pearsonr(pred, mos) return plcc def fit_curve(x, y, curve_type="logistic_4params"): r"""Fit the scale of predict scores to MOS scores using logistic regression 
suggested by VQEG. The function with 4 params is more commonly used. The 5 params function takes from DBCNN: - https://github.com/zwx8981/DBCNN/blob/master/dbcnn/tools/verify_performance.m """ assert curve_type in [ "logistic_4params", "logistic_5params", ], f"curve type should be in [logistic_4params, logistic_5params], but got {curve_type}." betas_init_4params = [np.max(y), np.min(y), np.mean(x), np.std(x) / 4.0] def logistic_4params(x, beta1, beta2, beta3, beta4): yhat = (beta1 - beta2) / (1 + np.exp(-(x - beta3) / beta4)) + beta2 return yhat betas_init_5params = [10, 0, np.mean(y), 0.1, 0.1] def logistic_5params(x, beta1, beta2, beta3, beta4, beta5): logistic_part = 0.5 - 1.0 / (1 + np.exp(beta2 * (x - beta3))) yhat = beta1 * logistic_part + beta4 * x + beta5 return yhat if curve_type == "logistic_4params": logistic = logistic_4params betas_init = betas_init_4params elif curve_type == "logistic_5params": logistic = logistic_5params betas_init = betas_init_5params betas, _ = curve_fit(logistic, x, y, p0=betas_init, maxfev=10000) yhat = logistic(x, *betas) return yhat def cal_score(level_names, logits=None, probs=None, use_openset_probs=False): if use_openset_probs: assert logits is None probs = np.array([probs[_] for _ in level_names]) else: assert probs is None logprobs = np.array([logits[_] for _ in level_names]) probs = np.exp(logprobs) / np.sum(np.exp(logprobs)) score = np.inner(probs, np.array([5., 4., 3., 2., 1.])) return score if __name__ == "__main__": args = parse_args() level_names = args.level_names pred_paths = args.pred_paths gt_paths = args.gt_paths use_openset_probs = args.use_openset_probs for pred_path, gt_path in zip(pred_paths, gt_paths): print("=" * 100) print("Pred: ", pred_path) print("GT: ", gt_path) # load predict results pred_metas = [] with open(pred_path) as fr: for line in fr: pred_meta = json.loads(line) pred_metas.append(pred_meta) # load gt results with open(gt_path) as fr: gt_metas = json.load(fr) preds = [] gts = [] for 
pred_meta, gt_meta in zip(pred_metas, gt_metas): assert pred_meta["id"] == gt_meta["id"] if use_openset_probs: pred_score = cal_score(level_names, probs=pred_meta["probs"], use_openset_probs=True) else: pred_score = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=False) preds.append(pred_score) gts.append(gt_meta["gt_score"]) preds_fit = fit_curve(preds, gts) srcc = calculate_srcc(preds_fit, gts) plcc = calculate_plcc(preds_fit, gts) print(f"SRCC: {srcc}") print(f"PLCC: {plcc}") ================================================ FILE: src/evaluate/eval_qbench_mcq.py ================================================ import argparse import torch from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN from src.conversation import conv_templates, SeparatorStyle from src.model.builder import load_pretrained_model from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria from PIL import Image import requests from PIL import Image from io import BytesIO from transformers import TextStreamer import json from tqdm import tqdm import os def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. 
""" import torch setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def load_image(image_file): if image_file.startswith('http://') or image_file.startswith('https://'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def main(args): # Model disable_torch_init() model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) os.makedirs(args.save_dir, exist_ok=True) with open(args.meta_path) as f: llvqa_data = json.load(f) pbar = tqdm(total=len(llvqa_data)) conv_mode = "mplug_owl2" if args.conv_mode is not None and conv_mode != args.conv_mode: print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) else: args.conv_mode = conv_mode conv = conv_templates[args.conv_mode].copy() roles = conv.roles correct = 0 for i, llddata in enumerate((llvqa_data)): filename = llddata["img_path"] message = llddata["question"] + "\n" for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]): message += f"{choice} {ans}\n" if "correct_ans" in llddata and ans == llddata["correct_ans"]: correct_choice = choice[0] message = message + "Answer with the option's letter from the given choices directly.\n" inp = message conv = conv_templates[args.conv_mode].copy() inp = "The input image:" + DEFAULT_IMAGE_TOKEN + inp conv.append_message(conv.roles[0], inp) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() print(prompt) image = load_image(os.path.join(args.root_dir, filename)) image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(model.device) input_ids = tokenizer_image_token(prompt, 
tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) with torch.inference_mode(): output_ids = model.generate( input_ids, attention_mask=torch.ones_like(input_ids), images=image_tensor, do_sample=False, temperature=args.temperature, max_new_tokens=args.max_new_tokens, num_beams=1, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]) outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() llddata["response"] = outputs if correct_choice in outputs: correct += 1 pbar.update(1) pbar.set_description("[Running Accuracy]: {:.4f},[Response]: {}, [Correct Ans]: {}, , [Prog]: {}".format(correct/(i+1), outputs, llddata.get("correct_ans", -1), i+1)) save_path = os.path.join(args.save_dir, os.path.basename(args.meta_path)) with open(save_path, "a") as fw: fw.write(json.dumps(llddata) + "\n") if args.debug: print("\n", {"prompt": prompt, "outputs": outputs}, "\n") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, required=True) parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--root-dir", type=str, required=True) parser.add_argument("--save-dir", type=str, required=True) parser.add_argument("--meta-path", type=str, required=True) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--conv-mode", type=str, default=None) parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--max-new-tokens", type=int, default=512) parser.add_argument("--load-8bit", action="store_true") parser.add_argument("--load-4bit", action="store_true") parser.add_argument("--debug", action="store_true") 
parser.add_argument("--image-aspect-ratio", type=str, default='pad') args = parser.parse_args() main(args) ================================================ FILE: src/evaluate/iqa_eval.py ================================================ import argparse import json import os from collections import defaultdict from io import BytesIO import requests import torch from PIL import Image from tqdm import tqdm from src.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX from src.conversation import conv_templates from src.mm_utils import get_model_name_from_path, tokenizer_image_token from src.model.builder import load_pretrained_model def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. """ import torch setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def load_image(image_file): if image_file.startswith("http://") or image_file.startswith("https://"): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert("RGB") else: image = Image.open(image_file).convert("RGB") return image def main(args): # Model disable_torch_init() model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model( args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device, preprocessor_path=args.preprocessor_path, ) meta_paths = args.meta_paths root_dir = args.root_dir batch_size = args.batch_size save_dir = args.save_dir os.makedirs(save_dir, exist_ok=True) with_prob = args.with_prob conv_mode = "mplug_owl2" inp = "How would you rate the quality of this image?" 
conv = conv_templates[conv_mode].copy() inp = inp + "\n" + DEFAULT_IMAGE_TOKEN conv.append_message(conv.roles[0], inp) image = None conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() + " The quality of the image is" toks = args.level_names print(toks) ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] print(ids_) input_ids = ( tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") .unsqueeze(0) .to(args.device) ) for meta_path in meta_paths: with open(meta_path) as f: iqadata = json.load(f) image_tensors = [] batch_data = [] imgs_handled = [] save_path = os.path.join(save_dir, os.path.basename(meta_path)) if os.path.exists(save_path): with open(save_path) as fr: for line in fr: meta_res = json.loads(line) imgs_handled.append(meta_res["image"]) meta_name = os.path.basename(meta_path) for i, llddata in enumerate(tqdm(iqadata, desc=f"Evaluating [{meta_name}]")): try: filename = llddata["image"] except: filename = llddata["img_path"] if filename in imgs_handled: continue llddata["logits"] = defaultdict(float) llddata["probs"] = defaultdict(float) image = load_image(os.path.join(root_dir, filename)) def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image = expand2square( image, tuple(int(x * 255) for x in image_processor.image_mean) ) image_tensor = ( image_processor.preprocess(image, return_tensors="pt")["pixel_values"] .half() .to(args.device) ) image_tensors.append(image_tensor) batch_data.append(llddata) if (i + 1) % batch_size == 0 or i == len(iqadata) - 1: with torch.inference_mode(): output_logits = model( input_ids=input_ids.repeat(len(image_tensors), 1), 
images=torch.cat(image_tensors, 0), )["logits"][:, -1] if with_prob: output_probs = torch.softmax(output_logits, dim=1) for j, xllddata in enumerate(batch_data): for tok, id_ in zip(toks, ids_): xllddata["logits"][tok] += output_logits[j, id_].item() if with_prob: xllddata["probs"][tok] += output_probs[j, id_].item() meta_res = { "id": xllddata["id"], "image": xllddata["image"], "gt_score": xllddata["gt_score"], "logits": xllddata["logits"], } if with_prob: meta_res["probs"] = xllddata["probs"] with open(save_path, "a") as fw: fw.write(json.dumps(meta_res) + "\n") image_tensors = [] batch_data = [] if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, required=True) parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--preprocessor-path", type=str, default=None) parser.add_argument("--meta-paths", type=str, required=True, nargs="+") parser.add_argument("--root-dir", type=str, required=True) parser.add_argument("--save-dir", type=str, default="results") parser.add_argument("--level-names", type=str, required=True, nargs="+") parser.add_argument("--with-prob", type=bool, default=False) # whether to save openset prob parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--conv-mode", type=str, default=None) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--max-new-tokens", type=int, default=512) parser.add_argument("--load-8bit", action="store_true") parser.add_argument("--load-4bit", action="store_true") parser.add_argument("--debug", action="store_true") parser.add_argument("--image-aspect-ratio", type=str, default="pad") args = parser.parse_args() main(args) ================================================ FILE: src/evaluate/scorer.py ================================================ from PIL import Image import torch.nn as nn import torch from typing import List 
from src.model.builder import load_pretrained_model

from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria


class Scorer(nn.Module):
    """Image-quality scorer wrapping a DeQA-Score checkpoint.

    Calling the module on a list of PIL images returns one score per image:
    the softmax over the logits of the level words "excellent", "good",
    "fair", "poor", "bad" at the last prompt position, weighted by
    5, 4, 3, 2, 1 respectively. Each score therefore lies in [1, 5]
    (computed in float16).
    """

    def __init__(self, pretrained="zhiyuanyou/DeQA-Score-Mix3", device="cuda:0"):
        super().__init__()
        tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
        # Fixed prompt; the model's next-token distribution after "is" is read below.
        prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"

        # Index [1] of each level word's tokenization; presumably index 0 is the
        # BOS token the tokenizer prepends (TODO confirm).
        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]

        # Score of each level, aligned with preferential_ids_.
        self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(model.device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        # Prompt ids (batch of 1); repeated per image in forward().
        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    def expand2square(self, pil_img, background_color):
        """Centre ``pil_img`` on a square canvas of its longer side filled with ``background_color``.

        Square inputs are returned unchanged (same object).
        """
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def forward(self, image: List[Image.Image]):
        """Score a list of PIL images; returns a 1-D float16 tensor with one score per image."""
        # Pad to square with the processor's mean colour (0-255 scale) before preprocessing.
        image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
        with torch.inference_mode():
            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
            # Last-position logits restricted to the five level-word ids.
            output_logits = self.model(
                input_ids=self.input_ids.repeat(image_tensor.shape[0], 1),
                images=image_tensor
            )["logits"][:,-1, self.preferential_ids_]

            return torch.softmax(output_logits, -1) @ self.weight_tensor


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="zhiyuanyou/DeQA-Score-Mix3")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
    args = parser.parse_args()

    scorer = Scorer(pretrained=args.model_path, device=args.device)
    print(scorer([Image.open(args.img_path)]).tolist())

================================================
FILE: src/mm_utils.py
================================================
from PIL import Image
from io import BytesIO
import base64

import torch
from transformers import StoppingCriteria
from src.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
from icecream import ic


def load_image_from_base64(image):
    """Decode a base64-encoded image payload into a PIL image."""
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    """Centre ``pil_img`` on a square canvas of its longer side filled with ``background_color``.

    Square inputs are returned unchanged (same object).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg=None):
    """Preprocess a list of PIL images according to ``model_cfg.image_aspect_ratio``.

    - ``'pad'``: pad each image to a square with the processor's mean colour.
    - ``'resize'`` (also used when ``model_cfg`` is None): stretch each image to a
      square of its longer side.
    - anything else (including a ``model_cfg`` without the attribute): hand the
      whole list to ``image_processor`` and return its ``pixel_values`` directly.

    In the first two modes the per-image tensors are stacked into one tensor when
    all shapes match; otherwise the list of tensors is returned.
    """
    if model_cfg is not None:
        image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    else:
        image_aspect_ratio = 'resize'
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == 'resize':
        for image in images:
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    # NOTE(review): with an empty ``images`` list this reaches torch.stack([]), which raises.
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing each ``DEFAULT_IMAGE_TOKEN`` with ``image_token_index``.

    The text between image placeholders is tokenized chunk by chunk. If the first
    chunk starts with ``tokenizer.bos_token_id``, that id is emitted once and
    ``offset = 1`` drops the first id of every chunk (presumably each chunk's own
    BOS, since chunks are tokenized separately — TODO confirm). The separator is
    then ``[image_token_index] * 2``, so exactly one image id survives the same
    ``x[offset:]`` slice.

    Returns a list of ints, or a 1-D ``torch.long`` tensor when
    ``return_tensors == 'pt'``; any other ``return_tensors`` raises ValueError.
    """
    prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

    def insert_separator(X, sep):
        # Interleave ``sep`` between the elements of X (no trailing separator).
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    """Derive a model name from a checkpoint path.

    Returns the last path component, or ``"<parent>_checkpoint-N"`` when the last
    component starts with ``checkpoint-``. Leading/trailing slashes are ignored.
    """
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of ``keywords`` appears at the end of the output.

    Only batch size 1 is supported (asserted in ``__call__``).
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so the ids can match mid-sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        # Prompt length; tokens past this index are newly generated.
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the sequence ends with a keyword's token ids, or when
        the decoded last ``min(#generated, max_keyword_len)`` tokens contain a
        keyword as a substring."""
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

================================================
FILE: src/model/__init__.py
================================================
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
from .configuration_mplug_owl2 import MPLUGOwl2Config

================================================
FILE: src/model/builder.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import warnings import torch from transformers import ( AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, ) from transformers.models.clip.image_processing_clip import CLIPImageProcessor from src.model import * def load_pretrained_model( model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", preprocessor_path=None, ): kwargs = {"device_map": device_map} if device != "cuda": kwargs["device_map"] = {"": device} if load_8bit: kwargs["load_in_8bit"] = True elif load_4bit: kwargs["load_in_4bit"] = True kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) else: kwargs["torch_dtype"] = torch.float16 if preprocessor_path is None: preprocessor_path = model_path if "deqa" in model_name.lower(): # Load LLaVA model if "lora" in model_name.lower() and model_base is None: warnings.warn( "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged." 
) if "lora" in model_name.lower() and model_base is not None: lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False) print("Loading mPLUG-Owl2 from base model...") model = MPLUGOwl2LlamaForCausalLM.from_pretrained( model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs ) token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: model.lm_head.weight = torch.nn.Parameter( torch.empty( token_num, tokem_dim, device=model.device, dtype=model.dtype ) ) model.model.embed_tokens.weight = torch.nn.Parameter( torch.empty( token_num, tokem_dim, device=model.device, dtype=model.dtype ) ) print("Loading additional mPLUG-Owl2 weights...") if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): non_lora_trainables = torch.load( os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu", ) print(non_lora_trainables.keys()) else: # this is probably from HF Hub from huggingface_hub import hf_hub_download def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder ) return torch.load(cache_file, map_location="cpu") non_lora_trainables = load_from_hf( model_path, "non_lora_trainables.bin" ) non_lora_trainables = { (k[17:] if k.startswith("base_model.model.") else k): v for k, v in non_lora_trainables.items() } model.load_state_dict(non_lora_trainables, strict=False) from peft import PeftModel print("Loading LoRA weights...") model = PeftModel.from_pretrained(model, model_path) print("Merging LoRA weights...") model = model.merge_and_unload() print("Model is loaded...") elif model_base is not None: # this may be mm projector only print("Loading mPLUG-Owl2 from base model...") tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) model = 
MPLUGOwl2LlamaForCausalLM.from_pretrained( model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs ) else: tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False) model = MPLUGOwl2LlamaForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) else: # Load language model if model_base is not None: # PEFT model from peft import PeftModel tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained( model_base, low_cpu_mem_usage=True, **kwargs ) print(f"Loading LoRA weights from {model_path}") model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() print("Convert to FP16...") model.to(torch.float16) else: tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) # vision_tower = model.get_model().vision_model # print(vision_tower.device) # vision_tower.to(device=device, dtype=torch.float16) image_processor = CLIPImageProcessor.from_pretrained(preprocessor_path) if hasattr(model.config, "max_sequence_length"): context_len = model.config.max_sequence_length else: context_len = 2048 return tokenizer, model, image_processor, context_len ================================================ FILE: src/model/configuration_mplug_owl2.py ================================================ # Copyright (c) Alibaba. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import copy import os from typing import Union from transformers.configuration_utils import PretrainedConfig from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from transformers.utils import logging from transformers.models.auto import CONFIG_MAPPING class LlamaConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the LLaMA-7B. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`LlamaModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 11008): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. 
For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, Llama 2 up to 4096, CodeLlama up to 16384. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*): Padding token id. bos_token_id (`int`, *optional*, defaults to 1): Beginning of stream token id. eos_token_id (`int`, *optional*, defaults to 2): End of stream token id. pretraining_tp (`int`, *optional*, defaults to 1): Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. 
The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. ```python >>> from transformers import LlamaModel, LlamaConfig >>> # Initializing a LLaMA llama-7b style configuration >>> configuration = LlamaConfig() >>> # Initializing a model from the llama-7b style configuration >>> model = LlamaModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "llama" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, vocab_size=32000, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.pretraining_tp = 
pretraining_tp self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. """ if self.rope_scaling is None: return if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") class MplugOwlVisionConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate a mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration defaults will yield a similar configuration to that of the mPLUG-Owl [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. 
intermediate_size (`int`, *optional*, defaults to 3072): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. image_size (`int`, *optional*, defaults to 224): The size (resolution) of each image. patch_size (`int`, *optional*, defaults to 32): The size (resolution) of each patch. hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. initializer_factor (`float`, *optional*, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). 
```""" model_type = "mplug_owl_vision_model" def __init__( self, hidden_size=1024, intermediate_size=4096, projection_dim=768, num_hidden_layers=24, num_attention_heads=16, num_channels=3, image_size=448, patch_size=14, hidden_act="quick_gelu", layer_norm_eps=1e-6, attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, use_flash_attn=False, **kwargs, ): super().__init__(**kwargs) self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.projection_dim = projection_dim self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_channels = num_channels self.patch_size = patch_size self.image_size = image_size self.initializer_range = initializer_range self.initializer_factor = initializer_factor self.attention_dropout = attention_dropout self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.use_flash_attn = use_flash_attn @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the vision config dict if we are loading from MplugOwlConfig if config_dict.get("model_type") == "mplug-owl": config_dict = config_dict["vision_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
) return cls.from_dict(config_dict, **kwargs) class MplugOwlVisualAbstractorConfig(PretrainedConfig): model_type = "mplug_owl_visual_abstract" def __init__( self, num_learnable_queries=64, hidden_size=1024, num_hidden_layers=6, num_attention_heads=16, intermediate_size=2816, attention_probs_dropout_prob=0., initializer_range=0.02, layer_norm_eps=1e-6, encoder_hidden_size=1024, grid_size=None, **kwargs, ): super().__init__(**kwargs) self.hidden_size = hidden_size self.num_learnable_queries = num_learnable_queries self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.attention_probs_dropout_prob = attention_probs_dropout_prob self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.encoder_hidden_size = encoder_hidden_size self.grid_size = grid_size if grid_size else 32 @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the visual_abstractor config dict if we are loading from MplugOwlConfig if config_dict.get("model_type") == "mplug-owl": config_dict = config_dict["abstractor_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
) return cls.from_dict(config_dict, **kwargs) DEFAULT_VISUAL_CONFIG = { "visual_model": MplugOwlVisionConfig().to_dict(), "visual_abstractor": MplugOwlVisualAbstractorConfig().to_dict() } class MPLUGOwl2Config(LlamaConfig): model_type = "mplug_owl2" def __init__(self, visual_config=None, **kwargs): if visual_config is None: self.visual_config = DEFAULT_VISUAL_CONFIG else: self.visual_config = visual_config super().__init__( **kwargs, ) if __name__ == "__main__": print(MplugOwlVisionConfig().to_dict()) ================================================ FILE: src/model/convert_mplug_owl2_weight_to_hf.py ================================================ # Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import gc import json import math import os import shutil import warnings import torch from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM try: from transformers import LlamaTokenizerFast except ImportError as e: warnings.warn(e) warnings.warn( "The converted tokenizer will be the `slow` tokenizer. 
To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" ) LlamaTokenizerFast = None """ Sample usage: ``` python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \ --input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B ``` Thereafter, models can be loaded via: ```py from transformers import LlamaForCausalLM, LlamaTokenizer model = LlamaForCausalLM.from_pretrained("/output/path") tokenizer = LlamaTokenizer.from_pretrained("/output/path") ``` Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). """ llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80} llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64} llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016, 70: 28672} # should be (2/3)*4*d, but it isn't exaclty that llama_s2hidden = {7: 4096, 13: 5120, 32: 6656, 65: 8192, 70: 8192} def compute_intermediate_size(n): return int(math.ceil(n * 8 / 3) + 255) // 256 * 256 def read_json(path): with open(path, "r") as f: return json.load(f) def write_json(text, path): with open(path, "w") as f: json.dump(text, f) def write_model(model_path, input_base_path, model_size, num_input_shards=1, num_output_shards=2, skip_permute=True, norm_eps=1e-05): # if os.path.exists(model_path): # shutil.rmtree(model_path) os.makedirs(model_path, exist_ok=True) # tmp_model_path = os.path.join(model_path, "tmp") tmp_model_path = model_path os.makedirs(tmp_model_path, exist_ok=True) num_shards = num_input_shards n_layers = llama_s2layer[model_size] n_heads = llama_s2heads[model_size] n_heads_per_shard = n_heads // num_shards n_dense = llama_s2dense[model_size] n_hidden = llama_s2hidden[model_size] hidden_per_head = n_hidden // n_heads base = 10000.0 inv_freq = 1.0 / (base ** (torch.arange(0, 
hidden_per_head, 2).float() / hidden_per_head)) # permute for sliced rotary def permute(w, skip_permute=skip_permute): if skip_permute: return w return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden) print(f"Fetching all parameters from the checkpoint at {input_base_path}.") # Load weights if num_shards==1: # Not sharded # (The sharded implementation would also work, but this is simpler.) # /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt if os.path.exists(os.path.join(input_base_path, 'release')): filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt') elif input_base_path.split('/')[-1].startswith('iter_'): iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0')) load_dir = '/'.join(input_base_path.split('/')[:-1]) filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt') if not os.path.exists(filename): filename = filename.replace('model_optim_rng.pt', 'model_rng.pt') else: tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt') with open(tracker_filename, 'r') as f: metastring = f.read().strip() iteration = 'iter_{:07d}'.format(int(metastring)) filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt') if not os.path.exists(filename): filename = filename.replace('model_optim_rng.pt', 'model_rng.pt') original_filename = filename loaded = torch.load(filename, map_location="cpu")['model']['language_model'] else: # Sharded filenames = [] for i in range(num_shards): if os.path.exists(os.path.join(input_base_path, 'release')): filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt') else: tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt') with open(tracker_filename, 'r') as f: metastring = f.read().strip() iteration = 
'iter_{:07d}'.format(int(metastring)) filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt') if not os.path.exists(filename): filename = filename.replace('model_optim_rng.pt', 'model_rng.pt') filenames.append(filename) loaded = [ torch.load(filenames[i], map_location="cpu")['model']['language_model'] for i in range(num_shards) ] print('Llama-Megatron Loaded!') param_count = 0 index_dict = {"weight_map": {}} print(f'Weighted Converting for {n_layers} layers...') for layer_i in range(n_layers): print(layer_i) filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" if num_shards == 1: # Unsharded state_dict = { f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"], f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"], f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"], f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"], f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"], f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"], f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"], f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"], f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"], f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"], f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": 
loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"], f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"], f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"], } else: raise NotImplemented # else: # # Sharded # # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share # # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is # # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. # state_dict = { # f"model.layers.{layer_i}.input_layernorm.weight": loaded[0]['encoder'][ # f"layers.{layer_i}.input_layernorm.multiway.0.weight" # ].clone(), # f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0]['encoder'][ # f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight" # ].clone(), # } # wqs, wks, wvs, ffn_w1s, ffn_w3s = [], [], [], [], [] # for shard_idx in range(num_shards): # wqs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"]) # wks.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"]) # wvs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"]) # state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( # torch.cat( # [ # wq.view(n_heads_per_shard, hidden_per_head, n_hidden) # for wq in range(wqs) # ], # dim=0, # ).reshape(n_hidden, n_hidden) # ) # state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( # torch.cat( # [ # wk.view(n_heads_per_shard, hidden_per_head, n_hidden) # for wk in range(wks) # ], # dim=0, # ).reshape(n_hidden, n_hidden) # ) # state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( # [ # 
wv.view(n_heads_per_shard, hidden_per_head, n_hidden) # for wv in range(wvs) # ], # dim=0, # ).reshape(n_hidden, n_hidden) # state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( # [loaded[i]['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"] for i in range(num_shards)], dim=1 # ) # state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( # [loaded[i]['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"] for i in range(num_shards)], dim=0 # ) # state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( # [loaded[i]['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"] for i in range(num_shards)], dim=1 # ) # state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( # [loaded[i]['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"] for i in range(num_shards)], dim=0 # ) state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq for k, v in state_dict.items(): index_dict["weight_map"][k] = filename param_count += v.numel() torch.save(state_dict, os.path.join(tmp_model_path, filename)) print(f'Sharded file saved to {filename}') filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" if num_shards==1: # Unsharded state_dict = { "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'], "model.norm.weight": loaded['encoder']['norm.weight'], "lm_head.weight": loaded['encoder']['lm_head.weight'], } else: state_dict = { "model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'], "model.norm.weight": loaded[0]['encoder']['norm.weight'], "lm_head.weight": loaded[0]['encoder']['lm_head.weight'], } loaded_all = torch.load(original_filename, map_location="cpu")['model'] # Vision Part state_dict.update({ "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'], "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'], "model.vision_model.embeddings.position_embedding": 
loaded_all['vision_model']['position_embeddings'], "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'], "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'], "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'], "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'], }) for v_layer_idx in range(24): state_dict.update({ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'], f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'], f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'], f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'], f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'], f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'], f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'], f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'], f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": 
loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'], f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'], f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'], f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'], }) # Abstractor Part state_dict.update({ "model.visual_abstractor.query_embeds": loaded_all['vision_abstractor']['learnable_queries'], "model.visual_abstractor.visual_fc.bias": loaded_all['vision_abstractor']['visual_fc']['bias'], "model.visual_abstractor.visual_fc.weight": loaded_all['vision_abstractor']['visual_fc']['weight'], "model.visual_abstractor.vit_eos": loaded_all['vision_abstractor']['vit_eos'], }) for v_layer_idx in range(6): state_dict.update({ # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.k_pos_embed": f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.weight"], # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.q_pos_embed": "pytorch_model-00004-of-00004.bin", f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.bias"], 
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.weight"], 
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.weight"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.bias"], f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.weight"], }) for k, v in state_dict.items(): index_dict["weight_map"][k] = filename param_count += v.numel() torch.save(state_dict, os.path.join(tmp_model_path, filename)) # Write configs index_dict["metadata"] = {"total_size": param_count * 2} write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) config = MPLUGOwl2Config() config.save_pretrained(tmp_model_path) # Make space so we can load the model properly now. 
del state_dict del loaded del loaded_all gc.collect() def write_tokenizer(tokenizer_path, input_tokenizer_path): # Initialize the tokenizer based on the `spm` model tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") tokenizer = tokenizer_class(input_tokenizer_path) tokenizer.save_pretrained(tokenizer_path) def main(): parser = argparse.ArgumentParser() parser.add_argument( "--input_dir", help="Location of LLaMA_Megatron weights", ) parser.add_argument( "--model_size", type=int, default=7, choices=[7, 13, 30, 65, 70], ) parser.add_argument( "--num_input_shards", type=int, default=1, ) parser.add_argument( "--num_output_shards", type=int, default=1, ) parser.add_argument('--skip_permute', action='store_true') parser.add_argument( "--output_dir", help="Location to write HF model and tokenizer", ) args = parser.parse_args() write_model( model_path=args.output_dir, input_base_path=args.input_dir, model_size=args.model_size, num_input_shards=args.num_input_shards, num_output_shards=args.num_output_shards, skip_permute=args.skip_permute ) if __name__ == "__main__": main() ================================================ FILE: src/model/modeling_attn_mask_utils.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List, Optional, Tuple, Union import torch class AttentionMaskConverter: """ A utility attention mask class that allows one to: - Create a causal 4d mask - Create a causal 4d mask with slided window - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, key_value_length) that can be multiplied with attention scores Parameters: is_causal (`bool`): Whether the attention mask should be a uni-directional (causal) or bi-directional mask. sliding_window (`int`, *optional*): Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. """ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): self.is_causal = is_causal self.sliding_window = sliding_window if self.sliding_window is not None and self.sliding_window <= 0: raise ValueError( f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" ) def to_causal_4d( self, batch_size: int, query_length: int, key_value_length: int, dtype: torch.dtype = torch.float32, device: Union[torch.device, "str"] = "cpu", ) -> torch.Tensor: """ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative bias to upper right hand triangular matrix (causal mask). 
""" if not self.is_causal: raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") # If shape is not cached, create a new causal mask and cache it input_shape = (batch_size, query_length) past_key_values_length = key_value_length - query_length # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] causal_4d_mask = None if input_shape[-1] > 1 or self.sliding_window is not None: causal_4d_mask = self._make_causal_mask( input_shape, dtype, device=device, past_key_values_length=past_key_values_length, sliding_window=self.sliding_window, ) return causal_4d_mask def to_4d( self, attention_mask_2d: torch.Tensor, query_length: int, key_value_length: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is causal, a causal mask will be added. """ input_shape = (attention_mask_2d.shape[0], query_length) # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] causal_4d_mask = None if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: if key_value_length is None: raise ValueError( "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." 
) past_key_values_length = key_value_length - query_length causal_4d_mask = self._make_causal_mask( input_shape, dtype, device=attention_mask_2d.device, past_key_values_length=past_key_values_length, sliding_window=self.sliding_window, ) elif self.sliding_window is not None: raise NotImplementedError("Sliding window is currently only implemented for causal masking") # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( attention_mask_2d.device ) expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask return expanded_4d_mask @staticmethod def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, sliding_window: Optional[int] = None, ): """ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) # add lower triangular sliding window mask if necessary if sliding_window is not None: diagonal = past_key_values_length - sliding_window + 1 context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
""" bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) def _prepare_4d_causal_attention_mask( attention_mask: Optional[torch.Tensor], input_shape: Union[torch.Size, Tuple, List], inputs_embeds: torch.Tensor, past_key_values_length: int, sliding_window: Optional[int] = None, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)` Args: attention_mask (`torch.Tensor` or `None`): A 2D attention mask of shape `(batch_size, key_value_length)` input_shape (`tuple(int)` or `list(int)` or `torch.Size`): The input shape should be a tuple that defines `(batch_size, query_length)`. inputs_embeds (`torch.Tensor`): The embedded inputs as a torch Tensor. past_key_values_length (`int`): The length of the key value cache. sliding_window (`int`, *optional*): If the model uses windowed attention, a sliding window should be passed. 
""" attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length # 4d mask is passed through the layers if attention_mask is not None: attention_mask = attn_mask_converter.to_4d( attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype ) else: attention_mask = attn_mask_converter.to_causal_4d( input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device ) return attention_mask def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)` Args: mask (`torch.Tensor` or `None`): A 2D attention mask of shape `(batch_size, key_value_length)` dtype (`torch.dtype`): The torch dtype the created mask shall have. tgt_len (`int`): The target length or query length the created mask shall have. """ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) def _create_4d_causal_attention_mask( input_shape: Union[torch.Size, Tuple, List], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, sliding_window: Optional[int] = None, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` Args: input_shape (`tuple(int)` or `list(int)` or `torch.Size`): The input shape should be a tuple that defines `(batch_size, query_length)`. dtype (`torch.dtype`): The torch dtype the created mask shall have. device (`int`): The torch device the created mask shall have. sliding_window (`int`, *optional*): If the model uses windowed attention, a sliding window should be passed. 
""" attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = past_key_values_length + input_shape[-1] attention_mask = attn_mask_converter.to_causal_4d( input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device ) return attention_mask ================================================ FILE: src/model/modeling_llama2.py ================================================ import math import warnings from functools import partial from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn import copy import os import sys dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, dir_path) import transformers from transformers.models.llama.modeling_llama import * def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from .configuration_mplug_owl2 import LlamaConfig class MultiwayNetwork(nn.Module): def __init__(self, module_provider, num_multiway=2): super(MultiwayNetwork, self).__init__() self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)]) def forward(self, hidden_states, multiway_indices): if len(self.multiway) == 1: return self.multiway[0](hidden_states) output_hidden_states = torch.empty_like(hidden_states) for idx, subway in enumerate(self.multiway): local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True) hidden = hidden_states[local_indices].unsqueeze(1).contiguous() if 
hidden.numel(): output = subway(hidden) if isinstance(output, tuple): output = output[0] output = output.squeeze(1) output_hidden_states[local_indices] = output return output_hidden_states.contiguous() class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." 
) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = MultiwayNetwork(module_provider=partial( nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias) ) self.v_proj = MultiwayNetwork(module_provider=partial( nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias) ) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) self._init_rope() def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, modality_indicators: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states 
= self.q_proj(hidden_states, ) key_states = self.k_proj(hidden_states, modality_indicators) value_states = self.v_proj(hidden_states, modality_indicators) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, 
self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class LlamaFlashAttention2(LlamaAttention): """ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, modality_indicators: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # LlamaFlashAttention2 attention does not support output_attentions if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" ) # overwrite attention_mask with padding_mask attention_mask = kwargs.pop("padding_mask") output_attentions = False bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states, modality_indicators) value_states = self.v_proj(hidden_states, modality_indicators) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. 
Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. 
Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, 
num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) class LlamaSdpaAttention(LlamaAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ # Adapted from LlamaAttention.forward def forward( self, hidden_states: torch.Tensor, modality_indicators: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, modality_indicators=modality_indicators, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states, modality_indicators) value_states = self.v_proj(hidden_states, modality_indicators) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) # SDPA 
with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value LLAMA_ATTENTION_CLASSES = { "eager": LlamaAttention, "flash_attention_2": LlamaFlashAttention2, "sdpa": LlamaSdpaAttention, } class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, layer_idx): super().__init__() self.hidden_size = config.hidden_size self.self_attn = LlamaAttention(config=config) self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = MultiwayNetwork(module_provider=partial( LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps )) self.post_attention_layernorm = MultiwayNetwork(module_provider=partial( LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps )) def forward( self, hidden_states: torch.Tensor, modality_indicators: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: 
Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states, modality_indicators) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, modality_indicators=modality_indicators, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs def model_forward( self, input_ids: torch.LongTensor = None, modality_indicators: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, 
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions: # output_attentions=True can 
not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs, past_key_value, output_attentions) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, modality_indicators, attention_mask, position_ids, ) else: layer_outputs = decoder_layer( hidden_states, modality_indicators=modality_indicators, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) 
next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def causal_model_forward( self, input_ids: torch.LongTensor = None, modality_indicators: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, modality_indicators=modality_indicators, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def replace_llama_modality_adaptive(): transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention 
transformers.models.llama.modeling_llama.LlamaFlashAttention2 = LlamaFlashAttention2 transformers.models.llama.modeling_llama.LlamaSdpaAttention = LlamaSdpaAttention transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward if __name__ == "__main__": replace_llama_modality_adaptive() config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/') model = transformers.LlamaForCausalLM(config) print(model) ================================================ FILE: src/model/modeling_mplug_owl2.py ================================================ # Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import sys from abc import ABC, abstractmethod from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, dir_path) from transformers import (AutoConfig, AutoModelForCausalLM, LlamaForCausalLM, LlamaModel) from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_mplug_owl2 import (MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig) from .modeling_llama2 import replace_llama_modality_adaptive from .utils import extend_list, find_prefix from .visual_encoder import MplugOwlVisionModel, MplugOwlVisualAbstractorModel IGNORE_INDEX = -100 IMAGE_TOKEN_INDEX = -200 DEFAULT_IMAGE_TOKEN = "<|image|>" from icecream import ic class MPLUGOwl2MetaModel: def __init__(self, config): super(MPLUGOwl2MetaModel, self).__init__(config) self.vision_model = MplugOwlVisionModel( MplugOwlVisionConfig(**config.visual_config["visual_model"]) ) self.visual_abstractor = MplugOwlVisualAbstractorModel( MplugOwlVisualAbstractorConfig(**config.visual_config["visual_abstractor"]), config.hidden_size, ) def get_vision_tower(self): vision_model = getattr(self, "vision_model", None) if type(vision_model) is list: vision_model = vision_model[0] return vision_model def get_visual_abstractor(self): visual_abstractor = getattr(self, "visual_abstractor", None) if type(visual_abstractor) is list: visual_abstractor = visual_abstractor[0] return visual_abstractor class MPLUGOwl2MetaForCausalLM(ABC): @abstractmethod def get_model(self): pass def encode_images(self, images): image_features = self.get_model().vision_model(images).last_hidden_state image_features = ( self.get_model() .visual_abstractor(encoder_hidden_states=image_features) .last_hidden_state ) return image_features def prepare_inputs_labels_for_multimodal( self, input_ids, attention_mask, 
past_key_values, labels, images ): if images is None or input_ids.shape[1] == 1: if ( past_key_values is not None and images is not None and input_ids.shape[1] == 1 ): attention_mask = torch.ones( (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device, ) multiway_indices = torch.zeros_like(input_ids).long().to(self.device) return ( input_ids, multiway_indices, attention_mask, past_key_values, None, labels, ) if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) new_input_embeds = [] new_modality_indicators = [] new_labels = [] if labels is not None else None cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal # FIXME: this is a hacky fix, for deepspeed zero3 to work half_len = cur_input_ids.shape[0] // 2 cur_image_features = image_features[cur_image_idx] cur_input_embeds_1 = self.get_model().embed_tokens( cur_input_ids[:half_len] ) cur_input_embeds_2 = self.get_model().embed_tokens( cur_input_ids[half_len:] ) cur_input_embeds = torch.cat( [cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0, ) new_input_embeds.append(cur_input_embeds) cur_modality_indicators = ( torch.zeros(len(cur_input_embeds)).long().to(self.device) ) new_modality_indicators.append(cur_modality_indicators) if labels is not None: new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] cur_new_input_embeds = [] cur_modality_indicators = [] if labels is not None: cur_labels = 
labels[batch_idx] cur_new_labels = [] assert cur_labels.shape == cur_input_ids.shape while image_token_indices.numel() > 0: cur_image_features = image_features[cur_image_idx] image_token_start = image_token_indices[0] cur_new_input_embeds.append( self.get_model().embed_tokens(cur_input_ids[:image_token_start]) ) cur_new_input_embeds.append(cur_image_features) # Add modality indicator assert image_token_start == len(cur_input_ids[:image_token_start]) cur_modality_indicators.append( torch.zeros(len(cur_input_ids[:image_token_start])).long() ) cur_modality_indicators.append( torch.ones(len(cur_image_features)).long() ) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append( torch.full( (cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype, ) ) cur_labels = cur_labels[image_token_start + 1 :] cur_image_idx += 1 cur_input_ids = cur_input_ids[image_token_start + 1 :] image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] if cur_input_ids.numel() > 0: cur_new_input_embeds.append( self.get_model().embed_tokens(cur_input_ids) ) cur_modality_indicators.append(torch.zeros(len(cur_input_ids)).long()) if labels is not None: cur_new_labels.append(cur_labels) cur_new_input_embeds = [ x.to(device=self.device) for x in cur_new_input_embeds ] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) new_input_embeds.append(cur_new_input_embeds) # Modality cur_modality_indicators = [ x.to(device=self.device) for x in cur_modality_indicators ] cur_modality_indicators = torch.cat(cur_modality_indicators, dim=0) new_modality_indicators.append(cur_modality_indicators) if labels is not None: cur_new_labels = torch.cat(cur_new_labels, dim=0) new_labels.append(cur_new_labels) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) # Embedding new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = 
torch.cat( ( cur_new_embed, torch.zeros( (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device, ), ), dim=0, ) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) # Modality new_modality_indicators_align = [] for cur_modality_indicator in new_modality_indicators: cur_new_embed = torch.cat( ( cur_modality_indicator, torch.zeros( max_len - cur_modality_indicator.shape[0], dtype=cur_modality_indicator.dtype, device=cur_modality_indicator.device, ), ), dim=0, ) new_modality_indicators_align.append(cur_new_embed) new_modality_indicators = torch.stack(new_modality_indicators_align, dim=0) # Label if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat( ( cur_new_label, torch.full( (max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device, ), ), dim=0, ) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) # Attention Mask if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip( attention_mask, _new_labels, new_labels ): new_attn_mask_pad_left = torch.full( (cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device, ) new_attn_mask_pad_right = torch.full( (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device, ) cur_new_attention_mask = torch.cat( ( new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right, ), dim=0, ) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) new_modality_indicators = torch.stack(new_modality_indicators, dim=0) if labels is 
not None: new_labels = torch.stack(new_labels, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full( ( attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1], ), True, dtype=attention_mask.dtype, device=attention_mask.device, ) attention_mask = torch.cat( (new_attn_mask_pad_left, attention_mask), dim=1 ) assert attention_mask.shape == new_input_embeds.shape[:2] return ( None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels, ) class MPLUGOwl2LlamaModel(MPLUGOwl2MetaModel, LlamaModel): config_class = MPLUGOwl2Config def __init__(self, config: MPLUGOwl2Config): super(MPLUGOwl2LlamaModel, self).__init__(config) class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM): config_class = MPLUGOwl2Config def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = MPLUGOwl2LlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward(self, input_type=None, **kwargs): if input_type is None: return self.forward_single(**kwargs) elif input_type == "single": kwargs_desp = self.get_subitem(kwargs, task_type="description") kwargs_score = self.get_subitem(kwargs, task_type="score") loss_desp = 0 if len(kwargs_desp["task_types"]) > 0: del kwargs_desp["task_types"] output_desp = self.forward_single(**kwargs_desp) loss_desp = output_desp.loss loss_score = 0 if len(kwargs_score["task_types"]) > 0: del kwargs_score["task_types"] output_score = self.forward_single( use_softkl_loss=self.config.softkl_loss, **kwargs_score, ) loss_score = output_score.loss if dist.get_rank() == 0: loss_desp_item = loss_desp if type(loss_desp) == int else loss_desp.item() loss_score_item = loss_score if type(loss_score) == int else loss_score.item() print( f"[loss (w/o weight) | " f"description loss: {round(loss_desp_item, 6)}, " f"score 
loss: {round(loss_score_item, 6)}]" ) loss = self.config.weight_desp * loss_desp + self.config.weight_next_token * loss_score return CausalLMOutputWithPast(loss=loss) elif input_type == "pair": return self.forward_pair(**kwargs) else: raise ValueError def softkl_loss(self, logits, labels, level_probs): batch_size = logits.shape[0] level_prefix = torch.tensor(self.config.level_prefix).to(labels.device) idx_prefix_label = find_prefix(labels, level_prefix) # B idx_level_label = idx_prefix_label + level_prefix.shape[0] level_ids_label = labels[torch.arange(batch_size), idx_level_label] for level_id in level_ids_label: assert level_id in self.config.level_ids # After padding in prepare_inputs_labels_for_multimodal(), the length of labels will be the same as logits assert logits.shape[1] == labels.shape[1] idx_level_logit = idx_level_label - 1 logits_level_ids = logits[ torch.arange(batch_size), idx_level_logit ].contiguous() # [B, V] preds = torch.softmax(logits_level_ids, dim=1) # [B, V] target = torch.zeros_like(preds) # [B, V] target[:, self.config.level_ids] = level_probs target = target.detach() pred_log = torch.log(preds) loss_kl = F.kl_div(pred_log, target, reduction="batchmean") return loss_kl, idx_level_label, idx_level_logit def forward_single( self, input_ids: torch.LongTensor = None, # modality_indicators: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, use_softkl_loss: Optional[bool] = None, level_probs: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) 
output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) ( input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels, ) = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, modality_indicators=modality_indicators, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss_kl = None if use_softkl_loss and labels is not None: loss_kl, idx_level_label, idx_level_logit = self.softkl_loss(logits, labels, level_probs) def del_elements(source, idx): """source: [B, N] / [B, N, V], idx: [B, ] with the value range [0, N-1]""" mask = torch.ones([*source.shape[:2]], dtype=torch.bool) for idx_1, idx_del in enumerate(idx): mask[idx_1, idx_del] = False if len(source.shape) == 2: source_del = source[mask].view(source.size(0), source.size(1)-1) else: assert len(source.shape) == 3 source_del = source[mask].view(source.size(0), source.size(1)-1, source.size(2)) return source_del labels_del = del_elements(labels, idx_level_label) logits_del = del_elements(logits, idx_level_logit) loss = None if labels is not None: # Shift so that tokens < n predict n if loss_kl is None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() else: shift_logits = logits_del[..., :-1, :].contiguous() shift_labels = labels_del[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = 
shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output if loss is not None and loss_kl is not None: loss = loss + self.config.weight_softkl * loss_kl return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def get_score(self, item): outputs = self.forward_single( input_ids=item["input_ids"], attention_mask=item["attention_mask"], labels=item["labels"], images=item["images"], return_dict=True, use_softkl_loss=self.config.softkl_loss, level_probs=item["level_probs"], ) batch_size = outputs.logits.shape[0] level_prefix = torch.tensor(self.config.level_prefix).to(item["labels"].device) idx_prefix_label = find_prefix(item["labels"], level_prefix) # B idx_level_label = idx_prefix_label + level_prefix.shape[0] level_ids_label = item["labels"][torch.arange(batch_size), idx_level_label] for level_id in level_ids_label: assert level_id in self.config.level_ids num_vision_tokens = outputs.logits.shape[1] - item["labels"].shape[1] idx_level_logit = idx_level_label + num_vision_tokens - 1 logits_level_ids = outputs.logits[ torch.arange(batch_size), idx_level_logit ].contiguous() # [B, V] probs_org = torch.softmax(logits_level_ids, dim=1) # [B, V] loss_in_level = 1 - probs_org[:, self.config.level_ids].contiguous().sum(dim=1) # [B, 5] -> [B, ] bound = torch.tensor(1e-2).to(loss_in_level) loss_in_level = torch.max(bound, loss_in_level.mean()) # level prob > 0.99 if self.config.closeset_rating_loss: logits_levels = logits_level_ids[:, self.config.level_ids].contiguous() probs = torch.softmax(logits_levels, dim=1) else: probs = probs_org[:, self.config.level_ids].contiguous() weights = torch.tensor([5, 4, 3, 2, 1]).to(probs) scores = torch.matmul(probs, weights) 
variances = (weights.repeat(batch_size, 1) - scores.unsqueeze(1)) ** 2 stds = torch.sqrt(torch.sum(probs * variances, dim=1)) return scores, stds, outputs.loss, loss_in_level def get_subitem(self, item, task_type): for key in list(item.keys()): if item[key] is None: del item[key] subitem = {} for key in item: subitem[key] = [] for idx in range(len(item["task_types"])): if item["task_types"][idx] == task_type: for key in item: subitem[key].append(item[key][idx]) batch_size = torch.tensor(len(subitem["task_types"])).cuda() world_size = dist.get_world_size() batch_size_allrank = [torch.tensor(0).cuda() for _ in range(world_size)] dist.barrier() dist.all_gather(batch_size_allrank, batch_size) batch_size_max = torch.stack(batch_size_allrank, dim=0).max().item() batch_size_min = torch.stack(batch_size_allrank, dim=0).min().item() for key in item: subitem[key] = extend_list(subitem[key], batch_size_max, batch_size_min) if torch.is_tensor(item[key]) and len(subitem[key]): subitem[key] = torch.stack(subitem[key], dim=0) return subitem def forward_pair(self, item_A, item_B, **kwargs): item_A_desp = self.get_subitem(item_A, task_type="description") item_B_desp = self.get_subitem(item_B, task_type="description") assert item_A_desp["task_types"] == item_B_desp["task_types"] item_A_score = self.get_subitem(item_A, task_type="score") item_B_score = self.get_subitem(item_B, task_type="score") assert item_A_score["task_types"] == item_B_score["task_types"] # calculate loss_desp for description tasks loss_desp = 0 if len(item_A_desp["task_types"]) > 0: outputs = self.forward_single( input_ids=item_A_desp["input_ids"], attention_mask=item_A_desp["attention_mask"], labels=item_A_desp["labels"], images=item_A_desp["images"], return_dict=True, use_softkl_loss=False, ) loss_desp = outputs.loss # calculate loss_score for score tasks loss_score = 0 if len(item_A_score["task_types"]) > 0: gt_scores_A = item_A_score["gt_scores"] pred_scores_A, pred_stds_A, loss_next_token_A, loss_in_level_A 
= self.get_score(item_A_score) gt_scores_B = item_B_score["gt_scores"] pred_scores_B, pred_stds_B, loss_next_token_B, loss_in_level_B = self.get_score(item_B_score) if not self.config.continuous_rating_loss: loss_rank = self.binary_rating_loss(pred_scores_A, gt_scores_A, pred_scores_B, gt_scores_B) else: gt_stds_A = item_A_score["stds"] gt_stds_B = item_B_score["stds"] assert (gt_stds_A >= 0).all() and (gt_stds_B >= 0).all() loss_rank = self.rating_loss( pred_scores_A, pred_stds_A, gt_scores_A, gt_stds_A, pred_scores_B, pred_stds_B, gt_scores_B, gt_stds_B, ) loss_next_token = loss_next_token_A + loss_next_token_B loss_in_level = loss_in_level_A + loss_in_level_B if dist.get_rank() == 0: print( f"[score loss (w/o weight) | " f"ranking loss: {round(loss_rank.item(), 6)}, " f"next token loss: {round(loss_next_token.item(), 6)}, " f"in level loss: {round(loss_in_level.item(), 6)}]" ) loss_rank = self.config.weight_rank * loss_rank if self.config.weight_next_token: assert self.config.weight_next_token > 0 loss_next_token = self.config.weight_next_token * loss_next_token else: loss_next_token = 0 if self.config.weight_in_level: assert self.config.weight_in_level > 0 loss_in_level = self.config.weight_in_level * loss_in_level else: loss_in_level = 0 loss_score = loss_rank + loss_next_token + loss_in_level if dist.get_rank() == 0: loss_desp_item = loss_desp if type(loss_desp) == int else loss_desp.item() loss_score_item = loss_score if type(loss_score) == int else loss_score.item() print( f"[loss (w/o weight) | " f"description loss: {round(loss_desp_item, 6)}, " f"score loss: {round(loss_score_item, 6)}]" ) loss = self.config.weight_desp * loss_desp + loss_score return CausalLMOutputWithPast(loss=loss) def rating_loss( self, pred_scores_A, pred_stds_A, gt_scores_A, gt_stds_A, pred_scores_B, pred_stds_B, gt_scores_B, gt_stds_B, ): # eps=1e-8 is important. eps=0 is unable to step, and lr keeps unchanged. 
eps = 1e-8 if self.config.use_fix_std: pred = 0.5 * (1 + torch.erf((pred_scores_A - pred_scores_B) / 2)) # 2 -> sqrt(2 * (1**2 + 1**2)) else: pred_var = pred_stds_A * pred_stds_A + pred_stds_B * pred_stds_B + eps if self.config.detach_pred_std: pred_var = pred_var.detach() pred = 0.5 * (1 + torch.erf((pred_scores_A - pred_scores_B) / torch.sqrt(2 * pred_var))) gt_var = gt_stds_A * gt_stds_A + gt_stds_B * gt_stds_B + eps gt = 0.5 * (1 + torch.erf((gt_scores_A - gt_scores_B) / torch.sqrt(2 * gt_var))).to(pred.device) gt = gt.detach() loss = (1 - (pred * gt + eps).sqrt() - ((1 - pred) * (1 - gt) + eps).sqrt()).mean() return loss def binary_rating_loss(self, pred_scores_A, gt_scores_A, pred_scores_B, gt_scores_B): pred = 0.5 * (1 + torch.erf((pred_scores_A - pred_scores_B) / 2)) # 2 -> sqrt(2 * (1**2 + 1**2)) gt = (gt_scores_A > gt_scores_B).to(pred.dtype).to(pred.device) gt = gt.detach() if self.config.binary_rating_loss == "bce": loss = F.binary_cross_entropy(pred, gt) elif self.config.binary_rating_loss == "fidelity": loss_1 = 1 - pred[gt == 1].sqrt() loss_2 = 1 - (1 - pred[gt == 0]).sqrt() loss = (loss_1.sum() + loss_2.sum()) / pred_scores_A.shape[0] else: raise NotImplementedError(f"Wrong type of binary_rating_loss: {self.config.binary_rating_loss}") return loss def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, images=None, **kwargs, ): if past_key_values: input_ids = input_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "images": images, } ) return model_inputs AutoConfig.register("mplug_owl2", MPLUGOwl2Config) AutoModelForCausalLM.register(MPLUGOwl2Config, 
MPLUGOwl2LlamaForCausalLM) replace_llama_modality_adaptive() if __name__ == "__main__": config = MPLUGOwl2Config.from_pretrained("zhiyuanyou/DeQA-Score-Mix3") from icecream import ic # config = MPLUGOwl2Config() model = AutoModelForCausalLM(config) images = torch.randn(2, 3, 448, 448) input_ids = torch.cat( [ torch.ones(8).long(), torch.tensor([-1] * 1).long(), torch.ones(8).long(), torch.tensor([-1] * 1).long(), torch.ones(8).long(), ], dim=0, ).unsqueeze(0) labels = input_ids.clone() labels[labels < 0] = -100 # image_feature = model.encode_images(images) # ic(image_feature.shape) output = model(images=images, input_ids=input_ids, labels=labels) ic(output.loss) ic(output.logits.shape) ================================================ FILE: src/model/utils.py ================================================ import torch from transformers import AutoConfig def extend_list(data_list, n, min_n): if min_n == 0: return [] while len(data_list) < n: data_list.extend(data_list[:n - len(data_list)]) return data_list def find_prefix(input_ids, prefix): """ input_ids: [B, N1], no start token prefix: [N2, ], no start token """ len_prefix = prefix.shape[0] # N2 # Create all possible windows of len_prefix input_ids_unfold = input_ids.unfold(1, len_prefix, 1) # Check if all elements in the window match the sequence matches = (input_ids_unfold == prefix).all(dim=2) # Convert boolean matches to integers for argmax operation matches_int = matches.type(torch.int64) # Calculate indices for the first match, if any, otherwise set to -1 indices = torch.where( matches.any(dim=1), matches_int.argmax(dim=1), torch.tensor(-1, dtype=torch.int64), ) assert (indices >= 0).all(), "Some inputs do not contain prefix" return indices def auto_upgrade(config): cfg = AutoConfig.from_pretrained(config) if "mplug_owl2" in config and "mplug_owl2" not in cfg.model_type: assert cfg.model_type == "mplug_owl2" print( "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." 
) print( "You must upgrade the checkpoint to the new code base (this can be done automatically)." ) confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") if confirm.lower() in ["y", "yes"]: print("Upgrading checkpoint...") assert len(cfg.architectures) == 1 setattr(cfg.__class__, "model_type", "mplug_owl2") cfg.architectures[0] = "LlavaLlamaForCausalLM" cfg.save_pretrained(config) print("Checkpoint upgraded.") else: print("Checkpoint upgrade aborted.") exit(1) ================================================ FILE: src/model/visual_encoder.py ================================================ import math from typing import Any, Optional, Tuple, Union from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer import numpy as np import torch import torch.nn as nn import torch.utils.checkpoint from icecream import ic def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C # tgt_size: M # return: M, C src_size = int(math.sqrt(abs_pos.size(0))) tgt_size = int(math.sqrt(tgt_size)) dtype = abs_pos.dtype if src_size != tgt_size: return F.interpolate( abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), size=(tgt_size, tgt_size), mode="bicubic", align_corners=False, ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) else: return abs_pos # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first 
grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb class MplugOwlVisionEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size)) self.patch_embed = nn.Conv2d( in_channels=3, out_channels=self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, bias=False, ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size)) self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: batch_size = pixel_values.size(0) image_embeds = self.patch_embed(pixel_values) image_embeds = 
image_embeds.flatten(2).transpose(1, 2) class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype) embeddings = torch.cat([class_embeds, image_embeds], dim=1) embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype) embeddings = self.pre_layernorm(embeddings) return embeddings class MplugOwlVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = nn.Dropout(config.attention_dropout) self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size) self.dense = nn.Linear(self.hidden_size, self.hidden_size) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, seq_len, embed_dim = hidden_states.size() mixed_qkv = self.query_key_value(hidden_states) mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute( 3, 0, 2, 1, 4 ) # [3, b, np, sq, hn] query_states, key_states, value_states = ( mixed_qkv[0], mixed_qkv[1], mixed_qkv[2], ) # if self.config.use_flash_attn and flash_attn_func is not None: if False: # [b*sq, np, hn] query_states = query_states.permute(0, 2, 1, 3).contiguous() query_states = 
query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1) key_states = key_states.permute(0, 2, 1, 3).contiguous() key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1) value_states = value_states.permute(0, 2, 1, 3).contiguous() value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1) cu_seqlens = torch.arange( 0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device ) context_layer = flash_attn_func( query_states, key_states, value_states, cu_seqlens, cu_seqlens, seq_len, seq_len, self.dropout if self.training else 0.0, softmax_scale=self.scale, causal=False, return_attn_probs=False, ) # [b*sq, np, hn] => [b, sq, np, hn] context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2)) else: # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) attention_scores = attention_scores * self.scale # Normalize the attention scores to probabilities. attention_probs = torch.softmax(attention_scores, dim=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
attention_probs = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) context_layer = context_layer.reshape(new_context_layer_shape) output = self.dense(context_layer) outputs = (output, attention_probs) if output_attentions else (output, None) return outputs class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class MplugOwlMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = QuickGELU() self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class MplugOwlVisionEncoderLayer(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.self_attn = MplugOwlVisionAttention(config) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) self.mlp = MplugOwlMLP(config) self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(config.encoder_attention_heads,)`. 
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, head_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = hidden_states + residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = hidden_states + residual outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class MplugOwlVisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`MplugOwlVisionEncoderLayer`]. Args: config (`MplugOwlVisionConfig`): The corresponding vision configuration for the `MplugOwlEncoder`. """ def __init__(self, config): super().__init__() self.config = config self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = True def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Embedded representation of the inputs. Should be float, not int tokens. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
[What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, output_attentions) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(encoder_layer), hidden_states, attention_mask, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) class MplugOwlVisionModel(PreTrainedModel): main_input_name = "pixel_values" _no_split_modules = 
["MplugOwlVisionEncoderLayer"] def __init__(self, config): super().__init__(config) self.config = config self.hidden_size = config.hidden_size self.embeddings = MplugOwlVisionEmbeddings(config) self.encoder = MplugOwlVisionEncoder(config) self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) self.post_init() def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values) encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def get_input_embeddings(self): return self.embeddings class MplugOwlVisualAbstractorMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config in_features = config.hidden_size self.act = nn.SiLU() self.w1 = nn.Linear(in_features, config.intermediate_size) self.w2 = nn.Linear(config.intermediate_size, 
in_features) self.w3 = nn.Linear(in_features, config.intermediate_size) self.ffn_ln = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.act(self.w1(hidden_states)) * self.w3(hidden_states) hidden_states = self.ffn_ln(hidden_states) hidden_states = self.w2(hidden_states) return hidden_states class MplugOwlVisualAbstractorMultiHeadAttention(nn.Module): def __init__(self, config): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.save_attention = False # self.q_pos_embed = nn.Parameter( # torch.from_numpy(get_1d_sincos_pos_embed_from_grid(config.hidden_size, np.arange(config.num_learnable_queries, dtype=np.float32))).float() # ).requires_grad_(False) # grids = config.grid_size # self.k_pos_embed = nn.Parameter( # torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=True)).float() # ).requires_grad_(False) grids = config.grid_size self.register_buffer( 'q_pos_embed', torch.from_numpy(get_1d_sincos_pos_embed_from_grid(config.hidden_size, np.arange(config.num_learnable_queries, dtype=np.float32))).float() ) self.register_buffer( 'k_pos_embed', torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=True)).float() ) def save_attn_gradients(self, attn_gradients): 
self.attn_gradients = attn_gradients def get_attn_gradients(self): return self.attn_gradients def save_attention_map(self, attention_map): self.attention_map = attention_map def get_attention_map(self): return self.attention_map def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. qk_pos_embed = torch.cat([self.q_pos_embed, self.k_pos_embed], dim = 0).unsqueeze(0).to(dtype=hidden_states.dtype) key_layer = self.transpose_for_scores(self.key(encoder_hidden_states + qk_pos_embed)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask mixed_query_layer = self.query(hidden_states + self.q_pos_embed.unsqueeze(0).to(dtype=hidden_states.dtype)) query_layer = self.transpose_for_scores(mixed_query_layer) past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. 
attention_probs = nn.Softmax(dim=-1)(attention_scores) if self.save_attention: self.save_attention_map(attention_probs) attention_probs.register_hook(self.save_attn_gradients) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs_dropped = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) outputs = outputs + (past_key_value,) return outputs class MplugOwlVisualAbstractorCrossOutput(nn.Module): def __init__(self, config): super().__init__() dim = config.hidden_size self.out_proj = nn.Linear(dim, dim, bias=True) self.norm2 = nn.LayerNorm(dim) self.mlp = MplugOwlVisualAbstractorMLP(config) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: input_tensor = input_tensor + self.out_proj(hidden_states) input_tensor = input_tensor + self.mlp(self.norm2(input_tensor)) return input_tensor class MplugOwlVisualAbstractorAttention(nn.Module): def __init__(self, config): super().__init__() self.attention = MplugOwlVisualAbstractorMultiHeadAttention(config) self.output = MplugOwlVisualAbstractorCrossOutput(config) self.pruned_heads = set() self.norm1 = nn.LayerNorm(config.hidden_size) self.normk = nn.LayerNorm(config.hidden_size) def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads ) # Prune linear layers self.attention.query = 
prune_linear_layer(self.attention.query, index) self.attention.key = prune_linear_layer(self.attention.key, index) self.attention.value = prune_linear_layer(self.attention.value, index) self.output.dense = prune_linear_layer(self.output.out_proj, index, dim=1) # Update hyper params and store pruned heads self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: # HACK we apply norm on q and k hidden_states = self.norm1(hidden_states) encoder_hidden_states = self.normk(encoder_hidden_states) encoder_hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) encoder_attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=-1) self_outputs = self.attention( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] return outputs class MplugOwlVisualAbstractorLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.layer_idx = layer_idx self.crossattention = MplugOwlVisualAbstractorAttention(config) self.has_cross_attention = True def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, 
encoder_attention_mask=None, output_attentions=False, ): if encoder_hidden_states is None: raise ValueError("encoder_hidden_states must be given for cross-attention layers") cross_attention_outputs = self.crossattention( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions=output_attentions, ) query_attention_output = cross_attention_outputs[0] outputs = (query_attention_output,) return outputs class MplugOwlVisualAbstractorEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layers = nn.ModuleList( [MplugOwlVisualAbstractorLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.gradient_checkpointing = True def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, output_attentions=False, output_hidden_states=False, return_dict=True, ): all_hidden_states = () if output_hidden_states else None for i in range(self.config.num_hidden_layers): layer_module = self.layers[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions, ) hidden_states = layer_outputs[0] return BaseModelOutput( last_hidden_state=hidden_states, ) class 
MplugOwlVisualAbstractorModel(PreTrainedModel): _no_split_modules = ["MplugOwlVisualAbstractorLayer"] def __init__(self, config, language_hidden_size): super().__init__(config) self.config = config self.encoder = MplugOwlVisualAbstractorEncoder(config) self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size) self.query_embeds = torch.nn.Parameter(torch.randn(1, config.num_learnable_queries, config.hidden_size)) self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size)) self.post_init() def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device, ) -> torch.Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (`Tuple[int]`): The shape of the input to the model. device: (`torch.device`): The device of the input to the model. Returns: `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( input_shape, attention_mask.shape ) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of: shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict query_embeds = self.query_embeds.repeat(encoder_hidden_states.shape[0], 1, 1) embedding_output = query_embeds input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if attention_mask is None: attention_mask = torch.ones( (query_embeds.shape[0], query_embeds.shape[1]), dtype=torch.long, device=query_embeds.device ) extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if type(encoder_hidden_states) == list: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() else: ( encoder_batch_size, encoder_sequence_length, _, ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = encoder_outputs[0] pooled_output = sequence_output[:, 0, :] 
sequence_output = self.visual_fc(sequence_output) sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(sequence_output.shape[0], 1, 1)], dim=1) return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, ) ================================================ FILE: src/train/mplug_owl2_trainer.py ================================================ import os from typing import List, Optional import torch from icecream import ic from torch.utils.data import Sampler from transformers import Trainer from transformers.trainer import (ALL_LAYERNORM_LAYERS, get_parameter_names, has_length, is_sagemaker_mp_enabled, logger) def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: print(name, 'no ignore status') with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} return to_return def split_to_even_chunks(indices, lengths, num_chunks): """ Split a list of indices into `chunks` chunks of roughly equal lengths. 
""" if len(indices) % num_chunks != 0: return [indices[i::num_chunks] for i in range(num_chunks)] num_indices_per_chunk = len(indices) // num_chunks chunks = [[] for _ in range(num_chunks)] chunks_lengths = [0 for _ in range(num_chunks)] for index in indices: shortest_chunk = chunks_lengths.index(min(chunks_lengths)) chunks[shortest_chunk].append(index) chunks_lengths[shortest_chunk] += lengths[index] if len(chunks[shortest_chunk]) == num_indices_per_chunk: chunks_lengths[shortest_chunk] = float("inf") return chunks def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): # We need to use torch for the random part as a distributed sampler will set the random seed for torch. assert all(l != 0 for l in lengths), "Should not have zero length." if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): # all samples are in the same modality return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] megabatch_size = world_size * batch_size mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] last_mm = mm_megabatches[-1] last_lang = lang_megabatches[-1] additional_batch = last_mm + last_lang megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] megabatch_indices = torch.randperm(len(megabatches), generator=generator) megabatches = [megabatches[i] for i in megabatch_indices] if len(additional_batch) > 0: 
megabatches.append(sorted(additional_batch)) return [i for megabatch in megabatches for i in megabatch] def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): # We need to use torch for the random part as a distributed sampler will set the random seed for torch. indices = torch.randperm(len(lengths), generator=generator) megabatch_size = world_size * batch_size megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] return [i for megabatch in megabatches for batch in megabatch for i in batch] class LengthGroupedSampler(Sampler): r""" Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while keeping a bit of randomness. """ def __init__( self, batch_size: int, world_size: int, lengths: Optional[List[int]] = None, generator=None, group_by_modality: bool = False, ): if lengths is None: raise ValueError("Lengths must be provided.") self.batch_size = batch_size self.world_size = world_size self.lengths = lengths self.generator = generator self.group_by_modality = group_by_modality def __len__(self): return len(self.lengths) def __iter__(self): if self.group_by_modality: indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) else: indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) return iter(indices) class MPLUGOwl2Trainer(Trainer): def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.train_dataset is None or not has_length(self.train_dataset): return None if self.args.group_by_modality_length: lengths = self.train_dataset.modality_lengths return LengthGroupedSampler( 
self.args.train_batch_size, world_size=self.args.world_size * self.args.gradient_accumulation_steps, lengths=lengths, group_by_modality=True, ) else: return super()._get_train_sampler() def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass. """ #if is_sagemaker_mp_enabled(): # return super().create_optimizer() #if self.sharded_ddp == ShardedDDPOption.SIMPLE: # return super().create_optimizer() opt_model = self.model if self.optimizer is None: decay_parameters = get_parameter_names(opt_model, forbidden_layer_types=ALL_LAYERNORM_LAYERS) # params except normlayers decay_parameters = [name for name in decay_parameters if "bias" not in name] # params except normlayers and bias if self.args.visual_abstractor_lr is not None: projector_parameters = [name for name, _ in opt_model.named_parameters() if "visual_abstractor_lr" in name] optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) ], "weight_decay": 0.0, }, { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, "lr": self.args.visual_abstractor_lr, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) ], "weight_decay": 0.0, "lr": self.args.visual_abstractor_lr, }, ] else: optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) ], 
"weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) ], "weight_decay": 0.0, }, ] ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params'])) optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) if True: self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes manager = bitsandbytes.optim.GlobalOptimManager.get_instance() skipped = 0 for module in opt_model.modules(): if isinstance(module, nn.Embedding): skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) logger.info(f"skipped {module}: {skipped/2**20}M params") manager.register_module_override(module, "weight", {"optim_bits": 32}) logger.debug(f"bitsandbytes: will optimize {module} in fp32") logger.info(f"skipped: {skipped/2**20}M params") return self.optimizer def _save_checkpoint(self, model, trial, metrics=None): super(MPLUGOwl2Trainer, self)._save_checkpoint(model, trial, metrics) def _save(self, output_dir: Optional[str] = None, state_dict=None): super(MPLUGOwl2Trainer, self)._save(output_dir, state_dict) ================================================ FILE: src/train/train_mem.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import os from dataclasses import dataclass, field from typing import Dict, List, Optional from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True import torch import transformers from transformers.models.clip.image_processing_clip import CLIPImageProcessor from src import conversation as conversation_lib from src.datasets import make_data_module from src.model import * from src.train.mplug_owl2_trainer import MPLUGOwl2Trainer local_rank = None def rank0_print(*args): if local_rank == 0: print(*args) @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") version: Optional[str] = field(default="v0") freeze_backbone: bool = field(default=False) @dataclass class DataArguments: dataset_type: str = "single" # [single, pair] data_paths: List[str] = field(default_factory=lambda: []) data_weights: List[int] = field(default_factory=lambda: []) lazy_preprocess: bool = False is_multimodal: bool = False image_folder: Optional[str] = field(default=None) image_aspect_ratio: str = "square" image_grid_pinpoints: Optional[str] = field(default=None) @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) level_prefix: str = field(default="") level_names: List[str] = field(default_factory=lambda: []) weight_desp: float = field(default=1.0, metadata={"help": "Absolute weight of description loss."}) weight_rank: float = field(default=1.0, metadata={"help": "Absolute 
weight of ranking loss."}) softkl_loss: bool = field( default=False, metadata={ "help": "If True, use softkl_loss for level token; else, use next token loss." }, ) weight_softkl: float = field( default=1.0, metadata={ "help": "Relative weight of softkl loss (w.r.t weight of next token loss as 1.0)." }, ) weight_next_token: float = field(default=1.0, metadata={"help": "Absolute weight of next token loss."}) weight_in_level: float = field(default=None, metadata={"help": "Absolute weight of in level loss."}) continuous_rating_loss: bool = field( default=True, metadata={ "help": "Used in pair dataset. If True, use continuous_rating_loss; else, use binary_rating_loss.", }, ) binary_rating_loss: str = field( default="fidelity", metadata={ "help": "Used in pair dataset if continuous_rating_loss is False or no std in dataset. bce loss / fidelity loss.", "choices": ["bce", "fidelity"], }, ) closeset_rating_loss: bool = field( default=False, metadata={ "help": "Used in pair dataset. If True, softmax in closeset; else, softmax in openset." }, ) use_fix_std: bool = field( default=True, metadata={ "help": "Use fixed std or predicted std." }, ) detach_pred_std: bool = field( default=False, metadata={ "help": "Detach predicted std." }, ) tune_visual_abstractor: bool = field(default=False) freeze_vision_model: bool = field(default=True) model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) double_quant: bool = field( default=True, metadata={ "help": "Compress the quantization statistics through double quantization." }, ) quant_type: str = field( default="nf4", metadata={ "help": "Quantization data type to use. Should be one of `fp4` or `nf4`." 
}, ) bits: int = field(default=16, metadata={"help": "How many bits to use."}) lora_enable: bool = False lora_r: int = 128 lora_alpha: int = 256 lora_dropout: float = 0.05 lora_weight_path: str = "" lora_bias: str = "none" visual_abstractor_lr: Optional[float] = None group_by_modality_length: bool = field(default=False) save_safetensors: bool = False def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: logging.warning( f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}" ) with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param # Borrowed from peft.utils.get_peft_model_state_dict def get_peft_state_maybe_zero_3(named_params, bias): if bias == "none": to_return = {k: t for k, t in named_params if "lora_" in k} elif bias == "all": to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} elif bias == "lora_only": to_return = {} maybe_lora_bias = {} lora_bias_names = set() for k, t in named_params: if "lora_" in k: to_return[k] = t bias_name = k.split("lora_")[0] + "bias" lora_bias_names.add(bias_name) elif "bias" in k: maybe_lora_bias[k] = t for k, t in maybe_lora_bias: if bias_name in lora_bias_names: to_return[bias_name] = t else: raise NotImplementedError to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} return to_return def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): to_return = {k: t for k, t in named_params if "lora_" not in k} if require_grad_only: to_return = {k: t for k, t in to_return.items() if t.requires_grad} to_return = { k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items() } return to_return def get_mm_adapter_state_maybe_zero_3(named_params, 
keys_to_match): to_return = { k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match) } to_return = { k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items() } return to_return def find_all_lora_names(model): lora_module_names = set() multimodal_keywords = ["vision_model", "visual_abstractor"] for name, _ in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue if "v_proj.multiway.1" in name or "q_proj" in name: lora_module_names.add(name) ls = list(lora_module_names) print(ls) return ls def safe_save_model_for_hf_trainer( trainer: transformers.Trainer, output_dir: str, ): """Collects the state dict and dump to disk.""" if trainer.deepspeed: torch.cuda.synchronize() trainer.save_model(output_dir) return state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ): """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
""" num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True ) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True ) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def train(): global local_rank parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments) ) model_args, data_args, training_args = parser.parse_args_into_dataclasses() local_rank = training_args.local_rank compute_dtype = ( torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) ) bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: from transformers import BitsAndBytesConfig bnb_model_from_pretrained_args.update( dict( # device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'} ), ) ) model = MPLUGOwl2LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation="flash_attention_2", torch_dtype=compute_dtype, **bnb_model_from_pretrained_args, ) print(model.config) model.config.use_cache = False if model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype = ( torch.float32 if 
training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) ) model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=training_args.gradient_checkpointing ) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_lora_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) rank0_print("Adding LoRA adapters...") model = get_peft_model(model, lora_config) tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[ model_args.version ] else: conversation_lib.default_conversation = conversation_lib.conv_templates[ "vicuna_v1" ] if not training_args.freeze_vision_model and training_args.bits in [4, 8]: model.get_model().vision_model.to( dtype=compute_dtype, device=training_args.device ) else: vision_tower = model.get_model().vision_model vision_tower.to( dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device, ) if training_args.tune_visual_abstractor and training_args.bits in [4, 8]: model.get_model().visual_abstractor.to( dtype=compute_dtype, device=training_args.device ) else: 
visual_abstractor = model.get_model().visual_abstractor visual_abstractor.to( dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device, ) data_args.image_processor = CLIPImageProcessor.from_pretrained( model_args.model_name_or_path ) data_args.is_multimodal = True model.config.softkl_loss = training_args.softkl_loss model.config.weight_desp = training_args.weight_desp model.config.weight_next_token = training_args.weight_next_token if data_args.dataset_type == "pair": model.config.weight_rank = training_args.weight_rank model.config.weight_in_level = training_args.weight_in_level model.config.continuous_rating_loss = training_args.continuous_rating_loss model.config.binary_rating_loss = training_args.binary_rating_loss model.config.closeset_rating_loss = training_args.closeset_rating_loss model.config.use_fix_std = training_args.use_fix_std model.config.detach_pred_std = training_args.detach_pred_std if training_args.level_prefix and training_args.level_names: model.config.weight_softkl = training_args.weight_softkl model.config.level_prefix = tokenizer(training_args.level_prefix).input_ids[1:] # index 1: no need start token for level_name in training_args.level_names: level_id = tokenizer(level_name)["input_ids"] assert len(level_id) == 2 and level_id[0] == 1 model.config.level_ids = [ id_[1] for id_ in tokenizer(training_args.level_names).input_ids ] # index 1: no need start token model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.image_grid_pinpoints = data_args.image_grid_pinpoints for n, p in model.named_parameters(): if training_args.lora_enable: p.requires_grad = True if "lora_" in n else False # if "lm_head" in n: # print(n) # p.requires_grad = True else: p.requires_grad = True if training_args.lora_enable: model.print_trainable_parameters() model.config.tune_visual_abstractor = model_args.tune_visual_abstractor = ( training_args.tune_visual_abstractor ) print(training_args.tune_visual_abstractor) 
model.get_model().visual_abstractor.requires_grad_(False) if training_args.tune_visual_abstractor: for n, p in model.get_model().visual_abstractor.named_parameters(): p.requires_grad = True model.config.freeze_vision_model = training_args.freeze_vision_model print(training_args.freeze_vision_model) model.get_model().vision_model.requires_grad_(True) if training_args.freeze_vision_model: for p in model.get_model().vision_model.parameters(): p.requires_grad = False if training_args.lora_enable: model.print_trainable_parameters() model.config.visual_abstractor_lr = training_args.visual_abstractor_lr if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) if "norm" in name: module = module.to(torch.float32) if "lm_head" in name or "embed_tokens" in name: if hasattr(module, "weight"): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_data_module(tokenizer=tokenizer, data_args=data_args) trainer = MPLUGOwl2Trainer( model=model, tokenizer=tokenizer, args=training_args, **data_module ) # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): # trainer.train(resume_from_checkpoint=True) # else: # trainer.train() # TODO I dont like auto resume << REMOVE IT AND UNCOMMENT THE ABOVE CODE trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save( non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin"), 
) else: safe_save_model_for_hf_trainer( trainer=trainer, output_dir=training_args.output_dir, ) if __name__ == "__main__": train() ================================================ FILE: src/utils.py ================================================ import datetime import logging import logging.handlers import os import sys import requests from mplug_owl2.constants import LOGDIR server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." handler = None def build_logger(logger_name, logger_filename): global handler formatter = logging.Formatter( fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # Set the format of root handlers if not logging.getLogger().handlers: logging.basicConfig(level=logging.INFO) logging.getLogger().handlers[0].setFormatter(formatter) # Redirect stdout and stderr to loggers stdout_logger = logging.getLogger("stdout") stdout_logger.setLevel(logging.INFO) sl = StreamToLogger(stdout_logger, logging.INFO) sys.stdout = sl stderr_logger = logging.getLogger("stderr") stderr_logger.setLevel(logging.ERROR) sl = StreamToLogger(stderr_logger, logging.ERROR) sys.stderr = sl # Get logger logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) # Add a file handler for all loggers if handler is None: os.makedirs(LOGDIR, exist_ok=True) filename = os.path.join(LOGDIR, logger_filename) handler = logging.handlers.TimedRotatingFileHandler( filename, when='D', utc=True) handler.setFormatter(formatter) for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) return logger class StreamToLogger(object): """ Fake file-like stream object that redirects writes to a logger instance. 
""" def __init__(self, logger, log_level=logging.INFO): self.terminal = sys.stdout self.logger = logger self.log_level = log_level self.linebuf = '' def __getattr__(self, attr): return getattr(self.terminal, attr) def write(self, buf): temp_linebuf = self.linebuf + buf self.linebuf = '' for line in temp_linebuf.splitlines(True): # From the io.TextIOWrapper docs: # On output, if newline is None, any '\n' characters written # are translated to the system default line separator. # By default sys.stdout.write() expects '\n' newlines and then # translates them so this is still cross platform. if line[-1] == '\n': self.logger.log(self.log_level, line.rstrip()) else: self.linebuf += line def flush(self): if self.linebuf != '': self.logger.log(self.log_level, self.linebuf.rstrip()) self.linebuf = '' def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. """ import torch setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def violates_moderation(text): """ Check whether the text violates OpenAI moderation API. 
""" url = "https://api.openai.com/v1/moderations" headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} text = text.replace("\n", "") data = "{" + '"input": ' + f'"{text}"' + "}" data = data.encode("utf-8") try: ret = requests.post(url, headers=headers, data=data, timeout=5) flagged = ret.json()["results"][0]["flagged"] except requests.exceptions.RequestException as e: flagged = False except KeyError as e: flagged = False return flagged def pretty_print_semaphore(semaphore): if semaphore is None: return "None" return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" ================================================ FILE: tests/datasets/test_pair_dataset.py ================================================ from dataclasses import dataclass, field from typing import List, Optional import transformers from torch.utils.data import DataLoader from transformers import AutoTokenizer from transformers.models.clip.image_processing_clip import CLIPImageProcessor from src.datasets import make_data_module @dataclass class DataArguments: dataset_type: str = "pair" data_paths: List[str] = field(default_factory=lambda: []) lazy_preprocess: bool = False is_multimodal: bool = False image_folder: Optional[str] = field(default=None) image_aspect_ratio: str = "square" image_grid_pinpoints: Optional[str] = field(default=None) if __name__ == "__main__": cfg_path = "./preprocessor" tokenizer = AutoTokenizer.from_pretrained(cfg_path, use_fast=False) parser = transformers.HfArgumentParser(DataArguments) (data_args,) = parser.parse_args_into_dataclasses() data_args.image_folder = "../Data-DeQA-Score/" data_args.data_paths = [ "../Data-DeQA-Score//KONIQ/metas/train_koniq_7k.json", "../Data-DeQA-Score//SPAQ/metas/train_spaq_9k.json", "../Data-DeQA-Score//KADID10K/metas/train_kadid_8k.json", ] data_args.data_weights = [1,1,1] data_args.image_processor = CLIPImageProcessor.from_pretrained(cfg_path) data_args.is_multimodal = 
True data_module = make_data_module(tokenizer=tokenizer, data_args=data_args) train_dataset = data_module["train_dataset"] collate_fn = data_module["data_collator"] data_loader = DataLoader( train_dataset, batch_size=4, collate_fn=collate_fn, shuffle=True ) for idx, data in enumerate(data_loader): print("=" * 100) print(f"{idx} / {len(data_loader)}") print(data.keys()) print(data["item_A"].keys()) print(data["item_B"].keys()) print(data["item_A"]["image_files"]) print(data["item_A"]["gt_scores"]) print(data["item_B"]["image_files"]) print(data["item_B"]["gt_scores"]) ================================================ FILE: tests/datasets/test_uncertainty_levels.py ================================================ from transformers import AutoTokenizer if __name__ == "__main__": model_base = "./preprocessor" tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) print("=" * 100) levels = ["minimal", "low", "medium", "high", "severe"] for level in levels: input_ids = tokenizer(level)["input_ids"] assert len(input_ids) == 2 and input_ids[0] == 1 text = tokenizer.decode(input_ids[1]) print(input_ids, text) ================================================ FILE: tests/model/test_find_prefix.py ================================================ import torch from transformers import AutoTokenizer def find_prefix(input_ids, prefix): """ input_ids: [B, N1], no start token prefix: [N2, ], no start token """ len_prefix = prefix.shape[0] # N2 # Create all possible windows of len_prefix input_ids_unfold = input_ids.unfold(1, len_prefix, 1) # Check if all elements in the window match the sequence matches = (input_ids_unfold == prefix).all(dim=2) # Convert boolean matches to integers for argmax operation matches_int = matches.type(torch.int64) # Calculate indices for the first match, if any, otherwise set to -1 indices = torch.where( matches.any(dim=1), matches_int.argmax(dim=1), torch.tensor(-1, dtype=torch.int64), ) return indices pretrained = "./preprocessor" 
tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False) # Example input_ids, [B, N1] input_ids = tokenizer( [ "I am happy. The quality of the image is good.", "ABCD. The quality of the image is poor.", "Please go to school. The quality of the image is excellent.", ], return_tensors="pt", padding=True, ).input_ids[:, 1:] print("=" * 100) print("input_ids: ") print(input_ids) # Example prefix, [N2, ] prefix = tokenizer("The quality of the image is", return_tensors="pt").input_ids[0, 1:] print("=" * 100) print("prefix: ") print(prefix) # Find prefix indices indices = find_prefix(input_ids, prefix) print("=" * 100) print("Indices of the prefix: ") print(indices) ================================================ FILE: tests/model/test_grad.py ================================================ import torch import torch.nn as nn class MyModel(nn.Module): def __init__(self, dim_in, closeset): super().__init__() self.closeset = closeset self.lm_head = nn.Linear(dim_in, 1024) self.level_ids = [128, 256, 512, 640, 768] def forward(self, x_A, x_B, gt): logits_A = self.lm_head(x_A) # [B, V] scores_A, loss_inlevel_A = self.get_score(logits_A) # [B, ] logits_B = self.lm_head(x_B) # [B, V] scores_B, loss_inlevel_B = self.get_score(logits_B) # [B, ] loss = self.rating_loss(scores_A, scores_B, gt) return loss def get_score(self, logits): probs_org = torch.softmax(logits, dim=1) # [B, V] loss_in_level = 1 - probs_org[:, self.level_ids].contiguous().sum(dim=1) # [B, 5] -> [B, ] loss_in_level = loss_in_level.mean() # level prob > 0.99 if self.closeset: logits_levels = logits[:, self.level_ids].contiguous() probs = torch.softmax(logits_levels, dim=1) else: probs = probs_org[:, self.level_ids].contiguous() weights = torch.tensor([5., 4., 3., 2., 1.]).to(probs) scores = torch.matmul(probs, weights) return scores, loss_in_level def rating_loss(self, pred_scores_A, pred_scores_B, gt): pred = 0.5 * (1 + torch.erf((pred_scores_A - pred_scores_B) / 2)) # 2 -> sqrt(2 * (1**2 + 1**2)) 
# eps=1e-8 is important. eps=0 is unable to step, and lr keeps unchanged. eps = 1e-8 loss = (1 - (pred * gt + eps).sqrt() - ((1 - pred) * (1 - gt) + eps).sqrt()).mean() return loss if __name__ == "__main__": dim_in = 512 closeset = False batch_size = 8 num_epoch = 5 num_step = 100 lr = 0.1 model = MyModel(dim_in, closeset) parameters = model.parameters() optim = torch.optim.AdamW( parameters, lr = lr, ) for epoch in range(num_epoch): for step in range(num_step): x_A = torch.rand(batch_size, dim_in) x_B = torch.rand(batch_size, dim_in) gt = torch.rand(batch_size) loss = model(x_A, x_B, gt) optim.zero_grad() loss.backward() optim.step() print("=" * 100) print(model.lm_head.weight.grad) print(model.lm_head.weight)