Repository: Vision-CAIR/MiniGPT-4
Branch: main
Commit: d94738a7626e
Files: 126
Total size: 491.5 KB
Directory structure:
gitextract_bx6vsnx8/
├── .github/
│ └── ISSUE_TEMPLATE/
│ ├── bug_report.md
│ └── feature_request.md
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE.md
├── LICENSE_Lavis.md
├── MiniGPT4_Train.md
├── MiniGPTv2_Train.md
├── README.md
├── SECURITY.md
├── dataset/
│ ├── README_1_STAGE.md
│ ├── README_2_STAGE.md
│ ├── README_MINIGPTv2_FINETUNE.md
│ ├── convert_cc_sbu.py
│ └── convert_laion.py
├── demo.py
├── demo_v2.py
├── environment.yml
├── eval_configs/
│ ├── minigpt4_eval.yaml
│ ├── minigpt4_llama2_eval.yaml
│ ├── minigptv2_benchmark_evaluation.yaml
│ └── minigptv2_eval.yaml
├── eval_scripts/
│ ├── EVAL_README.md
│ ├── eval_ref.py
│ └── eval_vqa.py
├── minigpt4/
│ ├── __init__.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dist_utils.py
│ │ ├── eval_utils.py
│ │ ├── gradcam.py
│ │ ├── logger.py
│ │ ├── optims.py
│ │ ├── registry.py
│ │ ├── utils.py
│ │ └── vqa_tools/
│ │ ├── VQA/
│ │ │ ├── PythonEvaluationTools/
│ │ │ │ ├── vqaEvalDemo.py
│ │ │ │ └── vqaEvaluation/
│ │ │ │ ├── __init__.py
│ │ │ │ └── vqaEval.py
│ │ │ ├── PythonHelperTools/
│ │ │ │ ├── vqaDemo.py
│ │ │ │ └── vqaTools/
│ │ │ │ ├── __init__.py
│ │ │ │ └── vqa.py
│ │ │ ├── QuestionTypes/
│ │ │ │ ├── abstract_v002_question_types.txt
│ │ │ │ └── mscoco_question_types.txt
│ │ │ ├── README.md
│ │ │ └── license.txt
│ │ ├── __init__.py
│ │ ├── vqa.py
│ │ └── vqa_eval.py
│ ├── configs/
│ │ ├── datasets/
│ │ │ ├── aokvqa/
│ │ │ │ └── defaults.yaml
│ │ │ ├── cc_sbu/
│ │ │ │ ├── align.yaml
│ │ │ │ └── defaults.yaml
│ │ │ ├── coco/
│ │ │ │ ├── caption.yaml
│ │ │ │ └── defaults_vqa.yaml
│ │ │ ├── coco_bbox/
│ │ │ │ ├── invrefcoco.yaml
│ │ │ │ ├── invrefcocog.yaml
│ │ │ │ ├── invrefcocop.yaml
│ │ │ │ ├── refcoco.yaml
│ │ │ │ ├── refcocog.yaml
│ │ │ │ └── refcocop.yaml
│ │ │ ├── flickr/
│ │ │ │ ├── caption_to_phrase.yaml
│ │ │ │ ├── default.yaml
│ │ │ │ └── object_to_phrase.yaml
│ │ │ ├── gqa/
│ │ │ │ └── balanced_val.yaml
│ │ │ ├── laion/
│ │ │ │ └── defaults.yaml
│ │ │ ├── llava/
│ │ │ │ ├── conversation.yaml
│ │ │ │ ├── detail.yaml
│ │ │ │ └── reason.yaml
│ │ │ ├── multitask_conversation/
│ │ │ │ └── default.yaml
│ │ │ ├── nlp/
│ │ │ │ └── unnatural_instruction.yaml
│ │ │ ├── ocrvqa/
│ │ │ │ └── ocrvqa.yaml
│ │ │ ├── okvqa/
│ │ │ │ └── defaults.yaml
│ │ │ ├── textcaps/
│ │ │ │ └── caption.yaml
│ │ │ └── vg/
│ │ │ └── ref.yaml
│ │ ├── default.yaml
│ │ └── models/
│ │ ├── minigpt4_llama2.yaml
│ │ ├── minigpt4_vicuna0.yaml
│ │ └── minigpt_v2.yaml
│ ├── conversation/
│ │ ├── __init__.py
│ │ └── conversation.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── builders/
│ │ │ ├── __init__.py
│ │ │ ├── base_dataset_builder.py
│ │ │ └── image_text_pair_builder.py
│ │ ├── data_utils.py
│ │ └── datasets/
│ │ ├── __init__.py
│ │ ├── aok_vqa_datasets.py
│ │ ├── base_dataset.py
│ │ ├── caption_datasets.py
│ │ ├── cc_sbu_dataset.py
│ │ ├── coco_caption.py
│ │ ├── coco_dataset.py
│ │ ├── coco_vqa_datasets.py
│ │ ├── dataloader_utils.py
│ │ ├── flickr.py
│ │ ├── gqa_datasets.py
│ │ ├── laion_dataset.py
│ │ ├── llava_dataset.py
│ │ ├── multitask_conversation.py
│ │ ├── ocrvqa_dataset.py
│ │ ├── text_caps.py
│ │ ├── unnatural_instruction.py
│ │ ├── vg_dataset.py
│ │ └── vqa_datasets.py
│ ├── models/
│ │ ├── Qformer.py
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── eva_vit.py
│ │ ├── minigpt4.py
│ │ ├── minigpt_base.py
│ │ ├── minigpt_v2.py
│ │ └── modeling_llama.py
│ ├── processors/
│ │ ├── __init__.py
│ │ ├── base_processor.py
│ │ ├── blip_processors.py
│ │ └── randaugment.py
│ ├── runners/
│ │ ├── __init__.py
│ │ └── runner_base.py
│ └── tasks/
│ ├── __init__.py
│ ├── base_task.py
│ └── image_text_pretrain.py
├── train.py
└── train_configs/
├── minigpt4_llama2_stage1_pretrain.yaml
├── minigpt4_llama2_stage2_finetune.yaml
├── minigpt4_stage1_pretrain.yaml
├── minigpt4_stage2_finetune.yaml
└── minigptv2_finetune.yaml
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. iOS]
- Browser [e.g. chrome, safari]
- Version [e.g. 22]
**Smartphone (please complete the following information):**
- Device: [e.g. iPhone6]
- OS: [e.g. iOS8.1]
- Browser [e.g. stock browser, safari]
- Version [e.g. 22]
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
wandb/
jobs/logs/
*.out
*ipynb
.history/
*.json
*.sh
.ipynb_common
logs/
results/
prompts/
output/
ckpt/
divide_vqa.py
jobs/
*.slurm
slurm*
sbatch_generate*
eval_data/
dataset/Evaluation.md
jupyter_notebook.slurm
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
https://discord.gg/2aNvvYVv.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
================================================
FILE: LICENSE.md
================================================
BSD 3-Clause License
Copyright 2023 Deyao Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: LICENSE_Lavis.md
================================================
BSD 3-Clause License
Copyright (c) 2022 Salesforce, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: MiniGPT4_Train.md
================================================
## Training of MiniGPT-4
The training of MiniGPT-4 contains two alignment stages.
**1. First pretraining stage**
In the first pretraining stage, the model is trained using image-text pairs from the Laion and CC datasets
to align the vision and language model. To download and prepare the datasets, please check
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
After the first stage, the visual features are mapped and can be understood by the language
model.
To launch the first stage training, run the following command. In our experiments, we use 4 A100.
You can change the save path in the config file
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
```
A MiniGPT-4 checkpoint with only stage one training can be downloaded
[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
Compared to the model after stage two, this checkpoint frequently generates incomplete and repeated sentences.
**2. Second finetuning stage**
In the second stage, we use a small high quality image-text pair dataset created by ourselves
and convert it to a conversation format to further align MiniGPT-4.
To download and prepare our second stage dataset, please check our
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
To launch the second stage alignment,
first specify the path to the checkpoint file trained in stage 1 in
[train_configs/minigpt4_stage2_finetune.yaml](train_configs/minigpt4_stage2_finetune.yaml).
You can also specify the output path there.
Then, run the following command. In our experiments, we use 1 A100.
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
```
After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and in a user-friendly way.
================================================
FILE: MiniGPTv2_Train.md
================================================
## Finetuning of MiniGPT-v2
First, prepare the datasets by following
our [dataset preparation](dataset/README_MINIGPTv2_FINETUNE.md) instructions.
In the train_configs/minigptv2_finetune.yaml, you need to set up the following paths:
llama_model checkpoint path: "/path/to/llama_checkpoint"
ckpt: "/path/to/pretrained_checkpoint"
ckpt save path: "/path/to/save_checkpoint"
For ckpt, you may load from our pretrained model checkpoints:
| MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo) |
|------------------------------|------------------------------|------------------------------|
| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1HkoUUrjzFGn33cSiUkI-KcT-zysCynAz/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigptv2_finetune.yaml
```
================================================
FILE: README.md
================================================
# MiniGPT-V
**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning**
Jun Chen, Deyao Zhu, Xiaoqian Shen, Xiang Li, Zechun Liu, Pengchuan Zhang, Raghuraman Krishnamoorthi, Vikas Chandra, Yunyang Xiong☨, Mohamed Elhoseiny☨
☨equal last author
[](https://www.youtube.com/watch?v=atFCwV2hSY4)
**MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models**
Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
*equal contribution
[](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
*King Abdullah University of Science and Technology*
## 💡 Get help - [Q&A](https://github.com/Vision-CAIR/MiniGPT-4/discussions/categories/q-a) or [Discord 💬](https://discord.gg/5WdJkjbAeE)
**Example Community Efforts Built on Top of MiniGPT-4**
* **InstructionGPT-4**: A 200-Instruction Paradigm for Fine-Tuning MiniGPT-4 Lai Wei, Zihao Jiang, Weiran Huang, Lichao Sun, Arxiv, 2023
* **PatFig**: Generating Short and Long Captions for Patent Figures, Aubakirova, Dana, Kim Gerdes, and Lufei Liu, ICCVW, 2023
* **SkinGPT-4**: An Interactive Dermatology Diagnostic System with Visual Large Language Model, Juexiao Zhou and Xiaonan He and Liyuan Sun and Jiannan Xu and Xiuying Chen and Yuetan Chu and Longxi Zhou and Xingyu Liao and Bin Zhang and Xin Gao, Arxiv, 2023
* **ArtGPT-4**: Artistic Vision-Language Understanding with Adapter-enhanced MiniGPT-4, Yuan, Zhengqing, Huiwen Xue, Xinyi Wang, Yongming Liu, Zhuanzhe Zhao, and Kun Wang, Arxiv, 2023
## News
[Oct.31 2023] We release the evaluation code of our MiniGPT-v2.
[Oct.24 2023] We release the finetuning code of our MiniGPT-v2.
[Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2
[Aug.28 2023] We now provide a llama 2 version of MiniGPT-4
## Online Demo
Click the image to chat with MiniGPT-v2 around your images
[](https://minigpt-v2.github.io/)
Click the image to chat with MiniGPT-4 around your images
[](https://minigpt-4.github.io)
## MiniGPT-v2 Examples

## MiniGPT-4 Examples
| | |
:-------------------------:|:-------------------------:
 | 
 | 
More examples can be found in the [project page](https://minigpt-4.github.io).
## Getting Started
### Installation
**1. Prepare the code and the environment**
Git clone our repository, create a Python environment, and activate it via the following commands:
```bash
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
cd MiniGPT-4
conda env create -f environment.yml
conda activate minigptv
```
**2. Prepare the pretrained LLM weights**
**MiniGPT-v2** is based on Llama2 Chat 7B. For **MiniGPT-4**, we have both Vicuna V0 and Llama 2 version.
Download the corresponding LLM weights from the following Hugging Face space by cloning the repository using git-lfs.
| Llama 2 Chat 7B | Vicuna V0 13B | Vicuna V0 7B |
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main)
Then, set the variable *llama_model* in the model config file to the LLM weight path.
* For MiniGPT-v2, set the LLM path
[here](minigpt4/configs/models/minigpt_v2.yaml#L15) at Line 15.
* For MiniGPT-4 (Llama2), set the LLM path
[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
* For MiniGPT-4 (Vicuna), set the LLM path
[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18
**3. Prepare the pretrained model checkpoints**
Download the pretrained model checkpoints
| MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo)|
|------------------------------|------------------------------|------------------------------|
| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1HkoUUrjzFGn33cSiUkI-KcT-zysCynAz/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8.
| MiniGPT-4 (Vicuna 13B) | MiniGPT-4 (Vicuna 7B) | MiniGPT-4 (LLaMA-2 Chat 7B) |
|----------------------------|---------------------------|---------------------------------|
| [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) |
For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version.
### Launching Demo Locally
For MiniGPT-v2, run
```
python demo_v2.py --cfg-path eval_configs/minigptv2_eval.yaml --gpu-id 0
```
For MiniGPT-4 (Vicuna version), run
```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```
For MiniGPT-4 (Llama2 version), run
```
python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
```
To save GPU memory, LLMs are loaded in 8-bit by default, with a beam search width of 1.
This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memory for 7B LLM.
For more powerful GPUs, you can run the model
in 16 bit by setting `low_resource` to `False` in the relevant config file:
* MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L6)
* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L6)
* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L6)
Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
### Training
For training details of MiniGPT-4, check [here](MiniGPT4_Train.md).
For finetuning details of MiniGPT-v2, check [here](MiniGPTv2_Train.md)
### Evaluation
For evaluation details of MiniGPT-v2, check [here](eval_scripts/EVAL_README.md)
## Acknowledgement
+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check out this great open-source work if you haven't seen it before!
+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
+ [LLaMA](https://github.com/facebookresearch/llama) The strong open-sourced LLaMA 2 language model.
If you're using MiniGPT-4/MiniGPT-v2 in your research or applications, please cite using this BibTeX:
```bibtex
@article{chen2023minigptv2,
title={MiniGPT-v2: large language model as a unified interface for vision-language multi-task learning},
author={Chen, Jun and Zhu, Deyao and Shen, Xiaoqian and Li, Xiang and Liu, Zechun and Zhang, Pengchuan and Krishnamoorthi, Raghuraman and Chandra, Vikas and Xiong, Yunyang and Elhoseiny, Mohamed},
year={2023},
journal={arXiv preprint arXiv:2310.09478},
}
@article{zhu2023minigpt,
title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
journal={arXiv preprint arXiv:2304.10592},
year={2023}
}
```
## License
This repository is under [BSD 3-Clause License](LICENSE.md).
Many codes are based on [Lavis](https://github.com/salesforce/LAVIS) with
BSD 3-Clause License [here](LICENSE_Lavis.md).
================================================
FILE: SECURITY.md
================================================
# Security Policy
## Supported Versions
Use this section to tell people about which versions of your project are
currently being supported with security updates.
| Version | Supported |
| ------- | ------------------ |
| 5.1.x | :white_check_mark: |
| 5.0.x | :x: |
| 4.0.x | :white_check_mark: |
| < 4.0 | :x: |
## Reporting a Vulnerability
Use this section to tell people how to report a vulnerability.
Tell them where to go, how often they can expect to get an update on a
reported vulnerability, what to expect if the vulnerability is accepted or
declined, etc.
================================================
FILE: dataset/README_1_STAGE.md
================================================
## Download the filtered Conceptual Captions, SBU, LAION datasets
### Pre-training datasets download:
We use the filtered synthetic captions prepared by BLIP. For more details about the dataset, please refer to [BLIP](https://github.com/salesforce/BLIP).
Storing the LAION and CC3M+CC12M+SBU datasets requires about 2.3 TB of disk space.
Image source | Filtered synthetic caption by ViT-L
--- | :---:
CC3M+CC12M+SBU | Download
LAION115M | Download
This will download two json files
```
ccs_synthetic_filtered_large.json
laion_synthetic_filtered_large.json
```
## prepare the data step-by-step
### setup the dataset folder and move the annotation file to the data storage folder
```
export MINIGPT4_DATASET=/YOUR/PATH/FOR/LARGE/DATASET/
mkdir ${MINIGPT4_DATASET}/cc_sbu
mkdir ${MINIGPT4_DATASET}/laion
mv ccs_synthetic_filtered_large.json ${MINIGPT4_DATASET}/cc_sbu
mv laion_synthetic_filtered_large.json ${MINIGPT4_DATASET}/laion
```
### Copy the scripts to the data storage folder
```
cp convert_cc_sbu.py ${MINIGPT4_DATASET}/cc_sbu
cp download_cc_sbu.sh ${MINIGPT4_DATASET}/cc_sbu
cp convert_laion.py ${MINIGPT4_DATASET}/laion
cp download_laion.sh ${MINIGPT4_DATASET}/laion
```
### Convert the laion and cc_sbu annotation file format to be img2dataset format
```
cd ${MINIGPT4_DATASET}/cc_sbu
python convert_cc_sbu.py
cd ${MINIGPT4_DATASET}/laion
python convert_laion.py
```
### Download the datasets with img2dataset
```
cd ${MINIGPT4_DATASET}/cc_sbu
sh download_cc_sbu.sh
cd ${MINIGPT4_DATASET}/laion
sh download_laion.sh
```
The final dataset structure
```
.
├── ${MINIGPT4_DATASET}
│ ├── cc_sbu
│ ├── convert_cc_sbu.py
│ ├── download_cc_sbu.sh
│ ├── ccs_synthetic_filtered_large.json
│ ├── ccs_synthetic_filtered_large.tsv
│ └── cc_sbu_dataset
│ ├── 00000.tar
│ ├── 00000.parquet
│ ...
│ ├── laion
│ ├── convert_laion.py
│ ├── download_laion.sh
│ ├── laion_synthetic_filtered_large.json
│ ├── laion_synthetic_filtered_large.tsv
│ └── laion_dataset
│ ├── 00000.tar
│ ├── 00000.parquet
│ ...
...
```
## Set up the dataset configuration files
Then, set up the LAION dataset loading path in
[here](../minigpt4/configs/datasets/laion/defaults.yaml#L5) at Line 5 as
${MINIGPT4_DATASET}/laion/laion_dataset/{00000..10488}.tar
and the Conceptual Captions and SBU datasets loading path in
[here](../minigpt4/configs/datasets/cc_sbu/defaults.yaml#L5) at Line 5 as
${MINIGPT4_DATASET}/cc_sbu/cc_sbu_dataset/{00000..01255}.tar
================================================
FILE: dataset/README_2_STAGE.md
================================================
## Second Stage Data Preparation
Our second stage dataset can be downloaded from
[here](https://drive.google.com/file/d/1nJXhoEcy3KTExr17I7BXqY5Y9Lx_-n-9/view?usp=share_link)
After extraction, you will get a data folder with the following structure:
```
cc_sbu_align
├── filter_cap.json
└── image
├── 2.jpg
├── 3.jpg
...
```
Put the folder to any path you want.
Then, set up the dataset path in the dataset config file
[here](../minigpt4/configs/datasets/cc_sbu/align.yaml#L5) at Line 5.
================================================
FILE: dataset/README_MINIGPTv2_FINETUNE.md
================================================
## Download the dataset for finetuning the MiniGPT-v2
Download the dataset
Image source | Download path
--- | :---:
COCO 2014 images | images captions
COCO VQA | vqa train vqa val
Visual Genome | images part1 images part2 image meta data
TextCaps | images annotations
RefCOCO | annotations
RefCOCO+ | annotations
RefCOCOg | annotations
OKVQA | annotations
AOK-VQA | annotations
OCR-VQA | annotations
GQA | images annotations
Filtered flickr-30k | annotations
Multi-task conversation | annotations
Filtered unnatural instruction | annotations
LLaVA | Complex reasoning Detailed description Conversation
### COCO captions
Download the COCO 2014 images and captions
coco 2014 images path
```
${MINIGPTv2_DATASET}
├── coco
│ ├── images
...
```
coco caption annotation path
```
${MINIGPTv2_DATASET}
├── coco_captions
│ └── annotations
│ ├── coco_karpathy_train.json
...
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the coco_karpathy_train.json path
- [minigpt4/configs/datasets/coco/caption.yaml](../minigpt4/configs/datasets/coco/caption.yaml)
### COCO VQA
Download the vqa v2 train and validation json files
```
├── ${MINIGPTv2_DATASET}
│ ├── vqav2
│ ├── vqa_train.json
| ├── vqa_val.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the vqa_train.json and vqa_val.json path
- [minigpt4/configs/datasets/coco/defaults_vqa.yaml](../minigpt4/configs/datasets/coco/defaults_vqa.yaml)
### Visual genome
Download the Visual Genome images and annotation files
```
${MINIGPTv2_DATASET}
├── visual_genome
│ ├── VG_100K
│ ├── VG_100K_2
│ └── region_descriptions.json
│ └── image_data.json
...
```
Set **image_path** to visual_genome folder.
Similarly, set **ann_path** to the visual_genome folder.
- [minigpt4/configs/datasets/vg/ref.yaml](../minigpt4/configs/datasets/vg/ref.yaml)
### TextCaps
Download the TextCaps images and annotation files
```
├── ${MINIGPTv2_DATASET}
│ ├── textcaps
│ ├── train_images
│ ├── TextCaps_0.1_train.json
```
Set **image_path** to TextCaps train_images folder.
Similarly, set **ann_path** to the TextCaps_0.1_train.json path
- [minigpt4/configs/datasets/textcaps/caption.yaml](../minigpt4/configs/datasets/textcaps/caption.yaml)
### RefCOCO, RefCOCO+, RefCOCOg
Download the RefCOCO, RefCOCO+, RefCOCOg annotation files
```
${MINIGPTv2_DATASET}
├── refcoco_annotations
│ ├── refcoco
│ │ ├── instances.json
│ │ ├── refs(google).p
│ │ └── refs(unc).p
│ ├── refcoco+
│ │ ├── instances.json
│ │ └── refs(unc).p
│ └── refcocog
│ ├── instances.json
│ ├── refs(google).p
│ └── refs(umd).p
...
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** in all the following configs to the above folder *refcoco_annotations* that contains refcoco, refcoco+, and refcocog.
- [minigpt4/configs/datasets/coco_bbox/refcoco.yaml](../minigpt4/configs/datasets/coco_bbox/refcoco.yaml)
- [minigpt4/configs/datasets/coco_bbox/refcocog.yaml](../minigpt4/configs/datasets/coco_bbox/refcocog.yaml)
- [minigpt4/configs/datasets/coco_bbox/refcocop.yaml](../minigpt4/configs/datasets/coco_bbox/refcocop.yaml)
- [minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml)
- [minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml)
- [minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml)
### OKVQA
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── okvqa
│ ├── okvqa_train.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the location of the OKVQA dataset
- [minigpt4/configs/datasets/okvqa/defaults.yaml](../minigpt4/configs/datasets/okvqa/defaults.yaml)
### COCO-VQA
- [OK-VQA Input Questions](https://okvqa.allenai.org/static/data/OpenEnded_mscoco_train2014_questions.json.zip)
- [OK-VQA Annotations](https://okvqa.allenai.org/static/data/mscoco_train2014_annotations.json.zip)
### AOK-VQA
Download the AOK-VQA annotation dataset
```
export AOKVQA_DIR=YOUR_DATASET_PATH
mkdir -p ${AOKVQA_DIR}
curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
```
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── aokvqa
│ ├── aokvqa_v1p0_train.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the location of the AOKVQA dataset
- [minigpt4/configs/datasets/aokvqa/defaults.yaml](../minigpt4/configs/datasets/aokvqa/defaults.yaml)
### OCR-VQA
Download the OCR-VQA annotation files
download the images with loadDataset.py script
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── ocrvqa
│ ├── images
│ ├── dataset.json
```
Set **image_path** as the ocrvqa/images folder.
Similarly, set **ann_path** to the dataset.json
- [minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml](../minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml)
### GQA
Download the GQA annotation files and images
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── gqa
│ ├── images
│ ├── train_balanced_questions.json
```
Set **image_path** as the gqa/images folder.
Similarly, set **ann_path** to the train_balanced_questions.json
- [minigpt4/configs/datasets/gqa/balanced_val.yaml](../minigpt4/configs/datasets/gqa/balanced_val.yaml)
### filtered Flickr-30k
Download filtered Flickr-30k images (fill this [form](https://forms.illinois.edu/sec/229675) on official website or from [kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset/download?datasetVersionNumber=1)) and annotation files
```
${MINIGPTv2_DATASET}
├── filtered_flickr
│ ├── images
│ ├── captiontobbox.json
│ ├── groundedcaption.json
│ └── phrasetobbox.json
...
```
Set **image_path** as the flickr-30k images folder.
Similarly, set **ann_path** to the groundedcaption.json, captiontobbox.json and phrasetobbox.json for the
grounded image caption, caption to bbox, and phrase to bbox datasets.
- [minigpt4/configs/datasets/flickr/default.yaml](../minigpt4/configs/datasets/flickr/default.yaml)
- [minigpt4/configs/datasets/flickr/caption_to_phrase.yaml](../minigpt4/configs/datasets/flickr/caption_to_phrase.yaml)
- [minigpt4/configs/datasets/flickr/object_to_phrase.yaml](../minigpt4/configs/datasets/flickr/object_to_phrase.yaml)
### Multi-task conversation
Download the multi-task conversation dataset
```
Location_you_like
${MINIGPTv2_DATASET}
├── multitask_conversation
│ └── multitask_conversation.json
...
```
Set **image_path** as the COCO 2014 images folder.
Similarly, set **ann_path** to the multitask_conversation.json file path
- [minigpt4/configs/datasets/multitask_conversation/default.yaml](../minigpt4/configs/datasets/multitask_conversation/default.yaml)
### Unnatural instruction
Download the filtered unnatural instruction annotation files (we remove the very long sentences from the original unnatural instruction dataset)
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── unnatural_instructions
│ ├── filtered_unnatural_instruction.json
```
There is no image path.
Similarly, set **ann_path** to the filtered_unnatural_instruction.json file path
- [minigpt4/configs/datasets/nlp/unnatural_instruction.yaml](../minigpt4/configs/datasets/nlp/unnatural_instruction.yaml)
### LLaVA
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── llava
│ ├── conversation_58k.json
│ ├── detail_23k.json
│ ├── complex_reasoning_77k.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the location of the previous downloaded conversation_58k.json,
detail_23k.json, and complex_reasoning_77k.json in conversation.yaml, detail.yaml, and reason.yaml, respectively.
- [minigpt4/configs/datasets/llava/conversation.yaml](../minigpt4/configs/datasets/llava/conversation.yaml)
- [minigpt4/configs/datasets/llava/detail.yaml](../minigpt4/configs/datasets/llava/detail.yaml)
- [minigpt4/configs/datasets/llava/reason.yaml](../minigpt4/configs/datasets/llava/reason.yaml)
================================================
FILE: dataset/convert_cc_sbu.py
================================================
import json
import csv
# Convert the CC3M+CC12M+SBU annotation JSON (a list of records) into a
# tab-separated file with one header row followed by one row per record.

# specify input and output file paths
input_file = 'ccs_synthetic_filtered_large.json'
output_file = 'ccs_synthetic_filtered_large.tsv'

# load JSON data from input file
# (JSON text is UTF-8; do not depend on the platform's default encoding)
with open(input_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

# extract header and data from JSON; column names come from the first record
header = data[0].keys()
rows = [x.values() for x in data]

# write data to TSV file
# newline='' is what the csv module expects: without it, platforms that
# translate '\n' (e.g. Windows) emit '\r\r\n' line endings.
with open(output_file, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(header)
    writer.writerows(rows)
================================================
FILE: dataset/convert_laion.py
================================================
import json
import csv
# Convert the LAION annotation JSON (a list of records) into a tab-separated
# file with one header row followed by one row per record.

# specify input and output file paths
input_file = 'laion_synthetic_filtered_large.json'
output_file = 'laion_synthetic_filtered_large.tsv'

# load JSON data from input file
# (JSON text is UTF-8; do not depend on the platform's default encoding)
with open(input_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

# extract header and data from JSON; column names come from the first record
header = data[0].keys()
rows = [x.values() for x in data]

# write data to TSV file
# newline='' is what the csv module expects: without it, platforms that
# translate '\n' (e.g. Windows) emit '\r\r\n' line endings.
with open(output_file, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(header)
    writer.writerows(rows)
================================================
FILE: demo.py
================================================
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr
from transformers import StoppingCriteriaList
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``cfg_path`` (required), ``gpu_id`` (int,
        default 0) and ``options`` (list of strings, or None when omitted).
    """
    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )

    cli = argparse.ArgumentParser(description="Demo")
    cli.add_argument("--cfg-path", required=True, help="path to configuration file.")
    cli.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    cli.add_argument("--options", nargs="+", help=options_help)
    return cli.parse_args()
def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN settings.

    The seed used is ``config.run_cfg.seed`` plus the value of ``get_rank()``.
    """
    rank_seed = config.run_cfg.seed + get_rank()

    # Same seed for every generator, applied in the same order as before.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    cudnn.deterministic = True
    cudnn.benchmark = False
# ========================================
# Model Initialization
# ========================================
# Runs at import time: parses the CLI, builds the model named in the config,
# moves it to the selected GPU and wraps it in a Chat object used by the
# Gradio callbacks below.
# Maps the config's model_type to the conversation template to start from.
conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
'pretrain_llama2': CONV_VISION_LLama2}
print('Initializing Chat')
args = parse_args()
cfg = Config(args)
model_config = cfg.model_cfg
# The chosen GPU index is also stored on the model config as device_8bit
# (how the model class consumes it is not visible in this file).
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
# Raises KeyError if model_type is not one of the conv_dict keys.
CONV_VISION = conv_dict[model_config.model_type]
# The image processor is taken from the cc_sbu_align dataset's "train" entry.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
# Token-id sequences that stop generation.
# NOTE(review): presumably tokenizer ids of the turn separator — confirm.
stop_words_ids = [[835], [2277, 29937]]
# Convert each id list to a tensor on the same GPU as the model.
stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
print('Initialization Finished')
# ========================================
# Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    """Clear the conversation and return the UI to its pre-upload state.

    Returns a 6-tuple: cleared chatbot value, image update, text box update,
    upload button update, the (emptied) chat state and the image list.
    """
    if chat_state is not None:
        chat_state.messages = []
    # A missing list stays None; an existing one is replaced by a fresh list.
    img_list = None if img_list is None else []

    image_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder='Please upload your image first', interactive=False)
    button_update = gr.update(value="Upload & Start Chat", interactive=True)
    return None, image_update, textbox_update, button_update, chat_state, img_list
def upload_img(gr_img, text_input, chat_state):
    """Start a new conversation around the uploaded image.

    Returns a 5-tuple of updates/values for: image, text box, upload button,
    chat state and image list. Without an image the chat state is passed
    through unchanged and the image list is None.
    """
    if gr_img is not None:
        chat_state = CONV_VISION.copy()
        img_list = []
        chat.upload_img(gr_img, chat_state, img_list)
        chat.encode_img(img_list)

        image_update = gr.update(interactive=False)
        textbox_update = gr.update(interactive=True, placeholder='Type and press Enter')
        button_update = gr.update(value="Start Chatting", interactive=False)
        return image_update, textbox_update, button_update, chat_state, img_list

    # Nothing uploaded yet: leave the button clickable.
    return None, None, gr.update(interactive=True), chat_state, None
def gradio_ask(user_message, chatbot, chat_state):
    """Record the user's message in the conversation and the chat window.

    Returns (text box value/update, chatbot history, chat state). An empty
    message changes nothing except the text box placeholder.
    """
    if len(user_message) != 0:
        chat.ask(user_message, chat_state)
        # The answer slot stays None until gradio_answer fills it in.
        chatbot = chatbot + [[user_message, None]]
        return '', chatbot, chat_state

    warning = gr.update(interactive=True, placeholder='Input should not be empty!')
    return warning, chatbot, chat_state
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """Generate the model's reply and write it into the last chatbot entry.

    Returns (chatbot history, chat state, image list).
    """
    generation_kwargs = dict(
        conv=chat_state,
        img_list=img_list,
        num_beams=num_beams,
        temperature=temperature,
        max_new_tokens=300,
        max_length=2000,
    )
    # Only the first element of the result is shown to the user.
    reply = chat.answer(**generation_kwargs)[0]
    chatbot[-1][1] = reply
    return chatbot, chat_state, img_list
# ----------------------------------------
# Static text shown at the top of the page (each passed to gr.Markdown below).
# ----------------------------------------
title = """
Demo of MiniGPT-4
"""
description = """This is the demo of MiniGPT-4. Upload your images and start chatting!
"""
article = """


"""
#TODO show examples below
# Page layout: a narrow column with the image, the two buttons and the
# generation sliders, and a wider column with the chat window and text box.
with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
gr.Markdown(article)
with gr.Row():
with gr.Column(scale=1):
image = gr.Image(type="pil")
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Restart")
# Passed to gradio_answer as num_beams.
num_beams = gr.Slider(
minimum=1,
maximum=10,
value=1,
step=1,
interactive=True,
label="beam search numbers)",
)
# Passed to gradio_answer as temperature.
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.0,
step=0.1,
interactive=True,
label="Temperature",
)
with gr.Column(scale=2):
# Per-session values handed to and returned from the callbacks.
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label='MiniGPT-4')
# Starts disabled; upload_img returns an update that enables it.
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
# Event wiring: upload -> upload_img; Enter in the text box -> gradio_ask and
# then gradio_answer; Restart -> gradio_reset.
upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
)
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
# NOTE(review): share=True asks Gradio for a public link — confirm this is
# intended wherever the demo is deployed.
demo.launch(share=True, enable_queue=True)
================================================
FILE: demo_v2.py
================================================
import argparse
import os
import random
from collections import defaultdict
import cv2
import re
import numpy as np
from PIL import Image
import torch
import html
import gradio as gr
import torchvision.transforms as T
import torch.backends.cudnn as cudnn
from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``cfg_path`` (defaults to
        'eval_configs/minigptv2_eval.yaml'), ``gpu_id`` (int, default 0) and
        ``options`` (list of strings, or None when omitted).
    """
    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )

    cli = argparse.ArgumentParser(description="Demo")
    cli.add_argument(
        "--cfg-path",
        default='eval_configs/minigptv2_eval.yaml',
        help="path to configuration file.",
    )
    cli.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    cli.add_argument("--options", nargs="+", help=options_help)
    return cli.parse_args()
# Fixed seeds for the Python, NumPy and PyTorch RNGs, and fixed cuDNN settings.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
cudnn.benchmark = False
cudnn.deterministic = True
print('Initializing Chat')
# Runs at import time: parse the CLI, build the model named in the config and
# move it to the selected GPU.
args = parse_args()
cfg = Config(args)
device = 'cuda:{}'.format(args.gpu_id)
model_config = cfg.model_cfg
# The chosen GPU index is also stored on the model config as device_8bit
# (how the model class consumes it is not visible in this file).
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
# Box coordinates parsed from text are divided by this value before being
# scaled to the image size (see visualize_all_bbox_together).
bounding_box_size = 100
# The image processor is taken from the cc_sbu_align dataset's "train" entry.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
model = model.eval()
# Conversation template that gradio_ask copies for each new conversation.
CONV_VISION = Conversation(
system="",
roles=(r"[INST] ", r" [/INST]"),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="",
)
def extract_substrings(string):
    """Return the text chunks that end at a '}' not followed by '<'.

    Anything after the last '}' in ``string`` is ignored, so an unfinished
    trailing group does not produce a match. The closing '}' itself is not
    part of the returned chunks.
    """
    # Cut the input right after its final closing brace, if there is one.
    last_close = string.rfind('}')
    if last_close != -1:
        string = string[:last_close + 1]

    return list(re.findall(r'(.*?)\}(?!<)', string))
def is_overlapping(rect1, rect2):
    """Tell whether two (x1, y1, x2, y2) rectangles intersect.

    Rectangles that merely touch along an edge or corner count as overlapping.
    """
    left_a, top_a, right_a, bottom_a = rect1
    left_b, top_b, right_b, bottom_b = rect2

    apart_horizontally = right_a < left_b or left_a > right_b
    apart_vertically = bottom_a < top_b or top_a > bottom_b
    return not (apart_horizontally or apart_vertically)
def computeIoU(bbox1, bbox2):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes.

    Widths and heights are computed inclusively (``x2 - x1 + 1``).
    """
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2

    # Extent of the shared region, clamped at zero when the boxes are apart.
    shared_w = max(0, min(ax2, bx2) - max(ax1, bx1) + 1)
    shared_h = max(0, min(ay2, by2) - max(ay1, by1) + 1)
    shared_area = shared_w * shared_h

    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)

    return shared_area / (area_a + area_b - shared_area)
def save_tmp_img(visual_img):
    """Save ``visual_img`` to a uniquely named ``.jpg`` file and return its path.

    Args:
        visual_img: object exposing ``save(path)`` (the callers pass the image
            returned by visualize_all_bbox_together).

    Returns:
        Path of the written file, inside the system temporary directory, with
        the file name starting with "gradio".
    """
    import os
    import tempfile

    # The previous name was five digits from the global RNG, which this script
    # seeds with a constant at startup: names repeated across runs and could
    # collide between sessions. mkstemp reserves an unused name instead and
    # also avoids hard-coding "/tmp".
    fd, file_path = tempfile.mkstemp(prefix="gradio", suffix=".jpg")
    os.close(fd)  # only the path is needed; save() reopens the file itself
    visual_img.save(file_path)
    return file_path
def mask2bbox(mask):
    """Turn a drawn mask into a '{<x0><y0><x1><y1>}' bounding-box string.

    The mask is resized to 100x100 and only its first channel is inspected.
    Returns '' when ``mask`` is None or contains no non-zero pixel.
    """
    if mask is None:
        return ''

    resized = mask.resize([100, 100], resample=Image.NEAREST)
    grid = np.array(resized)[:, :, 0]
    row_hits = np.any(grid, axis=1)
    col_hits = np.any(grid, axis=0)

    if not row_hits.sum():
        return ''

    # First and last index of the rows / columns that contain mask pixels.
    rmin, rmax = np.where(row_hits)[0][[0, -1]]
    cmin, cmax = np.where(col_hits)[0][[0, -1]]
    return '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
def escape_markdown(text):
    """Prefix every '<' and '>' in ``text`` with a backslash."""
    # Single pass over the text; each angle bracket maps to its escaped form.
    escapes = {ord('<'): '\\<', ord('>'): '\\>'}
    return text.translate(escapes)
def reverse_escape(text):
    """Undo escape_markdown: turn '\\<' and '\\>' back into '<' and '>'."""
    without_lt = text.replace('\\<', '<')
    return without_lt.replace('\\>', '>')
# Palette of 3-component colour tuples used when drawing boxes and labels.
colors = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (210, 210, 0),
    (255, 0, 255), (0, 255, 255), (114, 128, 250), (0, 165, 255),
    (0, 128, 0), (144, 238, 144), (238, 238, 175), (255, 191, 0),
    (0, 128, 0), (226, 43, 138), (255, 0, 255), (0, 215, 255),
]

# Palette index (as a string) -> '#rrggbb'-style hex string. The components
# are written in reverse tuple order: third, second, first.
color_map = {
    str(idx): "#{:02x}{:02x}{:02x}".format(triple[2], triple[1], triple[0])
    for idx, triple in enumerate(colors)
}

# Alias of the full palette (same list object).
used_colors = colors
def visualize_all_bbox_together(image, generation):
# Draw every bounding box found in `generation` onto a resized copy of `image`.
#
# Returns a 2-tuple (image with boxes drawn, text for display). The text is
# the recoloured generation in 'all' mode and '' in 'single' mode. Returns
# (None, '') when `image` is None or when no box could be parsed.
#
# Box coordinates read from the text are divided by the module-level
# `bounding_box_size` and multiplied by the resized image width / height.
if image is None:
return None, ''
generation = html.unescape(generation)
# Work on a copy that is 500 px wide with the original aspect ratio.
image_width, image_height = image.size
image = image.resize([500, int(500 / image_width * image_height)])
image_width, image_height = image.size
string_list = extract_substrings(generation)
if string_list: # it is grounding or detection
mode = 'all'
# entities maps each phrase to the list of its scaled boxes.
entities = defaultdict(list)
# i and j are only incremented in the loop below; their values are not read
# before the names are reused as loop variables further down.
i = 0
j = 0
for string in string_list:
# NOTE(review): the separator literal in the split below is spread over two
# lines as written — verify it against the model's output format.
try:
obj, string = string.split('
')
except ValueError:
print('wrong string: ', string)
continue
# NOTE(review): str.split('') raises ValueError (empty separator); the
# delimiter literal looks truncated — confirm the intended separator.
bbox_list = string.split('')
flag = False
for bbox_string in bbox_list:
# A box is recognised only when exactly four integers are present.
integers = re.findall(r'-?\d+', bbox_string)
if len(integers) == 4:
x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
left = x0 / bounding_box_size * image_width
bottom = y0 / bounding_box_size * image_height
right = x1 / bounding_box_size * image_width
top = y1 / bounding_box_size * image_height
entities[obj].append([left, bottom, right, top])
j += 1
flag = True
if flag:
i += 1
else:
# No phrase/box groups: fall back to a single box taken from the whole text.
integers = re.findall(r'-?\d+', generation)
if len(integers) == 4: # it is refer
mode = 'single'
entities = list()
x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
left = x0 / bounding_box_size * image_width
bottom = y0 / bounding_box_size * image_height
right = x1 / bounding_box_size * image_width
top = y1 / bounding_box_size * image_height
entities.append([left, bottom, right, top])
else:
# don't detect any valid bbox to visualize
return None, ''
if len(entities) == 0:
return None, ''
# Convert the image to a NumPy array for drawing.
# NOTE(review): `image.size` / `image.resize` above are already used as on a
# PIL image, so the str and torch.Tensor branches look unreachable — confirm.
if isinstance(image, Image.Image):
image_h = image.height
image_w = image.width
image = np.array(image)
elif isinstance(image, str):
if os.path.exists(image):
pil_img = Image.open(image).convert("RGB")
# The channel order is reversed in this branch only.
image = np.array(pil_img)[:, :, [2, 1, 0]]
image_h = pil_img.height
image_w = pil_img.width
else:
raise ValueError(f"invaild image path, {image}")
elif isinstance(image, torch.Tensor):
image_tensor = image.cpu()
# Undo a mean/std normalisation with the constants below before converting.
reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
pil_img = T.ToPILImage()(image_tensor)
image_h = pil_img.height
image_w = pil_img.width
image = np.array(pil_img)[:, :, [2, 1, 0]]
else:
raise ValueError(f"invaild image format, {type(image)} for {image}")
# `indices` is not read anywhere below.
indices = list(range(len(entities)))
new_image = image.copy()
# Label backgrounds drawn so far, used to avoid stacking labels on each other.
previous_bboxes = []
# size of text
text_size = 0.5
# thickness of text
text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
box_line = 2
# Reference glyph size used to derive the label height and the width of the
# coloured marker at the start of each label.
(c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
base_height = int(text_height * 0.675)
text_offset_original = text_height - base_height
text_spaces = 2
# num_bboxes = sum(len(x[-1]) for x in entities)
used_colors = colors # random.sample(colors, k=num_bboxes)
# color_id is incremented per entity but never read; the colour is chosen
# from entity_idx instead.
color_id = -1
for entity_idx, entity_name in enumerate(entities):
# In 'single' mode entities is a list of boxes, so the loop item is the box
# itself. 'identify' is never assigned to `mode` inside this function.
if mode == 'single' or mode == 'identify':
bboxes = entity_name
bboxes = [bboxes]
else:
bboxes = entities[entity_name]
color_id += 1
for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
skip_flag = False
orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
# Colours repeat once there are more entities than palette entries.
color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
# Text labels are drawn in 'all' mode only.
if mode == 'all':
l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
x1 = orig_x1 - l_o
y1 = orig_y1 - l_o
# Not enough room above the box: place the label just inside its top edge.
if y1 < text_height + text_offset_original + 2 * text_spaces:
y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
x1 = orig_x1 + r_o
# add text background
(text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
text_line)
text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
for prev_bbox in previous_bboxes:
# Same phrase at almost the same place: do not draw the label twice.
if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
prev_bbox['phrase'] == entity_name:
skip_flag = True
break
# Move the label down one label height at a time until it no longer
# overlaps the earlier label, clamping at the bottom of the image.
while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
y1 += (text_height + text_offset_original + 2 * text_spaces)
if text_bg_y2 >= image_h:
text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
text_bg_y2 = image_h
y1 = image_h
break
if not skip_flag:
# Blend a half-transparent background behind the label, pixel by pixel.
alpha = 0.5
for i in range(text_bg_y1, text_bg_y2):
for j in range(text_bg_x1, text_bg_x2):
if i < image_h and j < image_w:
if j < text_bg_x1 + 1.35 * c_width:
# original color
bg_color = color
else:
# white
bg_color = [255, 255, 255]
new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
np.uint8)
cv2.putText(
new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
)
previous_bboxes.append(
{'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
# Build the display text: strip the box markup and rewrite each phrase.
if mode == 'all':
def color_iterator(colors):
# Endless cycle over the palette.
while True:
for color in colors:
yield color
color_gen = color_iterator(colors)
# Add colors to phrases and remove
def colored_phrases(match):
phrase = match.group(1)
# NOTE(review): `color` is fetched but does not appear in the returned
# string as written — confirm the intended replacement text.
color = next(color_gen)
return f'{phrase}'
generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|', '', generation)
generation_colored = re.sub(r'(.*?)
', colored_phrases, generation)
else:
generation_colored = ''
pil_image = Image.fromarray(new_image)
return pil_image, generation_colored
def gradio_reset(chat_state, img_list):
    """Clear the conversation and return the UI to its initial state.

    Returns a 5-tuple: cleared chatbot value, image update, text box update,
    the (emptied) chat state and the image list.
    """
    if chat_state is not None:
        chat_state.messages = []
    # A missing list stays None; an existing one is replaced by a fresh list.
    img_list = None if img_list is None else []

    image_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder='Upload your image and chat', interactive=True)
    return None, image_update, textbox_update, chat_state, img_list
def image_upload_trigger(upload_flag, replace_flag, img_list):
    """Flag that a new image arrived.

    The upload flag is always set to 1. The replace flag becomes 1 when an
    earlier image (and so an earlier conversation) exists, so the conversation
    can be reset later; otherwise it is returned unchanged.
    """
    had_previous_image = bool(img_list)
    return 1, (1 if had_previous_image else replace_flag)
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    """Flag that a new image arrived via an example.

    The upload flag is always set to 1. The replace flag becomes 1 when an
    earlier image exists or the flag already equals 1; otherwise it is
    returned unchanged. ``text_input`` and ``image`` are accepted but unused.
    """
    must_replace = img_list or replace_flag == 1
    return 1, (1 if must_replace else replace_flag)
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
# Record the user's message, (re)starting the conversation when a new image
# was uploaded.
#
# Returns a 6-tuple: text box value, chatbot history, chat state, image list,
# upload flag and replace flag.
# The text box shows a warning for an empty message; the rest of the function
# still runs in that case.
if len(user_message) == 0:
text_box_show = 'Input should not be empty!'
else:
text_box_show = ''
# The image component may deliver a dict holding both the image and the
# drawn mask.
if isinstance(gr_img, dict):
gr_img, mask = gr_img['image'], gr_img['mask']
else:
mask = None
if '[identify]' in user_message:
# check if user provide bbox in the text input
integers = re.findall(r'-?\d+', user_message)
if len(integers) != 4: # no bbox in text
# Derive the box from the drawn mask ('' when there is no usable mask).
bbox = mask2bbox(mask)
user_message = user_message + bbox
if chat_state is None:
chat_state = CONV_VISION.copy()
# A pending upload is registered with the chat before the question is asked;
# both flags are cleared once handled.
if upload_flag:
if replace_flag:
chat_state = CONV_VISION.copy() # new image, reset everything
replace_flag = 0
chatbot = []
img_list = []
# The return value of upload_img is not used below.
llm_message = chat.upload_img(gr_img, chat_state, img_list)
upload_flag = 0
chat.ask(user_message, chat_state)
# The answer slot stays None until the answer callback fills it in.
chatbot = chatbot + [[user_message, None]]
# For identify requests, also show the image with the queried box drawn on it.
if '[identify]' in user_message:
visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
if visual_img is not None:
file_path = save_tmp_img(visual_img)
chatbot = chatbot + [[(file_path,), None]]
return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
def gradio_answer(chatbot, chat_state, img_list, temperature):
    """Generate the model's reply and write it into the last chatbot entry.

    Returns (chatbot history, chat state).
    """
    generation_kwargs = dict(
        conv=chat_state,
        img_list=img_list,
        temperature=temperature,
        max_new_tokens=500,
        max_length=2000,
    )
    # Only the first element of the result is shown to the user.
    reply = chat.answer(**generation_kwargs)[0]
    chatbot[-1][1] = reply
    return chatbot, chat_state
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
    """Stream the model's reply into the last chatbot entry.

    Yields (chatbot history, chat state) after every streamed piece, with the
    text escaped by escape_markdown. Once the stream ends, the text of the
    last conversation message is reset to ''.
    """
    # Encode the image first unless the list already starts with a tensor.
    needs_encoding = len(img_list) > 0 and not isinstance(img_list[0], torch.Tensor)
    if needs_encoding:
        chat.encode_img(img_list)

    pieces = chat.stream_answer(
        conv=chat_state,
        img_list=img_list,
        temperature=temperature,
        max_new_tokens=500,
        max_length=2000,
    )

    shown_text = ''
    for piece in pieces:
        shown_text = shown_text + escape_markdown(piece)
        chatbot[-1][1] = shown_text
        yield chatbot, chat_state

    chat_state.messages[-1][1] = ''
    return chatbot, chat_state
def gradio_visualize(chatbot, gr_img):
    """Append a picture of the boxes mentioned in the last reply.

    When boxes are found, the last reply is replaced by its recoloured text
    (if any) and a new chatbot row holding the rendered image is added.
    Returns the chatbot history.
    """
    if isinstance(gr_img, dict):
        # The mask is looked up as before but is not needed here.
        gr_img, _mask = gr_img['image'], gr_img['mask']

    last_reply = reverse_escape(chatbot[-1][1])
    rendered, colored_text = visualize_all_bbox_together(gr_img, last_reply)
    if rendered is None:
        return chatbot

    if len(colored_text):
        chatbot[-1][1] = colored_text
    image_path = save_tmp_img(rendered)
    return chatbot + [[None, (image_path,)]]
def gradio_taskselect(idx):
    """Return the prompt prefix and the hint text for the selected task shortcut.

    Args:
        idx: position of the chosen shortcut (0 = no tag, 1 = grounding, 2 = refer,
            3 = detection, 4 = identify, 5 = vqa).

    Returns:
        A ``(prompt, hint)`` tuple: the text to place in the input box and the
        markdown hint to display.

    Raises:
        IndexError: if ``idx`` is outside the range of the two lists.
    """
    prompt_list = [
        '',
        '[grounding] describe this image in detail',
        '[refer] ',
        '[detection] ',
        '[identify] what is this ',
        '[vqa] '
    ]

    instruct_list = [
        '**Hint:** Type in whatever you want',
        '**Hint:** Send the command to generate a grounded image description',
        '**Hint:** Type in a phrase about an object in the image and send the command',
        '**Hint:** Type in a caption or phrase, and see object locations in the image',
        # Fixed the user-facing typo "botton" and the ungrammatical "before redraw".
        '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" button on the top right of the image before redrawing',
        '**Hint:** Send a question to get a short answer',
    ]
    return prompt_list[idx], instruct_list[idx]
# Shared chat wrapper used by all Gradio callbacks in this script.
chat = Chat(model, vis_processor, device=device)

# Text rendered at the top of the page via gr.Markdown.
title = """MiniGPT-v2 Demo
"""
description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
# Earlier variant of `article`, kept for reference. The whole literal is commented
# out, including its closing delimiter, which was previously left as live code
# and mispaired with the triple-quoted strings that follow.
# article = """
#
#
#
# """
article = """
"""
# Markdown help text displayed under the image column (rendered with gr.Markdown).
introduction = '''
For Abilities Involving Visual Grounding:
1. Grounding: CLICK **Send** to generate a grounded image description.
2. Refer: Input a referring object and CLICK **Send**.
3. Detection: Write a caption or phrase, and CLICK **Send**.
4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
5. VQA: Input a visual question and CLICK **Send**.
6. No Tag: Input whatever you want and CLICK **Send** without any tagging
You can also simply chat in free form!
'''
# The input box is created here, outside the Blocks context, and is placed into the
# layout later through text_input.render().
text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
                        scale=8)
# Build the demo page: layout first, then the event wiring, then launch.
with gr.Blocks() as demo:
    gr.Markdown(title)
    # gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        # Left column: image input, sampling temperature, restart button and usage notes.
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil", tool='sketch', brush_radius=20)

            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            clear = gr.Button("Restart")

            gr.Markdown(introduction)

        # Right column: per-session state, the chat window, task shortcuts and the input row.
        with gr.Column():
            chat_state = gr.State(value=None)
            img_list = gr.State(value=[])
            chatbot = gr.Chatbot(label='MiniGPT-v2')

            # type="index" makes the click handler receive the position of the clicked
            # sample, which gradio_taskselect uses to index its prompt and hint lists.
            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
                type="index",
                label='Task Shortcuts',
            )
            task_inst = gr.Markdown('**Hint:** Upload your image and chat')
            with gr.Row():
                text_input.render()
                send = gr.Button("Send", variant='primary', size='sm', scale=1)

    # Flags shared between the upload handler and gradio_ask.
    upload_flag = gr.State(value=0)
    replace_flag = gr.State(value=0)
    image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])

    # Two columns of clickable examples; selecting one runs example_trigger.
    with gr.Row():
        with gr.Column():
            gr.Examples(examples=[
                ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
                 img_list],
                ["examples_v2/sofa.jpg", "[detection] sofas", upload_flag, replace_flag, img_list],
                ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
                 img_list],
                ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])
        with gr.Column():
            gr.Examples(examples=[
                ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
                 upload_flag, replace_flag, img_list],
                ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
                ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
                ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])

    # Clicking a task shortcut fills the input box and updates the hint line.
    dataset.click(
        gradio_taskselect,
        inputs=[dataset],
        outputs=[text_input, task_inst],
        show_progress="hidden",
        postprocess=False,
        queue=False,
    )

    # Pressing Enter runs three chained steps: record the question, stream the
    # answer, then visualize any boxes. Each .success step runs only if the
    # previous one finished without raising.
    text_input.submit(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    # The Send button triggers exactly the same three-step chain as pressing Enter.
    send.click(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    # Restart: gradio_reset supplies new values for the chat window, image, input box and state.
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)

# share=True asks Gradio to create a public share link in addition to the local server.
demo.launch(share=True, enable_queue=True)
================================================
FILE: environment.yml
================================================
# Conda environment definition for this repository.
name: minigptv
channels:
  - pytorch
  - defaults
  - anaconda
dependencies:
  # Packages resolved by conda itself.
  - python=3.9
  - cudatoolkit
  - pip
  # Everything below is installed with pip inside the conda environment.
  - pip:
    - torch==2.0.0
    - torchaudio
    - torchvision
    - huggingface-hub==0.18.0
    - matplotlib==3.7.0
    - psutil==5.9.4
    - iopath
    - pyyaml==6.0
    - regex==2022.10.31
    - tokenizers==0.13.2
    - tqdm==4.64.1
    - transformers==4.30.0
    - timm==0.6.13
    - webdataset==0.2.48
    - omegaconf==2.3.0
    - opencv-python==4.7.0.72
    - decord==0.6.0
    - peft==0.2.0
    - sentence-transformers
    - gradio==3.47.1
    - accelerate==0.20.3
    - bitsandbytes==0.37.0
    - scikit-image
    - visual-genome
    - wandb
================================================
FILE: eval_configs/minigpt4_eval.yaml
================================================
# Evaluation/demo configuration for MiniGPT-4 with the `pretrain_vicuna0` model type.
model:
  arch: minigpt4
  model_type: pretrain_vicuna0
  max_txt_len: 160
  # End-of-turn marker; it matches the "###" separators used in prompt_template.
  end_sym: "###"
  low_resource: True
  # "{}" is presumably replaced by the user message — confirm in the model code.
  prompt_template: '###Human: {} ###Assistant: '
  # Placeholder: replace with the path of the pretrained checkpoint before running.
  ckpt: 'please set this value to the path of pretrained checkpoint'

datasets:
  cc_sbu_align:
    # Image and text processors; both are declared under the `train` key.
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
================================================
FILE: eval_configs/minigpt4_llama2_eval.yaml
================================================
# Evaluation/demo configuration for MiniGPT-4 with the `pretrain_llama2` model type.
model:
  arch: minigpt4
  model_type: pretrain_llama2
  max_txt_len: 160
  # NOTE(review): end_sym is an empty string here — confirm that no end-of-sequence
  # marker was lost from this value.
  end_sym: ""
  low_resource: True
  # "{}" is presumably replaced by the user message — confirm in the model code.
  prompt_template: '[INST] {} [/INST] '
  # Placeholder: replace with the path of the pretrained checkpoint before running.
  ckpt: 'please set this value to the path of pretrained checkpoint'

datasets:
  cc_sbu_align:
    # Image and text processors; both are declared under the `train` key.
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
================================================
FILE: eval_configs/minigptv2_benchmark_evaluation.yaml
================================================
# Benchmark configuration read by eval_scripts/eval_ref.py and eval_scripts/eval_vqa.py.
model:
  arch: minigpt_v2
  model_type: pretrain
  max_txt_len: 500
  # NOTE(review): end_sym is an empty string here — confirm that no end-of-sequence
  # marker was lost from this value.
  end_sym: ""
  low_resource: False
  prompt_template: '[INST] {} [/INST]'
  # Both paths are empty placeholders and must be filled in before running.
  llama_model: ""
  ckpt: ""
  lora_r: 64
  lora_alpha: 16

datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 448
    text_processor:
      train:
        name: "blip_caption"

# One entry per benchmark. The evaluation scripts read eval_file_path, img_path,
# max_new_tokens and batch_size from the entry whose key matches the --dataset name.
evaluation_datasets:
  refcoco:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  refcocog:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  refcoco+:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  gqa:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  okvqa:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  vizwiz:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  iconvqa:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  vsr:
    # Hugging Face dataset id rather than a local path.
    eval_file_path: cambridgeltl/vsr_zeroshot
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 10
  hm:
    eval_file_path: /path/to/eval/annotation/path
    img_path: /path/to/eval/image/path
    max_new_tokens: 20
    batch_size: 100

run:
  task: image_text_pretrain
  name: minigptv2_evaluation
  # Folder where the evaluation scripts write their prediction JSON files.
  save_path: /path/to/save/folder_path
================================================
FILE: eval_configs/minigptv2_eval.yaml
================================================
# Evaluation/demo configuration for MiniGPT-v2.
model:
  arch: minigpt_v2
  model_type: pretrain
  max_txt_len: 500
  # NOTE(review): end_sym is an empty string here — confirm that no end-of-sequence
  # marker was lost from this value.
  end_sym: ""
  low_resource: True
  prompt_template: '[INST] {} [/INST]'
  # Placeholder: replace with the path of the pretrained checkpoint before running.
  ckpt: "please set this value to the path of pretrained checkpoint"
  lora_r: 64
  lora_alpha: 16

datasets:
  cc_sbu_align:
    # Image and text processors; both are declared under the `train` key.
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 448
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
================================================
FILE: eval_scripts/EVAL_README.md
================================================
## Evaluation Instruction for MiniGPT-v2
### Data preparation
Images download
Image source | Download path
--- | :---:
OKVQA| annotations images
gqa | annotations images
hateful meme | images and annotations
iconqa | images and annotation
vizwiz | images and annotation
RefCOCO | annotations
RefCOCO+ | annotations
RefCOCOg | annotations
### Evaluation dataset structure
```
${MINIGPTv2_EVALUATION_DATASET}
├── gqa
│   ├── test_balanced_questions.json
│   ├── testdev_balanced_questions.json
│   └── gqa_images
├── hateful_meme
│   ├── hm_images
│   └── dev.jsonl
├── iconvqa
│   ├── iconvqa_images
│   └── choose_text_val.json
├── vizwiz
│   ├── vizwiz_images
│   └── val.json
├── vsr
│   └── vsr_images
├── okvqa
│   ├── okvqa_test_split.json
│   ├── mscoco_val2014_annotations_clean.json
│   └── OpenEnded_mscoco_val2014_questions_clean.json
├── refcoco
│   ├── instances.json
│   ├── refs(google).p
│   └── refs(unc).p
├── refcoco+
│   ├── instances.json
│   └── refs(unc).p
├── refcocog
│   ├── instances.json
│   ├── refs(google).p
│   └── refs(umd).p
...
```
### environment setup
```
export PYTHONPATH=$PYTHONPATH:/path/to/directory/of/MiniGPT-4
```
### config file setup
Set **llama_model** to the path of LLaMA model.
Set **ckpt** to the path of our pretrained model.
Set **eval_file_path** to the path of the annotation files for each evaluation data.
Set **img_path** to the img_path for each evaluation dataset.
Set **save_path** to the save_path for each evaluation dataset.
in [eval_configs/minigptv2_benchmark_evaluation.yaml](../eval_configs/minigptv2_benchmark_evaluation.yaml)
### start evaluating RefCOCO, RefCOCO+, RefCOCOg
port=port_number
cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml
dataset names:
| refcoco | refcoco+ | refcocog |
| ------- | -------- | -------- |
```
torchrun --master-port ${port} --nproc_per_node 1 eval_ref.py \
--cfg-path ${cfg_path} --dataset refcoco,refcoco+,refcocog --resample
```
### start evaluating visual question answering
port=port_number
cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml
dataset names:
| okvqa | vizwiz | iconvqa | gqa | vsr | hm |
| ------- | -------- | -------- |-------- | -------- | -------- |
```
torchrun --master-port ${port} --nproc_per_node 1 eval_vqa.py \
--cfg-path ${cfg_path} --dataset okvqa,vizwiz,iconvqa,gqa,vsr,hm
```
================================================
FILE: eval_scripts/eval_ref.py
================================================
import os
import re
import json
import argparse
from collections import defaultdict
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from minigpt4.common.config import Config
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.datasets.datasets.coco_caption import RefCOCOEvalData
def list_of_str(arg):
    """Split a comma-separated command-line value into a list of strings."""
    return [str(piece) for piece in arg.split(',')]
# ---------------------------------------------------------------------------
# Referring-expression evaluation: for each requested dataset and split, ask the
# model for a bounding box per sample, save the raw answers, and report the
# percentage of predictions whose IoU with the ground truth exceeds 0.5.
# ---------------------------------------------------------------------------
parser = eval_parser()
# argparse passes string defaults through `type`, so the default becomes ['refcoco'].
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
parser.add_argument("--resample", action='store_true',
                    help="re-query samples whose answer is not a well-formed bounding box")
args = parser.parse_args()

cfg = Config(args)

# Splits evaluated for each dataset name accepted by --dataset.
eval_dict = {'refcoco': ['val', 'testA', 'testB'],
             'refcoco+': ['val', 'testA', 'testB'],
             'refcocog': ['val', 'test']}

model, vis_processor = init_model(args)
model.eval()
CONV_VISION = CONV_VISION_minigptv2
conv_temp = CONV_VISION.copy()
conv_temp.system = ""

save_path = cfg.run_cfg.save_path

# A well-formed answer looks like {<x1><y1><x2><y2>} with 1-3 digits per number.
BBOX_PATTERN = re.compile(r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}')
REFER_PROMPT_PREFIX = '[refer] give me the location of'
# Retry rounds used with --resample; the last round accepts whatever the model says.
MAX_RESAMPLE_ROUNDS = 5


def _collect_predictions(samples, img_path, batch_size, max_new_tokens, predictions, accept_malformed=False):
    """Run one generation pass over ``samples`` and store the answers in ``predictions``.

    Answers matching BBOX_PATTERN (or every answer when ``accept_malformed`` is true)
    are appended to ``predictions[img_id]``. Returns the list of samples that still
    need another attempt, in the format expected by RefCOCOEvalData.
    """
    retry = []
    loader = DataLoader(RefCOCOEvalData(samples, vis_processor, img_path),
                        batch_size=batch_size, shuffle=False)
    for images, questions, img_ids in tqdm(loader):
        texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
        for answer, img_id, question in zip(answers, img_ids, questions):
            # NOTE(review): the first replace() previously had an empty search string
            # (a no-op); "<unk>" is assumed to be the token that was meant — confirm.
            answer = answer.replace("<unk>", "").replace(" ", "").strip()
            if accept_malformed or BBOX_PATTERN.match(answer):
                predictions[img_id].append(answer)
            else:
                retry.append({'img_id': img_id,
                              'sents': [question.replace(REFER_PROMPT_PREFIX, '').strip()]})
    return retry


for dataset in args.dataset:
    dataset_cfg = cfg.evaluation_datasets_cfg[dataset]
    eval_file_path = dataset_cfg["eval_file_path"]
    img_path = dataset_cfg["img_path"]
    batch_size = dataset_cfg["batch_size"]
    max_new_tokens = dataset_cfg["max_new_tokens"]

    for split in eval_dict[dataset]:
        with open(os.path.join(eval_file_path, f"{dataset}/{dataset}_{split}.json"), 'r') as f:
            refcoco = json.load(f)

        minigpt4_predict = defaultdict(list)
        resamples = _collect_predictions(refcoco, img_path, batch_size, max_new_tokens, minigpt4_predict)

        if args.resample:
            # NOTE(review): generation uses do_sample=False, so a retry presumably
            # reproduces the same answer — confirm that retries are useful here.
            for i in range(MAX_RESAMPLE_ROUNDS):
                if not resamples:
                    break
                resamples = _collect_predictions(
                    resamples, img_path, batch_size, max_new_tokens, minigpt4_predict,
                    accept_malformed=(i == MAX_RESAMPLE_ROUNDS - 1),
                )

        # Use the current dataset name: args.dataset is the whole list, which used
        # to produce file names such as "['refcoco', 'refcoco+']_val.json".
        os.makedirs(save_path, exist_ok=True)
        file_save_path = os.path.join(save_path, f"{dataset}_{split}.json")
        with open(file_save_path, 'w') as f:
            json.dump(minigpt4_predict, f)

        count = 0
        total = len(refcoco)
        res = args.res
        # NOTE(review): entries sharing an img_id overwrite each other here, while
        # `total` still counts every entry — verify against the annotation format.
        refcoco_dict = {item['img_id']: item for item in refcoco}
        for img_id, item in refcoco_dict.items():
            bbox = item['bbox']
            for output in minigpt4_predict.get(img_id, []):
                try:
                    pred_bbox = [int(num) for num in re.findall(r'\d+', output)]
                    height = item['height']
                    width = item['width']
                    # Predictions are expressed on a [0, res] grid; rescale to pixels.
                    pred_bbox[0] = pred_bbox[0] / res * width
                    pred_bbox[1] = pred_bbox[1] / res * height
                    pred_bbox[2] = pred_bbox[2] / res * width
                    pred_bbox[3] = pred_bbox[3] / res * height

                    # Ground truth is [x, y, w, h]; convert to corner format [x1, y1, x2, y2].
                    gt_bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]

                    iou_score = computeIoU(pred_bbox, gt_bbox)
                    if iou_score > 0.5:
                        count += 1
                except Exception:
                    # Best effort: a malformed prediction simply counts as a miss.
                    continue

        accuracy = count / total * 100 if total else 0.0
        print(f'{dataset} {split}:', accuracy, flush=True)
================================================
FILE: eval_scripts/eval_vqa.py
================================================
import os
import re
import json
import argparse
from collections import defaultdict
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from minigpt4.datasets.datasets.vqa_datasets import OKVQAEvalData,VizWizEvalData,IconQAEvalData,GQAEvalData,VSREvalData,HMEvalData
from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA
from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.common.config import Config
def list_of_str(arg):
    """Turn a comma-separated command-line value into a list of strings."""
    pieces = arg.split(',')
    return [str(piece) for piece in pieces]
# ---------------------------------------------------------------------------
# VQA-style evaluation: each `if '<name>' in args.dataset` section below loads
# one benchmark, generates answers, prints an accuracy and (where the original
# did) writes the predictions to <save_path>/<name>.json.
# ---------------------------------------------------------------------------
parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
args = parser.parse_args()
cfg = Config(args)

model, vis_processor = init_model(args)
conv_temp = CONV_VISION_minigptv2.copy()
conv_temp.system = ""
model.eval()
save_path = cfg.run_cfg.save_path


def _load_json(path):
    """Read a JSON file and close the handle afterwards."""
    with open(path, 'r') as f:
        return json.load(f)


def _save_predictions(predictions, filename):
    """Write ``predictions`` to ``save_path/filename`` and return the full path."""
    os.makedirs(save_path, exist_ok=True)
    path = os.path.join(save_path, filename)
    with open(path, 'w') as f:
        json.dump(predictions, f)
    return path


def _clean(answer):
    """Strip the unknown-token marker and surrounding whitespace from an answer.

    NOTE(review): the original replace() calls had an empty search string (a no-op);
    "<unk>" is assumed to be the token that was meant — confirm.
    """
    return answer.replace('<unk>', '').strip()


def _percent(count, total):
    """Return count/total as a percentage, or 0.0 when there were no samples."""
    return count / total * 100 if total else 0.0


if 'okvqa' in args.dataset:
    okvqa_cfg = cfg.evaluation_datasets_cfg["okvqa"]
    eval_file_path = okvqa_cfg["eval_file_path"]
    img_path = okvqa_cfg["img_path"]
    batch_size = okvqa_cfg["batch_size"]
    max_new_tokens = okvqa_cfg["max_new_tokens"]

    ok_vqa_test_split = _load_json(os.path.join(eval_file_path, "okvqa_test_split.json"))

    data = OKVQAEvalData(ok_vqa_test_split, vis_processor, img_path)
    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
    minigpt4_predict = []

    for images, questions, question_ids, img_ids in eval_dataloader:
        texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)

        for answer, question_id in zip(answers, question_ids):
            minigpt4_predict.append({
                'answer': _clean(answer.lower()),
                'question_id': int(question_id),
            })

    file_save_path = _save_predictions(minigpt4_predict, "okvqa.json")

    # Score with the VQA tools against the cleaned annotation and question files.
    annFile = os.path.join(eval_file_path, "mscoco_val2014_annotations_clean.json")
    quesFile = os.path.join(eval_file_path, "OpenEnded_mscoco_val2014_questions_clean.json")

    vqa = VQA(annFile, quesFile)
    vqaRes = vqa.loadRes(file_save_path, quesFile)

    vqaEval = VQAEval(vqa, vqaRes, n=2)
    vqaEval.evaluate()
    print("Overall OKVQA Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']), flush=True)

if 'vizwiz' in args.dataset:
    vizwiz_cfg = cfg.evaluation_datasets_cfg["vizwiz"]
    eval_file_path = vizwiz_cfg["eval_file_path"]
    img_path = vizwiz_cfg["img_path"]
    batch_size = vizwiz_cfg["batch_size"]
    max_new_tokens = vizwiz_cfg["max_new_tokens"]

    vizwiz = _load_json(eval_file_path)

    data = VizWizEvalData(vizwiz, vis_processor, img_path)
    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
    minigpt4_predict = []
    total_acc = []
    for images, texts, gt_answers in tqdm(eval_dataloader):
        texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template
        with torch.no_grad():
            answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False,
                                     repetition_penalty=1.0)

        for answer, gt_answer in zip(answers, gt_answers):
            # Score the same cleaned text that is saved; the raw answer used to be compared.
            answer = _clean(answer)
            minigpt4_predict.append({'answer': answer})
            # Ground-truth answers arrive joined by '_'; full credit needs three matches.
            matches = sum(1 for gt in gt_answer.split('_') if gt.lower() == answer.lower())
            total_acc.append(min(matches / 3.0, 1.0))

    _save_predictions(minigpt4_predict, "vizwiz.json")
    print('vizwiz Acc: ', np.average(total_acc) * 100.0, flush=True)

if 'iconvqa' in args.dataset:
    iconvqa_cfg = cfg.evaluation_datasets_cfg["iconvqa"]
    eval_file_path = iconvqa_cfg["eval_file_path"]
    img_path = iconvqa_cfg["img_path"]
    batch_size = iconvqa_cfg["batch_size"]
    max_new_tokens = iconvqa_cfg["max_new_tokens"]

    iconqa_text_val = _load_json(eval_file_path)

    data = IconQAEvalData(iconqa_text_val, vis_processor, img_path)
    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)

    count = 0
    for images, texts, candidates, answers in tqdm(eval_dataloader):
        # Candidates arrive as '_'-joined strings; pad every list to the same length
        # with 'none', then transpose so each row holds one option per sample.
        candidates = [candidate.split('_') for candidate in candidates]
        num_cand = [len(candidate) for candidate in candidates]
        max_cand = max(num_cand)
        for candidate in candidates:
            candidate.extend(['none'] * (max_cand - len(candidate)))
        candidates = [list(x) for x in zip(*candidates)]
        # NOTE(review): this literal was split across two lines and unterminated;
        # the image placeholder tags are assumed to be what was lost — confirm.
        instructions = ["<s>[INST] <Img><ImageHere></Img> {} [/INST]".format(text) for text in texts]
        answer_ranks = model.multi_select(images, instructions, candidates, num_cand=num_cand)
        for idx, answer in enumerate(answers):
            if answer_ranks[idx][0] == answer:
                count += 1

    print('iconqa Acc: ', _percent(count, len(iconqa_text_val)), flush=True)

if 'gqa' in args.dataset:
    gqa_cfg = cfg.evaluation_datasets_cfg["gqa"]
    eval_file_path = gqa_cfg["eval_file_path"]
    img_path = gqa_cfg["img_path"]
    batch_size = gqa_cfg["batch_size"]
    max_new_tokens = gqa_cfg["max_new_tokens"]

    gqa = _load_json(eval_file_path)
    data = GQAEvalData(gqa, vis_processor, img_path)
    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
    count = 0
    total = 0
    minigpt4_predict = []
    for images, texts, labels in tqdm(eval_dataloader):
        texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template
        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)

        for answer, label in zip(answers, labels):
            # Compare the cleaned prediction, i.e. exactly what is written to disk.
            pred = _clean(answer.lower())
            minigpt4_predict.append({'pred': pred, 'gt': label})
            if pred == label:
                count += 1
            total += 1
    print('gqa val:', _percent(count, total), flush=True)

    _save_predictions(minigpt4_predict, "gqa.json")

if 'vsr' in args.dataset:
    vsr_cfg = cfg.evaluation_datasets_cfg["vsr"]
    img_path = vsr_cfg["img_path"]
    batch_size = vsr_cfg["batch_size"]
    max_new_tokens = vsr_cfg["max_new_tokens"]

    # Annotations are loaded with load_dataset from a fixed dataset id, not from the config.
    annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test')
    data = VSREvalData(annotation, vis_processor, img_path)
    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
    count = 0
    total = 0

    minigpt4_predict = []

    for images, texts, labels in tqdm(eval_dataloader):
        texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template
        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)

        for answer, label in zip(answers, labels):
            # Compare the cleaned prediction, i.e. exactly what is written to disk.
            pred = _clean(answer)
            minigpt4_predict.append({'pred': pred, 'gt': label})
            if pred.lower() == label.lower():
                count += 1
            total += 1
    print('vsr test:', _percent(count, total), flush=True)
    _save_predictions(minigpt4_predict, "vsr.json")

if 'hm' in args.dataset:
    hm_cfg = cfg.evaluation_datasets_cfg["hm"]
    eval_file_path = hm_cfg["eval_file_path"]
    img_path = hm_cfg["img_path"]
    batch_size = hm_cfg["batch_size"]
    max_new_tokens = hm_cfg["max_new_tokens"]

    # The annotation file is JSON Lines: one JSON object per line.
    annotation = []
    with open(eval_file_path, 'r') as jsonl_file:
        for line in jsonl_file:
            annotation.append(json.loads(line))

    data = HMEvalData(annotation, vis_processor, img_path)
    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
    count = 0
    total = 0

    minigpt4_predict = []

    for images, texts, labels in tqdm(eval_dataloader):
        texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template

        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)

        for answer, label in zip(answers, labels):
            # Use a plain int for both the saved value and the comparison.
            label = int(label)
            normalized = answer.lower().strip()
            if normalized == "yes":
                answer = 1
            elif normalized == "no":
                answer = 0
            else:
                # Anything else is kept as text and can never equal the label.
                print("non-matching answer", answer)

            minigpt4_predict.append({'pred': answer, 'gt': label})
            if answer == label:
                count += 1
            total += 1

    print('hm val:', _percent(count, total), flush=True)
    _save_predictions(minigpt4_predict, "hm.json")
================================================
FILE: minigpt4/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import sys
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.tasks import *
# Directory that contains this __init__.py, i.e. the package directory.
root_dir = os.path.dirname(os.path.abspath(__file__))
# Package defaults, read from configs/default.yaml inside the package directory.
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
# The repository root is taken to be the parent of the package directory
# (the path is left un-normalized, with a trailing "..").
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
# Cache location: env.cache_root from the default config, joined onto the repository root.
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

# Shared constants made available through the registry.
registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])
================================================
FILE: minigpt4/common/__init__.py
================================================
================================================
FILE: minigpt4/common/config.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import json
from typing import Dict
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
class Config:
    """Merged run/model/dataset configuration built from a YAML file plus CLI overrides."""

    def __init__(self, args):
        """Load ``args.cfg_path``, apply ``args.options`` and store the merged result.

        Args:
            args: parsed command-line namespace; ``cfg_path`` is the YAML file to
                load and ``options`` the optional list of override strings.
        """
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)
        evaluation_dataset_config = self.build_evaluation_dataset_config(config)

        # Validate the user-provided runner configuration
        # model and dataset configuration are supposed to be validated by the respective classes
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, evaluation_dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        """Convert the raw CLI override list into an OmegaConf config."""
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        """Merge the model's default config file with the ``model`` section of ``config``.

        Args:
            config: the loaded YAML config; must contain a ``model`` section.
            **kwargs: user overrides. The model type is looked up under the flat key
                ``"model.model_type"`` and, failing that, under ``kwargs["model"]["model_type"]``.

        Returns:
            The merged model configuration (customized values win over defaults).
        """
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            # Overrides built by OmegaConf.from_dotlist arrive nested, so a user-supplied
            # "model.model_type=..." shows up as kwargs["model"]["model_type"]. The flat
            # key alone never matched, and the override was ignored when the default
            # config file was chosen.
            user_model = kwargs.get("model", None)
            if user_model is not None and hasattr(user_model, "get"):
                model_type = user_model.get("model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        """Return the ``run`` section wrapped under the ``"run"`` key."""
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        """Merge each dataset's default config file with its section in ``config``.

        Raises:
            KeyError: if ``config`` has no ``datasets`` section.
        """
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    @staticmethod
    def build_evaluation_dataset_config(config):
        """Collect the optional ``evaluation_datasets`` section.

        Returns an empty config when the section is absent. No default files are
        loaded for evaluation datasets; entries are copied as written.
        """
        datasets = config.get("evaluation_datasets", None)
        # if datasets is None:
        #     raise KeyError(
        #         "Expecting 'datasets' as the root key for dataset configuration."
        #     )

        dataset_config = OmegaConf.create()

        if datasets is not None:
            for dataset_name in datasets:
                # The builder class was looked up here but never used; the lookup was removed.
                dataset_config = OmegaConf.merge(
                    dataset_config,
                    {"evaluation_datasets": {dataset_name: config["evaluation_datasets"][dataset_name]}},
                )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        """Normalize CLI overrides to ``key=value`` strings.

        Accepts either ``["a=1", "b=2"]`` (returned unchanged, decided by the first
        element) or alternating ``["a", "1", "b", "2"]``. ``None`` becomes ``[]``.
        """
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        """Return the merged configuration."""
        return self.config

    @property
    def run_cfg(self):
        """The ``run`` section of the merged configuration."""
        return self.config.run

    @property
    def datasets_cfg(self):
        """The ``datasets`` section of the merged configuration."""
        return self.config.datasets

    @property
    def evaluation_datasets_cfg(self):
        """The ``evaluation_datasets`` section of the merged configuration."""
        return self.config.evaluation_datasets

    @property
    def model_cfg(self):
        """The ``model`` section of the merged configuration."""
        return self.config.model

    def pretty_print(self):
        """Log the run, dataset and model sections as indented JSON."""
        logging.info("\n=====  Running Parameters    =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        # Every name comes from self.config.datasets itself, so the former
        # "not in config" warning branch could never run and was removed.
        for dataset in datasets:
            logging.info(f"\n======== {dataset} =======")
            logging.info(self._convert_node_to_json(datasets[dataset]))

        logging.info("\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        """Serialize a config node to sorted, indented JSON with interpolations resolved."""
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        """Return the merged configuration as plain Python containers."""
        return OmegaConf.to_container(self.config)
def node_to_dict(node):
    """Convert an OmegaConf node into plain Python containers."""
    plain = OmegaConf.to_container(node)
    return plain
class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined, raise error if not.
        2. when type mismatches are found, the validator will raise an error.
        3. a central place to store and display helpful messages for supported configurations.

    """

    class _Argument:
        """One registered option: name, allowed choices, type, help text and last converted value."""

        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        # Registered arguments, keyed by name.
        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """
        Check every key of ``config`` against the registered arguments.

        For each key: it must have been registered; if the argument declares a type,
        the value is converted with it and stored on the argument; if it declares
        choices, the original value must be one of them.

        Args:
            config: dict-like object with ``items()``, or ``None`` (nothing to check).

        Returns:
            ``config`` unchanged.

        Raises:
            AssertionError: for an unregistered key or a value outside ``choices``.
            ValueError: when the type conversion raises ``ValueError``.
        """
        if config is None:
            # The default used to crash on ``None.items()``; treat it as "nothing to validate".
            return config

        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

            argument = self.arguments[k]

            if argument.type is not None:
                try:
                    argument.val = argument.type(v)
                except ValueError as err:
                    raise ValueError(f"{k} is not a valid {argument.type}.") from err

            if argument.choices is not None:
                assert (
                    v in argument.choices
                ), f"""{k} must be one of {argument.choices}."""

        return config

    def format_arguments(self):
        """Return the sorted argument names formatted as a list literal."""
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())
def create_runner_config_validator():
    """
    Build the ``ConfigValidator`` that declares every supported runner option.

    Scheduler and task choices are read from the registry at call time.

    Returns:
        ConfigValidator: validator with all runner arguments registered.
    """
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        # A subscripted typing alias (Dict[str, float]) cannot be called to
        # coerce a value; the builtin ``dict`` accepts any mapping.
        type=dict,
        help="""Ratios of training dataset. This is used in iteration-based runner.
        Do not support for epoch-based runner because how to define an epoch becomes tricky.
        Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        # Was misspelled "genearte", which made the correctly spelled value
        # fail the choices check.
        choices=["generate", "rank"],
        help="""Inference method to use for question answering. If rank, requires a answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator
================================================
FILE: minigpt4/common/dist_utils.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import functools
import os
import torch
import torch.distributed as dist
import timm.models.hub as timm_hub
def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that non-master processes stay silent.

    After this call ``print`` only produces output when ``is_master`` is true
    or when the caller passes ``force=True`` (the keyword is removed before
    delegating to the real ``print``).
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the distributed world size, or 1 when distributed mode is not active."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's distributed rank, or 0 when distributed mode is not active."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when ``get_rank()`` reports rank 0."""
    rank = get_rank()
    return rank == 0
def init_distributed_mode(args):
    """
    Initialise torch.distributed from environment variables and ``args``.

    ``args`` is mutated in place:
      * if ``args.distributed`` is ``False`` nothing else happens;
      * with ``RANK`` and ``WORLD_SIZE`` in the environment, ``rank``,
        ``world_size`` and ``gpu`` (from ``LOCAL_RANK``) are read from it;
      * with ``SLURM_PROCID`` in the environment, ``rank`` comes from it and
        ``gpu`` is the rank modulo the CUDA device count;
      * otherwise ``args.distributed`` is set to ``False`` and the call returns.

    When a launch mode was detected, the NCCL process group is created with
    ``args.dist_url``, all ranks meet at a barrier, and printing is restricted
    to rank 0.
    """
    env = os.environ

    if args.distributed is False:
        print("Not using distributed mode")
        return

    if "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"

    banner = "| distributed init (rank {}, world {}): {}".format(
        args.rank, args.world_size, args.dist_url
    )
    print(banner, flush=True)

    # allow auto-downloading and de-compressing
    generous_timeout = datetime.timedelta(days=365)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=generous_timeout,
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
def get_dist_info():
    """
    Return ``(rank, world_size)`` for the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable in this
    build or has not been initialised (non-distributed training).
    """
    if not dist.is_available():
        # Builds without distributed support may not expose the functions
        # used below, so do not probe any further.
        initialized = False
    elif torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()

    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size
def main_process(func):
    """
    Decorator that runs ``func`` only where ``get_dist_info`` reports rank 0.

    On every other rank the wrapped call does nothing and returns ``None``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank = get_dist_info()[0]
        if rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.

    Returns:
        str: the basename of the URL path joined onto timm's cache directory.
    """
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    # Non-main ranks wait here until the main rank has finished downloading.
    if is_dist_avail_and_initialized():
        dist.barrier()

    # a hack to sync the file path across processes: every rank rebuilds the
    # same path from the URL instead of receiving it from the main rank.
    url_path = torch.hub.urlparse(url).path
    return os.path.join(timm_hub.get_cache_dir(), os.path.basename(url_path))
================================================
FILE: minigpt4/common/eval_utils.py
================================================
import argparse
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
from minigpt4.common.registry import registry
from minigpt4.common.config import Config
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def eval_parser():
    """
    Build the command-line parser for the evaluation scripts.

    Returns:
        argparse.ArgumentParser: parser with ``--cfg-path`` (required) and the
        optional ``--name``, ``--ckpt``, ``--eval_opt``, ``--max_new_tokens``,
        ``--batch_size``, ``--lora_r``, ``--lora_alpha`` and ``--options``.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--name", type=str, default='A2', help="evaluation name")
    # The next two help strings used to repeat the --cfg-path text verbatim.
    parser.add_argument("--ckpt", type=str, help="path to the model checkpoint file.")
    parser.add_argument("--eval_opt", type=str, default='all', help="evaluation option (default: all).")
    parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
    parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser
def prepare_texts(texts, conv_temp):
    """
    Turn each input text into a full prompt using a conversation template.

    A separate copy of ``conv_temp`` is made per text; a first-role message
    containing the text and a second-role message of ``None`` are appended,
    and the prompt of each conversation is returned.

    Args:
        texts: sequence of input strings.
        conv_temp: conversation template exposing ``copy()``, ``roles``,
            ``append_message()`` and ``get_prompt()``.

    Returns:
        list: one prompt per entry of ``texts``, in the same order.
    """
    convs = [conv_temp.copy() for _ in range(len(texts))]
    # NOTE(review): the string literal below is split across two lines and is
    # not valid Python as shown; part of the literal appears to have been lost
    # before this file reached review. Restore it from the upstream source --
    # the missing content cannot be determined from this file.
    [conv.append_message(
    conv.roles[0], '
{}'.format(text)) for conv, text in zip(convs, texts)]
    # presumably ``None`` leaves the second role's reply open for generation --
    # confirm against the conversation class.
    [conv.append_message(conv.roles[1], None) for conv in convs]
    texts = [conv.get_prompt() for conv in convs]
    return texts
def init_model(args):
    """
    Build the model and the training visual processor described by the config.

    The model class is looked up in the registry by ``model_cfg.arch``, built
    with ``from_config`` and moved to ``'cuda:0'``. The visual processor is the
    ``vis_processor.train`` entry of the first dataset listed in the config.

    Returns:
        tuple: ``(model, vis_processor)``.
    """
    print('Initialization Model')

    cfg = Config(args)
    # Overrides from the command line, currently disabled:
    # cfg.model_cfg.ckpt = args.ckpt
    # cfg.model_cfg.lora_r = args.lora_r
    # cfg.model_cfg.lora_alpha = args.lora_alpha

    model_config = cfg.model_cfg
    model = registry.get_model_class(model_config.arch).from_config(model_config).to('cuda:0')

    # import pudb; pudb.set_trace()
    first_dataset = list(cfg.datasets_cfg.keys())[0]
    vis_processor_cfg = cfg.datasets_cfg.get(first_dataset).vis_processor.train
    processor_cls = registry.get_processor_class(vis_processor_cfg.name)
    vis_processor = processor_cls.from_config(vis_processor_cfg)

    print('Initialization Finished')
    return model, vis_processor
def computeIoU(bbox1, bbox2):
    """
    Compute the intersection-over-union of two boxes.

    Each box is a 4-sequence ``(x1, y1, x2, y2)``. Every width and height is
    measured as ``max - min + 1``, and a negative overlap is clamped to 0.

    Returns:
        The intersection area divided by the union area.
    """
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1) + 1)
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1) + 1)
    intersection_area = overlap_w * overlap_h

    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)

    return intersection_area / (area_a + area_b - intersection_area)
================================================
FILE: minigpt4/common/gradcam.py
================================================
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform
def getAttMap(img, attMap, blur=True, overlap=True):
    """
    Normalise an attention map, resize it to the image and optionally overlay it.

    Args:
        img: image array; only ``img.shape[:2]`` is used for sizing, and the
            whole array is blended when ``overlap`` is true.
        attMap: 2-D attention array. The caller's array is not modified.
        blur: apply a Gaussian filter (sigma = 2% of the larger image side)
            and re-normalise afterwards.
        overlap: blend the "jet"-coloured map with ``img``, weighting each
            pixel by ``attMap**0.7``.

    Returns:
        The blended image when ``overlap`` is true, otherwise the normalised
        (and possibly blurred) attention map.
    """
    # Work out of place so the caller's attention array is left untouched.
    attMap = attMap - attMap.min()
    if attMap.max() > 0:
        attMap = attMap / attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        # A constant map has max 0 at this point; dividing would fill it with NaN.
        if attMap.max() > 0:
            attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    # Drop the 4th (alpha) channel returned by the colormap.
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap
================================================
FILE: minigpt4/common/logger.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import logging
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
from minigpt4.common import dist_utils
class SmoothedValue(object):
    """
    Keep a sliding window of recent values plus running totals, exposing
    window statistics (median, avg, max, latest value) and the average over
    everything seen so far.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across all processes.

        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(packed)
        reduced_count, reduced_total = packed.tolist()
        self.count = int(reduced_count)
        self.total = reduced_total

    @property
    def median(self):
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
class MetricLogger(object):
    """
    Collect named ``SmoothedValue`` meters and print periodic progress lines.

    Meters are created on first use and can be read as attributes
    (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed each keyword's value (tensor, float or int) into its meter."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails. Read ``meters``
        # through __dict__: going through ``self.meters`` would re-enter this
        # method and recurse forever whenever ``meters`` has not been set
        # (e.g. on an instance created without __init__).
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        """Return every meter's global average, joined by the delimiter."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """
        Yield the items of ``iterable`` while printing progress.

        A line is printed every ``print_freq`` items and on the last item,
        followed by a total-time summary once the iterable is exhausted.

        Args:
            iterable: must support ``len()``.
            print_freq: number of items between progress lines.
            header: optional prefix for every printed line.
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        total_items = len(iterable)
        space_fmt = ":" + str(len(str(total_items))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the item vs. the whole iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total_items - 1:
                eta_seconds = iter_time.global_avg * (total_items - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if cuda_available:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total_items, **fields))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                # max(..., 1): an empty iterable must not divide by zero.
                header, total_time_str, total_time / max(total_items, 1)
            )
        )
class AttrDict(dict):
    """A dict whose entries are also attributes: the instance ``__dict__`` is the dict itself."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self
def setup_logger():
    """
    Configure root logging with a single stream handler.

    The level is INFO on the main process and WARN everywhere else.
    """
    level = logging.INFO if dist_utils.is_main_process() else logging.WARN
    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
================================================
FILE: minigpt4/common/optims.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import math
from minigpt4.common.registry import registry
@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    """
    Linear warm-up during epoch 0, then per-epoch exponential step decay.

    ``warmup_start_lr`` falls back to ``init_lr`` when negative. ``max_epoch``
    is stored but not read by this class.
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        if warmup_start_lr >= 0:
            self.warmup_start_lr = warmup_start_lr
        else:
            self.warmup_start_lr = init_lr

    def step(self, cur_epoch, cur_step):
        """Write the learning rate for ``(cur_epoch, cur_step)`` into the optimizer."""
        if cur_epoch != 0:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )
            return

        warmup_lr_schedule(
            step=cur_step,
            optimizer=self.optimizer,
            max_step=self.warmup_steps,
            init_lr=self.warmup_start_lr,
            max_lr=self.init_lr,
        )
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    """
    Linear warm-up for the first ``warmup_steps`` global steps, then cosine
    decay from ``init_lr`` to ``min_lr`` over ``max_epoch * iters_per_epoch``
    steps.

    ``warmup_start_lr`` falls back to ``init_lr`` when negative.
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        """
        Write the learning rate for the given position into the optimizer.

        Args:
            cur_epoch: zero-based epoch index.
            cur_step: step index within ``cur_epoch``.
        """
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                # Use the global step: this branch is selected on the global
                # step, so passing the per-epoch step would restart the ramp
                # whenever warm-up spans an epoch boundary. Identical to the
                # per-epoch step while cur_epoch == 0.
                step=total_cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """
    Set every param group's ``lr`` on a half-cosine curve.

    The rate is ``init_lr`` at ``epoch == 0`` and ``min_lr`` at
    ``epoch == max_epoch``.
    """
    cosine_term = math.cos(math.pi * epoch / max_epoch)
    lr = (init_lr - min_lr) * 0.5 * (1.0 + cosine_term) + min_lr
    for group in optimizer.param_groups:
        group["lr"] = lr
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """
    Set every param group's ``lr`` on a linear ramp from ``init_lr``.

    The ramp reaches ``max_lr`` at ``step == max_step`` and is capped there;
    a ``max_step`` below 1 is treated as 1.
    """
    ramp_length = max(max_step, 1)
    ramped = init_lr + (max_lr - init_lr) * step / ramp_length
    lr = min(max_lr, ramped)
    for group in optimizer.param_groups:
        group["lr"] = lr
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """
    Set every param group's ``lr`` to ``init_lr * decay_rate**epoch``,
    never going below ``min_lr``.
    """
    decayed = init_lr * (decay_rate**epoch)
    lr = max(min_lr, decayed)
    for group in optimizer.param_groups:
        group["lr"] = lr
================================================
FILE: minigpt4/common/registry.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
class Registry:
    """
    Process-wide lookup tables for builders, tasks, processors, models,
    lr schedulers, runners, paths and arbitrary state.

    All tables live in the class-level ``mapping`` dict, so they are shared by
    every instance and by the class itself.
    """

    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def _store_unique(cls, table, name, obj):
        """Insert ``obj`` under ``name`` in ``table``; raise KeyError if the name is taken."""
        entries = cls.mapping[table]
        if name in entries:
            raise KeyError(
                "Name '{}' already registered for {}.".format(name, entries[name])
            )
        entries[name] = obj
        return obj

    @classmethod
    def register_builder(cls, name):
        r"""Decorator registering a dataset builder class under ``name``.

        Args:
            name: Key with which the builder will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit ``BaseDatasetBuilder``.
            KeyError: ``name`` is already registered.

        Usage:

            from minigpt4.common.registry import registry
            from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
        """

        def wrap(builder_cls):
            # Imported lazily, at decoration time.
            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            return cls._store_unique("builder_name_mapping", name, builder_cls)

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Decorator registering a task class under ``name``.

        Args:
            name: Key with which the task will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit ``BaseTask``.
            KeyError: ``name`` is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(task_cls):
            from minigpt4.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            return cls._store_unique("task_name_mapping", name, task_cls)

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Decorator registering a model class under ``name``.

        Args:
            name: Key with which the model will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit ``BaseModel``.
            KeyError: ``name`` is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(model_cls):
            from minigpt4.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            return cls._store_unique("model_name_mapping", name, model_cls)

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Decorator registering a processor class under ``name``.

        Args:
            name: Key with which the processor will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit ``BaseProcessor``.
            KeyError: ``name`` is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(processor_cls):
            from minigpt4.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            return cls._store_unique("processor_name_mapping", name, processor_cls)

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Decorator registering a learning-rate scheduler class under ``name``.

        Args:
            name: Key with which the scheduler will be registered.

        Raises (when the decorator is applied):
            KeyError: ``name`` is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(lr_sched_cls):
            return cls._store_unique("lr_scheduler_name_mapping", name, lr_sched_cls)

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Decorator registering a runner class under ``name``.

        Args:
            name: Key with which the runner will be registered.

        Raises (when the decorator is applied):
            KeyError: ``name`` is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(runner_cls):
            return cls._store_unique("runner_name_mapping", name, runner_cls)

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path string under ``name``.

        Args:
            name: Key with which the path will be registered.
            path: The path; must be a ``str``.

        Raises:
            AssertionError: ``path`` is not a ``str``.
            KeyError: ``name`` is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """
        assert isinstance(path, str), "All path must be str."
        known_paths = cls.mapping["paths"]
        if name in known_paths:
            raise KeyError("Name '{}' already registered.".format(name))
        known_paths[name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Store ``obj`` in the state table under a dotted key.

        Every component of ``name`` except the last names a nested dict, which
        is created when missing; the last component receives ``obj``.

        Args:
            name: Key with which the item will be registered.

        Usage::

            from minigpt4.common.registry import registry

            registry.register("config", {})
        """
        *parents, leaf = name.split(".")
        current = cls.mapping["state"]

        for part in parents:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[leaf] = obj

    # @classmethod
    # def get_trainer_class(cls, name):
    #     return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Look up a dotted key in the state table.

        Args:
            name (string): Key whose value needs to be retrieved.
            default: Returned when any component of the key is missing.
                Default: None
            no_warning (bool): When False and a ``"writer"`` entry exists in
                the state table, a warning is sent to it if the result equals
                ``default``. Default: False
        """
        original_name = name
        state = cls.mapping["state"]

        value = state
        for subname in name.split("."):
            value = value.get(subname, default)
            if value is default:
                break

        if (
            "writer" in state
            and value == default
            and no_warning is False
        ):
            state["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove a top-level entry from the state table and return it.

        Args:
            name: Key which needs to be removed. Returns ``None`` when the
                key is absent.

        Usage::

            from minigpt4.common.registry import registry

            config = registry.unregister("config")
        """
        return cls.mapping["state"].pop(name, None)


registry = Registry()
================================================
FILE: minigpt4/common/utils.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse
import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from minigpt4.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
check_integrity,
download_file_from_google_drive,
extract_archive,
)
def now():
    """
    Return the current local time as an 11-digit string.

    The ``%Y%m%d%H%M`` stamp is built first and its final character (the last
    minute digit) is dropped.
    """
    from datetime import datetime

    stamp = datetime.now().strftime("%Y%m%d%H%M")
    return stamp[:-1]
def is_url(url_or_filename):
    """Return True when the argument parses with an ``http`` or ``https`` scheme."""
    scheme = urlparse(url_or_filename).scheme
    return scheme in ("http", "https")
def get_cache_path(rel_path):
    """Join ``rel_path`` onto the registered ``cache_root`` path and expand ``~``."""
    cache_root = registry.get_path("cache_root")
    return os.path.expanduser(os.path.join(cache_root, rel_path))
def get_abs_path(rel_path):
    """Join ``rel_path`` onto the registered ``library_root`` path."""
    library_root = registry.get_path("library_root")
    return os.path.join(library_root, rel_path)
def load_json(filename):
    """Open ``filename`` in text mode and return its parsed JSON content."""
    with open(filename, "r") as handle:
        payload = json.load(handle)
    return payload
# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Returns:
        bool: True when the directory already existed or was created, False
        when checking for it or creating it raised an error (the error is
        reported with ``print`` rather than propagated).
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        # Exception rather than BaseException: KeyboardInterrupt and
        # SystemExit must not be swallowed and turned into ``False``.
        print(f"Error creating directory: {dir_path}")
    return is_success
def get_redirected_url(url: str):
    """
    Given a URL, returns the URL it redirects to or the
    original URL in case of no indirection
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            # A non-empty history means at least one redirect was followed.
            was_redirected = bool(response.history)
            return response.url if was_redirected else url
def to_google_drive_download_url(view_url: str) -> str:
    """
    Turn a Google Drive "view" URL into the matching direct-download URL.

    The URL must end in ``/view``; the path component just before it is taken
    as the file id.

    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    parts = view_url.split("/")
    last_part = parts[-1]
    assert last_part == "view"
    return "https://drive.google.com/uc?export=download&id={}".format(parts[-2])
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive into ``output_path/output_file_name``.

    Downloading an URL from google drive requires confirmation when
    the file of the size is too big (google drive notifies that
    anti-viral checks cannot be performed on such files). The first request
    looks for a ``download_warning*`` cookie and appends its value as a
    ``confirm`` query parameter; the second request streams the content to
    disk while updating a progress bar.
    """
    import requests

    with requests.Session() as session:

        # First get the confirmation token and append it to the URL
        with session.get(url, stream=True, allow_redirects=True) as probe:
            for cookie_name, cookie_value in probe.cookies.items():
                if cookie_name.startswith("download_warning"):
                    url = url + "&confirm=" + cookie_value

        # Then download the content of the file
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            destination = os.path.join(output_path, output_file_name)
            expected_bytes = int(response.headers.get("Content-length", 0))
            with open(destination, "wb") as out_file:
                from tqdm import tqdm

                with tqdm(total=expected_bytes) as progress_bar:
                    blocks = response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE)
                    for block in blocks:
                        out_file.write(block)
                        progress_bar.update(len(block))
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    """
    Stream ``url`` into the local file ``filename`` with a progress bar.

    Args:
        url (str): address to fetch; sent with the ``vissl`` User-Agent.
        filename (str): destination path, opened for binary writing.
        chunk_size (int): number of bytes requested per read.
    """
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # read() returns bytes, so the end-of-stream sentinel must be
                # b"" (a str sentinel never compares equal to it).
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    # Advance by what was actually read; the final chunk is
                    # usually shorter than chunk_size.
                    pbar.update(len(chunk))
                    fh.write(chunk)
def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Fetch ``url`` and store it inside the directory ``root``.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: when the integrity check on the downloaded file fails.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    target = os.path.join(root, filename)

    makedir(root)

    # Nothing to do when a verified copy is already on disk.
    if check_integrity(target, md5):
        print("Using downloaded and verified file: " + target)
        return

    # Follow any redirect chain first so the checks below see the real host.
    url = get_redirected_url(url)

    # Google Drive links go through the dedicated downloader.
    drive_id = _get_google_drive_file_id(url)
    if drive_id is not None:
        return download_file_from_google_drive(drive_id, root, filename, md5)

    try:
        print("Downloading " + url + " to " + target)
        _urlretrieve(url, target)
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        if not url.startswith("https"):
            raise e
        # Retry once over plain http when the https attempt failed.
        url = url.replace("https:", "http:")
        print(
            "Failed download. Trying https -> http instead."
            " Downloading " + url + " to " + target
        )
        _urlretrieve(url, target)

    # Final verification of what ended up on disk.
    if not check_integrity(target, md5):
        raise RuntimeError("File not found or corrupted.")
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """
    Download an archive and unpack it.

    The archive is stored in ``download_root`` (under ``filename``, or the
    URL basename when no name is given) and extracted into ``extract_root``,
    which defaults to ``download_root``. ``md5`` is forwarded to
    ``download_url`` and ``remove_finished`` to ``extract_archive``.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename if filename else os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive_path = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive_path, extract_root))
    extract_archive(archive_path, extract_root, remove_finished)
def cache_url(url: str, cache_dir: str) -> str:
    """
    Return a local path for ``url``, downloading it only on first request.

    The file is placed under ``cache_dir`` in a sub-directory mirroring the
    directory part of the URL path. The existence check and the download run
    while holding a file lock on the cached path.
    """
    url_path = urlparse(url).path.lstrip("/")
    target_dir = os.path.join(cache_dir, os.path.dirname(url_path))
    makedir(target_dir)

    basename = url.split("/")[-1]
    cached = os.path.join(target_dir, basename)
    with file_lock(cached):
        already_there = os.path.isfile(cached)
        if not already_there:
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, target_dir, filename=basename)

    logging.info(f"URL {url} cached in {cached}")
    return cached
# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
    """
    Make ``file2`` a symlink pointing at ``file1``.

    An existing ``file2`` is removed first. Useful during checkpointing to
    keep a link to the latest successful checkpoint. Any failure is logged
    and swallowed, so this never raises an ``Exception`` to the caller.
    """
    try:
        stale_link = g_pathmgr.exists(file2)
        if stale_link:
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as err:
        logging.info(f"Could NOT create symlink. Error: {err}")
def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Write ``data`` to ``filename``, choosing the format from the extension.

    Supported extensions:
        .pkl, .pickle, .npy, .json, .yaml
    For .json one line is written per call; it is appended to the file by
    default, or the file is rewritten when ``append_to_json`` is false.
    Any other extension raises an ``Exception``.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")

    extension = os.path.splitext(filename)[1]
    if extension == ".json":
        # Same payload either way; only the open mode differs.
        json_mode = "a" if append_to_json else "w"
        with g_pathmgr.open(filename, json_mode) as sink:
            sink.write(json.dumps(data, sort_keys=True) + "\n")
            sink.flush()
    elif extension in (".pkl", ".pickle"):
        with g_pathmgr.open(filename, "wb") as sink:
            pickle.dump(data, sink, pickle.HIGHEST_PROTOCOL)
    elif extension == ".npy":
        with g_pathmgr.open(filename, "wb") as sink:
            np.save(sink, data)
    elif extension == ".yaml":
        with g_pathmgr.open(filename, "w") as sink:
            sink.write(yaml.dump(data))
            sink.flush()
    else:
        raise Exception(f"Saving {extension} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Read ``filename`` and return its content, choosing the parser from the
    file extension.

    Supported extensions:
        .txt (list of lines), .pkl, .pickle, .npy, .json, .yaml, .csv

    For .npy files ``mmap_mode`` is tried first through ``g_pathmgr``, then
    directly on the path, and finally the file is loaded without mmap.

    Args:
        filename: path of the file to read.
        mmap_mode: forwarded to ``np.load`` for .npy files; falsy disables mmap.
        verbose: when true, log the file being loaded.
        allow_pickle: forwarded to ``np.load`` for .npy files.

    Raises:
        Exception: when the extension is not one of the supported ones.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    def _load_npy_without_mmap():
        # Last-resort .npy reader: plain load through g_pathmgr, no mmap.
        with g_pathmgr.open(filename, "rb") as fopen:
            return np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        # SECURITY: unpickling can execute arbitrary code; only use on
        # trusted files.
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                # The second attempt needs its own try block: an exception
                # raised inside an except clause is not seen by the sibling
                # handlers of the same try statement.
                try:
                    data = np.load(
                        filename,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
                    logging.info("Successfully loaded without g_pathmgr")
                except Exception:
                    logging.info(
                        "Could not mmap without g_pathmgr. Trying without mmap"
                    )
                    data = _load_npy_without_mmap()
            except Exception:
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                data = _load_npy_without_mmap()
        else:
            data = _load_npy_without_mmap()
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        # SECURITY: FullLoader is not meant for untrusted input; prefer
        # yaml.safe_load if these files can come from outside.
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data
def abspath(resource_path: str):
    """
    Return an absolute version of ``resource_path``.

    Paths that start with a scheme prefix such as "http://" or "manifold://"
    are returned untouched; everything else goes through ``os.path.abspath``.
    """
    has_scheme = re.compile(r"^\w+://").match(resource_path) is not None
    if has_scheme:
        return resource_path
    return os.path.abspath(resource_path)
def makedir(dir_path):
    """
    Ensure that ``dir_path`` exists, creating it when necessary.

    Returns True when the directory exists or was created, and False when
    the attempt raised (the failure is logged, never propagated).
    """
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
        return False
    return True
def is_url(input_url):
    """
    Tell whether ``input_url`` starts with http:// or https:// (any letter case).
    """
    scheme_match = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE)
    return scheme_match is not None
def cleanup_dir(dir):
    """
    Remove the directory ``dir`` and everything below it, if it exists.

    Handy for reclaiming storage used by training artifacts such as
    checkpoints or data. A missing path is silently ignored.
    """
    if not os.path.exists(dir):
        return
    logging.info(f"Deleting directory: {dir}")
    shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")
def get_file_size(filename):
    """
    Return the size of ``filename`` in MB (bytes divided by 1024**2).
    """
    bytes_per_mb = float(1024**2)
    return os.path.getsize(filename) / bytes_per_mb
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
================================================
# coding: utf-8
import sys
dataDir = '../../VQA'
sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
from vqa import VQA
from vqaEvaluation.vqaEval import VQAEval
import matplotlib.pyplot as plt
import skimage.io as io
import json
import random
import os
# Demo script: evaluate a VQA result file, print the accuracies, show one
# low-scoring example, plot per-question-type accuracy and save the results.
# All print calls use the single-argument parenthesised form so that the
# script runs on Python 3 (and still behaves the same on Python 2).

# set up file names and paths
versionType = 'v2_'  # this should be '' when using VQA v1.0 dataset
taskType = 'OpenEnded'  # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType = 'mscoco'  # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType = 'train2014'
annFile = '%s/Annotations/%s%s_%s_annotations.json' % (dataDir, versionType, dataType, dataSubType)
quesFile = '%s/Questions/%s%s_%s_%s_questions.json' % (dataDir, versionType, taskType, dataType, dataSubType)
imgDir = '%s/Images/%s/%s/' % (dataDir, dataType, dataSubType)
resultType = 'fake'
fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']

# An example result json file has been provided in './Results' folder.
[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = [
    '%s/Results/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType,
                                         resultType, fileType)
    for fileType in fileTypes
]

# create vqa object and vqaRes object
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(resFile, quesFile)

# create vqaEval object by taking vqa and vqaRes
vqaEval = VQAEval(vqa, vqaRes, n=2)  # n is precision of accuracy (number of places after decimal), default is 2

# evaluate results
"""
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
By default it uses all the question ids in annotation file
"""
vqaEval.evaluate()

# print accuracies
print("\n")
print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
print("Per Question Type Accuracy is the following:")
for quesType in vqaEval.accuracy['perQuestionType']:
    print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType]))
print("\n")
print("Per Answer Type Accuracy is the following:")
for ansType in vqaEval.accuracy['perAnswerType']:
    print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
print("\n")

# demo how to use evalQA to retrieve low score result
evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId] < 35]  # 35 is per question percentage accuracy
if len(evals) > 0:
    print('ground truth answers')
    randomEval = random.choice(evals)
    randomAnn = vqa.loadQA(randomEval)
    vqa.showQA(randomAnn)

    print('\n')
    print('generated answer (accuracy %.02f)' % (vqaEval.evalQA[randomEval]))
    ann = vqaRes.loadQA(randomEval)[0]
    print("Answer: %s\n" % (ann['answer']))

    imgId = randomAnn[0]['image_id']
    imgFilename = 'COCO_' + dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
    if os.path.isfile(imgDir + imgFilename):
        I = io.imread(imgDir + imgFilename)
        plt.imshow(I)
        plt.axis('off')
        plt.show()

# plot accuracy for various question types
# dict views are materialised as lists: on Python 3 .values()/.keys() are
# views, not sequences.
perQuesTypeAcc = vqaEval.accuracy['perQuestionType']
plt.bar(range(len(perQuesTypeAcc)), list(perQuesTypeAcc.values()), align='center')
plt.xticks(range(len(perQuesTypeAcc)), list(perQuesTypeAcc.keys()), rotation='0', fontsize=10)
plt.title('Per Question Type Accuracy', fontsize=10)
plt.xlabel('Question Types', fontsize=10)
plt.ylabel('Accuracy', fontsize=10)
plt.show()

# save evaluation results to ./Results folder
# (context managers make sure every file is flushed and closed)
for outFile, payload in [
    (accuracyFile, vqaEval.accuracy),
    (evalQAFile, vqaEval.evalQA),
    (evalQuesTypeFile, vqaEval.evalQuesType),
    (evalAnsTypeFile, vqaEval.evalAnsType),
]:
    with open(outFile, 'w') as fout:
        json.dump(payload, fout)
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
================================================
author='aagrawal'
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
================================================
# coding=utf-8
__author__='aagrawal'
import re
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
class VQAEval:
def __init__(self, vqa, vqaRes, n=2):
self.n = n
self.accuracy = {}
self.evalQA = {}
self.evalQuesType = {}
self.evalAnsType = {}
self.vqa = vqa
self.vqaRes = vqaRes
self.params = {'question_id': vqa.getQuesIds()}
self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
"couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
"hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
"he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
"Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
"maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
"mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
"ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
"she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
"somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
"somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
"someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
"something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
"there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
"they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
"wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
"whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
"whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
"whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
"wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
"y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
"youll": "you'll", "youre": "you're", "youve": "you've"}
self.manualMap = { 'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10'
}
self.articles = ['a',
'an',
'the'
]
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile("(\d)(\,)(\d)")
self.punct = [';', r"/", '[', ']', '"', '{', '}',
'(', ')', '=', '+', '\\', '_', '-',
'>', '<', '@', '`', ',', '?', '!']
def evaluate(self, quesIds=None):
if quesIds == None:
quesIds = [quesId for quesId in self.params['question_id']]
gts = {}
res = {}
for quesId in quesIds:
gts[quesId] = self.vqa.qa[quesId]
res[quesId] = self.vqaRes.qa[quesId]
# =================================================
# Compute accuracy
# =================================================
accQA = []
accQuesType = {}
accAnsType = {}
# print "computing accuracy"
step = 0
for quesId in quesIds:
for ansDic in gts[quesId]['answers']:
ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
ansDic['answer'] = ansDic['answer'].strip()
resAns = res[quesId]['answer']
resAns = resAns.replace('\n', ' ')
resAns = resAns.replace('\t', ' ')
resAns = resAns.strip()
gtAcc = []
gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
if len(set(gtAnswers)) > 1:
for ansDic in gts[quesId]['answers']:
ansDic['answer'] = self.processPunctuation(ansDic['answer'])
ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
resAns = self.processPunctuation(resAns)
resAns = self.processDigitArticle(resAns)
for gtAnsDatum in gts[quesId]['answers']:
otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
acc = min(1, float(len(matchingAns))/3)
gtAcc.append(acc)
quesType = gts[quesId]['question_type']
ansType = gts[quesId]['answer_type']
avgGTAcc = float(sum(gtAcc))/len(gtAcc)
accQA.append(avgGTAcc)
if quesType not in accQuesType:
accQuesType[quesType] = []
accQuesType[quesType].append(avgGTAcc)
if ansType not in accAnsType:
accAnsType[ansType] = []
accAnsType[ansType].append(avgGTAcc)
self.setEvalQA(quesId, avgGTAcc)
self.setEvalQuesType(quesId, quesType, avgGTAcc)
self.setEvalAnsType(quesId, ansType, avgGTAcc)
if step%100 == 0:
self.updateProgress(step/float(len(quesIds)))
step = step + 1
self.setAccuracy(accQA, accQuesType, accAnsType)
# print "Done computing accuracy"
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
outText = self.periodStrip.sub("",
outText,
re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = ' '.join(outText)
return outText
def setAccuracy(self, accQA, accQuesType, accAnsType):
self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
def setEvalQA(self, quesId, acc):
self.evalQA[quesId] = round(100*acc, self.n)
def setEvalQuesType(self, quesId, quesType, acc):
if quesType not in self.evalQuesType:
self.evalQuesType[quesType] = {}
self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
def setEvalAnsType(self, quesId, ansType, acc):
if ansType not in self.evalAnsType:
self.evalAnsType[ansType] = {}
self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
def updateProgress(self, progress):
barLength = 20
status = ""
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = "error: progress var must be float\r\n"
if progress < 0:
progress = 0
status = "Halt...\r\n"
if progress >= 1:
progress = 1
status = "Done...\r\n"
block = int(round(barLength*progress))
text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
sys.stdout.write(text)
sys.stdout.flush()
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
================================================
# coding: utf-8
from vqaTools.vqa import VQA
import random
import skimage.io as io
import matplotlib.pyplot as plt
import os
# Demo script: pick random QA annotations by question type, by answer type
# and by image, print them and display the matching image when available.
dataDir = '../../VQA'
versionType = 'v2_'  # this should be '' when using VQA v1.0 dataset
taskType = 'OpenEnded'  # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType = 'mscoco'  # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType = 'train2014'
annFile = '%s/Annotations/%s%s_%s_annotations.json' % (dataDir, versionType, dataType, dataSubType)
quesFile = '%s/Questions/%s%s_%s_%s_questions.json' % (dataDir, versionType, taskType, dataType, dataSubType)
imgDir = '%s/Images/%s/%s/' % (dataDir, dataType, dataSubType)

# initialize VQA api for QA annotations
vqa = VQA(annFile, quesFile)


def _show_random_qa(annIds):
    """Print one randomly chosen QA annotation among ``annIds`` and display its image if the file exists."""
    anns = vqa.loadQA(annIds)
    randomAnn = random.choice(anns)
    vqa.showQA([randomAnn])
    imgId = randomAnn['image_id']
    imgFilename = 'COCO_' + dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
    if os.path.isfile(imgDir + imgFilename):
        I = io.imread(imgDir + imgFilename)
        plt.imshow(I)
        plt.axis('off')
        plt.show()


# load and display QA annotations for given question types
"""
All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
"""
_show_random_qa(vqa.getQuesIds(quesTypes='how many'))

# load and display QA annotations for given answer types
"""
ansTypes can be one of the following
yes/no
number
other
"""
_show_random_qa(vqa.getQuesIds(ansTypes='yes/no'))

# load and display QA annotations for given images
"""
Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
"""
ids = vqa.getImgIds()
_show_random_qa(vqa.getQuesIds(imgIds=random.sample(ids, 5)))
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
================================================
__author__ = 'aagrawal'
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
================================================
__author__ = 'aagrawal'
__version__ = '0.9'
# Interface for accessing the VQA dataset.
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
# The following functions are defined:
# VQA - VQA class that loads VQA annotation file and prepares data structures.
# getQuesIds - Get question ids that satisfy given filter conditions.
# getImgIds - Get image ids that satisfy given filter conditions.
# loadQA - Load questions and answers with the specified question ids.
# showQA - Display the specified questions and answers.
# loadRes - Load result file and create result object.
# Help on each function can be accessed by: "help(COCO.function)"
import json
import datetime
import copy
class VQA:
    """
    Helper for reading and querying VQA annotations and questions.

    Indexes built by ``createIndex``:
      * ``qa``      - question id -> annotation
      * ``qqa``     - question id -> question entry
      * ``imgToQA`` - image id -> list of annotations
    """

    def __init__(self, annotation_file=None, question_file=None):
        """
        Constructor of VQA helper class for reading and visualizing questions and answers.
        :param annotation_file (str): location of VQA annotation file
        :param question_file (str): location of VQA question file
        :return:

        The files are loaded and indexed only when both paths are given;
        otherwise an empty object is created.
        """
        # load dataset
        self.dataset = {}
        self.questions = {}
        self.qa = {}
        self.qqa = {}
        self.imgToQA = {}
        if annotation_file is not None and question_file is not None:
            # Context managers close the files as soon as they are parsed.
            with open(annotation_file, 'r') as f:
                dataset = json.load(f)
            with open(question_file, 'r') as f:
                questions = json.load(f)
            self.dataset = dataset
            self.questions = questions
            self.createIndex()

    def createIndex(self):
        """Build the ``qa``, ``qqa`` and ``imgToQA`` lookup tables from the loaded data."""
        imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
        qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
        qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
        for ann in self.dataset['annotations']:
            imgToQA[ann['image_id']] += [ann]
            qa[ann['question_id']] = ann
        for ques in self.questions['questions']:
            qqa[ques['question_id']] = ques

        # create class members
        self.qa = qa
        self.qqa = qqa
        self.imgToQA = imgToQA

    def info(self):
        """
        Print information about the VQA annotation file.
        :return:

        Prints nothing when the dataset has no 'info' entry.
        """
        for key, value in self.dataset.get('info', {}).items():
            print('%s: %s' % (key, value))

    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
        """
        Get question ids that satisfy given filter conditions. default skips that filter
        :param 	imgIds    (int array)   : get question ids for given imgs
                quesTypes (str array)   : get question ids for given question types
                ansTypes  (str array)   : get question ids for given answer types
        :return:    ids   (int array)   : integer array of question ids

        A single value may be passed instead of a list for any filter.
        """
        imgIds = imgIds if isinstance(imgIds, list) else [imgIds]
        quesTypes = quesTypes if isinstance(quesTypes, list) else [quesTypes]
        ansTypes = ansTypes if isinstance(ansTypes, list) else [ansTypes]

        if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset['annotations']
        else:
            if not len(imgIds) == 0:
                # imgToQA values are lists, so they are concatenated.
                anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
            else:
                anns = self.dataset['annotations']
            anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
            anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
        ids = [ann['question_id'] for ann in anns]
        return ids

    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
        """
        Get image ids that satisfy given filter conditions. default skips that filter
        :param quesIds   (int array)   : get image ids for given question ids
               quesTypes (str array)   : get image ids for given question types
               ansTypes  (str array)   : get image ids for given answer types
        :return: ids     (int array)   : integer array of image ids

        A single value may be passed instead of a list for any filter.
        """
        quesIds = quesIds if isinstance(quesIds, list) else [quesIds]
        quesTypes = quesTypes if isinstance(quesTypes, list) else [quesTypes]
        ansTypes = ansTypes if isinstance(ansTypes, list) else [ansTypes]

        if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset['annotations']
        else:
            if not len(quesIds) == 0:
                # Each qa entry is a single annotation dict (not a list), so
                # the matches are collected directly instead of concatenated.
                anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]
            else:
                anns = self.dataset['annotations']
            anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
            anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
        ids = [ann['image_id'] for ann in anns]
        return ids

    def loadQA(self, ids=[]):
        """
        Load questions and answers with the specified question ids.
        :param ids (int array)       : integer ids specifying question ids
        :return: qa (object array)   : loaded qa objects

        A single int id is accepted as well; any other argument type yields None.
        """
        if isinstance(ids, list):
            return [self.qa[id] for id in ids]
        elif isinstance(ids, int):
            return [self.qa[ids]]

    def showQA(self, anns):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: 0 when ``anns`` is empty, None otherwise
        """
        if len(anns) == 0:
            return 0
        for ann in anns:
            quesId = ann['question_id']
            print("Question: %s" % (self.qqa[quesId]['question']))
            for ans in ann['answers']:
                print("Answer %d: %s" % (ans['answer_id'], ans['answer']))

    def loadRes(self, resFile, quesFile):
        """
        Load result file and return a result object.
        :param   resFile (str)     : file name of result file
        :param   quesFile (str)    : file name of the question file
        :return: res (obj)         : result api object

        Raises AssertionError when the results are not a list, do not cover
        exactly the question ids of this object, or (multiple-choice task)
        contain an answer outside the offered choices.
        """
        res = VQA()
        with open(quesFile) as f:
            res.questions = json.load(f)
        res.dataset['info'] = copy.deepcopy(self.questions['info'])
        res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
        res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
        res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
        res.dataset['license'] = copy.deepcopy(self.questions['license'])

        with open(resFile) as f:
            anns = json.load(f)
        assert isinstance(anns, list), 'results is not an array of objects'
        annsQuesIds = [ann['question_id'] for ann in anns]
        assert set(annsQuesIds) == set(self.getQuesIds()), \
            'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
        for ann in anns:
            quesId = ann['question_id']
            if res.dataset['task_type'] == 'Multiple Choice':
                assert ann['answer'] in self.qqa[quesId][
                    'multiple_choices'], 'predicted answer is not one of the multiple choices'
            # Copy image/type metadata from the ground-truth annotation.
            qaAnn = self.qa[quesId]
            ann['image_id'] = qaAnn['image_id']
            ann['question_type'] = qaAnn['question_type']
            ann['answer_type'] = qaAnn['answer_type']

        res.dataset['annotations'] = anns
        res.createIndex()
        return res
================================================
FILE: minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt
================================================
how many
what color is the
is the
where is the
what
what is
are the
what is the
is there a
does the
is the woman
is the man
what is on the
is it
is the girl
is the boy
is the dog
are they
who is
what kind of
what color are the
what is in the
what is the man
is there
what is the woman
what are the
what is the boy
are there
what is the girl
is this
how
which
how many people are
is the cat
why is the
are
will the
what type of
what is the dog
do
is she
does
do the
is
is the baby
are there any
is the lady
can
what animal is
where are the
is the sun
what are they
did the
what is the cat
what is the lady
how many clouds are
is that
is the little girl
is he
are these
how many trees are
how many pillows
are the people
why
is the young
how many windows are
is this a
what is the little
is the tv
how many animals are
who
how many pictures
how many plants are
how many birds are
what color is
what is the baby
is anyone
what color
how many bushes
is the old man
none of the above
================================================
FILE: minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt
================================================
how many
is the
what
what color is the
what is the
is this
is this a
what is
are the
what kind of
is there a
what type of
is it
what are the
where is the
is there
does the
what color are the
are these
are there
which
is
what is the man
is the man
are
how
does this
what is on the
what does the
how many people are
what is in the
what is this
do
what are
are they
what time
what sport is
are there any
is he
what color is
why
where are the
what color
who is
what animal is
is the woman
is this an
do you
how many people are in
what room is
has
is this person
what is the woman
can you
why is the
is the person
what is the color of the
what is the person
could
was
is that a
what number is
what is the name
what brand
none of the above
================================================
FILE: minigpt4/common/vqa_tools/VQA/README.md
================================================
Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
===================
## VQA v2.0 release ##
This release consists of
- Real
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website](http://mscoco.org/dataset/#download))
- 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
- 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
There is only one type of task
- Open-ended task
## VQA v1.0 release ##
This release consists of
- Real
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website](http://mscoco.org/dataset/#download))
- 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
- 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
- Abstract
- 20,000 training images, 10,000 validation images and 20,000 testing images
- 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
- 600,000 answers for training and 300,000 answers for validation (10 per question)
There are two types of tasks
- Open-ended task
- Multiple-choice task (18 choices per question)
## Requirements ##
- python 2.7
- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
## Files ##
./Questions
- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
- [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
- [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
./Annotations
- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
- [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
- [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
./Images
- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
./PythonHelperTools
- This directory contains the Python API to read and visualize the VQA dataset
- vqaDemo.py (demo script)
- vqaTools (API to read and visualize data)
./PythonEvaluationTools
- This directory contains the Python evaluation code
- vqaEvalDemo.py (evaluation demo script)
- vqaEvaluation (evaluation code)
./Results
- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
- Visit the [VQA evaluation page](http://visualqa.org/evaluation) for more details.
./QuestionTypes
- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
- mscoco_question_types.txt
- abstract_v002_question_types.txt
## References ##
- [VQA: Visual Question Answering](http://visualqa.org/)
- [Microsoft COCO](http://mscoco.org/)
## Developers ##
- Aishwarya Agrawal (Virginia Tech)
- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
================================================
FILE: minigpt4/common/vqa_tools/VQA/license.txt
================================================
Copyright (c) 2014, Aishwarya Agrawal
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
The views and conclusions contained in the software and documentation are
those
of the authors and should not be interpreted as representing official
policies,
either expressed or implied, of the FreeBSD Project.
================================================
FILE: minigpt4/common/vqa_tools/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
__author__ = "aagrawal"
================================================
FILE: minigpt4/common/vqa_tools/vqa.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
__author__ = "aagrawal"
__version__ = "0.9"
# Interface for accessing the VQA dataset.
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
# The following functions are defined:
# VQA - VQA class that loads VQA annotation file and prepares data structures.
# getQuesIds - Get question ids that satisfy given filter conditions.
# getImgIds - Get image ids that satisfy given filter conditions.
# loadQA - Load questions and answers with the specified question ids.
# showQA - Display the specified questions and answers.
# loadRes - Load result file and create result object.
# Help on each function can be accessed by: "help(VQA.function)"
import json
import datetime
import copy
class VQA:
    """Helper for loading, indexing and querying VQA annotations and questions.

    Attributes:
        dataset (dict): parsed annotation file; indexing reads its
            "annotations" list.
        questions (dict): parsed question file; indexing reads its
            "questions" list.
        qa (dict): question_id -> annotation.
        qqa (dict): question_id -> question entry.
        imgToQA (dict): image_id -> list of annotations for that image.
    """

    def __init__(self, annotation_file=None, question_file=None):
        """
        Constructor of VQA helper class for reading and visualizing questions and answers.
        :param annotation_file (str): location of VQA annotation file
        :param question_file (str): location of VQA question file
        :return:

        Nothing is loaded unless BOTH paths are given; otherwise the object
        starts empty (this is how loadRes() builds its result object).
        """
        # load dataset
        self.dataset = {}
        self.questions = {}
        self.qa = {}
        self.qqa = {}
        self.imgToQA = {}
        if annotation_file is not None and question_file is not None:
            print("loading VQA annotations and questions into memory...")
            # Context managers make sure the file handles are closed.
            with open(annotation_file, "r") as ann_fp:
                dataset = json.load(ann_fp)
            with open(question_file, "r") as ques_fp:
                questions = json.load(ques_fp)
            self.dataset = dataset
            self.questions = questions
            self.createIndex()

    def createIndex(self):
        """Build the qa / qqa / imgToQA lookup tables from the loaded data."""
        # create index
        print("creating index...")
        imgToQA = {}
        qa = {}
        qqa = {}
        for ann in self.dataset["annotations"]:
            imgToQA.setdefault(ann["image_id"], []).append(ann)
            qa[ann["question_id"]] = ann
            # Placeholder so every annotated question id has a qqa entry even
            # when the question file has no matching question.
            qqa[ann["question_id"]] = []
        for ques in self.questions["questions"]:
            qqa[ques["question_id"]] = ques
        print("index created!")

        # create class members
        self.qa = qa
        self.qqa = qqa
        self.imgToQA = imgToQA

    def info(self):
        """
        Print information about the VQA annotation file.
        :return:
        """
        # Fixed: the attribute was misspelled as `self.datset`, which raised
        # AttributeError on every call.
        for key, value in self.dataset["info"].items():
            print("%s: %s" % (key, value))

    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
        """
        Get question ids that satisfy given filter conditions. default skips that filter
        :param  imgIds    (int array)   : get question ids for given imgs
                quesTypes (str array)   : get question ids for given question types
                ansTypes  (str array)   : get question ids for given answer types
        :return:    ids   (int array)   : integer array of question ids
        """
        # A bare (non-list) value is treated as a one-element filter.
        imgIds = imgIds if isinstance(imgIds, list) else [imgIds]
        quesTypes = quesTypes if isinstance(quesTypes, list) else [quesTypes]
        ansTypes = ansTypes if isinstance(ansTypes, list) else [ansTypes]

        if len(imgIds) == 0:
            anns = self.dataset["annotations"]
        else:
            # Unknown image ids are silently ignored.
            anns = [
                ann
                for imgId in imgIds
                if imgId in self.imgToQA
                for ann in self.imgToQA[imgId]
            ]
        if len(quesTypes) != 0:
            anns = [ann for ann in anns if ann["question_type"] in quesTypes]
        if len(ansTypes) != 0:
            anns = [ann for ann in anns if ann["answer_type"] in ansTypes]
        return [ann["question_id"] for ann in anns]

    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
        """
        Get image ids that satisfy given filter conditions. default skips that filter
        :param quesIds   (int array)   : get image ids for given question ids
               quesTypes (str array)   : get image ids for given question types
               ansTypes  (str array)   : get image ids for given answer types
        :return: ids     (int array)   : integer array of image ids
        """
        quesIds = quesIds if isinstance(quesIds, list) else [quesIds]
        quesTypes = quesTypes if isinstance(quesTypes, list) else [quesTypes]
        ansTypes = ansTypes if isinstance(ansTypes, list) else [ansTypes]

        if len(quesIds) == 0:
            anns = self.dataset["annotations"]
        else:
            # Fixed: self.qa maps a question id to ONE annotation dict, so the
            # previous `sum([...], [])` tried to add dicts to a list and raised
            # TypeError. Collect the annotations directly instead.
            anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]
        if len(quesTypes) != 0:
            anns = [ann for ann in anns if ann["question_type"] in quesTypes]
        if len(ansTypes) != 0:
            anns = [ann for ann in anns if ann["answer_type"] in ansTypes]
        return [ann["image_id"] for ann in anns]

    def loadQA(self, ids=[]):
        """
        Load questions and answers with the specified question ids.
        :param ids (int array)       : integer ids specifying question ids
        :return: qa (object array)   : loaded qa objects

        Returns None when `ids` is neither a list nor an int.
        """
        if type(ids) == list:
            return [self.qa[id] for id in ids]
        elif type(ids) == int:
            return [self.qa[ids]]

    def showQA(self, anns):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: None (0 when `anns` is empty)
        """
        if len(anns) == 0:
            return 0
        for ann in anns:
            quesId = ann["question_id"]
            print("Question: %s" % (self.qqa[quesId]["question"]))
            for ans in ann["answers"]:
                print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))

    def loadRes(self, resFile, quesFile):
        """
        Load result file and return a result object.
        :param   resFile (str)     : file name of result file
        :param   quesFile (str)    : file name of the question file
        :return: res (obj)         : result api object

        Raises AssertionError when the results are not a list, do not cover
        exactly the question ids of this object, or (multiple-choice task)
        contain an answer that is not one of the choices.
        """
        res = VQA()
        with open(quesFile) as ques_fp:
            res.questions = json.load(ques_fp)
        # Metadata is copied from the questions already loaded on `self`.
        for key in ("info", "task_type", "data_type", "data_subtype", "license"):
            res.dataset[key] = copy.deepcopy(self.questions[key])

        print("Loading and preparing results...     ")
        # Timezone-aware "now"; datetime.utcnow() is deprecated.
        time_t = datetime.datetime.now(datetime.timezone.utc)
        with open(resFile) as res_fp:
            anns = json.load(res_fp)
        assert type(anns) == list, "results is not an array of objects"
        annsQuesIds = [ann["question_id"] for ann in anns]
        assert set(annsQuesIds) == set(
            self.getQuesIds()
        ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
        for ann in anns:
            quesId = ann["question_id"]
            if res.dataset["task_type"] == "Multiple Choice":
                assert (
                    ann["answer"] in self.qqa[quesId]["multiple_choices"]
                ), "predicted answer is not one of the multiple choices"
            # Enrich each prediction with fields from the ground-truth annotation.
            qaAnn = self.qa[quesId]
            ann["image_id"] = qaAnn["image_id"]
            ann["question_type"] = qaAnn["question_type"]
            ann["answer_type"] = qaAnn["answer_type"]
        elapsed = datetime.datetime.now(datetime.timezone.utc) - time_t
        print("DONE (t=%0.2fs)" % (elapsed.total_seconds()))

        res.dataset["annotations"] = anns
        res.createIndex()
        return res
================================================
FILE: minigpt4/common/vqa_tools/vqa_eval.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
# coding=utf-8
__author__ = "aagrawal"
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re
class VQAEval:
def __init__(self, vqa=None, vqaRes=None, n=2):
self.n = n
self.accuracy = {}
self.evalQA = {}
self.evalQuesType = {}
self.evalAnsType = {}
self.vqa = vqa
self.vqaRes = vqaRes
if vqa is not None:
self.params = {"question_id": vqa.getQuesIds()}
self.contractions = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
self.manualMap = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
self.articles = ["a", "an", "the"]
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile("(\d)(,)(\d)")
self.punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def evaluate(self, quesIds=None):
if quesIds == None:
quesIds = [quesId for quesId in self.params["question_id"]]
gts = {}
res = {}
for quesId in quesIds:
gts[quesId] = self.vqa.qa[quesId]
res[quesId] = self.vqaRes.qa[quesId]
# =================================================
# Compute accuracy
# =================================================
accQA = []
accQuesType = {}
accAnsType = {}
print("computing accuracy")
step = 0
for quesId in quesIds:
resAns = res[quesId]["answer"]
resAns = resAns.replace("\n", " ")
resAns = resAns.replace("\t", " ")
resAns = resAns.strip()
resAns = self.processPunctuation(resAns)
resAns = self.processDigitArticle(resAns)
gtAcc = []
gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
if len(set(gtAnswers)) > 1:
for ansDic in gts[quesId]["answers"]:
ansDic["answer"] = self.processPunctuation(ansDic["answer"])
for gtAnsDatum in gts[quesId]["answers"]:
otherGTAns = [
item for item in gts[quesId]["answers"] if item != gtAnsDatum
]
matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
quesType = gts[quesId]["question_type"]
ansType = gts[quesId]["answer_type"]
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
accQA.append(avgGTAcc)
if quesType not in accQuesType:
accQuesType[quesType] = []
accQuesType[quesType].append(avgGTAcc)
if ansType not in accAnsType:
accAnsType[ansType] = []
accAnsType[ansType].append(avgGTAcc)
self.setEvalQA(quesId, avgGTAcc)
self.setEvalQuesType(quesId, quesType, avgGTAcc)
self.setEvalAnsType(quesId, ansType, avgGTAcc)
if step % 100 == 0:
self.updateProgress(step / float(len(quesIds)))
step = step + 1
self.setAccuracy(accQA, accQuesType, accAnsType)
print("Done computing accuracy")
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + " " in inText or " " + p in inText) or (
re.search(self.commaStrip, inText) != None
):
outText = outText.replace(p, "")
else:
outText = outText.replace(p, " ")
outText = self.periodStrip.sub("", outText, re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = " ".join(outText)
return outText
def setAccuracy(self, accQA, accQuesType, accAnsType):
self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
self.accuracy["perQuestionType"] = {
quesType: round(
100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
self.n,
)
for quesType in accQuesType
}
self.accuracy["perAnswerType"] = {
ansType: round(
100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
)
for ansType in accAnsType
}
def setEvalQA(self, quesId, acc):
self.evalQA[quesId] = round(100 * acc, self.n)
def setEvalQuesType(self, quesId, quesType, acc):
if quesType not in self.evalQuesType:
self.evalQuesType[quesType] = {}
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
def setEvalAnsType(self, quesId, ansType, acc):
if ansType not in self.evalAnsType:
self.evalAnsType[ansType] = {}
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
def updateProgress(self, progress):
barLength = 20
status = ""
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = "error: progress var must be float\r\n"
if progress < 0:
progress = 0
status = "Halt...\r\n"
if progress >= 1:
progress = 1
status = "Done...\r\n"
block = int(round(barLength * progress))
text = "\rFinshed Percent: [{0}] {1}% {2}".format(
"#" * block + "-" * (barLength - block), int(progress * 100), status
)
sys.stdout.write(text)
sys.stdout.flush()
================================================
FILE: minigpt4/configs/datasets/aokvqa/defaults.yaml
================================================
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
datasets:
aok_vqa:
# data_dir: ${env.data_dir}/datasets
data_type: images # [images|videos|features]
build_info:
# Be careful not to append minus sign (-) before split to avoid itemizing
annotations:
train:
url:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json
storage:
- /path/to/aokvqa_v1p0_train.json
images:
storage: /path/to/coco/images
================================================
FILE: minigpt4/configs/datasets/cc_sbu/align.yaml
================================================
datasets:
cc_sbu_align:
data_type: images
build_info:
storage: /path/to/cc_sbu_align/
================================================
FILE: minigpt4/configs/datasets/cc_sbu/defaults.yaml
================================================
datasets:
cc_sbu:
data_type: images
build_info:
storage: /path/to/cc_sbu_dataset/{00000..01255}.tar
================================================
FILE: minigpt4/configs/datasets/coco/caption.yaml
================================================
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
datasets:
coco_caption: # name of the dataset builder
# dataset_card: dataset_card/coco_caption.md
# data_dir: ${env.data_dir}/datasets
data_type: images # [images|videos|features]
build_info:
# Be careful not to append minus sign (-) before split to avoid itemizing
annotations:
train:
url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json
md5: aa31ac474cf6250ebb81d18348a07ed8
storage: /path/to/coco_caption/coco_karpathy_train.json
images:
storage: /path/to/coco/images
================================================
FILE: minigpt4/configs/datasets/coco/defaults_vqa.yaml
================================================
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
datasets:
coco_vqa:
# data_dir: ${env.data_dir}/datasets
data_type: images # [images|videos|features]
build_info:
annotations:
train:
url:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json
storage:
- /path/to/vqav2/vqa_train.json
- /path/to/vqav2/vqa_val.json
images:
storage: /path/to/coco/images
================================================
FILE: minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml
================================================
datasets:
invrefcoco:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/refcoco_annotations
dataset: invrefcoco
splitBy: unc
================================================
FILE: minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml
================================================
datasets:
invrefcocog:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/refcoco_annotations
dataset: invrefcocog
splitBy: umd
================================================
FILE: minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml
================================================
datasets:
invrefcocop:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/refcoco_annotations
dataset: invrefcoco+
splitBy: unc
================================================
FILE: minigpt4/configs/datasets/coco_bbox/refcoco.yaml
================================================
datasets:
refcoco:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/refcoco_annotations
dataset: refcoco
splitBy: unc
================================================
FILE: minigpt4/configs/datasets/coco_bbox/refcocog.yaml
================================================
datasets:
refcocog:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/refcoco_annotations
dataset: refcocog
splitBy: umd
================================================
FILE: minigpt4/configs/datasets/coco_bbox/refcocop.yaml
================================================
datasets:
refcocop:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/refcoco_annotations
dataset: refcoco+
splitBy: unc
================================================
FILE: minigpt4/configs/datasets/flickr/caption_to_phrase.yaml
================================================
datasets:
flickr_CaptionToPhrase:
data_type: images
build_info:
image_path: /path/to/filtered_flickr/images
ann_path: /path/to/filtered_flickr/captiontobbox.json
================================================
FILE: minigpt4/configs/datasets/flickr/default.yaml
================================================
datasets:
flickr_grounded_caption:
data_type: images
build_info:
image_path: /path/to/filtered_flickr/images
ann_path: /path/to/filtered_flickr/groundedcaption.json
================================================
FILE: minigpt4/configs/datasets/flickr/object_to_phrase.yaml
================================================
datasets:
flickr_ObjectToPhrase:
data_type: images
build_info:
image_path: /path/to/filtered_flickr/images
ann_path: /path/to/filtered_flickr/phrasetobbox.json
================================================
FILE: minigpt4/configs/datasets/gqa/balanced_val.yaml
================================================
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
datasets:
gqa:
# data_dir: ${env.data_dir}/datasets
data_type: images # [images|videos|features]
build_info:
# Be careful not to append minus sign (-) before split to avoid itemizing
annotations:
train:
url:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json
storage:
- /path/to/gqa/train_balanced_questions.json
images:
storage: /path/to/gqa/images
================================================
FILE: minigpt4/configs/datasets/laion/defaults.yaml
================================================
datasets:
laion:
data_type: images
build_info:
storage: /path/to/laion_dataset/{00000..10488}.tar
================================================
FILE: minigpt4/configs/datasets/llava/conversation.yaml
================================================
datasets:
llava_conversation:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/llava/conversation_58k.json
================================================
FILE: minigpt4/configs/datasets/llava/detail.yaml
================================================
datasets:
llava_detail:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/llava/detail_23k.json
================================================
FILE: minigpt4/configs/datasets/llava/reason.yaml
================================================
datasets:
llava_reason:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/llava/complex_reasoning_77k.json
================================================
FILE: minigpt4/configs/datasets/multitask_conversation/default.yaml
================================================
datasets:
multitask_conversation:
data_type: images
build_info:
image_path: /path/to/coco/images
ann_path: /path/to/multitask_conversation/multi_task_conversation.json
================================================
FILE: minigpt4/configs/datasets/nlp/unnatural_instruction.yaml
================================================
datasets:
unnatural_instruction:
data_type: text
build_info:
ann_path: /path/to/unnatural_instructions/filtered_unnatural_instruction.json
================================================
FILE: minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml
================================================
datasets:
ocrvqa:
data_type: images
build_info:
image_path: /path/to/ocrvqa/images
ann_path: /path/to/ocrvqa/dataset.json
================================================
FILE: minigpt4/configs/datasets/okvqa/defaults.yaml
================================================
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
datasets:
ok_vqa:
# data_dir: ${env.data_dir}/datasets
data_type: images # [images|videos|features]
build_info:
# Be careful not to append minus sign (-) before split to avoid itemizing
annotations:
train:
url:
# TODO make this order insensitive
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json
storage:
- /path/to/okvqa/okvqa_train.json
images:
storage: /path/to/coco/images
================================================
FILE: minigpt4/configs/datasets/textcaps/caption.yaml
================================================
datasets:
textcaps_caption:
data_type: images
build_info:
image_path: /path/to/textcaps/train_images
ann_path: /path/to/textcaps/TextCaps_0.1_train.json
================================================
FILE: minigpt4/configs/datasets/vg/ref.yaml
================================================
datasets:
refvg:
data_type: images
build_info:
data_dir: /path/to/visual_genome
================================================
FILE: minigpt4/configs/default.yaml
================================================
env:
# For default users
# cache_root: "cache"
# For internal use with persistent storage
cache_root: "/export/home/.cache/minigpt4"
================================================
FILE: minigpt4/configs/models/minigpt4_llama2.yaml
================================================
model:
arch: minigpt4
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
has_qformer: False
# generation configs
prompt: ""
llama_model: "please set this value to the path of llama2-chat-7b"
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
================================================
FILE: minigpt4/configs/models/minigpt4_vicuna0.yaml
================================================
model:
arch: minigpt4
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
freeze_qformer: True
# Q-Former
num_query_token: 32
# generation configs
prompt: ""
llama_model: "please set this value to the path of vicuna model"
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
================================================
FILE: minigpt4/configs/models/minigpt_v2.yaml
================================================
model:
arch: minigpt_v2
# vit encoder
image_size: 448
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
# generation configs
prompt: ""
llama_model: "please set this value to the path of llama2-chat-7b"
lora_r: 64
lora_alpha: 16
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 448
eval:
name: "blip2_image_eval"
image_size: 448
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
================================================
FILE: minigpt4/conversation/__init__.py
================================================
================================================
FILE: minigpt4/conversation/conversation.py
================================================
import argparse
import time
from threading import Thread
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
from minigpt4.common.registry import registry
class SeparatorStyle(Enum):
    """Selects how ``Conversation.get_prompt`` places separators between turns."""

    # One separator (``sep``) is appended after every message.
    SINGLE = 1
    # ``sep`` and ``sep2`` alternate by message index.
    TWO = 2
@dataclasses.dataclass
class Conversation:
    """Holds a system prompt plus the running list of ``[role, message]`` turns."""

    # Text placed at the very start of every rendered prompt.
    system: str
    # Role prefixes; index 0 is used for user turns and index 1 for replies.
    roles: List[str]
    # Conversation turns, each stored as a two-item list ``[role, message]``.
    messages: List[List[str]]
    # Number of leading messages skipped by ``to_gradio_chatbot``.
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Only read when ``sep_style`` is ``SeparatorStyle.TWO``.
    sep2: str = None
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the system text and all messages into a single prompt string.

        A message whose text is falsy (for example ``None``) contributes only
        its role prefix, which leaves the prompt open for the model to continue.

        Raises:
            ValueError: if ``sep_style`` is not a known ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            separators = (self.sep, self.sep)
        elif self.sep_style == SeparatorStyle.TWO:
            separators = (self.sep, self.sep2)
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        pieces = [self.system + separators[0]]
        for idx, (role, message) in enumerate(self.messages):
            if message:
                pieces.append(role + message + separators[idx % 2])
            else:
                pieces.append(role)
        return "".join(pieces)

    def append_message(self, role, message):
        """Add one ``[role, message]`` turn to the end of the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair up the messages after ``offset`` as ``[first, second]`` lists.

        Even positions open a new pair with ``None`` as the second item; odd
        positions fill in the second item of the most recent pair.
        """
        pairs = []
        visible = self.messages[self.offset:]
        for position, (_role, text) in enumerate(visible):
            if position % 2:
                pairs[-1][-1] = text
            else:
                pairs.append([text, None])
        return pairs

    def copy(self):
        """Return a new ``Conversation`` with a fresh list of message pairs.

        ``skip_next`` is not forwarded, so the copy gets its default value.
        """
        cloned_messages = [[role, text] for role, text in self.messages]
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=cloned_messages,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Return the conversation fields as a plain dictionary.

        ``sep_style`` and ``skip_next`` are not part of the result.
        """
        payload = {}
        payload["system"] = self.system
        # payload["system_img"] = self.system_img
        payload["roles"] = self.roles
        payload["messages"] = self.messages
        payload["offset"] = self.offset
        payload["sep"] = self.sep
        payload["sep2"] = self.sep2
        payload["conv_id"] = self.conv_id
        return payload
class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation when the newest tokens equal one of the stop sequences.

    Args:
        stops: iterable of 1-D token-id tensors; each one is compared against
            the trailing ``len(stop)`` columns of ``input_ids``.
        encounters: accepted for backward compatibility; not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a shared mutable default: every instance gets its own list.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return ``True`` as soon as any stop sequence matches the tail of ``input_ids``.

        ``**kwargs`` is accepted and ignored so the criterion can be invoked
        with extra keyword arguments by the generation loop.
        """
        for stop in self.stops:
            # With fewer generated tokens than the stop sequence, the slice
            # below would be shorter than ``stop`` and the comparison could
            # broadcast into a false match (or fail); such a prefix can never
            # be a real match, so skip it.
            if input_ids.shape[-1] < len(stop):
                continue
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True
        return False
# Conversation template using "Human: " / "Assistant: " role prefixes and "###"
# as the turn separator.
# NOTE(review): the system string was split across two lines (an unterminated
# literal). It is restored here with <Img>...</Img> around the image
# placeholder; confirm the exact tag spelling against the prompt handling code.
CONV_VISION_Vicuna0 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
# Conversation template using "[INST] " / " [/INST] " role prefixes and an
# empty turn separator.
# NOTE(review): the system string was split across two lines (an unterminated
# literal). It is restored here with <Img>...</Img> around the image
# placeholder; confirm the exact tag spelling against the prompt handling code.
CONV_VISION_LLama2 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("[INST] ", " [/INST] "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)
# Conversation template with an empty system prompt, "[INST] " / " [/INST]"
# role prefixes and an empty turn separator.
CONV_VISION_minigptv2 = Conversation(
    system="",
    roles=("[INST] ", " [/INST]"),
    messages=[],
    offset=2,
    sep="",
    sep_style=SeparatorStyle.SINGLE,
)
class Chat:
    """Drives a multi-turn, image-grounded chat on top of a wrapped model.

    Args:
        model: object exposing ``get_context_emb``, ``encode_img``,
            ``maybe_autocast``, ``llama_model`` and ``llama_tokenizer``.
        vis_processor: callable turning a PIL image into a tensor.
        device: device string that inputs and stop tokens are moved to.
        stopping_criteria: optional stopping criteria passed to ``generate``;
            when omitted, a default criterion on token id 2 is created.
    """

    # Tag that closes an uploaded-image message; ``ask`` checks for it.
    # NOTE(review): the tag text was missing from the literals in this class
    # and is reconstructed from the 6-character slice comparison in ``ask``;
    # confirm against the model's prompt handling.
    _IMG_CLOSE_TAG = '</Img>'

    def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            # Default: stop when the newest token id is 2
            # (presumably the tokenizer's end-of-sequence id -- confirm).
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        """Add the user's ``text`` to ``conv``.

        If the last message is a user turn that ends with the image closing
        tag, the text is joined onto that message with a single space;
        otherwise it is appended as a new user turn.
        """
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == self._IMG_CLOSE_TAG:  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
        """Build the keyword arguments for ``llama_model.generate``.

        An empty assistant turn is appended to ``conv`` first, so the rendered
        prompt ends with the assistant role prefix. When the context plus
        ``max_new_tokens`` exceeds ``max_length``, the oldest positions along
        dim 1 of the embeddings are dropped.
        """
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        embs = self.model.get_context_emb(prompt, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        generation_kwargs = dict(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=float(temperature),
        )
        return generation_kwargs

    def answer(self, conv, img_list, **kargs):
        """Generate a reply, store it in the last message of ``conv`` and return it.

        Returns:
            tuple ``(output_text, output_token)`` where ``output_token`` is the
            first generated sequence converted to a NumPy array.
        """
        generation_dict = self.answer_prepare(conv, img_list, **kargs)
        output_token = self.model_generate(**generation_dict)[0]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)

        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()

        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def stream_answer(self, conv, img_list, **kargs):
        """Start generation on a background thread and return the text streamer."""
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
        generation_kwargs['streamer'] = streamer
        thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def model_generate(self, *args, **kwargs):
        """Call ``llama_model.generate`` inside the model's autocast context."""
        # for 8 bit and 16 bit compatibility
        with self.model.maybe_autocast():
            output = self.model.llama_model.generate(*args, **kwargs)
        return output

    def encode_img(self, img_list):
        """Replace the first entry of ``img_list`` with its image embedding.

        The entry may be a file path, a PIL image or a tensor; a 3-D tensor
        gets a leading dimension added. The embedding is appended to the end
        of ``img_list`` after the raw entry has been popped from the front.
        """
        image = img_list[0]
        img_list.pop(0)
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)

    def upload_img(self, image, conv, img_list):
        """Record an uploaded image: add a placeholder user turn and queue the image.

        Returns:
            The acknowledgement string ``"Received."``.
        """
        # NOTE(review): placeholder text reconstructed; it must end with the
        # image closing tag so that ``ask`` merges the next question into it.
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        img_list.append(image)
        msg = "Received."
        return msg
================================================
FILE: minigpt4/datasets/__init__.py
================================================
================================================
FILE: minigpt4/datasets/builders/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
from minigpt4.datasets.builders.image_text_pair_builder import (
CCSBUBuilder,
LaionBuilder,
CCSBUAlignBuilder
)
from minigpt4.common.registry import registry
# Builder classes exported by a star-import of this package.
__all__ = [
    "CCSBUBuilder",
    "LaionBuilder",
    "CCSBUAlignBuilder"
]
def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """Build the dataset registered under ``name`` and return its splits.

    Example

    >>> dataset = load_dataset("coco_caption", cfg_path=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])

    Args:
        name: builder name as registered in the registry.
        cfg_path: optional path to a dataset config; when ``None`` the builder
            is constructed with ``cfg=None``.
        vis_path: optional override for the storage location of the visual data.
        data_type: key inside ``build_info`` whose ``storage`` is overridden by
            ``vis_path``; defaults to the builder config's ``data_type``.

    Returns:
        Whatever ``builder.build_datasets()`` returns.

    Raises:
        SystemExit: with status 1 when no builder is registered under ``name``.
    """
    cfg = None if cfg_path is None else load_dataset_config(cfg_path)

    # Look the class up first and construct it outside any ``except TypeError``
    # so that a TypeError raised inside a builder's constructor is not
    # misreported as "dataset not found".
    builder_cls = registry.get_builder_class(name)
    if builder_cls is None:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join([str(k) for k in dataset_zoo.get_names()])
        )
        raise SystemExit(1)
    builder = builder_cls(cfg)

    if vis_path is not None:
        if data_type is None:
            # use default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    dataset = builder.build_datasets()
    return dataset
class DatasetZoo:
    """Index of registered dataset builders and their config variants."""

    def __init__(self) -> None:
        # Maps each registered builder name (in sorted order) to the keys of
        # that builder's DATASET_CONFIG_DICT.
        registered = sorted(registry.mapping["builder_name_mapping"].items())
        self.dataset_zoo = {}
        for builder_name, builder_cls in registered:
            self.dataset_zoo[builder_name] = list(builder_cls.DATASET_CONFIG_DICT.keys())

    def get_names(self):
        """Return the registered builder names as a list."""
        return [builder_name for builder_name in self.dataset_zoo.keys()]
# Module-level instance built at import time; load_dataset() reads it to list
# the available dataset names in its "not found" message.
dataset_zoo = DatasetZoo()
================================================
FILE: minigpt4/datasets/builders/base_dataset_builder.py
================================================
"""
This file is from LAVIS.
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import shutil
import warnings
from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url
import minigpt4.common.utils as utils
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor
class BaseDatasetBuilder:
    """Base class that turns a dataset config into per-split dataset objects.

    Subclasses set ``train_dataset_cls`` / ``eval_dataset_cls`` and a
    ``DATASET_CONFIG_DICT`` mapping config-variant names to config paths.
    """

    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        """Resolve the dataset config and install default processors.

        Args:
            cfg: ``None`` (load the builder's default config), a string path
                to a config file, or an already-loaded config object.
        """
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

    def build_datasets(self):
        """Download data on the main process, synchronise, then build the datasets.

        Returns:
            The result of ``self.build()``.
        """
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        """Replace the default processors with the ones named in the config, if any."""
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        """Instantiate the processor registered under ``cfg.name``; ``None`` maps to ``None``."""
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of the config variant ``type``."""
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        """Fetch annotations, then check the visual-data location."""
        self._download_ann()
        self._download_vis()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the vision-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.

        Local annotation paths should be relative.

        A split that declares no ``url`` is skipped with a warning instead of
        failing on ``len(None)``.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if urls is None:
                # Nothing to fetch for this split.
                logging.warning(
                    "No annotation url configured for split {}; skipping download.".format(split)
                )
                continue

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths), (
                "Number of annotation urls and storage paths differ for split {}.".format(split)
            )

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                # exist_ok avoids failing if the directory appears between a
                # separate existence check and the creation call.
                os.makedirs(dirname, exist_ok=True)

                if os.path.isfile(url_or_filename):
                    # The "url" is a local file: copy it unless already present.
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # if only dirname is provided, suffix with basename of URL.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)

                    download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):
        """Warn when the configured visual-data location does not exist (nothing is downloaded)."""
        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create by split datasets inheriting torch.utils.data.Datasets.

        # build() can be dataset-specific. Overwrite to customize.

        Only the splits named "train", "val" and "test" are built; "train"
        uses ``train_dataset_cls`` with the train processors, the others use
        ``eval_dataset_cls`` with the eval processors.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = dict()
        for split in ann_info.keys():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"

            # processors
            vis_processor = (
                self.vis_processors["train"]
                if is_train
                else self.vis_processors["eval"]
            )
            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            # annotation path
            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]

            abs_ann_paths = []
            for ann_path in ann_paths:
                if not os.path.isabs(ann_path):
                    ann_path = utils.get_cache_path(ann_path)
                abs_ann_paths.append(ann_path)
            ann_paths = abs_ann_paths

            # visual data storage path
            vis_path = os.path.join(vis_info.storage, split)

            if not os.path.isabs(vis_path):
                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn("storage path {} does not exist.".format(vis_path))

            # create datasets
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets
def load_dataset_config(cfg_path):
    """Load ``cfg_path`` and return the first entry under its ``datasets`` key."""
    datasets_cfg = OmegaConf.load(cfg_path).datasets
    first_name = list(datasets_cfg.keys())[0]
    return datasets_cfg[first_name]
================================================
FILE: minigpt4/datasets/builders/image_text_pair_builder.py
================================================
import os
import logging
import warnings
from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.laion_dataset import LaionDataset
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
from minigpt4.datasets.datasets.text_caps import TextCapDataset
from minigpt4.datasets.datasets.llava_dataset import LlavaDetailDataset, LlavaReasonDataset, LlavaConversationDataset
from minigpt4.datasets.datasets.unnatural_instruction import UnnaturalDataset
from minigpt4.datasets.datasets.multitask_conversation import MultiTaskConversationDataset
from minigpt4.datasets.datasets.flickr import GroundedDetailDataset,CaptionToObjectDataset,PhraseToObjectDataset
from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset
from minigpt4.datasets.datasets.coco_dataset import ReferCOCODataset, InvReferCOCODataset
from minigpt4.datasets.datasets.gqa_datasets import GQADataset
from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset
from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
@registry.register_builder("multitask_conversation")
class MultitaskConversationBuilder(BaseDatasetBuilder):
    """Builder registered as ``multitask_conversation``; produces a 'train' split only."""

    train_dataset_cls = MultiTaskConversationDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/multitask_conversation/default.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``.

        This override does not call ``_download_data()``; the files are
        expected to be in place already.
        """
        logging.info("Building datasets...")
        self.build_processors()

        info = self.config.build_info
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=info.ann_path,
            vis_root=info.image_path,
        )
        return {"train": train_set}
@registry.register_builder("unnatural_instruction")
class UnnaturalInstructionBuilder(BaseDatasetBuilder):
    """Builder registered as ``unnatural_instruction``; text-only, 'train' split only."""

    train_dataset_cls = UnnaturalDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/nlp/unnatural_instruction.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path``.

        Only a text processor is passed to the dataset class. This override
        does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        info = self.config.build_info
        train_set = self.train_dataset_cls(
            text_processor=self.text_processors["train"],
            ann_path=info.ann_path,
        )
        return {"train": train_set}
@registry.register_builder("llava_detail")
class LlavaDetailBuilder(BaseDatasetBuilder):
    """Builder registered as ``llava_detail``; produces a 'train' split only."""

    train_dataset_cls = LlavaDetailDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/llava/detail.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``.

        This override does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        info = self.config.build_info
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=info.ann_path,
            vis_root=info.image_path,
        )
        return {"train": train_set}
@registry.register_builder("llava_reason")
class LlavaReasonBuilder(BaseDatasetBuilder):
    """Builder registered as ``llava_reason``; produces a 'train' split only."""

    train_dataset_cls = LlavaReasonDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/llava/reason.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``.

        This override does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        info = self.config.build_info
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=info.ann_path,
            vis_root=info.image_path,
        )
        return {"train": train_set}
@registry.register_builder("llava_conversation")
class LlavaConversationBuilder(BaseDatasetBuilder):
    """Builder registered as ``llava_conversation``; produces a 'train' split only.

    The class was previously also named ``LlavaReasonBuilder``, which rebound
    that module-level name and hid the ``llava_reason`` builder from direct
    imports. Lookup through the registry is unaffected by the rename.
    """

    train_dataset_cls = LlavaConversationDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/llava/conversation.yaml",
    }

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``.

        This override does not call ``_download_data()``.
        """
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()
        build_info = self.config.build_info
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets
class AllRefCOCOBuilder(BaseDatasetBuilder):
    """Shared ``build_datasets`` for builders configured with ``dataset`` and ``splitBy``."""

    def build_datasets(self):
        """Build the training dataset, warning first about missing paths.

        Reads ``image_path``, ``ann_path``, ``dataset`` and ``splitBy`` from
        ``build_info``. This override does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        info = self.config.build_info
        image_root = info.image_path
        annotation_root = info.ann_path

        # A missing path only triggers a warning; construction still proceeds.
        if not os.path.exists(image_root):
            warnings.warn("image path {} does not exist.".format(image_root))
        if not os.path.exists(annotation_root):
            warnings.warn("ann path {} does not exist.".format(annotation_root))

        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=annotation_root,
            vis_root=image_root,
            dataset=info.dataset,
            splitBy=info.splitBy,
        )
        return {"train": train_set}
@registry.register_builder("refcoco")
class RefCOCOBuilder(AllRefCOCOBuilder):
    """Builder registered as ``refcoco``; trains on ``ReferCOCODataset``."""

    train_dataset_cls = ReferCOCODataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/coco_bbox/refcoco.yaml")
@registry.register_builder("refcocop")
class RefCOCOPBuilder(AllRefCOCOBuilder):
    """Builder registered as ``refcocop``; trains on ``ReferCOCODataset``."""

    train_dataset_cls = ReferCOCODataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/coco_bbox/refcocop.yaml")
@registry.register_builder("refcocog")
class RefCOCOGBuilder(AllRefCOCOBuilder):
    """Builder registered as ``refcocog``; trains on ``ReferCOCODataset``."""

    train_dataset_cls = ReferCOCODataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/coco_bbox/refcocog.yaml")
@registry.register_builder("invrefcoco")
class InvRefCOCOBuilder(AllRefCOCOBuilder):
    """Builder registered as ``invrefcoco``; trains on ``InvReferCOCODataset``.

    Previously named ``RefCOCOBuilder``, which rebound that module-level name
    and hid the ``refcoco`` builder from direct imports. Lookup through the
    registry is unaffected by the rename.
    """

    train_dataset_cls = InvReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/invrefcoco.yaml",
    }
@registry.register_builder("invrefcocop")
class InvRefCOCOPBuilder(AllRefCOCOBuilder):
    """Builder registered as ``invrefcocop``; trains on ``InvReferCOCODataset``.

    Previously named ``RefCOCOPBuilder``, which rebound that module-level name
    and hid the ``refcocop`` builder from direct imports. Lookup through the
    registry is unaffected by the rename.
    """

    train_dataset_cls = InvReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/invrefcocop.yaml",
    }
@registry.register_builder("invrefcocog")
class InvRefCOCOGBuilder(AllRefCOCOBuilder):
    """Builder registered as ``invrefcocog``; trains on ``InvReferCOCODataset``.

    Previously named ``RefCOCOGBuilder``, which rebound that module-level name
    and hid the ``refcocog`` builder from direct imports. Lookup through the
    registry is unaffected by the rename.
    """

    train_dataset_cls = InvReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/invrefcocog.yaml",
    }
@registry.register_builder("refvg")
class RefVisualGenomeBuilder(BaseDatasetBuilder):
    """Builder registered as ``refvg``; produces a 'train' split only."""

    train_dataset_cls = ReferVisualGenomeDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/vg/ref.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from ``build_info.data_dir``.

        This override does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        root_dir = self.config.build_info.data_dir
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            data_dir=root_dir,
        )
        return {"train": train_set}
@registry.register_builder("textcaps_caption")
class TextcapCaptionBuilder(BaseDatasetBuilder):
    """Builder registered as ``textcaps_caption``; no downloads, 'train' split only."""

    train_dataset_cls = TextCapDataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/textcaps/caption.yaml")

    def _download_ann(self):
        """Intentionally a no-op."""
        return None

    def _download_vis(self):
        """Intentionally a no-op."""
        return None

    def build(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``."""
        self.build_processors()

        info = self.config.build_info
        split_name = "train"
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors[split_name],
            text_processor=self.text_processors[split_name],
            ann_path=info.ann_path,
            vis_root=info.image_path,
        )
        return {split_name: train_set}
@registry.register_builder("coco_vqa")
class COCOVQABuilder(BaseDatasetBuilder):
    """Builder registered as ``coco_vqa``; trains on ``COCOVQADataset``."""

    train_dataset_cls = COCOVQADataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/coco/defaults_vqa.yaml")
@registry.register_builder("ok_vqa")
class OKVQABuilder(COCOVQABuilder):
    """Builder registered as ``ok_vqa``; same as ``COCOVQABuilder`` with a different default config."""

    DATASET_CONFIG_DICT = dict(default="configs/datasets/okvqa/defaults.yaml")
@registry.register_builder("aok_vqa")
class AOKVQABuilder(BaseDatasetBuilder):
    """Builder registered as ``aok_vqa``; trains on ``AOKVQADataset``."""

    train_dataset_cls = AOKVQADataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/aokvqa/defaults.yaml")
@registry.register_builder("gqa")
class GQABuilder(BaseDatasetBuilder):
    """Builder registered as ``gqa``; trains on ``GQADataset``."""

    train_dataset_cls = GQADataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/gqa/balanced_val.yaml")
@registry.register_builder("flickr_grounded_caption")
class GroundedCaptionBuilder(BaseDatasetBuilder):
    """Builder registered as ``flickr_grounded_caption``; produces a 'train' split only."""

    train_dataset_cls = GroundedDetailDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/flickr/default.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``.

        This override does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        info = self.config.build_info
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=info.ann_path,
            vis_root=info.image_path,
        )
        return {"train": train_set}
@registry.register_builder("flickr_CaptionToPhrase")
class CaptionToPhraseBuilder(BaseDatasetBuilder):
    """Builder registered as ``flickr_CaptionToPhrase``; produces a 'train' split only."""

    train_dataset_cls = CaptionToObjectDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/flickr/caption_to_phrase.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``.

        This override does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        info = self.config.build_info
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=info.ann_path,
            vis_root=info.image_path,
        )
        return {"train": train_set}
@registry.register_builder("flickr_ObjectToPhrase")
class ObjectToPhraseBuilder(BaseDatasetBuilder):
    """Builder registered as ``flickr_ObjectToPhrase``; produces a 'train' split only.

    Previously named ``CaptionToPhraseBuilder``, which rebound that
    module-level name and hid the ``flickr_CaptionToPhrase`` builder from
    direct imports. Lookup through the registry is unaffected by the rename.
    """

    train_dataset_cls = PhraseToObjectDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/flickr/object_to_phrase.yaml",
    }

    def build_datasets(self):
        """Build the training dataset from ``build_info.ann_path`` and ``build_info.image_path``.

        This override does not call ``_download_data()``.
        """
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()
        build_info = self.config.build_info
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets
class DocumentVQABuilder(BaseDatasetBuilder):
    """Base builder with no downloads that produces a 'train' split only."""

    def _download_ann(self):
        """Intentionally a no-op."""
        return None

    def _download_vis(self):
        """Intentionally a no-op."""
        return None

    def build(self):
        """Build the training dataset from ``build_info.image_path`` and ``build_info.ann_path``."""
        self.build_processors()

        info = self.config.build_info
        split_name = "train"
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors[split_name],
            text_processor=self.text_processors[split_name],
            vis_root=info.image_path,
            ann_path=info.ann_path,
        )
        return {split_name: train_set}
@registry.register_builder("ocrvqa")
class OCRVQABuilder(DocumentVQABuilder):
    """Builder registered as ``ocrvqa``; trains on ``OCRVQADataset``."""

    train_dataset_cls = OCRVQADataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/ocrvqa/ocrvqa.yaml")
@registry.register_builder("cc_sbu")
class CCSBUBuilder(BaseDatasetBuilder):
    """Builder registered as ``cc_sbu``; no downloads, 'train' split only."""

    train_dataset_cls = CCSBUDataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/cc_sbu/defaults.yaml")

    def _download_ann(self):
        """Intentionally a no-op."""
        return None

    def _download_vis(self):
        """Intentionally a no-op."""
        return None

    def build(self):
        """Build the training dataset from ``build_info.storage``.

        The returned 'train' entry is the dataset's ``inner_dataset``
        attribute, not the dataset object itself.
        """
        self.build_processors()

        info = self.config.build_info
        split_name = "train"
        # [NOTE] return inner_datasets (wds.DataPipeline)
        wrapper = self.train_dataset_cls(
            vis_processor=self.vis_processors[split_name],
            text_processor=self.text_processors[split_name],
            location=info.storage,
        )
        return {split_name: wrapper.inner_dataset}
@registry.register_builder("laion")
class LaionBuilder(BaseDatasetBuilder):
    """Builder registered as ``laion``; no downloads, 'train' split only."""

    train_dataset_cls = LaionDataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/laion/defaults.yaml")

    def _download_ann(self):
        """Intentionally a no-op."""
        return None

    def _download_vis(self):
        """Intentionally a no-op."""
        return None

    def build(self):
        """Build the training dataset from ``build_info.storage``.

        The returned 'train' entry is the dataset's ``inner_dataset``
        attribute, not the dataset object itself.
        """
        self.build_processors()

        info = self.config.build_info
        split_name = "train"
        # [NOTE] return inner_datasets (wds.DataPipeline)
        wrapper = self.train_dataset_cls(
            vis_processor=self.vis_processors[split_name],
            text_processor=self.text_processors[split_name],
            location=info.storage,
        )
        return {split_name: wrapper.inner_dataset}
@registry.register_builder("coco_caption")
class COCOCapBuilder(BaseDatasetBuilder):
    """Builder registered as ``coco_caption``; trains on ``COCOCapDataset``."""

    train_dataset_cls = COCOCapDataset
    DATASET_CONFIG_DICT = dict(default="configs/datasets/coco/caption.yaml")
@registry.register_builder("cc_sbu_align")
class CCSBUAlignBuilder(BaseDatasetBuilder):
    """Builder registered as ``cc_sbu_align``; produces a 'train' split only."""

    train_dataset_cls = CCSBUAlignDataset
    DATASET_CONFIG_DICT = dict(
        default="configs/datasets/cc_sbu/align.yaml",
    )

    def build_datasets(self):
        """Build the training dataset from the directory ``build_info.storage``.

        Annotations are read from ``filter_cap.json`` and images from the
        ``image`` sub-directory of that location. A missing directory only
        triggers a warning. This override does not call ``_download_data()``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        root = self.config.build_info.storage
        if not os.path.exists(root):
            warnings.warn("storage path {} does not exist.".format(root))

        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(root, 'filter_cap.json')],
            vis_root=os.path.join(root, 'image'),
        )
        return {"train": train_set}
================================================
FILE: minigpt4/datasets/data_utils.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm
import decord
from decord import VideoReader
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset
from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.base_dataset import ConcatDataset
# Select decord's "torch" bridge at import time.
decord.bridge.set_bridge("torch")
# Value stored under the "MAX_INT" key of the project registry.
MAX_INT = registry.get("MAX_INT")
class ChainDataset(wds.DataPipeline):
    r"""Dataset for chaining multiple :class:`DataPipeline` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Each sample is drawn from one of the streams, chosen at random with the
    streams' ``sample_ratio`` values as weights.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: List[wds.DataPipeline]) -> None:
        super().__init__()
        self.datasets = datasets
        self.prob = []
        self.names = []
        for dataset in self.datasets:
            if hasattr(dataset, 'name'):
                self.names.append(dataset.name)
            else:
                self.names.append('Unknown')
            if hasattr(dataset, 'sample_ratio'):
                self.prob.append(dataset.sample_ratio)
            else:
                # A stream without an explicit ratio gets weight 1.
                self.prob.append(1)
                logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")

    def __iter__(self):
        """Yield samples until one of the selected streams is exhausted."""
        datastreams = [iter(dataset) for dataset in self.datasets]
        while True:
            select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
            try:
                sample = next(select_datastream)
            except StopIteration:
                # A StopIteration escaping a generator body is converted to
                # RuntimeError (PEP 479); end the chained stream cleanly instead.
                return
            yield sample
def apply_to_sample(f, sample):
    """Apply ``f`` to every tensor found in ``sample``, recursing into dicts and lists.

    Non-tensor leaves are returned unchanged. A ``sample`` of length zero
    yields an empty dict.
    """
    if len(sample) == 0:
        return {}

    def _walk(node):
        if torch.is_tensor(node):
            return f(node)
        if isinstance(node, dict):
            return {key: _walk(value) for key, value in node.items()}
        if isinstance(node, list):
            return [_walk(item) for item in node]
        return node

    return _walk(sample)
def move_to_cuda(sample):
    """Return ``sample`` with ``.cuda()`` called on every tensor it contains."""
    return apply_to_sample(lambda tensor: tensor.cuda(), sample)
def prepare_sample(samples, cuda_enabled=True):
    """Move ``samples`` to CUDA when ``cuda_enabled`` is true; otherwise return them unchanged."""
    # TODO fp16 support
    return move_to_cuda(samples) if cuda_enabled else samples
def reorg_datasets_by_split(datasets, batch_sizes):
    """
    Regroup per-dataset splits into per-split lists.

    Args:
        datasets: dict mapping dataset name to a dict of {split_name: dataset}.
        batch_sizes: dict mapping dataset name to its batch size.

    Returns:
        Tuple ``(datasets_by_split, batch_sizes_by_split)``. Both are dicts
        keyed by split name whose values are lists ordered the same way, so
        entry ``i`` of the batch-size list belongs to entry ``i`` of the
        dataset list.
    """
    datasets_by_split = {}
    batch_sizes_by_split = {}

    for name, splits in datasets.items():
        for split, split_dataset in splits.items():
            datasets_by_split.setdefault(split, []).append(split_dataset)
            batch_sizes_by_split.setdefault(split, []).append(batch_sizes[name])

    return datasets_by_split, batch_sizes_by_split
def concat_datasets(datasets):
    """
    Merge the datasets of each split into a single object, in place.

    Map-style datasets and WebDataset ``DataPipeline`` objects are supported.
    Any other ``IterableDataset`` is rejected because it would need its own
    sampler. Only the "train" split may hold several datasets; every other
    split must contain exactly one, since metrics should not be computed on a
    concatenation.

    Args:
        datasets: dict mapping split name -> list of datasets.

    Returns:
        The same dict, updated in place. Non-train splits map to their single
        dataset. "train" maps to the concatenated map-style dataset, the
        chained pipeline dataset, or a tuple ``(map_style, pipelines)`` when
        both kinds are present (an empty tuple when the list was empty).
    """
    for split_name in datasets:
        members = datasets[split_name]

        if split_name != "train":
            assert (
                len(members) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = members[0]
            continue

        pipelines, map_style = [], []
        for dataset in members:
            if isinstance(dataset, wds.DataPipeline):
                logging.info(
                    "Dataset {} is IterableDataset, can't be concatenated.".format(
                        dataset
                    )
                )
                pipelines.append(dataset)
            elif isinstance(dataset, IterableDataset):
                raise NotImplementedError(
                    "Do not support concatenation of generic IterableDataset."
                )
            else:
                map_style.append(dataset)

        # Pipelines and map-style datasets are merged separately.
        if len(pipelines) > 1:
            chained = ChainDataset(pipelines)
        elif len(pipelines) == 1:
            chained = pipelines[0]
        else:
            chained = None

        concatenated = ConcatDataset(map_style) if len(map_style) > 0 else None

        merged = tuple(d for d in (concatenated, chained) if d is not None)
        datasets[split_name] = merged[0] if len(merged) == 1 else merged

    return datasets
================================================
FILE: minigpt4/datasets/datasets/__init__.py
================================================
================================================
FILE: minigpt4/datasets/datasets/aok_vqa_datasets.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from collections import OrderedDict
import json
import os
import random
import torch
from PIL import Image
from minigpt4.datasets.datasets.vqa_datasets import VQADataset #, VQAEvalDataset
class __DisplMixin:
    """Mixin that renders one A-OKVQA item as an ordered, display-friendly record."""

    def displ_item(self, index):
        """Return an ``OrderedDict`` combining the raw annotation at ``index``
        with the image produced by ``__getitem__`` for the same index."""
        sample = self.__getitem__(index)
        ann = self.annotation[index]

        record = OrderedDict()
        record["file"] = ann["image"]
        record["question"] = ann["question"]
        record["question_id"] = ann["question_id"]
        record["direct_answers"] = "; ".join(ann["direct_answers"])
        record["choices"] = "; ".join(ann["choices"])
        record["correct_choice"] = ann["choices"][ann["correct_choice_idx"]]
        record["image"] = sample["image"]
        return record
class AOKVQADataset(VQADataset, __DisplMixin):
    """A-OKVQA dataset yielding ``[vqa]`` instruction/answer training pairs."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """Load annotations via the parent class and drop entries whose image
        file does not exist under ``vis_root``."""
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.instruction_pool =[
            "[vqa] {}",
            "[vqa] Based on the image, respond to this question with a short answer: {}"
        ]

        # Only the basename of ann["image"] is looked up under vis_root.
        exist_annotation = []
        for ann in self.annotation:
            image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
            if os.path.exists(image_path):
                exist_annotation.append(ann)
        self.annotation = exist_annotation

    def get_data(self, index):
        """Load the image, question and one sampled answer for ``index``.

        The answer is drawn from ``ann["direct_answers"]`` with probability
        proportional to how often each distinct answer occurs in that list.

        Returns:
            dict with keys ``image``, ``question`` and ``answer``.
        """
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])

        answer_key = "direct_answers"

        # Each occurrence contributes 1/len, so repeated answers weigh more.
        answer_weight = {}
        for answer in ann[answer_key]:
            if answer in answer_weight.keys():
                answer_weight[answer] += 1 / len(ann[answer_key])
            else:
                answer_weight[answer] = 1 / len(ann[answer_key])

        answers = list(answer_weight.keys())
        weights = list(answer_weight.values())

        answer = random.choices(answers, weights=weights, k=1)[0]  # random sample an answer according to weights

        return {
            "image": image,
            "question": question,
            "answer": answer,
        }

    def __getitem__(self, index):
        """Return one sample with keys ``image``, ``instruction_input`` and ``answer``."""
        data = self.get_data(index)
        question = self.text_processor(data["question"])
        instruction = random.choice(self.instruction_pool).format(question)

        # NOTE(review): this literal was split across two lines (unterminated
        # string) in this copy of the file; restored to the image-placeholder
        # prefix used by MiniGPT-4 prompts -- confirm against upstream.
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)
        answer = self.text_processor(data['answer'])

        return {
            "image": data['image'],
            "instruction_input": instruction,
            "answer": answer,
        }
class AOKVQGDataset(AOKVQADataset):
    """Question-generation variant: the sampled answer is the prompt and the
    question is the target."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """Initialise like the parent, then swap in question-generation prompts."""
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
        self.instruction_pool = [
            'Given the image, generate a question whose answer is: {}',
            'Based on the image, provide a question with the answer: {}',
            'Given the visual representation, create a question for which the answer is "{}"',
            'From the image provided, craft a question that leads to the reply: {}',
            'Considering the picture, come up with a question where the answer is: {}',
            'Taking the image into account, generate an question that has the answer: {}'
        ]

    def __getitem__(self, index):
        """Return ``image``, an instruction built from the sampled answer, and
        the processed question as ``answer``."""
        item = self.get_data(index)
        template = random.choice(self.instruction_pool)
        prompt = template.format(item['answer'])

        return {
            "image": item['image'],
            "instruction_input": prompt,
            "answer": item['question'],
        }
================================================
FILE: minigpt4/datasets/datasets/base_dataset.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import json
from typing import Iterable
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate
class BaseDataset(Dataset):
    """Map-style dataset base that loads JSON annotations into ``self.annotation``."""

    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
    ):
        """
        Args:
            vis_processor: callable stored for subclasses to apply to images.
            text_processor: callable stored for subclasses to apply to text.
            vis_root (string): Root directory of images (e.g. coco/images/)
            ann_paths (list of string): JSON annotation files. A file holding
                a list is appended as-is; a file holding a dict contributes
                its ``'annotations'`` entry. The default list is never mutated.
        """
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths:
            # Open and parse each file exactly once, closing the handle promptly.
            with open(ann_path, "r") as ann_file:
                ann = json.load(ann_file)
            if isinstance(ann, dict):
                self.annotation.extend(ann['annotations'])
            else:
                self.annotation.extend(ann)

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        """Number of loaded annotations."""
        return len(self.annotation)

    def collater(self, samples):
        """Batch ``samples`` with PyTorch's ``default_collate``."""
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        """Replace the stored visual and text processors."""
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        """Tag every annotation with its list position (as a string) under ``key``."""
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)
class ConcatDataset(ConcatDataset):
    """``torch`` ConcatDataset extended with a ``collater`` that batches only
    the keys shared by every sample."""

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        """Trim every sample to the keys present in all samples, then delegate
        to the first underlying dataset's ``collater``."""
        # TODO For now only supports datasets with same underlying collater implementations

        common_keys = set().union(*samples)
        for sample in samples:
            common_keys &= set(sample.keys())

        trimmed = [
            {key: value for key, value in sample.items() if key in common_keys}
            for sample in samples
        ]

        return self.datasets[0].collater(trimmed)
================================================
FILE: minigpt4/datasets/datasets/caption_datasets.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
from collections import OrderedDict
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from PIL import Image
import random
class __DisplMixin:
    """Mixin that renders one caption item as an ordered, display-friendly record."""

    def displ_item(self, index):
        """Return an ``OrderedDict`` with the annotation's file and caption plus
        the image produced by ``__getitem__`` for the same index."""
        sample = self.__getitem__(index)
        ann = self.annotation[index]

        record = OrderedDict()
        record["file"] = ann["image"]
        record["caption"] = ann["caption"]
        record["image"] = sample["image"]
        return record
class CaptionDataset(BaseDataset, __DisplMixin):
    """Image-caption dataset that maps raw image ids to dense integer ids."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): annotation files handed to the parent loader
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        # Number distinct image ids 0, 1, 2, ... in order of first appearance.
        self.img_ids = {}
        for ann in self.annotation:
            raw_id = ann["image_id"]
            if raw_id in self.img_ids.keys():
                continue
            self.img_ids[raw_id] = len(self.img_ids)

    def __getitem__(self, index):
        """Return ``image``, ``text_input`` (processed caption) and the dense ``image_id``."""
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        # File name is the image id left-padded with zeros to 12 characters.
        file_name = '{:0>12}.jpg'.format(ann["image_id"])
        pil_image = Image.open(os.path.join(self.vis_root, file_name)).convert("RGB")

        processed_image = self.vis_processor(pil_image)
        processed_caption = self.text_processor(ann["caption"])

        return {
            "image": processed_image,
            "text_input": processed_caption,
            "image_id": self.img_ids[ann["image_id"]],
        }
class COCOCaptionDataset(BaseDataset, __DisplMixin):
    """COCO caption training set yielding ``[caption]`` instruction/answer pairs."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): annotation files handed to the parent loader

        Only annotations whose ``image`` path contains "train" are kept.
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.img_ids = {}
        n = 0

        self.filter_anntation = []

        for ann in self.annotation:
            if "train" in ann["image"]:
                self.filter_anntation.append(ann)
        self.annotation = self.filter_anntation

        # Number distinct image ids 0, 1, 2, ... in order of first appearance.
        for ann in self.annotation:
            img_id = ann["image_id"]
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

        self.instruction_pool = [
            'Briefly describe this image.',
            'Provide a concise depiction of this image.',
            'Present a short description of this image.',
            'Summarize this image in a few words.',
            'A short image caption:',
            'A short image description:',
            'A photo of ',
            'An image that shows ',
            'Write a short description for the image. ',
            'Write a description for the photo.',
            'Provide a description of what is presented in the photo.',
            'Briefly describe the content of the image.',
            'Can you briefly explain what you see in the image?',
            'Could you use a few words to describe what you perceive in the photo?',
            'Please provide a short depiction of the picture.',
            'Using language, provide a short account of the image.',
            'Use a few words to illustrate what is happening in the picture.',
        ]

    def __getitem__(self, index):
        """Return ``image``, the processed caption as ``answer`` and a randomly
        chosen captioning prompt as ``instruction_input``."""
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = ann["image"].split("/")[-1]
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

        instruction = random.choice(self.instruction_pool)
        # NOTE(review): this literal was split across two lines (unterminated
        # string) in this copy of the file; restored to the image-placeholder
        # prefix used by MiniGPT-4 prompts -- confirm against upstream.
        instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)

        return {
            "image": image,
            "answer": caption,
            "instruction_input": instruction,
        }
class CaptionEvalDataset(BaseDataset, __DisplMixin):
    """Caption evaluation set: yields images with their ids, no caption text."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): annotation files handed to the parent loader
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return the processed image with the annotation's ``image_id`` and ``instance_id``."""
        record = self.annotation[index]

        full_path = os.path.join(self.vis_root, record["image"])
        processed = self.vis_processor(Image.open(full_path).convert("RGB"))

        return {
            "image": processed,
            "image_id": record["image_id"],
            "instance_id": record["instance_id"],
        }
================================================
FILE: minigpt4/datasets/datasets/cc_sbu_dataset.py
================================================
import os
from PIL import Image
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class CCSBUDataset(BaseDataset):
    """CC/SBU image-caption data exposed as a WebDataset pipeline in ``inner_dataset``."""

    def __init__(self, vis_processor, text_processor, location):
        """Build the pipeline over the shards at ``location``.

        Every stage after shard selection uses ``wds.warn_and_continue`` as its
        error handler. The final stage turns each ``(jpg, json)`` tuple into a
        dict via ``to_dict``.
        """
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        on_error = wds.warn_and_continue
        stages = [
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=on_error),
            wds.shuffle(1000, handler=on_error),
            wds.decode("pilrgb", handler=on_error),
            wds.to_tuple("jpg", "json", handler=on_error),
            wds.map_tuple(self.vis_processor, handler=on_error),
            wds.map(self.to_dict, handler=on_error),
        ]
        self.inner_dataset = wds.DataPipeline(*stages)

    def to_dict(self, sample):
        """Map an ``(image, metadata)`` tuple to ``{"image", "answer"}``, where
        the answer is the processed ``caption`` field of the metadata."""
        image, meta = sample[0], sample[1]
        return {
            "image": image,
            "answer": self.text_processor(meta["caption"]),
        }
class CCSBUAlignDataset(CaptionDataset):
    """Caption dataset whose images are stored as ``<image_id>.jpg`` and whose
    captions are returned without text processing."""

    def __getitem__(self, index):
        """Return ``image``, the raw caption as ``answer`` and the dense ``image_id``."""
        # TODO this assumes image input, not general enough
        record = self.annotation[index]

        file_name = '{}.jpg'.format(record["image_id"])
        pil_image = Image.open(os.path.join(self.vis_root, file_name)).convert("RGB")
        processed = self.vis_processor(pil_image)

        return {
            "image": processed,
            "answer": record["caption"],
            "image_id": self.img_ids[record["image_id"]],
        }
================================================
FILE: minigpt4/datasets/datasets/coco_caption.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import json
import torch
import numpy as np
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset
COCOCapDataset = COCOCaptionDataset
class COCOCapEvalDataset(CaptionEvalDataset):
    """COCO caption evaluation set; the image id is parsed from the file name."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): annotation files handed to the parent loader
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return the processed image, the id parsed from its file name, and
        the annotation's ``instance_id``.

        The id is whatever follows the last underscore of the file's basename
        once a trailing ``.jpg`` has been removed.
        """
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        file_name = ann["image"].split("/")[-1]
        # str.strip(".jpg") removes any of the characters '.', 'j', 'p', 'g'
        # from BOTH ends, which can eat into the name itself; drop only the
        # literal suffix instead.
        if file_name.endswith(".jpg"):
            file_name = file_name[: -len(".jpg")]
        img_id = file_name.split("_")[-1]

        return {
            "image": image,
            "image_id": img_id,
            "instance_id": ann["instance_id"],
        }
class NoCapsEvalDataset(CaptionEvalDataset):
    """NoCaps evaluation set; the image id is read from the ``img_id`` annotation field."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): annotation files handed to the parent loader
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return the processed image with the annotation's ``img_id`` and ``instance_id``."""
        record = self.annotation[index]

        full_path = os.path.join(self.vis_root, record["image"])
        processed = self.vis_processor(Image.open(full_path).convert("RGB"))

        return {
            "image": processed,
            "image_id": record["img_id"],
            "instance_id": record["instance_id"],
        }
class RefCOCOEvalData(torch.utils.data.Dataset):
    """Evaluation wrapper turning RefCOCO records into ``[refer]`` questions."""

    def __init__(self, loaded_data, vis_processor, root_path):
        """Keep the already-loaded records, the image processor and the image root."""
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        """Number of records."""
        return len(self.loaded_data)

    def __getitem__(self, idx):
        """Return ``(image, question, img_id)`` for record ``idx``.

        The image file name is the first 27 characters of ``img_id`` plus ``.jpg``.
        """
        record = self.loaded_data[idx]
        img_id = record['img_id']
        expression = record['sents']

        file_name = '{}.jpg'.format(img_id[:27])
        pil_image = Image.open(os.path.join(self.root_path, file_name)).convert('RGB')
        image = self.vis_processor(pil_image)

        question = "[refer] give me the location of {}".format(expression)
        return image, question, img_id
class EvalCaptionData(torch.utils.data.Dataset):
    """Caption evaluation wrapper with exactly one entry per distinct image id."""

    def __init__(self, loaded_data, vis_processor, root_path):
        """Deduplicate ``loaded_data`` by ``image_id``.

        Entries keep the order in which each id first appears; when an id
        repeats, the ``image`` of its last occurrence wins.
        """
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

        image_by_id = {item['image_id']: item['image'] for item in self.loaded_data}
        self.ann = [
            {'image_id': image_id, 'image': image}
            for image_id, image in image_by_id.items()
        ]

    def __len__(self):
        """Number of distinct image ids."""
        return len(self.ann)

    def __getitem__(self, idx):
        """Return ``(image, question, image_id)``; the image is loaded from the
        basename of the stored path under ``root_path``."""
        entry = self.ann[idx]
        image_id = entry['image_id']

        base_name = entry['image'].split('/')[-1]
        pil_image = Image.open(os.path.join(self.root_path, base_name)).convert('RGB')
        image = self.vis_processor(pil_image)

        question = "[caption] please describe this image?"
        return image, question, image_id
================================================
FILE: minigpt4/datasets/datasets/coco_dataset.py
================================================
import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class ReferCOCODataset(Dataset):
    """RefCOCO-family training set: referring expression in, bounding box out."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_path, dataset='refcoco', splitBy='unc'):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): directory passed to ``REFER`` as its data root
        dataset (string): dataset name passed to ``REFER``
        splitBy (string): split provider passed to ``REFER``

        Only refs of the "train" split are indexed.
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.refer = REFER(ann_path, vis_root, dataset, splitBy)
        self.ref_ids = self.refer.getRefIds(split="train")

        self.instruction_pool = [
            "[refer] {}",
            "[refer] give me the location of {}",
            "[refer] where is {} ?",
            "[refer] from this image, tell me the location of {}",
            "[refer] the location of {} is",
            "[refer] could you tell me the location for {} ?",
            "[refer] where can I locate the {} ?",
        ]

    def __len__(self):
        """Number of training refs."""
        return len(self.ref_ids)

    def preprocess(self, index):
        """Load the image, one random referring sentence and the box for ``index``.

        The box from ``REFER.getRefBox`` ([x, y, w, h] in original-image
        pixels) is converted to corner form, rescaled to a fixed 100x100 grid,
        truncated to integers and rendered as ``{<x1><y1><x2><y2>}``.

        Returns:
            dict with keys ``image``, ``refer_sentence``, ``bbox``, ``image_id``.
        """
        ref_id = self.ref_ids[index]
        ref = self.refer.loadRefs(ref_id)[0]

        image_file = 'COCO_train2014_{:0>12}.jpg'.format(ref["image_id"])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image_orig_size = image.size
        image = self.vis_processor(image)
        # Box coordinates are expressed on a fixed 100x100 grid, independent of
        # the processed image resolution (the shape-based value that used to be
        # computed here was dead code, immediately overwritten).
        image_new_size = [100,100]

        sample_sentence = random.choice(ref['sentences'])['raw']
        refer_sentence = self.text_processor(sample_sentence)

        bbox = self.refer.getRefBox(ref['ref_id'])
        bbox = [
            bbox[0] / image_orig_size[0] * image_new_size[0],
            bbox[1] / image_orig_size[1] * image_new_size[1],
            (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0],
            (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1]
        ]
        bbox = [int(x) for x in bbox]
        bbox = "{{<{}><{}><{}><{}>}}".format(*bbox)
        return {
            "image": image,
            "refer_sentence": refer_sentence,
            "bbox": bbox,
            "image_id": ref['image_id'],
        }

    def __getitem__(self, index):
        """Return ``image``, a ``[refer]`` prompt as ``instruction_input``, the
        box string as ``answer``, and ``image_id``."""
        data = self.preprocess(index)
        instruction = random.choice(self.instruction_pool).format(data['refer_sentence'])

        # NOTE(review): this literal was split across two lines (unterminated
        # string) in this copy of the file; restored to the image-placeholder
        # prefix used by MiniGPT-4 prompts -- confirm against upstream.
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": data['image'],
            "instruction_input": instruction,
            "answer": data['bbox'],
            "image_id": data['image_id'],
        }
class InvReferCOCODataset(ReferCOCODataset):
    """Inverse of ``ReferCOCODataset``: bounding box in, referring expression out."""

    def __init__(self, *args, **kwargs):
        """Initialise like the parent, then swap in ``[identify]`` prompts."""
        super(InvReferCOCODataset, self).__init__(*args, **kwargs)

        self.instruction_pool = [
            "[identify] {}",
            "[identify] what object is in this location {}",
            "[identify] identify the object present at this location {}",
            "[identify] what is it in {}",
            "[identify] describe this object in {}",
            "[identify] this {} is",
            "[identify] the object in {} is",
        ]

    def __getitem__(self, index):
        """Return ``image``, an ``[identify]`` prompt built from the box string
        as ``instruction_input``, the referring sentence as ``answer``, and
        ``image_id``."""
        data = self.preprocess(index)

        instruction = random.choice(self.instruction_pool).format(data['bbox'])

        # NOTE(review): this literal was split across two lines (unterminated
        # string) in this copy of the file; restored to the image-placeholder
        # prefix used by MiniGPT-4 prompts -- confirm against upstream.
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": data['image'],
            "instruction_input": instruction,
            "answer": self.text_processor(data['refer_sentence']),
            "image_id": data['image_id'],
        }
class REFER:
    """In-memory index over a RefCOCO-style referring-expression dataset.

    Loads ``refs(<splitBy>).p`` and ``instances.json`` from
    ``<data_root>/<dataset>`` and builds lookup tables between refs,
    annotations, images, categories and sentences.
    """

    def __init__(self, data_root, vis_root, dataset='refcoco', splitBy='unc'):
        """
        Args:
            data_root: folder which contains refclef, refcoco, refcoco+ and refcocog.
            vis_root: image directory (only used by ``showRef``).
            dataset: dataset name, e.g. 'refcoco'; a leading 'inv' is ignored.
            splitBy: split provider, e.g. 'unc'.

        Raises:
            ValueError: if ``dataset`` is 'refclef' or not a known dataset.
        """
        dataset = dataset.split('inv')[-1]  # inv dataset is stored in the same path as normal dataset
        print('loading dataset %s into memory...' % dataset)
        self.ann_dir = os.path.join(data_root, dataset)
        if dataset in ['refcoco', 'refcoco+', 'refcocog']:
            self.vis_root = vis_root
        elif dataset == 'refclef':
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError('No RefClef image data')
        else:
            raise ValueError('No refer dataset is called [%s]' % dataset)

        # load refs from data/dataset/refs(dataset).json
        tic = time.time()
        ref_file = os.path.join(self.ann_dir, 'refs(' + splitBy + ').p')
        self.data = {}
        self.data['dataset'] = dataset
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at trusted annotation files.
        with open(ref_file, 'rb') as ref_handle:
            self.data['refs'] = pickle.load(ref_handle)

        # load annotations from data/dataset/instances.json
        instances_file = os.path.join(self.ann_dir, 'instances.json')
        with open(instances_file, 'r') as instances_handle:
            instances = json.load(instances_handle)
        self.data['images'] = instances['images']
        self.data['annotations'] = instances['annotations']
        self.data['categories'] = instances['categories']

        # create index
        self.createIndex()
        print('DONE (t=%.2fs)' % (time.time() - tic))

    def createIndex(self):
        """Build the lookup tables from ``self.data``."""
        # create sets of mapping
        # 1)  Refs:         {ref_id: ref}
        # 2)  Anns:         {ann_id: ann}
        # 3)  Imgs:         {image_id: image}
        # 4)  Cats:         {category_id: category_name}
        # 5)  Sents:        {sent_id: sent}
        # 6)  imgToRefs:    {image_id: refs}
        # 7)  imgToAnns:    {image_id: anns}
        # 8)  refToAnn:     {ref_id: ann}
        # 9)  annToRef:     {ann_id: ref}
        # 10) catToRefs:    {category_id: refs}
        # 11) sentToRef:    {sent_id: ref}
        # 12) sentToTokens: {sent_id: tokens}
        print('creating index...')
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data['annotations']:
            Anns[ann['id']] = ann
            # append in place instead of rebuilding the list on every insert
            imgToAnns.setdefault(ann['image_id'], []).append(ann)
        for img in self.data['images']:
            Imgs[img['id']] = img
        for cat in self.data['categories']:
            Cats[cat['id']] = cat['name']

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data['refs']:
            # ids
            ref_id = ref['ref_id']
            ann_id = ref['ann_id']
            category_id = ref['category_id']
            image_id = ref['image_id']

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs.setdefault(image_id, []).append(ref)
            catToRefs.setdefault(category_id, []).append(ref)
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref['sentences']:
                Sents[sent['sent_id']] = sent
                sentToRef[sent['sent_id']] = ref
                sentToTokens[sent['sent_id']] = sent['tokens']

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print('index created.')

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
        """Return the ids of refs matching every given filter.

        Each id filter may be a list or a single value; an empty filter does
        not restrict. ``split`` accepts 'train', 'val', 'test', 'testA/B/C'
        (substring match on the last letter) and 'testAB/BC/AC' (exact match).

        Raises:
            ValueError: for an unrecognised ``split``.
        """
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data['refs']
        else:
            if not len(image_ids) == 0:
                # imgToRefs maps to LISTS of refs; flatten them (the previous
                # list-of-lists made every later ref[...] lookup fail).
                refs = list(itertools.chain.from_iterable(
                    self.imgToRefs[image_id] for image_id in image_ids if image_id in self.imgToRefs
                ))
            else:
                refs = self.data['refs']
            if not len(cat_ids) == 0:
                refs = [ref for ref in refs if ref['category_id'] in cat_ids]
            if not len(ref_ids) == 0:
                refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
            if not len(split) == 0:
                if split in ['testA', 'testB', 'testC']:
                    refs = [ref for ref in refs if
                            split[-1] in ref['split']]  # we also consider testAB, testBC, ...
                elif split in ['testAB', 'testBC', 'testAC']:
                    refs = [ref for ref in refs if ref['split'] == split]  # rarely used I guess...
                elif split == 'test':
                    refs = [ref for ref in refs if 'test' in ref['split']]
                elif split == 'train' or split == 'val':
                    refs = [ref for ref in refs if ref['split'] == split]
                else:
                    raise ValueError('No such split [%s]' % split)
        ref_ids = [ref['ref_id'] for ref in refs]
        return ref_ids

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        """Return the ids of annotations matching every given filter.

        Each filter may be a list or a single value; an empty filter does not
        restrict. Unknown image ids are skipped. When ``ref_ids`` is given,
        only annotations referenced by those refs are kept.
        """
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            ann_ids = [ann['id'] for ann in self.data['annotations']]
        else:
            if not len(image_ids) == 0:
                lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns]  # list of [anns]
                anns = list(itertools.chain.from_iterable(lists))
            else:
                anns = self.data['annotations']
            if not len(cat_ids) == 0:
                anns = [ann for ann in anns if ann['category_id'] in cat_ids]
            ann_ids = [ann['id'] for ann in anns]
            if not len(ref_ids) == 0:
                # The intersection used to be computed and then thrown away;
                # apply it, keeping the existing order of ann_ids.
                wanted = set(self.Refs[ref_id]['ann_id'] for ref_id in ref_ids)
                ann_ids = [ann_id for ann_id in ann_ids if ann_id in wanted]
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        """Return image ids of the given refs (deduplicated list), or a view of
        all image ids when ``ref_ids`` is empty."""
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if not len(ref_ids) == 0:
            image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
        else:
            image_ids = self.Imgs.keys()
        return image_ids

    def getCatIds(self):
        """Return a view of all category ids."""
        return self.Cats.keys()

    def loadRefs(self, ref_ids=[]):
        """Return a list of refs for a list of ids or a single int id
        (``None`` for any other argument type)."""
        if type(ref_ids) == list:
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif type(ref_ids) == int:
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        """Return a list of annotations for a list of ids or a single int id
        (``None`` for any other argument type)."""
        if type(ann_ids) == list:
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int:
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        """Return a list of image records for a list of ids or a single int id
        (``None`` for any other argument type)."""
        if type(image_ids) == list:
            return [self.Imgs[image_id] for image_id in image_ids]
        elif type(image_ids) == int:
            return [self.Imgs[image_ids]]

    def loadCats(self, cat_ids=[]):
        """Return a list of category names for a list of ids or a single int id
        (``None`` for any other argument type)."""
        if type(cat_ids) == list:
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif type(cat_ids) == int:
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        """Return the bounding box of the annotation linked to ``ref_id``."""
        ann = self.refToAnn[ref_id]
        return ann['bbox']  # [x, y, w, h]

    def showRef(self, ref, seg_box='box'):
        """Draw ``ref`` on the current matplotlib axes and print its sentences.

        ``seg_box='seg'`` outlines the segmentation polygons; ``seg_box='box'``
        draws the bounding box.

        Raises:
            NotImplementedError: for non-polygon (mask) segmentations.
        """
        ax = plt.gca()
        # show image
        image = self.Imgs[ref['image_id']]
        I = io.imread(os.path.join(self.vis_root, image['file_name']))
        ax.imshow(I)
        # show refer expression
        for sid, sent in enumerate(ref['sentences']):
            print('%s. %s' % (sid + 1, sent['sent']))
        # show segmentations
        if seg_box == 'seg':
            ann_id = ref['ann_id']
            ann = self.Anns[ann_id]
            polygons = []
            color = []
            c = 'none'
            if type(ann['segmentation'][0]) == list:
                # polygon used for refcoco*
                for seg in ann['segmentation']:
                    # integer division: reshape rejects a float dimension
                    poly = np.array(seg).reshape((len(seg) // 2, 2))
                    # `closed` is passed by keyword; newer matplotlib releases
                    # do not accept it positionally
                    polygons.append(Polygon(poly, closed=True, alpha=0.4))
                    color.append(c)
                p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 1, 0, 0), linewidths=3, alpha=1)
                ax.add_collection(p)  # thick yellow polygon
                p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 0, 0, 0), linewidths=1, alpha=1)
                ax.add_collection(p)  # thin red polygon
            else:
                # mask used for refclef
                raise NotImplementedError('RefClef is not downloaded')
        # show bounding-box
        elif seg_box == 'box':
            bbox = self.getRefBox(ref['ref_id'])
            box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
            ax.add_patch(box_plot)
================================================
FILE: minigpt4/datasets/datasets/coco_vqa_datasets.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import json
import random
from PIL import Image
from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset
from collections import OrderedDict
class __DisplMixin:
    """Mixin that renders one COCO-VQA item as an ordered, display-friendly record."""

    def displ_item(self, index):
        """Return an ``OrderedDict`` combining the raw annotation at ``index``
        with the image produced by ``__getitem__`` for the same index."""
        sample = self.__getitem__(index)
        ann = self.annotation[index]

        record = OrderedDict()
        record["file"] = ann["image"]
        record["question"] = ann["question"]
        record["question_id"] = ann["question_id"]
        record["answers"] = "; ".join(ann["answer"])
        record["image"] = sample["image"]
        return record
class COCOVQADataset(VQADataset, __DisplMixin):
    """COCO VQA training set yielding ``[vqa]`` instruction/answer pairs."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """Load annotations via the parent class and drop entries whose image
        file does not exist under ``vis_root``."""
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.instruction_pool =[
            "[vqa] {}",
            "[vqa] Based on the image, respond to this question with a short answer: {}"
        ]

        # Only the basename of ann["image"] is looked up under vis_root.
        exist_annotation = []
        for ann in self.annotation:
            image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
            if os.path.exists(image_path):
                exist_annotation.append(ann)
        self.annotation = exist_annotation

    def get_data(self, index):
        """Load the image, question and one sampled answer for ``index``.

        The answer is drawn from ``ann["answer"]`` with probability
        proportional to how often each distinct answer occurs in that list.

        Returns:
            dict with keys ``image``, ``question``, ``question_id``, ``answer``.
        """
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])
        question_id = ann["question_id"]

        # Each occurrence contributes 1/len, so repeated answers weigh more.
        answer_weight = {}
        for answer in ann["answer"]:
            if answer in answer_weight.keys():
                answer_weight[answer] += 1 / len(ann["answer"])
            else:
                answer_weight[answer] = 1 / len(ann["answer"])

        answers = list(answer_weight.keys())
        weights = list(answer_weight.values())

        answer = random.choices(answers, weights=weights, k=1)[0]  # random sample an answer according to weights

        return {
            "image": image,
            "question": question,
            "question_id": question_id,
            "answer": answer,
        }

    def __getitem__(self, index):
        """Return ``image``, ``question_id``, a ``[vqa]`` prompt as
        ``instruction_input`` and the processed answer."""
        data = self.get_data(index)
        instruction = random.choice(self.instruction_pool).format(data['question'])
        # NOTE(review): this literal was split across two lines (unterminated
        # string) in this copy of the file; restored to the image-placeholder
        # prefix used by MiniGPT-4 prompts -- confirm against upstream.
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": data['image'],
            "question_id": data["question_id"],
            "instruction_input": instruction,
            "answer": self.text_processor(data['answer']),
        }
class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin):
    """COCO VQA evaluation set yielding 'Question: ... Short answer:' prompts."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): ``[annotations, answer_list, (questions,
            coco_annotations)]``. The answer list is loaded only if its file
            exists; the last two entries are optional and default to ``None``.

        Note: the parent ``__init__`` is not invoked; all state is set here.
        """

        self.instruction_pool = [
            'Question: {} Short answer:',
        ]
        self.vis_root = vis_root

        with open(ann_paths[0]) as ann_file:
            self.annotation = json.load(ann_file)

        answer_list_path = ann_paths[1]
        if os.path.exists(answer_list_path):
            with open(answer_list_path) as answer_file:
                self.answer_list = json.load(answer_file)
        else:
            self.answer_list = None

        try:
            self.coco_fmt_qust_file = ann_paths[2]
            self.coco_fmt_anno_file = ann_paths[3]
        except IndexError:
            self.coco_fmt_qust_file = None
            self.coco_fmt_anno_file = None

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __getitem__(self, index):
        """Return the processed image and question together with the prompt,
        the image path, and the annotation's ``question_id``/``instance_id``."""
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])

        instruction = random.choice(self.instruction_pool).format(question)
        # NOTE(review): this literal was split across two lines (unterminated
        # string) in this copy of the file; restored to the image-placeholder
        # prefix used by MiniGPT-4 prompts -- confirm against upstream.
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": image,
            'image_path': image_path,
            "question": question,
            "question_id": ann["question_id"],
            "instruction_input": instruction,
            "instance_id": ann["instance_id"],
        }
================================================
FILE: minigpt4/datasets/datasets/dataloader_utils.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import time
import random
import torch
from minigpt4.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader
class MultiIterLoader:
    """
    Draws the next batch from one of several iterators, chosen at random.

    Args:
        loaders (List[Loader]): iterators; each must expose ``__next__``.
        ratios (List[float]): relative sampling weights, normalised to sum to 1.
            If None, every loader gets weight 1.0.
    """

    def __init__(self, loaders, ratios=None):
        # every loader must already be an iterator
        for candidate in loaders:
            assert hasattr(
                candidate, "__next__"
            ), "Loader {} has no __next__ method.".format(candidate)

        if ratios is not None:
            assert len(ratios) == len(loaders)
            total = sum(ratios)
            ratios = [float(weight) / total for weight in ratios]
        else:
            ratios = [1.0] * len(loaders)

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        """Pick a loader index according to ``self.ratios`` and return its next item."""
        # random sample from each loader by ratio
        indices = range(len(self.loaders))
        chosen = random.choices(indices, self.ratios, k=1)[0]
        return next(self.loaders[chosen])
class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.
    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)

    Wraps ``loader``: while the caller works on one batch, the next batch is
    fetched and passed through ``move_to_cuda`` inside a dedicated
    ``torch.cuda.Stream``. Attributes not defined on this class are looked up
    on the wrapped loader.
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        """Yield batches from the wrapped loader, keeping one batch preloaded."""
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            if isinstance(batch, tuple):
                # Tuples are unpacked as (task, batch) and yielded as a pair.
                task, payload = batch
                yield task, payload
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        """Fetch the next batch from ``it`` into ``self.batch`` (None when exhausted)."""
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        """Return the preloaded batch and start preloading the following one."""
        # Make the current stream wait for the side stream's work first.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. Read `loader` from the
        # instance dict directly: going through `self.loader` would re-enter
        # __getattr__ forever on an instance where `loader` is not set yet
        # (e.g. one created via __new__ during copy/unpickle).
        try:
            loader = self.__dict__["loader"]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None
        return getattr(loader, name)
def record_cuda_stream(batch):
    """
    Walk ``batch`` and call ``record_stream`` with the current CUDA stream on
    every ``torch.Tensor`` found.

    Lists, tuples and dict values are traversed recursively; any other kind
    of object is ignored. Returns None.
    """
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
        return

    if isinstance(batch, (list, tuple)):
        children = batch
    elif isinstance(batch, dict):
        children = batch.values()
    else:
        # Not a tensor and not a supported container: nothing to record.
        return

    for child in children:
        record_cuda_stream(child)
class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    When the wrapped dataloader is exhausted, the epoch counter is advanced
    and a fresh iterator is created, so ``next()`` keeps producing data.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        """Number of times the wrapped dataloader has been exhausted so far."""
        return self._epoch

    def __next__(self):
        try:
            return next(self.iter_loader)
        except StopIteration:
            # End of the current pass: bump the epoch and restart.
            self._epoch += 1
            sampler = self._dataloader.sampler
            if hasattr(sampler, "set_epoch") and self._use_distributed:
                sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            return next(self.iter_loader)

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)
================================================
FILE: minigpt4/datasets/datasets/flickr.py
================================================
import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class GroundedDetailDataset(Dataset):
# Dataset pairing an image with a randomly chosen "[grounding]" instruction and the
# annotation's `grounded_caption` as the answer.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
# text_processor is stored but not used by __getitem__ in this class.
self.text_processor = text_processor
# Candidate instructions; one is chosen at random for every item.
self.instruction_pool = [
'[grounding] please describe this image in details',
'[grounding] describe this image as detailed as possible',
'[grounding] summarize this image in details',
'[grounding] give a thorough description of what you see in this image',
]
with open(ann_path, 'r') as f:
self.ann = json.load(f)
def __len__(self):
# One sample per annotation record.
return len(self.ann)
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction, the grounded caption
# (as "answer") and the image id.
info = self.ann[index]
# image_file = 'COCO_train2014_{}.jpg'.format(info['image_id'])
# The image file name is "<image_id>.jpg" under vis_root.
image_file = '{}.jpg'.format(info['image_id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
answer = info['grounded_caption']
instruction = random.choice(self.instruction_pool)
# NOTE(review): the string literal below is split across two lines in this copy of the
# file, which a single-quoted Python string cannot do; text between the opening quote and
# "{} " appears to be missing -- confirm against the upstream file.
instruction = "
{} ".format(instruction)
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": info['image_id'],
}
class CaptionToObjectDataset(Dataset):
# Dataset whose instruction is "[detection] <caption>" and whose answer is the
# annotation's `output` field.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
# text_processor is stored but not used by __getitem__ in this class.
self.text_processor = text_processor
# Single template; "{}" is filled with the annotation's caption.
self.instruction_pool = [
'[detection] {}',
]
with open(ann_path, 'r') as f:
self.ann = json.load(f)
def __len__(self):
# One sample per annotation record.
return len(self.ann)
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction, the answer and the image id.
info = self.ann[index]
# The image file name is "<image_id>.jpg" under vis_root.
image_file = '{}.jpg'.format(info['image_id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
# `input` shadows the Python builtin of the same name inside this method.
input = info["caption"]
answer = info["output"]
instruction = random.choice(self.instruction_pool).format(input)
# NOTE(review): the string literal below is split across two lines in this copy of the
# file, which a single-quoted Python string cannot do; text between the opening quote and
# "{} " appears to be missing -- confirm against the upstream file.
instruction = "
{} ".format(instruction)
# NOTE(review): these prints run on every item fetch; they look like leftover debugging
# output -- confirm whether they are still wanted.
print("CaptionToObject instruction", instruction)
print("CaptionToObject answer", answer)
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": info['image_id'],
}
class PhraseToObjectDataset(Dataset):
# Dataset whose instruction is "[detection] <phrase>" and whose answer is built from the
# phrase and the annotation's `bbox` string.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
# text_processor is stored but not used by __getitem__ in this class.
self.text_processor = text_processor
# Single template; "{}" is filled with the annotation's phrase.
self.instruction_pool = [
'[detection] {}',
]
with open(ann_path, 'r') as f:
self.ann = json.load(f)
def __len__(self):
# One sample per annotation record.
return len(self.ann)
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction, the answer and the image id.
info = self.ann[index]
# The image file name is "<image_id>.jpg" under vis_root.
image_file = '{}.jpg'.format(info['image_id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
# `input` shadows the Python builtin of the same name inside this method.
input = info["phrase"]
# NOTE(review): the answer expression below is split across two lines inside a string
# literal and starts with an empty string; markup around the phrase appears to have been
# lost in this copy of the file -- confirm against the upstream file.
# info["bbox"] is concatenated directly, so it is expected to be a string already.
answer = ""+input+"
"+info["bbox"]
instruction = random.choice(self.instruction_pool).format(input)
# NOTE(review): same problem as above -- the literal is split across two lines and text
# before "{} " appears to be missing.
instruction = "
{} ".format(instruction)
# NOTE(review): these prints run on every item fetch; they look like leftover debugging
# output -- confirm whether they are still wanted.
print("PhraseToObject instruction", instruction)
print("PhraseToObject answer", answer)
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": info['image_id'],
}
================================================
FILE: minigpt4/datasets/datasets/gqa_datasets.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import json
from PIL import Image
from minigpt4.datasets.datasets.vqa_datasets import VQADataset
from collections import OrderedDict
import random
class __DisplMixin:
    """Mixin that adds a display-oriented summary of a single dataset item."""

    def displ_item(self, index):
        """
        Return an ``OrderedDict`` describing item ``index``.

        The summary combines fields from the raw annotation (image file,
        question, question id, answers joined with "; ") with the ``image``
        entry of the sample produced by ``__getitem__``.
        """
        sample = self.__getitem__(index)
        ann = self.annotation[index]

        summary = OrderedDict()
        summary["file"] = ann["image"]
        summary["question"] = ann["question"]
        summary["question_id"] = ann["question_id"]
        summary["answers"] = "; ".join(ann["answer"])
        summary["image"] = sample["image"]
        return summary
class GQADataset(VQADataset, __DisplMixin):
# Dataset producing "[vqa]" instructions from the annotation's question, with the
# processed `answer` field as the target.
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
# All arguments are forwarded unchanged to the parent constructor.
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
# Candidate templates; one is chosen at random per item and "{}" receives the question.
self.instruction_pool =[
"[vqa] {}",
"[vqa] Based on the image, respond to this question with a short answer: {}"
]
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction and the processed answer.
ann = self.annotation[index]
# ann["image"] is joined onto vis_root to locate the image file.
image_path = os.path.join(self.vis_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
question = self.text_processor(ann["question"])
instruction = random.choice(self.instruction_pool).format(question)
# NOTE(review): the string literal below is split across two lines in this copy of the
# file, which a single-quoted Python string cannot do; text between the opening quote and
# "{} " appears to be missing -- confirm against the upstream file.
instruction = "
{} ".format(instruction)
answers = self.text_processor(ann["answer"])
return {
"image": image,
"instruction_input": instruction,
"answer": answers,
}
================================================
FILE: minigpt4/datasets/datasets/laion_dataset.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
class LaionDataset(BaseDataset):
    """
    Image/caption dataset read from webdataset shards at ``location``.

    The processing pipeline is exposed as ``inner_dataset``; each sample it
    produces is the dict built by :meth:`to_dict`.
    """

    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        # Every stage that accepts a handler uses the same one.
        on_error = wds.warn_and_continue
        stages = [
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=on_error),
            wds.shuffle(1000, handler=on_error),
            wds.decode("pilrgb", handler=on_error),
            wds.to_tuple("jpg", "json", handler=on_error),
            # Only the first tuple element (the image) gets a mapping function.
            wds.map_tuple(self.vis_processor, handler=on_error),
            wds.map(self.to_dict, handler=on_error),
        ]
        self.inner_dataset = wds.DataPipeline(*stages)

    def to_dict(self, sample):
        """Turn an ``(image, metadata)`` tuple into ``{"image", "answer"}``.

        The answer is the text-processed ``"caption"`` entry of the metadata.
        """
        image = sample[0]
        caption = sample[1]["caption"]
        return {
            "image": image,
            "answer": self.text_processor(caption),
        }
================================================
FILE: minigpt4/datasets/datasets/llava_dataset.py
================================================
import os
import json
import pickle
import random
import time
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class LlavaDetailDataset(Dataset):
# Dataset built from conversation-style annotations: the first turn is used as the
# instruction and the second turn as the answer.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
with open(ann_path, 'r') as f:
self.ann = json.load(f)
def __len__(self):
# One sample per annotation record.
return len(self.ann)
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction, the answer and the image id.
info = self.ann[index]
# The image file name is "COCO_train2014_<id>.jpg" under vis_root.
image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
# conversations[1] supplies the answer, conversations[0] the instruction text.
answer = info['conversations'][1]['value']
# NOTE(review): `.replace('', '')` as written leaves the text unchanged; the token that
# was meant to be removed appears to have been lost in this copy -- confirm upstream.
# Newlines are removed and surrounding whitespace is stripped.
instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip()
# NOTE(review): the string literal below is split across two lines, which a
# single-quoted Python string cannot do; text before "{} " appears to be missing.
instruction = '
{} '.format(self.text_processor(instruction))
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": info['id'],
}
class LlavaReasonDataset(Dataset):
# Dataset built from conversation-style annotations: the first turn is used as the
# instruction and the second turn as the answer.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
with open(ann_path, 'r') as f:
self.ann = json.load(f)
def __len__(self):
# One sample per annotation record.
return len(self.ann)
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction, the answer and the image id.
info = self.ann[index]
# The image file name is "COCO_train2014_<id>.jpg" under vis_root.
image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
# conversations[1] supplies the answer, conversations[0] the instruction text.
answer = info['conversations'][1]['value']
# NOTE(review): `.replace('', '')` as written leaves the text unchanged; the token that
# was meant to be removed appears to have been lost in this copy -- confirm upstream.
# Newlines are removed and surrounding whitespace is stripped.
instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip()
# NOTE(review): the string literal below is split across two lines, which a
# single-quoted Python string cannot do; text before "{} " appears to be missing.
instruction = '
{} '.format(self.text_processor(instruction))
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": info['id'],
}
class LlavaConversationDataset(Dataset):
# Dataset that flattens a multi-turn conversation into two strings: all human turns
# joined with `connect_sym`, and all assistant turns joined with `connect_sym`.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
# text_processor is stored but not used by __getitem__ in this class.
self.text_processor = text_processor
# Placeholder value; replaced by the parsed JSON just below.
self.ann=[]
with open(ann_path, 'r') as f:
self.ann = json.load(f)
# Separator placed between turns when they are joined into a single string.
self.connect_sym = "!@#"
def __len__(self):
# One sample per annotation record.
return len(self.ann)
def __getitem__(self, index):
# Returns a dict with the processed image, the joined questions ("conv_q"), the joined
# answers ("conv_a"), the image id and the separator used for joining.
info = self.ann[index]
# The image file name is "COCO_train2014_<id>.jpg" under vis_root.
image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
# NOTE(review): `.replace('', '')` as written leaves the text unchanged; the token that
# was meant to be removed appears to have been lost in this copy -- confirm upstream.
first_instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip()
# NOTE(review): the string literal below is split across two lines, which a
# single-quoted Python string cannot do; text before "{} " appears to be missing.
first_instruction = '
{} '.format(first_instruction)
questions = [first_instruction]
answers = []
# Turns after the first one are processed in order: even positions in this slice are
# collected as answers, odd positions as further questions (with a trailing space).
for i, item in enumerate(info["conversations"][1:]):
if i % 2 ==0: # assistant
assistant_answer = item["value"]
answers.append(assistant_answer)
else:
human_instruction = item["value"]+" "
questions.append(human_instruction)
questions = self.connect_sym.join(questions)
answers = self.connect_sym.join(answers)
return {
"image": image,
"conv_q": questions,
'conv_a': answers,
"image_id": info['id'],
"connect_sym": self.connect_sym
}
================================================
FILE: minigpt4/datasets/datasets/multitask_conversation.py
================================================
import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class MultiTaskConversationDataset(Dataset):
# Dataset that flattens a multi-turn conversation into two strings: all human turns
# joined with `connect_sym`, and all assistant turns joined with `connect_sym`.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
# text_processor is stored but not used by __getitem__ in this class.
self.text_processor = text_processor
with open(ann_path, 'r') as f:
self.ann = json.load(f)
# Separator placed between turns when they are joined into a single string.
self.connect_sym = "!@#"
def __len__(self):
# One sample per annotation record.
return len(self.ann)
def __getitem__(self, index):
# Returns a dict with the processed image, the joined questions ("conv_q"), the joined
# answers ("conv_a"), the image id and the separator used for joining.
info = self.ann[index]
# The image file name is "COCO_train2014_<id>.jpg" under vis_root.
image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
# NOTE(review): `.replace('', '')` as written leaves the text unchanged; the token that
# was meant to be removed appears to have been lost in this copy -- confirm upstream.
first_instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip()
# NOTE(review): the string literal below is split across two lines, which a
# single-quoted Python string cannot do; text before "{} " appears to be missing.
first_instruction = '
{} '.format(first_instruction)
questions = [first_instruction]
answers = []
# Turns after the first one are processed in order: even positions in this slice are
# collected as answers, odd positions as further questions (with a trailing space).
for i, item in enumerate(info["conversations"][1:]):
if i % 2 ==0: # assistant
assistant_answer = item["value"]
answers.append(assistant_answer)
else:
human_instruction = item["value"]+" "
questions.append(human_instruction)
questions = self.connect_sym.join(questions)
answers = self.connect_sym.join(answers)
return {
"image": image,
"conv_q": questions,
'conv_a': answers,
"image_id": info['id'],
"connect_sym": self.connect_sym
}
================================================
FILE: minigpt4/datasets/datasets/ocrvqa_dataset.py
================================================
import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class OCRVQADataset(Dataset):
# Dataset of question/answer pairs about images; annotations are expanded so that every
# question/answer pair of a training-split entry becomes its own sample.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which create_data() opens and parses as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
self.data = self.create_data(ann_path)
# Candidate templates; one is chosen at random per item and "{}" receives the question.
self.instruction_pool =[
"[vqa] {}",
"[vqa] Based on the image, respond to this question with a short answer: {}"
]
def create_data(self, ann_path):
# Reads the JSON mapping at ann_path and returns a flat list with one dict per
# question/answer pair (keys: question, answer, image_path, image_id, title, genre).
processed_data = []
with open(ann_path, 'r') as f:
data = json.load(f)
for k in data.keys():
if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test
# The local file name is the entry key plus the extension taken from imageURL.
ext = os.path.splitext(data[k]['imageURL'])[1]
imageFile = k + ext
# Questions and answers are paired by position, so the lists must be equally long.
assert len(data[k]['questions']) == len(data[k]['answers'])
for q, a in zip(data[k]['questions'], data[k]['answers']):
processed_data.append(
{'question': q,
'answer': a,
'image_path': imageFile,
'image_id': k,
'title': data[k]['title'],
'genre': data[k]['genre'],
}
)
return processed_data
def __len__(self):
# One sample per question/answer pair.
return len(self.data)
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction, the processed answer and
# the image id.
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
image = self.vis_processor(image)
question = self.text_processor(sample["question"])
answer = self.text_processor(sample["answer"])
instruction = random.choice(self.instruction_pool).format(question)
# NOTE(review): the string literal below is split across two lines in this copy of the
# file, which a single-quoted Python string cannot do; text between the opening quote and
# "{} " appears to be missing -- confirm against the upstream file.
instruction = "
{} ".format(instruction)
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": sample['image_id']
}
================================================
FILE: minigpt4/datasets/datasets/text_caps.py
================================================
import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class TextCapDataset(Dataset):
# Captioning dataset: each item pairs an image with a randomly chosen captioning prompt
# and the processed `caption_str` of the annotation as the answer.
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above mentions `ann_root`; the parameter actually taken is
# `ann_path`, which is opened below and parsed as JSON.
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
# Candidate prompts; one is chosen at random for every item.
self.instruction_pool = [
'Briefly describe this image.',
'Provide a concise depiction of this image.',
'Present a short description of this image.',
'Summarize this image in a few words.',
'A short image caption:',
'A short image description:',
'A photo of ',
'An image that shows ',
'Write a short description for the image. ',
'Write a description for the photo.',
'Provide a description of what is presented in the photo.',
'Briefly describe the content of the image.',
'Can you briefly explain what you see in the image?',
'Could you use a few words to describe what you perceive in the photo?',
'Please provide a short depiction of the picture.',
'Using language, provide a short account of the image.',
'Use a few words to illustrate what is happening in the picture.',
]
with open(ann_path, 'r') as f:
self.ann = json.load(f)
def __len__(self):
# The records live under the top-level "data" key of the annotation file.
return len(self.ann["data"])
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction and the processed caption.
info = self.ann["data"][index]
# The image file name is "<image_id>.jpg" under vis_root.
image_file = '{}.jpg'.format(info['image_id'])
image_path = os.path.join(self.vis_root, image_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
caption = info["caption_str"]
caption = self.text_processor(caption)
# NOTE(review): the string literal below is split across two lines in this copy of the
# file, which a single-quoted Python string cannot do; text between the opening quote and
# "[caption] {} " appears to be missing -- confirm against the upstream file.
instruction = "
[caption] {} ".format(random.choice(self.instruction_pool))
return {
"image": image,
"instruction_input": instruction,
"answer": caption,
}
================================================
FILE: minigpt4/datasets/datasets/unnatural_instruction.py
================================================
import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class UnnaturalDataset(Dataset):
    """Text-only instruction dataset loaded from a JSON annotation file.

    Each record is expected to hold an ``"instances"`` list; only the first
    instance of every record is used.
    """

    def __init__(self, text_processor, ann_path):
        """
        text_processor (callable): applied to both the instruction and the answer
        ann_path (string): path of the JSON annotation file to load
        """
        self.text_processor = text_processor

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``{"instruction_input", "answer"}`` for record ``index``.

        When the instance carries constraints, they are appended to the
        instruction after a single space.
        """
        info = self.ann[index]["instances"][0]
        instruction = info["instruction_with_input"]
        constraints = info["constraints"]
        answer = info["output"]

        # Identity check is the correct test for "no constraints given".
        if constraints is not None:
            instruction = instruction + " " + constraints

        return {
            "instruction_input": self.text_processor(instruction),
            "answer": self.text_processor(answer),
        }
================================================
FILE: minigpt4/datasets/datasets/vg_dataset.py
================================================
import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from visual_genome import local
class ReferVisualGenomeDataset(Dataset):
# Referring-expression dataset built from region descriptions: the instruction asks for
# the location of a region phrase and the answer is the region's bounding-box string.
def __init__(self, vis_processor, text_processor, data_dir):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
# NOTE: the docstring above names `vis_root`/`ann_root`; the parameter actually taken is
# `data_dir`, which is passed to the region loader and also used as the image root.
self.data_dir = data_dir
self.vis_processor = vis_processor
self.text_processor = text_processor
all_regions = local.get_all_region_descriptions(self.data_dir)
# Flatten the nested result (a sequence of region sequences) into one list.
all_regions = [region for regions in all_regions for region in regions]
# follow OFA practice, only regions smaller than 16384 pixels are used for refer
self.regions = [region for region in all_regions if region.width * region.height < 16384]
# Candidate templates; one is chosen at random per item and "{}" receives the phrase.
self.instruction_pool = [
"[refer] {}",
"[refer] give me the location of {}",
"[refer] where is {} ?",
"[refer] from this image, tell me the location of {}",
"[refer] the location of {} is",
"[refer] could you tell me the location for {} ?",
"[refer] where can I locate the {} ?",
]
def __len__(self):
# One sample per retained region.
return len(self.regions)
def preprocess(self, index):
# Loads the image of region `index` and returns a dict with the processed image, the
# processed phrase, the bounding-box string and the image id.
region = self.regions[index]
# The last two components of the image URL are used as the path below data_dir.
image_file = region.image.url.split('/')[-2:]
image_path = os.path.join(self.data_dir, *image_file)
image = Image.open(image_path).convert("RGB")
# Keep the original (width, height) before the image is processed.
image_orig_size = image.size
image = self.vis_processor(image)
# Box coordinates are rescaled to a fixed 100 x 100 frame, independent of image size.
image_new_size = [100,100]
sample_sentence = region.phrase
refer_sentence = self.text_processor(sample_sentence)
# Convert (x, y, width, height) into (x1, y1, x2, y2) while rescaling.
bbox = [region.x, region.y, region.width, region.height]
bbox = [
bbox[0] / image_orig_size[0] * image_new_size[0],
bbox[1] / image_orig_size[1] * image_new_size[1],
(bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0],
(bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1]
]
# int() truncates the fractional part of each coordinate.
bbox = [int(x) for x in bbox]
# Doubled braces are literal braces: the result looks like "{<x1><y1><x2><y2>}".
bbox = "{{<{}><{}><{}><{}>}}".format(*bbox)
return {
"image": image,
"refer_sentence": refer_sentence,
"bbox": bbox,
"image_id": region.image.id,
}
def __getitem__(self, index):
# Returns a dict with the processed image, the instruction, the bounding-box string
# (as "answer") and the image id.
data = self.preprocess(index)
instruction = random.choice(self.instruction_pool).format(data['refer_sentence'])
# NOTE(review): the string literal below is split across two lines in this copy of the
# file, which a single-quoted Python string cannot do; text between the opening quote and
# "{} " appears to be missing -- confirm against the upstream file.
instruction = "
{} ".format(instruction)
return {
"image": data['image'],
"instruction_input": instruction,
"answer": data['bbox'],
"image_id": data['image_id'],
}
================================================
FILE: minigpt4/datasets/datasets/vqa_datasets.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
from PIL import Image
import os
from minigpt4.datasets.datasets.base_dataset import BaseDataset
class VQADataset(BaseDataset):
    """Base class for VQA training datasets.

    Adds nothing to ``BaseDataset``; the constructor forwards its four
    arguments unchanged.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        super(VQADataset, self).__init__(
            vis_processor,
            text_processor,
            vis_root,
            ann_paths,
        )
class VQAEvalDataset(BaseDataset):
    """Base class for VQA evaluation datasets.

    Adds nothing to ``BaseDataset``; the constructor forwards its four
    arguments unchanged.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        super(VQAEvalDataset, self).__init__(
            vis_processor,
            text_processor,
            vis_root,
            ann_paths,
        )
class OKVQAEvalData(torch.utils.data.Dataset):
    """Evaluation dataset over already-loaded OK-VQA style records.

    Each item is ``(image, prompt, question_id, image_id)``.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        record = self.loaded_data[idx]
        img_id = record['image_id']
        question = record['question']
        question_id = record['question_id']

        # File names are the image id left-padded with zeros to 12 characters.
        image_path = os.path.join(self.root_path, f'{img_id:0>12}.jpg')
        image = self.vis_processor(Image.open(image_path).convert('RGB'))

        question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
        return image, question, question_id, img_id
class VizWizEvalData(torch.utils.data.Dataset):
    """Evaluation dataset over already-loaded VizWiz style records.

    Each item is ``(image, prompt, answers)`` where ``answers`` is every
    annotated answer joined with underscores.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        record = self.loaded_data[idx]
        img_id = record['image']
        question = record['question']
        # Collapse the list of answer dicts into one "_"-separated string.
        answers = '_'.join(entry['answer'] for entry in record['answers'])

        image_path = os.path.join(self.root_path, img_id)
        image = self.vis_processor(Image.open(image_path).convert('RGB'))

        question = f"[vqa] The question is '{question}' Based on the image, answer the question with a single word or phrase. and reply 'unanswerable' when the provided information is insufficient"
        return image, question, answers
class IconQAEvalData(torch.utils.data.Dataset):
    """Evaluation dataset over already-loaded IconQA style records.

    Each item is ``(image, prompt, candidates, answer)`` where ``candidates``
    is the record's choices joined with underscores.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        record = self.loaded_data[idx]
        image_id = record['image_id']
        question = record['question']

        # Every sample lives in its own directory and is always called image.png.
        image_path = os.path.join(self.root_path, image_id, 'image.png')
        pil_image = Image.open(image_path).convert('RGB')
        # Unlike the sibling eval datasets, the processed image is converted
        # with .half() and moved with .cuda() right here.
        image = self.vis_processor(pil_image).half().cuda()

        candidates = '_'.join(record['choices'])
        answer = record['answer']
        question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
        return image, question, candidates, answer
class GQAEvalData(torch.utils.data.Dataset):
    """Evaluation dataset over already-loaded GQA style records.

    Each item is ``(image, prompt, labels)`` where ``labels`` is the record's
    ``"answer"`` entry.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        record = self.loaded_data[idx]

        # The "image" entry is formatted to text and joined onto root_path.
        image_id = record["image"]
        image_path = os.path.join(self.root_path, f"{image_id}")
        image = self.vis_processor(Image.open(image_path).convert("RGB"))

        question = record["question"]
        question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
        labels = record["answer"]
        return image, question, labels
class HMEvalData(torch.utils.data.Dataset):
    """Evaluation dataset over already-loaded image/text/label records.

    Each item is ``(image, prompt, labels)``; the prompt quotes the record's
    ``"text"`` and asks whether the image is hateful.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        record = self.loaded_data[idx]

        # The "img" entry is formatted to text and joined onto root_path.
        image_id = record["img"]
        image_path = os.path.join(self.root_path, f"{image_id}")
        image = self.vis_processor(Image.open(image_path).convert("RGB"))

        question = record["text"]
        # The prompt wording is kept exactly as it is sent to the model.
        question = f"This is an image writting '{question}'. Is this image hateful? Answer yes or no. Answer:"
        labels = record["label"]
        return image, question, labels
class VSREvalData(torch.utils.data.Dataset):
    """Evaluation dataset over already-loaded image/caption/label records.

    Each item is ``(image, prompt, labels)`` where ``labels`` is the string
    ``'true'`` when the record's label equals 1 and ``'false'`` otherwise.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        record = self.loaded_data[idx]

        image_path = os.path.join(self.root_path, record["image"])
        image = self.vis_processor(Image.open(image_path).convert("RGB"))

        question = record["caption"]
        question = f'[vqa] Based on the image, is this statement true or false? {question}'

        if record["label"] == 1:
            labels = 'true'
        else:
            labels = 'false'
        return image, question, labels
================================================
FILE: minigpt4/models/Qformer.py
================================================
"""
* Copyright (c) 2023, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
* Based on huggingface code base
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
"""
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.file_utils import (
ModelOutput,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig
logger = logging.get_logger(__name__)
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    Optionally, externally supplied ``query_embeds`` are prepended to the
    token embeddings (or used on their own when no ``input_ids`` are given)
    before layer normalisation and dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(
            config.vocab_size, hidden, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(max_positions, hidden)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(max_positions).expand((1, -1))
        )
        # Falls back to "absolute" when the config does not define the attribute.
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed ``input_ids`` and/or ``query_embeds``.

        When ``position_ids`` is omitted, a slice of the registered
        ``position_ids`` buffer starting at ``past_key_values_length`` and
        covering the token sequence length is used. Position embeddings are
        only added for the "absolute" position embedding type.
        """
        seq_length = 0 if input_ids is None else input_ids.size()[1]

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start : seq_length + start].clone()

        if input_ids is None:
            # No tokens: the query embeddings are the whole sequence.
            embeddings = query_embeds
        else:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                embeddings = embeddings + self.position_embeddings(position_ids)
            if query_embeds is not None:
                # Queries go in front of the token embeddings along the sequence axis.
                embeddings = torch.cat((query_embeds, embeddings), dim=1)

        return self.dropout(self.LayerNorm(embeddings))
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention, usable for self- or cross-attention.

    When built with ``is_cross_attention`` the key/value projections read inputs
    of width ``config.encoder_width``; otherwise they read ``config.hidden_size``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config

        heads = config.num_attention_heads
        uneven_split = config.hidden_size % heads != 0
        if uneven_split and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Keys/values come from the encoder (different width) in cross-attention.
        kv_width = config.encoder_width if is_cross_attention else config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(kv_width, self.all_head_size)
        self.value = nn.Linear(kv_width, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient delivered by the hook on the attention probabilities."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradients stored by :meth:`save_attn_gradients`."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the most recent cross-attention probabilities."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the map stored by :meth:`save_attention_map`."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and move heads to dim 1."""
        per_head = (self.num_attention_heads, self.attention_head_size)
        return x.view(*(x.size()[:-1] + per_head)).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to itself or to ``encoder_hidden_states``.

        Returns:
            ``(context, attention_probs, (key, value))`` when ``output_attentions``
            is true, otherwise ``(context, (key, value))``.
        """
        # Passing encoder states switches this call to cross-attention: keys and
        # values come from the encoder and the encoder mask replaces the self mask.
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            kv_input = encoder_hidden_states
            attention_mask = encoder_attention_mask
        else:
            kv_input = hidden_states

        key_layer = self.transpose_for_scores(self.key(kv_input))
        value_layer = self.transpose_for_scores(self.value(kv_input))
        if not is_cross_attention and past_key_value is not None:
            # Extend the cached keys/values along the sequence axis.
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(self.query(hidden_states))
        past_key_value = (key_layer, value_layer)

        # Raw scores: query · keyᵀ.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            positions = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            )
            distance = positions.view(-1, 1) - positions.view(1, -1)
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            # Match the query dtype (fp16 compatibility).
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            query_term = torch.einsum(
                "bhld,lrd->bhlr", query_layer, positional_embedding
            )
            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + query_term
            else:
                key_term = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = attention_scores + query_term + key_term

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask, added to the scores before the softmax.
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout here drops whole attended-to tokens, as in the original
        # Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if output_attentions:
            return (context_layer, attention_probs, past_key_value)
        return (context_layer, past_key_value)
class BertSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and normalise its sum with ``input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertAttention(nn.Module):
    """Attention block: :class:`BertSelfAttention` followed by :class:`BertSelfOutput`."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections of this block."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink q/k/v along their output dimension, and the output projection
        # along its input dimension.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping consistent with the new head count.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention plus the residual output projection.

        Returns the projected context followed by every extra element the inner
        attention module returned (attention probabilities if requested, then
        the key/value pair).
        """
        inner = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        projected = self.output(inner[0], hidden_states)
        return (projected,) + inner[1:]
class BertIntermediate(nn.Module):
    """First half of the feed-forward block: expand to ``intermediate_size`` and activate."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # ``hidden_act`` is either a key into ACT2FN or the callable itself.
        activation = config.hidden_act
        self.intermediate_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )

    def forward(self, hidden_states):
        """Apply the dense expansion followed by the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Second half of the feed-forward block: project back, dropout, residual, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` down and normalise its sum with ``input_tensor``."""
        reduced = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(reduced + input_tensor)
class BertLayer(nn.Module):
    """One transformer layer with separate feed-forward paths for query and text tokens.

    The first ``query_length`` positions of the sequence are treated as query
    tokens: they optionally cross-attend to ``encoder_hidden_states`` and use the
    ``*_query`` feed-forward modules. Remaining positions use the regular
    feed-forward modules. Cross-attention is instantiated only when
    ``config.add_cross_attention`` is set and ``layer_num`` is a multiple of
    ``config.cross_attention_freq``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        wants_cross = bool(
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        )
        if wants_cross:
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
        self.has_cross_attention = wants_cross

        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def _apply_ffn(self, chunk_fn, states):
        """Run ``chunk_fn`` over ``states`` using the configured chunking."""
        return apply_chunking_to_forward(
            chunk_fn,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            states,
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run self-attention, optional query cross-attention and the feed-forward paths.

        Returns:
            ``(layer_output, *attention_maps, present_key_value)`` where the
            attention maps are only present when ``output_attentions`` is true.
        """
        # Only the first two cached entries belong to self-attention.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_attn_out = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_attn_out[0]
        attn_extras = self_attn_out[1:-1]
        present_key_value = self_attn_out[-1]

        if query_length > 0:
            query_states = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_out = self.crossattention(
                    query_states,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_states = cross_out[0]
                # Append cross-attention maps when attention weights are requested.
                attn_extras = attn_extras + cross_out[1:-1]

            layer_output = self._apply_ffn(self.feed_forward_chunk_query, query_states)

            if attention_output.shape[1] > query_length:
                # Positions after the queries are text tokens with their own FFN.
                text_output = self._apply_ffn(
                    self.feed_forward_chunk,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, text_output], dim=1)
        else:
            layer_output = self._apply_ffn(self.feed_forward_chunk, attention_output)

        return (layer_output,) + attn_extras + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Feed-forward path for text tokens."""
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Feed-forward path for query tokens."""
        return self.output_query(
            self.intermediate_query(attention_output), attention_output
        )
class BertEncoder(nn.Module):
    """Sequential stack of ``config.num_hidden_layers`` :class:`BertLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer in order and collect the requested intermediate outputs.

        Args:
            hidden_states: input to the first layer.
            attention_mask: additive self-attention mask forwarded to each layer.
            head_mask: optional per-layer head masks, indexed by layer number.
            encoder_hidden_states: states the cross-attention layers attend to.
            encoder_attention_mask: additive mask for the cross-attention.
            past_key_values: optional per-layer cached key/value tuples.
            use_cache: when truthy, each layer's key/value pair is collected.
            output_attentions: when true, attention maps are collected.
            output_hidden_states: when true, the input of every layer and the
                final output are collected.
            return_dict: return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a plain tuple (in which ``None`` members are dropped).
            query_length: number of leading query positions, forwarded to layers.

        Returns:
            The final hidden state together with the optional cache, hidden
            states, self-attention maps and cross-attention maps. The
            cross-attention tuple has one entry per layer; the entry is ``None``
            for a layer that did not run cross-attention in this call.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    # `Logger.warn` is a deprecated alias of `Logger.warning`.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Non-tensor arguments are bound through the closure; only
                    # the tensors are passed through ``checkpoint``.
                    def custom_forward(*inputs):
                        return module(
                            *inputs, past_key_value, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if all_cross_attentions is not None:
                    # A layer returns (hidden, self_attn[, cross_attn], present_kv).
                    # The cross-attention map exists only if the layer actually
                    # cross-attended (it has a cross-attention module and
                    # query_length > 0); otherwise index 2 is the key/value cache
                    # and must not be reported as an attention map.
                    cross_attention = (
                        layer_outputs[2] if len(layer_outputs) > 3 else None
                    )
                    all_cross_attentions = all_cross_attentions + (cross_attention,)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class BertPooler(nn.Module):
    """Summarise a sequence by projecting its first position through dense + tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # "Pooling" here simply means picking the hidden state at index 0.
        leading_state = hidden_states[:, 0]
        return self.activation(self.dense(leading_state))
class BertPredictionHeadTransform(nn.Module):
    """Transform applied before the LM decoder: dense -> activation -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # ``hidden_act`` is either a key into ACT2FN or the callable itself.
        activation = config.hidden_act
        self.transform_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Apply the three-step transform to ``hidden_states``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
class BertLMPredictionHead(nn.Module):
    """Language-model head: hidden-state transform followed by a vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The projection itself is created without a bias; a separate per-token
        # bias parameter is attached to it below.
        vocab_bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = vocab_bias

        # Link the two so the bias is resized together with the decoder by
        # `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary scores for ``hidden_states``."""
        return self.decoder(self.transform(hidden_states))
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing :class:`BertLMPredictionHead` under ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores for ``sequence_output``."""
        return self.predictions(sequence_output)
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base providing weight initialisation and the pretrained-model
    download/load interface inherited from ``PreTrainedModel``.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialise the parameters of a single sub-module.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; LayerNorm is reset to
        weight 1 / bias 0; linear biases are zeroed.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Plain normal init; the TF version uses truncated_normal instead,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        if is_linear and module.bias is not None:
            module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    Decoder behaviour (causal masking) is selected per call through the :obj:`is_decoder` argument of
    :meth:`forward`. When the configuration has :obj:`add_cross_attention` set to :obj:`True`, an
    :obj:`encoder_hidden_states` is expected as an input to the forward pass whenever query embeddings are given.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        # The pooler is optional; `forward` returns `None` as pooled output without it.
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel

        Only the self-attention block of each listed layer is pruned here.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                If true and the mask is 2-D, a causal mask over ``input_shape`` is combined with the padding mask.
            has_query (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Only used when the 2-D ``attention_mask`` is longer than ``input_shape[1]`` (a prefix precedes the
                input). If true, the prefix rows are filled with zeros for the input columns, so prefix positions do
                not attend to the input (UniLM-style mask). Every row gets ones for the prefix columns.

        Returns:
            :obj:`torch.Tensor` The extended additive attention mask, cast to :obj:`self.dtype`: 0.0 for positions to
            attend to and -10000.0 for masked positions.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                # Lower-triangular (incl. diagonal) boolean mask: position i may see positions <= i.
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    # The attention mask covers more positions than the input: the extra leading
                    # positions form a prefix.
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        # Prepend zero rows: prefix positions cannot attend to the input positions.
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    # Prepend one columns: every row can attend to all prefix positions.
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        query_embeds (:obj:`torch.FloatTensor`, `optional`):
            Embeddings placed in front of the token embeddings along the sequence dimension. Required when
            :obj:`input_ids` is :obj:`None`. Its size along dimension 1 is passed to the encoder as ``query_length``.
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder. A list of tensors is also accepted; the first element is then used
            to derive the encoder batch size and sequence length.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            Defaults to all ones when :obj:`encoder_hidden_states` is given without a mask.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, a causal self-attention mask is built from the shape of :obj:`input_ids`, which must
            therefore be provided.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"

        # past_key_values_length: cached sequence length minus the configured number of query
        # positions, i.e. the number of already-processed text positions.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length
            if past_key_values is not None
            else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        # Shape of the embedded sequence (queries included), without the hidden dimension.
        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # The causal part is sized from the token ids only; query positions are handled as a
            # prefix inside get_extended_attention_mask.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder
            )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0
                ].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                # No mask supplied: attend to every encoder position.
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            # Tuple layout: (sequence_output, pooled_output, *remaining encoder outputs).
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class BertLMHeadModel(BertPreTrainedModel):
    """:class:`BertModel` with a language-modelling head for next-token prediction."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the vocabulary projection of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection of the LM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        query_embeds (:obj:`torch.FloatTensor`, `optional`):
            Embeddings placed in front of the token embeddings. They are ignored when :obj:`past_key_values` is
            given, and their positions are removed from the output before the LM head is applied.
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
            Supplying labels forces ``use_cache`` to :obj:`False`.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        return_logits (:obj:`bool`, `optional`):
            If :obj:`True`, return only the prediction scores with the last position dropped.
        reduction (:obj:`str`, `optional`):
            Reduction passed to the cross-entropy loss. With ``"none"`` the per-token losses are summed per sample.
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if labels is not None:
            # Caching is switched off whenever a loss is being computed.
            use_cache = False
        if past_key_values is not None:
            # With a cache present the query embeddings are not fed again.
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Strip the leading query positions; only text positions are scored.
            num_query_tokens = query_embeds.shape[1]
            sequence_output = outputs[0][:, num_query_tokens:, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: scores at position t are compared with the
            # label at position t + 1.
            shifted_scores = prediction_scores[:, :-1, :].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
            criterion = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = criterion(
                shifted_scores.view(-1, self.config.vocab_size),
                shifted_labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=lm_loss,
                logits=prediction_scores,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        output = (prediction_scores,) + outputs[2:]
        if lm_loss is None:
            return output
        return (lm_loss,) + output

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
    ):
        """Assemble the keyword arguments for one generation step."""
        # When no mask is supplied, every token position is attended to.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        # The query positions precede the tokens and are always visible.
        ones_for_queries = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([ones_for_queries, attention_mask], dim=-1)

        # With a cache only the newest token has to be fed.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states"),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask"),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Select the cache entries along dimension 0 according to ``beam_idx``."""
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past
        )
class BertForMaskedLM(BertPreTrainedModel):
    """BERT encoder (without pooling layer) topped with a masked-LM head."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder layer of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder layer of the prediction head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``

        When ``return_logits`` is true the prediction scores are returned directly
        and no loss is computed. Otherwise the result is a ``MaskedLMOutput`` (or a
        plain tuple when ``return_dict`` is false).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Always bind sequence_output: previously it was only assigned when
        # query_embeds was given, so a text-only call raised UnboundLocalError.
        sequence_output = outputs[0]
        if query_embeds is not None:
            # Drop the leading query positions so only text positions are scored.
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
================================================
FILE: minigpt4/models/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import torch
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.minigpt4 import MiniGPT4
from minigpt4.models.minigpt_v2 import MiniGPTv2
from minigpt4.processors.base_processor import BaseProcessor
# Names exported by ``from minigpt4.models import *``. The helpers
# ``load_preprocess`` / ``load_model_and_preprocess`` and the ``model_zoo``
# instance defined later in this module are not listed here.
__all__ = [
    "load_model",
    "BaseModel",
    "MiniGPTBase",
    "MiniGPT4",
    "MiniGPTv2"
]
def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
    """
    Load supported models.

    To list all available models and types in registry:
    >>> from minigpt4.models import model_zoo
    >>> print(model_zoo)

    Args:
        name (str): name of the model.
        model_type (str): type of the model.
        is_eval (bool): whether the model is in eval mode. Default: False.
        device (str or torch.device): device to use. Default: "cpu".
        checkpoint (str): path or URL to a checkpoint. Default: None.
            Note that expecting the checkpoint to have the same keys in state_dict as the model.

    Returns:
        model (torch.nn.Module): model.
    """

    model = registry.get_model_class(name).from_pretrained(model_type=model_type)

    if checkpoint is not None:
        model.load_checkpoint(checkpoint)

    if is_eval:
        model.eval()

    # Treat the string "cpu" and torch.device("cpu") alike, matching
    # load_model_and_preprocess; the model is cast to float32 for CPU use.
    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device)
def load_preprocess(config):
    """
    Construct the visual and text preprocessors described by ``config``.

    A missing processor entry yields a ``BaseProcessor`` instance instead.

    Args:
        config (dict): preprocessor configs, read via the keys
            "vis_processor" and "text_processor".

    Returns:
        vis_processors (dict): preprocessors for visual inputs.
        txt_processors (dict): preprocessors for text inputs.

        Key is "train" or "eval" for processors used in training and evaluation respectively.
    """

    def _make_processor(proc_cfg):
        # No config for this slot: fall back to the base processor.
        if proc_cfg is None:
            return BaseProcessor()
        return registry.get_processor_class(proc_cfg.name).from_config(proc_cfg)

    def _make_pair(group_cfg):
        # Build the train processor first, then the eval one.
        if group_cfg is None:
            train_cfg, eval_cfg = None, None
        else:
            train_cfg = group_cfg.get("train")
            eval_cfg = group_cfg.get("eval")
        pair = dict()
        pair["train"] = _make_processor(train_cfg)
        pair["eval"] = _make_processor(eval_cfg)
        return pair

    vis_group = config.get("vis_processor")
    txt_group = config.get("text_processor")

    vis_processors = _make_pair(vis_group)
    txt_processors = _make_pair(txt_group)

    return vis_processors, txt_processors
def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
    """
    Load a model together with the preprocessors from its default config.

    List all available models and types in registry:
    >>> from minigpt4.models import model_zoo
    >>> print(model_zoo)

    Args:
        name (str): name of the model.
        model_type (str): type of the model.
        is_eval (bool): whether the model is in eval mode. Default: False.
        device (str): device to use. Default: "cpu".

    Returns:
        model (torch.nn.Module): model.
        vis_processors (dict): preprocessors for visual inputs.
        txt_processors (dict): preprocessors for text inputs.
    """
    model_cls = registry.get_model_class(name)

    # Model first, then its preprocessing config.
    model = model_cls.from_pretrained(model_type=model_type)
    if is_eval:
        model.eval()

    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
    if cfg is None:
        vis_processors, txt_processors = None, None
        logging.info(
            f"No default preprocess for model {name} ({model_type}).\n"
            "This can happen if the model is not finetuned on downstream datasets,\n"
            "or it is not intended for direct use without finetuning.\n"
        )
    else:
        vis_processors, txt_processors = load_preprocess(cfg.preprocess)

    # The model is cast to float32 before being placed on the CPU.
    runs_on_cpu = device == "cpu" or device == torch.device("cpu")
    if runs_on_cpu:
        model = model.float()

    return model.to(device), vis_processors, txt_processors
class ModelZoo:
    """
    Printable table of the registered model architectures and their types.

    >>> from minigpt4.models import model_zoo
    >>> # list all available models
    >>> print(model_zoo)
    >>> # show total number of models
    >>> print(len(model_zoo))
    """

    def __init__(self) -> None:
        # Map each registered architecture name to its pretrained type names.
        self.model_zoo = {}
        for arch_name, model_cls in registry.mapping["model_name_mapping"].items():
            self.model_zoo[arch_name] = list(
                model_cls.PRETRAINED_MODEL_CONFIG_DICT.keys()
            )

    def __str__(self) -> str:
        rule = "=" * 50
        header = f"{'Architectures':<30} {'Types'}\n"
        rows = []
        for name, types in self.model_zoo.items():
            rows.append(f"{name:<30} {', '.join(types)}")
        return rule + "\n" + header + rule + "\n" + "\n".join(rows)

    def __iter__(self):
        return iter(self.model_zoo.items())

    def __len__(self):
        # Total number of (architecture, type) combinations.
        return sum(len(types) for types in self.model_zoo.values())
# Module-level instance, built at import time from the model registry contents.
model_zoo = ModelZoo()
================================================
FILE: minigpt4/models/base_model.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import logging
import contextlib
from omegaconf import OmegaConf
import numpy as np
import torch
import torch.nn as nn
from transformers import LlamaTokenizer
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
)
from minigpt4.common.dist_utils import download_cached_file
from minigpt4.common.utils import get_abs_path, is_url
from minigpt4.models.eva_vit import create_eva_vit_g
from minigpt4.models.modeling_llama import LlamaForCausalLM
class BaseModel(nn.Module):
    """Base class for models.

    Provides checkpoint loading, construction from the default config file and
    shared helpers that build the vision encoder and the LLaMA language model.
    """

    def __init__(self):
        super().__init__()

    @property
    def device(self):
        """Device of the last tensor yielded by ``self.parameters()``."""
        return list(self.parameters())[-1].device

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This should expect no mismatch in the model keys and the checkpoint keys.
        The state dict is nevertheless loaded with ``strict=False``; missing keys
        are logged rather than raised.

        Args:
            url_or_filename (str): URL or local path of the checkpoint.

        Returns:
            The result object of ``load_state_dict``.

        Raises:
            RuntimeError: if the argument is neither a URL nor an existing file.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # Accept both a wrapped checkpoint ({"model": ...}) and a bare state dict.
        if "model" in checkpoint.keys():
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Build a pretrained model from default configuration file, specified by model_type.

        Args:
            - model_type (str): model type, specifying architecture and checkpoints.

        Returns:
            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        model = cls.from_config(model_cfg)

        return model

    @classmethod
    def default_config_path(cls, model_type):
        """Return the absolute path of the config file registered for ``model_type``."""
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
        When loading the pretrained model, each task-specific architecture may define their
        own load_from_pretrained() method.

        Raises:
            AssertionError: if the path required by the selected mode is missing.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            # load pre-trained weights
            pretrain_path = cfg.get("pretrained", None)
            # The original asserted only the message string, which is always
            # truthy, so a missing path was never caught here.
            assert (
                pretrain_path is not None
            ), "Found load_finetuned is False, but pretrain_path is None."
            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

    def before_evaluation(self, **kwargs):
        """Hook invoked before evaluation; does nothing in the base class."""
        pass

    def show_n_params(self, return_str=True):
        """Count the parameter elements of the model.

        Returns a string such as "1.2M" / "3.4K" when ``return_str`` is true,
        otherwise the raw integer count.
        """
        tot = 0
        for p in self.parameters():
            w = 1
            for x in p.shape:
                w *= x
            tot += w
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            else:
                return "{:.1f}K".format(tot / 1e3)
        else:
            return tot

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context, or a no-op context on CPU."""
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
    ):
        """Create the EVA ViT-g encoder and its LayerNorm.

        When ``freeze`` is true both modules have gradients disabled, are put in
        eval mode and have ``train`` replaced by ``disabled_train``.

        Returns:
            tuple: ``(visual_encoder, ln_vision)``.
        """
        logging.info('Loading VIT')

        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
        if not freeze:
            precision = "fp32"  # fp16 is not for training

        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)

        if freeze:
            for name, param in visual_encoder.named_parameters():
                param.requires_grad = False
            visual_encoder = visual_encoder.eval()
            visual_encoder.train = disabled_train
            for name, param in ln_vision.named_parameters():
                param.requires_grad = False
            ln_vision = ln_vision.eval()
            ln_vision.train = disabled_train
            logging.info("freeze vision encoder")

        logging.info('Loading VIT Done')
        return visual_encoder, ln_vision

    # NOTE(review): the first parameter is named ``cls`` but there is no
    # @classmethod decorator, so on an instance call it is bound to the
    # instance. The body never uses it.
    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
                 lora_target_modules=["q_proj","v_proj"], **lora_kargs):
        """Load the LLaMA tokenizer and model, optionally wrapping it with LoRA.

        Args:
            llama_model_path: path passed to ``from_pretrained``.
            low_resource: load in 8-bit on ``low_res_device`` when true.
            low_res_device: device index used by the 8-bit device map.
            lora_r: LoRA rank; LoRA is applied only when greater than 0,
                otherwise every LLaMA parameter is frozen.
            lora_target_modules: module names handed to ``LoraConfig``
                (the default list is not modified here).
            **lora_kargs: extra keyword arguments for ``LoraConfig``.

        Returns:
            tuple: ``(llama_model, llama_tokenizer)``.
        """
        logging.info('Loading LLAMA')
        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
        llama_tokenizer.pad_token = "$$"

        if low_resource:
            llama_model = LlamaForCausalLM.from_pretrained(
                llama_model_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': low_res_device}
            )
        else:
            llama_model = LlamaForCausalLM.from_pretrained(
                llama_model_path,
                torch_dtype=torch.float16,
            )

        if lora_r > 0:
            llama_model = prepare_model_for_int8_training(llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=lora_target_modules,
                **lora_kargs
            )
            llama_model = get_peft_model(llama_model, loraconfig)

            llama_model.print_trainable_parameters()

        else:
            for name, param in llama_model.named_parameters():
                param.requires_grad = False
        logging.info('Loading LLAMA Done')
        return llama_model, llama_tokenizer

    def load_from_pretrained(self, url_or_filename):
        """Load weights from a checkpoint that stores them under the "model" key.

        Args:
            url_or_filename (str): URL or local path of the checkpoint.

        Returns:
            The result object of ``load_state_dict`` (called with ``strict=False``).

        Raises:
            RuntimeError: if the argument is neither a URL nor an existing file.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg
def disabled_train(self, mode=True):
    """No-op replacement for ``model.train``.

    Assigned over a module's ``train`` attribute so that later train/eval
    switches leave the module's mode untouched. ``mode`` is ignored.
    """
    return self
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalises in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Normalise in full precision, then cast back to the caller's dtype.
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
================================================
FILE: minigpt4/models/eva_vit.py
================================================
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from minigpt4.common.dist_utils import download_cached_file
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
    """Stochastic depth: drops whole paths per sample (for the main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the imported drop_path helper, passing the module's training flag.
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear, activation, linear, dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Unspecified (or zero) widths fall back to the input width.
        out_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # Dropout after the activation is intentionally left out here
        # (disabled to follow the original BERT implementation).
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional q/v biases and relative position bias."""

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # One bias-free projection for q, k and v; biases are kept separately
        # and only exist for q and v.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            win_h, win_w = window_size[0], window_size[1]
            # Three extra entries: cls to token, token to cls, cls to cls.
            self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # Pair-wise relative position index for every token in the window.
            grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
            flat_grid = torch.flatten(grid, 1)  # 2, Wh*Ww
            offsets = flat_grid[:, :, None] - flat_grid[:, None, :]  # 2, Wh*Ww, Wh*Ww
            offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            offsets[:, :, 0] += win_h - 1  # shift to start from 0
            offsets[:, :, 1] += win_w - 1
            offsets[:, :, 0] *= 2 * win_w - 1

            side = win_h * win_w + 1
            position_index = torch.zeros(size=(side, side), dtype=offsets.dtype)
            position_index[1:, 1:] = offsets.sum(-1)  # Wh*Ww, Wh*Ww
            position_index[0, 0:] = self.num_relative_distance - 3
            position_index[0:, 0] = self.num_relative_distance - 2
            position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        batch, tokens, _ = x.shape

        if self.q_bias is None:
            fused_bias = None
        else:
            # k has no bias, so a zero block sits between the q and v biases.
            fused_bias = torch.cat(
                (self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

        projected = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        projected = projected.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # Index instead of unpacking: make torchscript happy (cannot use tensor as tuple)
        q, k, v = projected[0], projected[1], projected[2]

        q = q * self.scale
        scores = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            side = self.window_size[0] * self.window_size[1] + 1
            own_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(side, side, -1)  # Wh*Ww,Wh*Ww,nH
            own_bias = own_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            scores = scores + own_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            scores = scores + rel_pos_bias

        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(mixed))
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        # Optional learnable per-channel scales applied to both residual branches.
        has_scale = init_values is not None and init_values > 0
        if has_scale:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        # As in the constructor, gamma_1 decides whether both branches are scaled.
        scaled = self.gamma_1 is not None

        attn_branch = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if scaled:
            attn_branch = self.gamma_1 * attn_branch
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        if scaled:
            mlp_branch = self.gamma_2 * mlp_branch
        x = x + self.drop_path(mlp_branch)
        return x
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding: a strided convolution turns an image into a patch sequence.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        rows = img_size[0] // patch_size[0]
        cols = img_size[1] // patch_size[1]
        self.patch_shape = (rows, cols)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = cols * rows

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        expected_h, expected_w = self.img_size[0], self.img_size[1]
        assert H == expected_h and W == expected_w, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, h, w) -> (B, h*w, embed_dim)
        patches = self.proj(x)
        return patches.flatten(2).transpose(1, 2)
class RelativePositionBias(nn.Module):
    """Learnable relative position bias table for a window plus one cls token."""

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        win_h, win_w = window_size[0], window_size[1]
        # Three extra entries: cls to token, token to cls, cls to cls.
        self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # Pair-wise relative position index for every token in the window.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat_grid = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat_grid[:, :, None] - flat_grid[:, None, :]  # 2, Wh*Ww, Wh*Ww
        offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # shift to start from 0
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1

        side = win_h * win_w + 1
        position_index = torch.zeros(size=(side, side), dtype=offsets.dtype)
        position_index[1:, 1:] = offsets.sum(-1)  # Wh*Ww, Wh*Ww
        position_index[0, 0:] = self.num_relative_distance - 3
        position_index[0:, 0] = self.num_relative_distance - 2
        position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self):
        side = self.window_size[0] * self.window_size[1] + 1
        flat_index = self.relative_position_index.view(-1)
        bias = self.relative_position_bias_table[flat_index].view(side, side, -1)  # Wh*Ww,Wh*Ww,nH
        return bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    In this version the classification head and the final norm are commented
    out, so ``forward`` returns the token sequence produced by the last block
    (cls token first, then one token per patch).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
        # NOTE(review): use_mean_pooling and init_scale are only referenced by
        # commented-out code below, so they currently have no effect.
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Absolute position embedding covers the cls token plus every patch.
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Shared bias: one RelativePositionBias module whose output is passed to every block.
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None
        self.use_checkpoint = use_checkpoint

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        # Per-block bias: each block builds its own table only when use_rel_pos_bias is set.
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])
#         self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
#         self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
#         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        # trunc_normal_(self.mask_token, std=.02)
#         if isinstance(self.head, nn.Linear):
#             trunc_normal_(self.head.weight, std=.02)
        # Initialise every submodule, then rescale the residual output projections by depth.
        self.apply(self._init_weights)
        self.fix_init_weight()
#         if isinstance(self.head, nn.Linear):
#             self.head.weight.data.mul_(init_scale)
#             self.head.bias.data.mul_(init_scale)

    def fix_init_weight(self):
        """Divide each block's attention/MLP output weights by sqrt(2 * layer number), in place."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        # layer_id + 1: layer numbers start at 1, which also avoids dividing by zero.
        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Per-module initialiser: truncated normal for Linear weights, zero biases, unit LayerNorm weight."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        # NOTE(review): __init__ no longer creates self.head (its assignment is
        # commented out), so this only works after reset_classifier() was called.
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Create a new ``head`` for ``num_classes`` outputs (Identity when 0); ``global_pool`` is unused."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed the image, prepend the cls token and run all blocks; returns the full token sequence."""
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            # With use_checkpoint set, each block runs through torch.utils.checkpoint.
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
        return x
#         x = self.norm(x)

#         if self.fc_norm is not None:
#             t = x[:, 1:, :]
#             return self.fc_norm(t.mean(1))
#         else:
#             return x[:, 0]

    def forward(self, x):
        """Return the output of ``forward_features``; no head is applied."""
        x = self.forward_features(x)
#         x = self.head(x)
        return x

    def get_intermediate_layers(self, x):
        """Run all blocks and return a list with the output of each one.

        Unlike ``forward_features`` this path never uses checkpointing.
        """
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        features = []
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias)
            features.append(x)

        return features
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to match the model's patch grid.

    Nothing happens when the key is absent or the grid side lengths already
    agree. Only the position tokens are bicubically interpolated; the extra
    (class/dist) tokens are copied unchanged.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_embed = checkpoint_model['pos_embed'].float()
    channels = ckpt_embed.shape[-1]
    patch_count = model.patch_embed.num_patches
    extra_count = model.pos_embed.shape[-2] - patch_count
    # height (== width) for the checkpoint position embedding
    orig_size = int((ckpt_embed.shape[-2] - extra_count) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(patch_count ** 0.5)

    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    # class_token and dist_token are kept unchanged
    kept_tokens = ckpt_embed[:, :extra_count]
    # only the position tokens are interpolated
    grid_tokens = ckpt_embed[:, extra_count:]
    grid_tokens = grid_tokens.reshape(-1, orig_size, orig_size, channels).permute(0, 3, 1, 2)
    grid_tokens = torch.nn.functional.interpolate(
        grid_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    grid_tokens = grid_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((kept_tokens, grid_tokens), dim=1)
def convert_weights_to_fp16(model: nn.Module):
    """Cast weights and biases of every Conv1d/Conv2d/Linear submodule to fp16, in place."""

    halvable = (nn.Conv1d, nn.Conv2d, nn.Linear)

    def _halve(module):
        # Other layer types (e.g. LayerNorm) keep their precision.
        if not isinstance(module, halvable):
            return
        module.weight.data = module.weight.data.half()
        if module.bias is not None:
            module.bias.data = module.bias.data.half()

#         if isinstance(l, (nn.MultiheadAttention, Attention)):
#             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
#                 tensor = getattr(l, attr)
#                 if tensor is not None:
#                     tensor.data = tensor.data.half()

    model.apply(_halve)
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
    """Build the EVA ViT-g encoder and load its downloaded pretrained weights.

    Args:
        img_size: input resolution handed to the transformer.
        drop_path_rate: maximum stochastic-depth rate.
        use_checkpoint: forwarded to ``VisionTransformer``.
        precision: when equal to "fp16", Conv/Linear weights are cast to half.

    Returns:
        VisionTransformer: the model with weights loaded non-strictly.
    """
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408//88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    model = VisionTransformer(**vit_kwargs)

    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
    cached_file = download_cached_file(
        url, check_hash=False, progress=True
    )
    # NOTE(review): torch.load unpickles a downloaded file fetched with check_hash=False.
    state_dict = torch.load(cached_file, map_location="cpu")
    # Adapt the checkpoint's position embedding to the requested image size.
    interpolate_pos_embed(model,state_dict)

    # Non-strict load; the reported incompatible keys are not inspected.
    model.load_state_dict(state_dict, strict=False)

    if precision == "fp16":
#         model.to("cuda")
        convert_weights_to_fp16(model)
    return model
================================================
FILE: minigpt4/models/minigpt4.py
================================================
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.base_model import disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
@registry.register_model("minigpt4")
class MiniGPT4(MiniGPTBase):
"""
MiniGPT-4 model
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
"pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
}
def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
):
    """Build the model: shared base components, optional Q-Former, LLaMA projection, prompts.

    Args:
        q_former_model: URL or path of the checkpoint loaded right after the
            Q-Former is created (only when ``has_qformer`` is true).
        has_qformer: when false the projection input width is four times the
            vision encoder's feature width instead of the Q-Former hidden size.
        freeze_qformer: forwarded to ``init_Qformer``.
        num_query_token: number of query tokens for the Q-Former.
        prompt_path: text file with one prompt per line; empty disables prompts.
        prompt_template: ``str.format`` template applied to each prompt line.

    The remaining arguments are passed through to the base-class constructor.
    """
    super().__init__(
        vit_model=vit_model,
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        use_grad_checkpoint=use_grad_checkpoint,
        vit_precision=vit_precision,
        freeze_vit=freeze_vit,
        llama_model=llama_model,
        max_txt_len=max_txt_len,
        end_sym=end_sym,
        low_resource=low_resource,
        device_8bit=device_8bit,
    )

    self.has_qformer = has_qformer
    if self.has_qformer:
        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features, freeze_qformer
        )
        self.load_from_pretrained(url_or_filename=q_former_model)  # load q-former weights here

        img_f_dim = self.Qformer.config.hidden_size
        print('Loading Q-Former Done')
    else:
        img_f_dim = self.visual_encoder.num_features * 4
        print('Do not use Q-Former here.')

    # Projects image features to the language model's hidden size.
    self.llama_proj = nn.Linear(
        img_f_dim, self.llama_model.config.hidden_size
    )

    if prompt_path:
        with open(prompt_path, 'r') as f:
            raw_prompts = f.read().splitlines()
        # NOTE(review): `"" in raw_prompt` is true for every string, so this
        # filter keeps all lines; presumably a placeholder token was meant
        # here - confirm against the prompt files before changing it.
        filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt]
        self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
        print('Load {} training prompts'.format(len(self.prompt_list)))
        # An empty prompt file would make random.choice raise IndexError in
        # what is only an informational print, so skip the example then.
        if self.prompt_list:
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
    else:
        self.prompt_list = []
@classmethod
def init_Qformer(cls, num_query_token, vision_width, freeze):
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = 2
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel(config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
Qformer.cls = None
Qformer.bert.embeddings.word_embeddings = None
Qformer.bert.embeddings.position_embeddings = None
for layer in Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
if freeze:
for name, param in Qformer.named_parameters():
param.requires_grad = False
Qformer = Qformer.eval()
Qformer.train = disabled_train
query_tokens.requires_grad = False
logging.info("freeze Qformer")
return Qformer, query_tokens
def encode_img(self, image):
device = image.device
if len(image.shape) > 4:
image = image.reshape(-1, *image.shape[-3:])
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
if self.has_qformer:
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
else:
image_embeds = image_embeds[:, 1:, :]
bs, pn, hs = image_embeds.shape
image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
inputs_llama = self.llama_proj(image_embeds)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
llama_model = cfg.get("llama_model")
drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
has_qformer = cfg.get("has_qformer", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
device_8bit = cfg.get("device_8bit", 0)
prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
max_txt_len = cfg.get("max_txt_len", 32)
end_sym = cfg.get("end_sym", '\n')
model = cls(
vit_model=vit_model,
q_former_model=q_former_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
has_qformer=has_qformer,
freeze_qformer=freeze_qformer,
num_query_token=num_query_token,
llama_model=llama_model,
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
end_sym=end_sym,
low_resource=low_resource,
device_8bit=device_8bit,
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model
================================================
FILE: minigpt4/models/minigpt_base.py
================================================
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
from transformers import StoppingCriteria, StoppingCriteriaList
from minigpt4.conversation.conversation import StoppingCriteriaSub
class MiniGPTBase(BaseModel):
    """
    Base class for MiniGPT-4 and MiniGPT-v2.

    Owns the LLM, its tokenizer and the vision encoder, and implements the
    shared prompt wrapping, training forward pass and generation logic.
    Subclasses provide ``encode_img``. Prompts mark image positions with the
    ``<ImageHere>`` placeholder.
    """

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        llama_model="",
        max_txt_len=32,
        max_context_len=3800,
        prompt_template="",
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
        lora_r=0,  # lora_r means lora is not used
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        """
        Initialise the LLM (``init_llm``) and the vision encoder
        (``init_vision_encoder``) and store the text-length limits, end symbol
        and prompt template. ``prompt_list`` starts empty.
        """
        super().__init__()

        self.llama_model, self.llama_tokenizer = self.init_llm(
            llama_model_path=llama_model,
            low_resource=low_resource,
            low_res_device=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
        )

        self.max_txt_len = max_txt_len
        self.max_context_len = max_context_len
        self.end_sym = end_sym

        self.prompt_template = prompt_template
        self.prompt_list = []

    def vit_to_cpu(self):
        """Move the vision encoder and its layer norm to the CPU in float32."""
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def get_context_emb(self, prompt, img_list):
        """
        Interleave the text segments of ``prompt`` with the image embeddings.

        ``prompt`` is split on ``<ImageHere>``; there must be exactly one
        placeholder per entry of ``img_list``. Returns the embeddings
        concatenated along dim 1.
        """
        device = img_list[0].device
        # str.split('') raises ValueError, so the separator must be the placeholder.
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids  # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        # text_0, img_0, text_1, img_1, ..., text_last
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
        """
        Combine image embeddings with prompt text into right-padded batches.

        Args:
            img_embeds: image embeddings, or ``None`` for text-only prompts.
            atts_img: attention mask matching ``img_embeds``.
            prompts: a single prompt string (broadcast over the batch), a list
                of prompts, or ``None``/empty to return the image inputs as-is.
            lengths: optional per-sample count of images to keep when each
                sample carries several images.

        Returns:
            ``(embeddings, attention_mask)``.
        """
        if prompts is None or len(prompts) == 0:
            # prompts is not provided, just return the original image embedding
            return img_embeds, atts_img
        elif img_embeds is None:
            # prompt is provided but there is no image embedding. return the prompt embedding in right padding
            self.llama_tokenizer.padding_side = "right"
            prompt_tokens = self.llama_tokenizer(
                prompts,
                return_tensors="pt",
                padding="longest",
                add_special_tokens=False
            ).to(self.device)
            prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
            atts_prompt = prompt_tokens.attention_mask
            return prompt_embeds, atts_prompt
        else:
            # return the multi-modal embedding in right padding
            emb_lists = []
            if isinstance(prompts, str):
                prompts = [prompts] * len(img_embeds)

            for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
                pn = each_img_embed.shape[-2]
                if lengths is not None:
                    # Flatten the per-image tokens and keep only the first lengths[idx] images.
                    each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
                    each_img_embed = each_img_embed[:lengths[idx] * pn]
                # str.split('') raises ValueError, so the separator must be the placeholder.
                p_segs = each_prompt.split('<ImageHere>')
                interleave_emb = []
                # seg_idx is kept distinct from the outer sample index idx.
                for seg_idx, seg in enumerate(p_segs[:-1]):
                    p_tokens = self.llama_tokenizer(
                        seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                    p_embed = self.embed_tokens(p_tokens.input_ids)
                    interleave_emb.append(
                        torch.cat([p_embed, each_img_embed[None][:, seg_idx * pn:(seg_idx + 1) * pn]], dim=1))
                wrapped_emb = torch.cat(interleave_emb, dim=1)
                p_tokens = self.llama_tokenizer(
                    p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_embed = self.embed_tokens(p_tokens.input_ids)
                wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
                emb_lists.append(wrapped_emb)

            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))

            # Cap every sequence at max_context_len.
            max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
            wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
            wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)

            for i, emb in enumerate(emb_lists):
                length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
                wrapped_embs[i, :length] = emb[:, :length]
                wrapped_atts[i, :length] = 1
            return wrapped_embs, wrapped_atts

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        """
        Concatenate the batched input embedding and batched output embedding together.
        Both the input and the output embedding should be right padded.

        For every sample the output is inserted right after the valid input
        tokens, and the input padding is moved to the end.

        Returns:
            ``(cat_embs, cat_atts, input_lens)`` where ``input_lens`` holds the
            per-sample count of valid input tokens.
        """
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens

    def tokenize_conversation(self, conv_q, conv_a):
        """concatenate conversation and make sure the model is only trained to regress the answer"""

        to_regress_token_ids_list = []
        targets_list = []

        batch_size = len(conv_q)
        for batch_idx in range(batch_size):
            questions, answers = conv_q[batch_idx], conv_a[batch_idx]
            questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q,
                                              return_tensors="pt",
                                              add_special_tokens=False).to(self.device) for q in questions[1:]]  # the first question is handled in the prompt wrap function, skip it
            answers = [self.llama_tokenizer(a + self.end_sym,
                                            return_tensors="pt",
                                            add_special_tokens=False).to(self.device) for a in answers]
            cur_id = []
            cur_target = []
            # Sequence layout: a_0, q_1, a_1, ..., q_n, a_n. Questions are masked with -100.
            for i in range(len(questions)):
                cur_id.append(answers[i].input_ids)
                cur_target.append(answers[i].input_ids)
                cur_id.append(questions[i].input_ids)
                cur_target.append(torch.ones_like(questions[i].input_ids) * -100)

            cur_id.append(answers[-1].input_ids)
            cur_target.append(answers[-1].input_ids)

            cur_id = torch.cat(cur_id, dim=1)
            cur_target = torch.cat(cur_target, dim=1)
            to_regress_token_ids_list.append(cur_id)
            targets_list.append(cur_target)

        max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
        to_regress_token_ids = torch.ones([batch_size, max_len],
                                          dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id
        targets = torch.ones([batch_size, max_len],
                             dtype=cur_id.dtype, device=self.device) * -100
        for batch_idx in range(batch_size):
            cur_len = to_regress_token_ids_list[batch_idx].shape[1]
            # Slicing past max_len on the left is clamped, so both sides stay the same length.
            to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len]
            targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]

        to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int)

        return to_regress_token_ids, to_regress_token_attn, targets

    def preparing_embedding(self, samples):
        """
        Build the conditioning and regression embeddings for one batch.

        Reads ``samples['image']`` if present, and either the conversation
        fields (``conv_q``, ``conv_a``, ``connect_sym``) or
        ``instruction_input``/``answer`` (optionally ``length``).

        Returns:
            ``(cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets)``.
        """
        ### prepare input tokens
        if 'image' in samples:
            img_embeds, img_atts = self.encode_img(samples["image"])
        else:
            img_embeds = img_atts = None

        if 'conv_q' in samples:
            # handling conversation datasets
            conv_q, conv_a = samples['conv_q'], samples['conv_a']

            connect_sym = samples['connect_sym'][0]
            conv_q = [q.split(connect_sym) for q in conv_q]
            conv_a = [a.split(connect_sym) for a in conv_a]

            conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]

            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
            regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)

        else:
            if "instruction_input" in samples:
                instruction = samples["instruction_input"]
            elif self.prompt_list:
                instruction = random.choice(self.prompt_list)
            else:
                instruction = None

            if hasattr(self, 'chat_template') and self.chat_template:
                # NOTE(review): this iterates `instruction`, so it assumes a list of
                # strings here (not None or a single str) — confirm against callers.
                instruction = [self.prompt_template.format(instruct) for instruct in instruction]

            if 'length' in samples:
                # the input is a image train (like videos)
                bsz, pn, hs = img_embeds.shape
                img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
            else:
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)

            ### prepare target tokens
            self.llama_tokenizer.padding_side = "right"
            text = [t + self.end_sym for t in samples["answer"]]

            regress_tokens = self.llama_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                add_special_tokens=False
            ).to(self.device)

            regress_token_ids = regress_tokens.input_ids
            regress_atts = regress_tokens.attention_mask
            # Padding positions do not contribute to the loss.
            part_targets = regress_token_ids.masked_fill(
                regress_token_ids == self.llama_tokenizer.pad_token_id, -100
            )

        regress_embeds = self.embed_tokens(regress_token_ids)

        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets

    def forward(self, samples, reduction='mean'):
        """
        Compute the language-modelling loss for a batch.

        Only the answer tokens carry labels; BOS, conditioning and padding
        positions are set to -100.

        Returns:
            ``{"loss": loss}``.
        """
        # prepare the embedding to condition and the embedding to regress
        cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
            self.preparing_embedding(samples)

        # concat the embedding to condition and the embedding to regress
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)

        # get bos token embedding
        bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        bos_atts = cond_atts[:, :1]

        # add bos token at the beginning
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

        # ensemble the final targets
        targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                             dtype=torch.long).to(self.device).fill_(-100)

        for i, target in enumerate(part_targets):
            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                reduction=reduction
            )
        loss = outputs.loss

        return {"loss": loss}

    def embed_tokens(self, token_ids):
        """Look up token embeddings, unwrapping the LoRA wrapper if present."""
        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        else:
            embeds = self.llama_model.base_model.embed_tokens(token_ids)
        return embeds

    @torch.no_grad()
    def generate(
        self,
        images,
        texts,
        num_beams=1,
        max_new_tokens=20,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1,
        length_penalty=1,
        temperature=1,
        do_sample=False,
        stop_words_ids=[2],
    ):
        '''
        Generate one answer per (image, text) pair; intended for testing.

        Each text must contain exactly one ``<ImageHere>`` placeholder. The
        batch is left-padded before being passed to the LLM's ``generate``.

        Returns:
            list of decoded answer strings.
        '''

        # Built for parity with the chat pipeline but not passed to generate()
        # below; the argument is commented out there.
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
            stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])

        img_embeds, atts_img = self.encode_img(images.to(self.device))
        image_lists = [[image_emb[None]] for image_emb in img_embeds]

        batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]

        batch_size = len(batch_embs)
        max_len = max([emb.shape[1] for emb in batch_embs])
        emb_dim = batch_embs[0].shape[2]
        dtype = batch_embs[0].dtype
        device = batch_embs[0].device

        embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
        attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
        for i, emb in enumerate(batch_embs):
            emb_len = emb.shape[1]
            # Left padding: valid tokens occupy the tail of every row.
            embs[i, -emb_len:] = emb[0]
            attn_mask[i, -emb_len:] = 1

        with self.maybe_autocast():
            outputs = self.llama_model.generate(
                inputs_embeds=embs,
                attention_mask=attn_mask,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                length_penalty=length_penalty,
                temperature=temperature,
                do_sample=do_sample,
                min_length=min_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                # stopping_criteria=stopping_criteria,
            )

        answers = []
        for output_token in outputs:
            if output_token[0] == 0:
                output_token = output_token[1:]
            output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
            # str.split('') raises ValueError; cut at the end-of-sequence marker instead.
            output_texts = output_texts.split('</s>')[0]  # remove the stop sign </s>
            output_texts = output_texts.replace("<s>", "")
            output_texts = output_texts.split(r'[/INST]')[-1].strip()
            answers.append(output_texts)

        return answers

    @torch.no_grad()
    def multi_select(self, images, texts, answers, num_cand=None):
        """
        Rank candidate answers by per-sample loss (lower loss ranks first).

        Args:
            answers: iterable of candidate batches; each entry is used as the
                ``answer`` field for the whole batch.
            num_cand: optional per-sample number of valid candidates; the
                losses of the remaining candidates are set to 9999.

        Returns:
            nested list of candidate indices, sorted by ascending loss.
        """
        all_losses = []
        for answer in answers:
            choice_samples = {
                'image': images,
                'instruction_input': texts,
                'answer': answer
            }
            loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
            all_losses.append(loss)
            torch.cuda.empty_cache()
        all_losses = torch.cat(all_losses, dim=-1)
        if num_cand is not None:
            for i in range(all_losses.shape[0]):
                all_losses[i, num_cand[i]:] = 9999
        output_class_ranks = torch.argsort(all_losses, dim=-1)
        return output_class_ranks.tolist()
================================================
FILE: minigpt4/models/minigpt_v2.py
================================================
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.base_model import disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
@registry.register_model("minigpt_v2")
class MiniGPTv2(MiniGPTBase):
    """
    MiniGPT-v2 model.

    Has no Q-Former: patch features from the vision encoder are grouped four
    at a time and linearly projected to the LLM hidden size. LoRA is enabled
    by default (``lora_r=64``).
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/minigpt_v2.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=448,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        llama_model="",
        prompt_template='[INST] {} [/INST]',
        max_txt_len=300,
        end_sym='\n',
        lora_r=64,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
        chat_template=False,
        use_grad_checkpoint_llm=False,
        max_context_len=3800,
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
    ):
        """
        Build the base model, then add the projection layer.

        Args:
            chat_template: stored on the instance; when true, the base class
                wraps instructions with ``prompt_template``.
            use_grad_checkpoint_llm: if true, enable gradient checkpointing on
                the LLM.

        The remaining arguments are forwarded unchanged to ``MiniGPTBase``.
        """
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            max_txt_len=max_txt_len,
            max_context_len=max_context_len,
            end_sym=end_sym,
            prompt_template=prompt_template,
            low_resource=low_resource,
            device_8bit=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

        # encode_img concatenates 4 patch tokens per LLM token, hence the factor 4.
        img_f_dim = self.visual_encoder.num_features * 4
        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.chat_template = chat_template

        if use_grad_checkpoint_llm:
            self.llama_model.gradient_checkpointing_enable()

    def encode_img(self, image):
        """
        Encode images into LLM-space embeddings.

        Inputs with more than 4 dimensions are flattened so that only the last
        three dimensions are kept per image.

        Returns:
            ``(inputs_llama, atts_llama)``: the projected embeddings and an
            all-ones attention mask over their leading dimensions.
        """
        device = image.device

        if len(image.shape) > 4:
            image = image.reshape(-1, *image.shape[-3:])

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            # Drop the first token (presumably the CLS token — confirm against
            # the vision encoder), then merge every 4 consecutive tokens.
            image_embeds = image_embeds[:, 1:, :]
            bs, pn, hs = image_embeds.shape
            image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))

            inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    @classmethod
    def from_config(cls, cfg):
        """
        Build a model from a config object exposing ``get(key, default)`` and,
        if ``cfg`` has a non-empty ``ckpt`` entry, load that checkpoint's
        ``'model'`` state dict non-strictly.
        """
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("image_size")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        low_resource = cfg.get("low_resource", False)
        # Forward the 8-bit device like MiniGPT4.from_config does; otherwise a
        # configured device_8bit is silently ignored and the default 0 is used.
        device_8bit = cfg.get("device_8bit", 0)

        prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
        max_txt_len = cfg.get("max_txt_len", 300)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 64)
        lora_alpha = cfg.get("lora_alpha", 16)
        chat_template = cfg.get("chat_template", False)

        use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
        max_context_len = cfg.get("max_context_len", 3800)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            low_resource=low_resource,
            device_8bit=device_8bit,
            end_sym=end_sym,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            chat_template=chat_template,
            use_grad_checkpoint_llm=use_grad_checkpoint_llm,
            max_context_len=max_context_len,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            # strict=False: the checkpoint may hold only a subset of the weights.
            model.load_state_dict(ckpt['model'], strict=False)

        return model
================================================
FILE: minigpt4/models/modeling_llama.py
================================================
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
class LlamaForCausalLM(LlamaForCausalLMOrig):
    # Same model as the upstream class; only ``forward`` is overridden so that
    # callers can choose the cross-entropy ``reduction``.

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            reduction (`str`, *optional*, defaults to `"mean"`):
                Passed to `CrossEntropyLoss`. With `"none"`, the per-token losses are reshaped to one row per
                sequence and averaged over all positions, giving one loss value per sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to the config for every flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if getattr(self.config, 'pretraining_tp', 1) > 1:
            # Apply the LM head in pretraining_tp slices and stitch the logits together.
            n_slices = self.config.pretraining_tp
            head_slices = self.lm_head.weight.split(self.vocab_size // n_slices, dim=0)
            partial_logits = [F.linear(hidden_states, head_slices[k]) for k in range(n_slices)]
            logits = torch.cat(partial_logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n, then flatten the tokens
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_labels = labels[..., 1:].contiguous().view(-1)
            # Enable model parallelism
            flat_labels = flat_labels.to(flat_logits.device)
            criterion = CrossEntropyLoss(reduction=reduction)
            loss = criterion(flat_logits, flat_labels)
            if reduction == "none":
                # One loss value per sequence.
                loss = loss.view(logits.size(0), -1).mean(1)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output
================================================
FILE: minigpt4/processors/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from minigpt4.processors.base_processor import BaseProcessor
from minigpt4.processors.blip_processors import (
Blip2ImageTrainProcessor,
Blip2ImageEvalProcessor,
BlipCaptionProcessor,
)
from minigpt4.common.registry import registry
__all__ = [
"BaseProcessor",
"Blip2ImageTrainProcessor",
"Blip2ImageEvalProcessor",
"BlipCaptionProcessor",
]
def load_processor(name, cfg=None):
    """
    Build the processor registered under ``name``.

    The class is looked up with ``registry.get_processor_class`` and
    instantiated through its ``from_config(cfg)``.

    Example

        >>> processor = load_processor("blip2_image_eval", cfg=None)
    """
    processor_cls = registry.get_processor_class(name)
    return processor_cls.from_config(cfg)
================================================
FILE: minigpt4/processors/base_processor.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from omegaconf import OmegaConf
class BaseProcessor:
    """Pass-through processor; subclasses replace ``self.transform``."""

    def __init__(self):
        # Identity transform by default.
        self.transform = lambda x: x

    def __call__(self, item):
        """Apply ``self.transform`` to ``item`` and return the result."""
        apply = self.transform
        return apply(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Create an instance; the base implementation ignores ``cfg``."""
        return cls()

    def build(self, **kwargs):
        """Wrap ``kwargs`` in an OmegaConf config and delegate to ``from_config``."""
        return self.from_config(OmegaConf.create(kwargs))
================================================
FILE: minigpt4/processors/blip_processors.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import re
from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor
from minigpt4.processors.randaugment import RandomAugment
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
class BlipImageBaseProcessor(BaseProcessor):
    """Shared base for image processors; sets up ``self.normalize``."""

    def __init__(self, mean=None, std=None):
        # Per-channel statistics used when the caller supplies none.
        default_mean = (0.48145466, 0.4578275, 0.40821073)
        default_std = (0.26862954, 0.26130258, 0.27577711)

        chosen_mean = default_mean if mean is None else mean
        chosen_std = default_std if std is None else std
        self.normalize = transforms.Normalize(chosen_mean, chosen_std)
@registry.register_processor("blip_caption")
class BlipCaptionProcessor(BaseProcessor):
    """Normalise a caption and prepend a fixed prompt."""

    def __init__(self, prompt="", max_words=50):
        self.prompt = prompt
        self.max_words = max_words

    def __call__(self, caption):
        """Return ``prompt`` followed by the cleaned caption."""
        return self.prompt + self.pre_caption(caption)

    @classmethod
    def from_config(cls, cfg=None):
        """Read ``prompt`` and ``max_words`` from ``cfg`` (defaults: ``""`` and 50)."""
        cfg = OmegaConf.create() if cfg is None else cfg
        return cls(
            prompt=cfg.get("prompt", ""),
            max_words=cfg.get("max_words", 50),
        )

    def pre_caption(self, caption):
        """
        Lower-case the caption, replace the listed punctuation characters with
        spaces, collapse whitespace runs, trim the ends, and keep at most
        ``max_words`` space-separated words.
        """
        text = caption.lower()
        text = re.sub(r"([.!\"()*#:;~])", " ", text)
        text = re.sub(r"\s{2,}", " ", text)
        text = text.rstrip("\n").strip(" ")

        # truncate caption
        words = text.split(" ")
        if len(words) <= self.max_words:
            return text
        return " ".join(words[: self.max_words])
@registry.register_processor("blip2_image_train")
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    """Training image pipeline: bicubic resize, tensor conversion, normalisation."""

    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        # min_scale / max_scale are accepted but not used by this pipeline.
        super().__init__(mean=mean, std=std)

        resize = transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        )
        self.transform = transforms.Compose([resize, transforms.ToTensor(), self.normalize])

    def __call__(self, item):
        """Run the transform pipeline on ``item``."""
        pipeline = self.transform
        return pipeline(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build from ``cfg``, using the constructor defaults for missing keys."""
        cfg = OmegaConf.create() if cfg is None else cfg

        options = {
            "image_size": cfg.get("image_size", 224),
            "mean": cfg.get("mean", None),
            "std": cfg.get("std", None),
            "min_scale": cfg.get("min_scale", 0.5),
            "max_scale": cfg.get("max_scale", 1.0),
        }
        return cls(**options)
@registry.register_processor("blip2_image_eval")
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation image pipeline: bicubic resize, tensor conversion, normalisation."""

    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        target_size = (image_size, image_size)
        steps = [
            transforms.Resize(target_size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        """Run the transform pipeline on ``item``."""
        pipeline = self.transform
        return pipeline(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build from ``cfg``, using the constructor defaults for missing keys."""
        cfg = OmegaConf.create() if cfg is None else cfg

        return cls(
            image_size=cfg.get("image_size", 224),
            mean=cfg.get("mean", None),
            std=cfg.get("std", None),
        )
================================================
FILE: minigpt4/processors/randaugment.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import cv2
import numpy as np
import torch
## aug functions
def identity_func(img):
    """No-op augmentation: hand back the very same object that was passed in."""
    unchanged = img
    return unchanged
def autocontrast_func(img, cutoff=0):
    """
    Stretch each channel so that its darkest and brightest retained values map
    to 0 and 255 (intended to match PIL.ImageOps.autocontrast).

    Args:
        img: image array that is split into channels with ``cv2.split``; the
            channel values are used as indices into a 256-entry lookup table.
        cutoff: percentage of pixels to ignore at each end of the histogram.

    Returns:
        The remapped image, merged back with ``cv2.merge``.
    """
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            # Convert to Python ints: on a uint8 scalar, ``-low`` wraps around to
            # ``256 - low`` and produces a large positive offset.
            low, high = int(ch.min()), int(ch.max())
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            # First bin, counted from the dark end, whose cumulative count exceeds cut.
            from_dark = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if from_dark.shape[0] == 0 else int(from_dark[0, 0])
            # Same, counted from the bright end.
            from_bright = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if from_bright.shape[0] == 0 else n_bins - 1 - int(from_bright[0, 0])
        if high <= low:
            # Flat channel: nothing to stretch, use the identity table.
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
        table = table.clip(0, n_bins - 1).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
def equalize_func(img):
    """
    Per-channel histogram equalisation.

    Intended to give the same output as PIL.ImageOps.equalize; PIL's
    implementation is different from cv2.equalize.
    """
    n_bins = 256

    def tune_channel(channel):
        counts = cv2.calcHist([channel], [0], None, [n_bins], [0, n_bins])
        populated = counts[counts != 0].reshape(-1)
        # The last populated bin is left out of the step size.
        step = np.sum(populated[:-1]) // (n_bins - 1)
        if step == 0:
            # Nothing to spread out; return the channel untouched.
            return channel
        # Histogram shifted by one bin, seeded with half a step.
        shifted = np.empty_like(counts)
        shifted[0] = step // 2
        shifted[1:] = counts[:-1]
        lookup = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lookup[channel]

    equalized = []
    for channel in cv2.split(img):
        equalized.append(tune_channel(channel))
    return cv2.merge(equalized)
def rotate_func(img, degree, fill=(0, 0, 0)):
    """
    Rotate the image about its centre by ``degree`` degrees (not radians, as
    in PIL), keeping the original size and filling uncovered pixels with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    pivot = (width / 2, height / 2)
    rotation = cv2.getRotationMatrix2D(pivot, degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)
def solarize_func(img, thresh=128):
    """
    Invert every pixel value at or above ``thresh`` and leave the lower values
    untouched (the solarize operation, as in PIL.ImageOps.solarize).
    """
    levels = np.arange(256)
    lookup = np.where(levels < thresh, levels, 255 - levels)
    lookup = lookup.clip(0, 255).astype(np.uint8)
    return lookup[img]
def color_func(img, factor):
    """
    Adjust colour saturation (intended to match PIL.ImageEnhance.Color).

    The blend between the image and its grey-scale version is folded into one
    3x3 matrix so a single matmul does the work; the literal PIL-style blend
    is kept below for reference because it is noticeably slower.
    """
    ## implementation according to PIL definition, quite slow
    # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
    # out = blend(degenerate, img, factor)
    # M = (
    # np.eye(3) * factor
    # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
    # )[np.newaxis, np.newaxis, :]
    chroma = np.float32(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
    )
    luma = np.float32([[0.114], [0.587], [0.299]])
    mix = chroma * factor + luma
    blended = np.matmul(img, mix)
    return blended.clip(0, 255).astype(np.uint8)
def contrast_func(img, factor):
    """
    Scale pixel values away from (or towards) the weighted image mean
    (intended to match PIL.ImageEnhance.Contrast).
    """
    channel_means = np.mean(img, axis=(0, 1))
    weights = np.array([0.114, 0.587, 0.299])
    mean = np.sum(channel_means * weights)
    # Lookup table over every possible 8-bit value.
    levels = np.arange(256)
    table = ((levels - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return table[img]
def brightness_func(img, factor):
    """
    Multiply every pixel value by `factor`, saturating at 255
    (intended to match PIL.ImageEnhance.Brightness).

    Args:
        img: integer image array used to index a 256-entry lookup table.
        factor: brightness multiplier; 1.0 leaves the image unchanged.

    Returns:
        uint8 array with the same shape as `img`.
    """
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out
def sharpness_func(img, factor):
    """
    Blend the image with a smoothed copy of itself; factor < 1 blurs,
    factor > 1 sharpens.

    The result differs from PIL only on the four image borders; the interior
    is the same.

    Args:
        img: H x W x C image array (the slicing below requires three dims).
        factor: 0.0 returns the smoothed image, 1.0 returns `img` itself.

    Returns:
        The smoothed image, the input, or a new uint8 array for other factors.
    """
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        # Extrapolation (factor > 1) can leave [0, 255]; saturate instead of
        # letting the uint8 cast wrap around.
        out = out.clip(0, 255).astype(np.uint8)
    return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Shear the image along the x axis by `factor`, padding with `fill`."""
    height, width = img.shape[:2]
    affine = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """
    Shift the image horizontally by `-offset` pixels (intended to match
    PIL.Image.transform), padding exposed pixels with `fill`.
    """
    height, width = img.shape[:2]
    affine = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """
    Shift the image vertically by `-offset` pixels (intended to match
    PIL.Image.transform), padding exposed pixels with `fill`.
    """
    height, width = img.shape[:2]
    affine = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def posterize_func(img, bits):
    """
    Keep only the `bits` most significant bits of every pixel value
    (intended to match PIL.ImageOps.posterize).

    Args:
        img: uint8 image array.
        bits: number of high bits to keep, 0..8 (0 zeroes the image,
            8 leaves it unchanged).

    Returns:
        Array with the low bits of each value cleared.
    """
    # Reduce the shifted mask to 8 bits *before* building the uint8 scalar:
    # 255 << k exceeds 255 for k > 0, and NumPy refuses (NumPy 2) or warns
    # about (NumPy 1.24+) converting an out-of-range Python int to uint8.
    mask = np.uint8((255 << (8 - bits)) & 0xFF)
    out = np.bitwise_and(img, mask)
    return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Shear the image along the y axis by `factor`, padding with `fill`."""
    height, width = img.shape[:2]
    affine = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """
    Paint a randomly placed rectangle (at most `pad_size` pixels per side,
    clipped at the image border) with the `replace` colour on a copy of `img`.
    """
    fill = np.array(replace, dtype=np.uint8)
    height, width = img.shape[:2]
    # One RNG draw yields the relative centre for both axes.
    rel_row, rel_col = np.random.random(2)
    half = pad_size // 2
    centre_row = int(rel_row * height)
    centre_col = int(rel_col * width)
    top = max(centre_row - half, 0)
    bottom = min(centre_row + half, height)
    left = max(centre_col - half, 0)
    right = min(centre_col + half, width)
    patched = img.copy()
    patched[top:bottom, left:right, :] = fill
    return patched
### level to args
def enhance_level_to_args(MAX_LEVEL):
    """
    Build a converter from a magnitude level to a 1-tuple holding an
    enhancement factor (0.1 at level 0, growing linearly by 1.8 per MAX_LEVEL).
    """
    def level_to_args(level):
        fraction = level / MAX_LEVEL
        return (fraction * 1.8 + 0.1,)

    return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
    """
    Build a converter from a magnitude level to (shear factor, fill value).
    The factor scales up to 0.3 at MAX_LEVEL and gets a random sign.
    """
    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)

    return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """
    Build a converter from a magnitude level to (pixel offset, fill value).
    The offset scales up to `translate_const` at MAX_LEVEL, with a random sign.
    """
    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * float(translate_const)
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)

    return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """
    Build a converter from a magnitude level to (patch size, fill value);
    the size is `cutout_const` scaled by level / MAX_LEVEL, truncated to int.
    """
    def level_to_args(level):
        size = int((level / MAX_LEVEL) * cutout_const)
        return (size, replace_value)

    return level_to_args
def solarize_level_to_args(MAX_LEVEL):
    """
    Build a converter from a magnitude level to a 1-tuple holding the
    solarize threshold (level / MAX_LEVEL of 256, truncated to int).
    """
    def level_to_args(level):
        threshold = int((level / MAX_LEVEL) * 256)
        return (threshold,)

    return level_to_args
def none_level_to_args(level):
    """Converter for ops that take no extra arguments: the level is ignored."""
    return tuple()
def posterize_level_to_args(MAX_LEVEL):
    """
    Build a converter from a magnitude level to a 1-tuple holding the number
    of bits to keep (level / MAX_LEVEL of 4, truncated to int).
    """
    def level_to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits,)

    return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """
    Build a converter from a magnitude level to (angle in degrees, fill value).
    The angle scales up to 30 at MAX_LEVEL and gets a random sign.
    """
    def level_to_args(level):
        angle = (level / MAX_LEVEL) * 30
        negate = np.random.random() < 0.5
        return (-angle if negate else angle, replace_value)

    return level_to_args
# Op name -> augmentation function. Every key here also appears in arg_dict
# below; cutout_func is defined in this module but not registered.
func_dict = {
"Identity": identity_func,
"AutoContrast": autocontrast_func,
"Equalize": equalize_func,
"Rotate": rotate_func,
"Solarize": solarize_func,
"Color": color_func,
"Contrast": contrast_func,
"Brightness": brightness_func,
"Sharpness": sharpness_func,
"ShearX": shear_x_func,
"TranslateX": translate_x_func,
"TranslateY": translate_y_func,
"Posterize": posterize_func,
"ShearY": shear_y_func,
}
# Largest translation offset, reached at MAX_LEVEL.
translate_const = 10
# Level that corresponds to full augmentation strength.
MAX_LEVEL = 10
# Fill colour handed to the geometric ops for pixels exposed by the warp.
replace_value = (128, 128, 128)
# Op name -> callable turning a magnitude level into the positional
# arguments passed after `img` to the matching func_dict entry.
arg_dict = {
"Identity": none_level_to_args,
"AutoContrast": none_level_to_args,
"Equalize": none_level_to_args,
"Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
"Solarize": solarize_level_to_args(MAX_LEVEL),
"Color": enhance_level_to_args(MAX_LEVEL),
"Contrast": enhance_level_to_args(MAX_LEVEL),
"Brightness": enhance_level_to_args(MAX_LEVEL),
"Sharpness": enhance_level_to_args(MAX_LEVEL),
"ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
"TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
"TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
"Posterize": posterize_level_to_args(MAX_LEVEL),
"ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
    """
    Apply N randomly sampled augmentation ops, each with probability 0.5,
    at magnitude M.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # An empty `augs` means "use every registered op".
        self.augs = augs if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Sample N op names and return (name, probability, level) triples."""
        picked = np.random.choice(self.augs, self.N)
        ops = []
        for op_name in picked:
            ops.append((op_name, 0.5, self.M))
        return ops

    def __call__(self, img):
        """Augment `img`; the input is converted with np.array first when isPIL is set."""
        if self.isPIL:
            img = np.array(img)
        for name, prob, level in self.get_random_ops():
            if np.random.random() <= prob:
                op_args = arg_dict[name](level)
                img = func_dict[name](img, *op_args)
        return img
class VideoRandomAugment(object):
    """
    Frame-wise random augmentation: one set of N distinct ops and one
    apply/skip mask is drawn per call and reused for every frame.
    """

    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
        self.N = N
        self.M = M
        self.p = p
        self.tensor_in_tensor_out = tensor_in_tensor_out
        # An empty `augs` means "use every registered op".
        self.augs = augs if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Sample N op names without replacement, paired with the level M."""
        picked = np.random.choice(self.augs, self.N, replace=False)
        return [(op_name, self.M) for op_name in picked]

    def __call__(self, frames):
        """Augment every frame and return them stacked as a float tensor."""
        assert (
            frames.shape[-1] == 3
        ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
        if self.tensor_in_tensor_out:
            frames = frames.numpy().astype(np.uint8)
        # Draw the ops first, then the mask, once for the whole clip.
        shared_ops = self.get_random_ops()
        shared_mask = np.random.random(size=self.N) > self.p
        augmented = [self._aug(frame, shared_ops, shared_mask) for frame in frames]
        return torch.stack(augmented, dim=0).float()

    def _aug(self, img, ops, apply_or_not):
        """Run the selected ops on one frame and wrap the result in a tensor."""
        for (name, level), enabled in zip(ops, apply_or_not):
            if enabled:
                op_args = arg_dict[name](level)
                img = func_dict[name](img, *op_args)
        return torch.from_numpy(img)
if __name__ == "__main__":
    # Smoke test. The ops index 256-entry lookup tables and use bitwise
    # masks, so the sample image must be an integer (uint8) array; a float
    # image from randn() makes those ops fail.
    a = RandomAugment()
    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    a(img)
================================================
FILE: minigpt4/runners/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from minigpt4.runners.runner_base import RunnerBase
__all__ = ["RunnerBase"]
================================================
FILE: minigpt4/runners/runner_base.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import json
import logging
import os
import time
from pathlib import Path
import torch
import torch.distributed as dist
import webdataset as wds
from minigpt4.common.dist_utils import (
download_cached_file,
get_rank,
get_world_size,
is_main_process,
main_process,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import is_url
from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
from minigpt4.datasets.datasets.dataloader_utils import (
IterLoader,
MultiIterLoader,
PrefetchLoader,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
@registry.register_runner("runner_base")
class RunnerBase:
    """
    A runner class to train and evaluate a model given a task and datasets.

    The runner uses pytorch distributed data parallel by default. Future release
    will support other distributed frameworks.
    """

    def __init__(self, cfg, task, model, datasets, job_id):
        """
        Store the run configuration and collaborators and create the output
        directories. Device, DDP wrapper, optimizer, scaler, dataloaders and
        LR scheduler are built lazily by the properties below.
        """
        self.config = cfg
        self.job_id = job_id

        self.task = task
        self.datasets = datasets

        self._model = model

        self._wrapped_model = None
        self._device = None
        self._optimizer = None
        self._scaler = None
        self._dataloaders = None
        self._lr_sched = None

        self.start_epoch = 0

        # self.setup_seeds()
        self.setup_output_dir()

    @property
    def device(self):
        """torch.device built from run_cfg.device (cached after first use)."""
        if self._device is None:
            self._device = torch.device(self.config.run_cfg.device)

        return self._device

    @property
    def use_distributed(self):
        """Value of run_cfg.distributed."""
        return self.config.run_cfg.distributed

    @property
    def model(self):
        """
        A property to get the DDP-wrapped model on the device.

        Without distributed training the plain model is returned.
        """
        # move model to device
        if self._model.device != self.device:
            self._model = self._model.to(self.device)

        # distributed training wrapper; done independently of the device move
        # so the property never hands back None when the model already sits
        # on the target device.
        if self.use_distributed:
            if self._wrapped_model is None:
                self._wrapped_model = DDP(
                    self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
                )
        else:
            self._wrapped_model = self._model

        return self._wrapped_model

    @property
    def optimizer(self):
        """
        AdamW over the trainable parameters (created on first access).

        Parameters with fewer than two dims, or whose name contains "bias",
        "ln" or "bn", go into a group with weight decay 0; all others use
        run_cfg.weight_decay.
        """
        # TODO make optimizer class and configurations
        if self._optimizer is None:
            num_parameters = 0
            p_wd, p_non_wd = [], []
            for n, p in self.model.named_parameters():
                if not p.requires_grad:
                    continue  # frozen weights
                print(n)
                if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
                    p_non_wd.append(p)
                else:
                    p_wd.append(p)
                num_parameters += p.data.nelement()
            logging.info("number of trainable parameters: %d" % num_parameters)
            optim_params = [
                {
                    "params": p_wd,
                    "weight_decay": float(self.config.run_cfg.weight_decay),
                },
                {"params": p_non_wd, "weight_decay": 0},
            ]
            beta2 = self.config.run_cfg.get("beta2", 0.999)
            self._optimizer = torch.optim.AdamW(
                optim_params,
                lr=float(self.config.run_cfg.init_lr),
                weight_decay=float(self.config.run_cfg.weight_decay),
                betas=(0.9, beta2),
            )

        return self._optimizer

    @property
    def scaler(self):
        """GradScaler when run_cfg.amp is truthy, otherwise None."""
        amp = self.config.run_cfg.get("amp", False)

        if amp:
            if self._scaler is None:
                self._scaler = torch.cuda.amp.GradScaler()

        return self._scaler

    @property
    def lr_scheduler(self):
        """
        A property to get and create learning rate scheduler by split just in need.
        """
        if self._lr_sched is None:
            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)

            # max_epoch = self.config.run_cfg.max_epoch
            max_epoch = self.max_epoch
            # min_lr = self.config.run_cfg.min_lr
            min_lr = self.min_lr
            # init_lr = self.config.run_cfg.init_lr
            init_lr = self.init_lr

            # optional parameters
            decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
            warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
            warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
            iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None)

            if iters_per_epoch is None:
                # Fall back to the train loader length; loaders without a
                # usable len() get a fixed 10000 iterations per epoch.
                try:
                    iters_per_epoch = len(self.dataloaders['train'])
                except (AttributeError, TypeError):
                    iters_per_epoch = 10000

            self._lr_sched = lr_sched_cls(
                optimizer=self.optimizer,
                max_epoch=max_epoch,
                iters_per_epoch=iters_per_epoch,
                min_lr=min_lr,
                init_lr=init_lr,
                decay_rate=decay_rate,
                warmup_start_lr=warmup_start_lr,
                warmup_steps=warmup_steps,
            )

        return self._lr_sched

    @property
    def dataloaders(self) -> dict:
        """
        A property to get and create dataloaders by split just in need.

        Datasets are regrouped by split with reorg_datasets_by_split, using the
        per-dataset batch_size from datasets_cfg. A split that holds a
        list/tuple of datasets is served by a MultiIterLoader (see
        create_loaders); otherwise a single loader is created.

        Returns:
            dict: {split_name: dataloader}
        """
        if self._dataloaders is None:

            # concatenate map-style datasets and chain wds.DataPipe datasets separately
            # training set becomes a tuple (ConcatDataset, ChainDataset), both are
            # optional but at least one of them is required. The resultant ConcatDataset
            # and ChainDataset will be sampled evenly.
            logging.info(
                "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
            )

            batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size
                           for dataset_name in self.datasets.keys()}
            datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes)
            self.datasets = datasets
            # self.datasets = concat_datasets(datasets)

            # print dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                if isinstance(self.datasets[split_name], tuple) or isinstance(
                    self.datasets[split_name], list
                ):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset
                    num_records = sum(
                        [
                            len(d)
                            if not type(d) in [wds.DataPipeline, ChainDataset]
                            else 0
                            for d in self.datasets[split_name]
                        ]
                    )

                else:
                    if hasattr(self.datasets[split_name], "__len__"):
                        # a single map-style dataset
                        num_records = len(self.datasets[split_name])
                    else:
                        # a single wds.DataPipeline
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            batch_sizes = [batch_sizes[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            print("batch sizes", batch_sizes)

            # One collate function per dataset, mirroring the (possibly nested)
            # structure of `datasets`.
            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, tuple) or isinstance(dataset, list):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
            )

            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}

        return self._dataloaders

    @property
    def cuda_enabled(self):
        """True when the configured device type is "cuda"."""
        return self.device.type == "cuda"

    @property
    def max_epoch(self):
        """run_cfg.max_epoch as int."""
        return int(self.config.run_cfg.max_epoch)

    @property
    def log_freq(self):
        """run_cfg.log_freq as int (default 50)."""
        log_freq = self.config.run_cfg.get("log_freq", 50)
        return int(log_freq)

    @property
    def init_lr(self):
        """run_cfg.init_lr as float."""
        return float(self.config.run_cfg.init_lr)

    @property
    def min_lr(self):
        """run_cfg.min_lr as float."""
        return float(self.config.run_cfg.min_lr)

    @property
    def accum_grad_iters(self):
        """run_cfg.accum_grad_iters as int (default 1)."""
        return int(self.config.run_cfg.get("accum_grad_iters", 1))

    @property
    def valid_splits(self):
        """run_cfg.valid_splits (default []); logs a note when it is empty."""
        valid_splits = self.config.run_cfg.get("valid_splits", [])

        if len(valid_splits) == 0:
            logging.info("No validation splits found.")

        return valid_splits

    @property
    def test_splits(self):
        """run_cfg.test_splits (default [])."""
        test_splits = self.config.run_cfg.get("test_splits", [])

        return test_splits

    @property
    def train_splits(self):
        """run_cfg.train_splits (default []); logs a note when it is empty."""
        train_splits = self.config.run_cfg.get("train_splits", [])

        if len(train_splits) == 0:
            logging.info("Empty train splits.")

        return train_splits

    @property
    def evaluate_only(self):
        """
        Set to True to skip training.
        """
        return self.config.run_cfg.evaluate

    @property
    def use_dist_eval_sampler(self):
        """run_cfg.use_dist_eval_sampler (default True)."""
        return self.config.run_cfg.get("use_dist_eval_sampler", True)

    @property
    def resume_ckpt_path(self):
        """run_cfg.resume_ckpt_path, or None when not configured."""
        return self.config.run_cfg.get("resume_ckpt_path", None)

    @property
    def train_loader(self):
        """The dataloader registered under the "train" split."""
        train_dataloader = self.dataloaders["train"]

        return train_dataloader

    def setup_output_dir(self):
        """
        Create <library_root>/<run_cfg.output_dir>/<job_id> and its "result"
        sub-directory, and register both paths in the registry.
        """
        lib_root = Path(registry.get_path("library_root"))

        output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
        # output_dir = lib_root / self.config.run_cfg.output_dir
        result_dir = output_dir / "result"

        output_dir.mkdir(parents=True, exist_ok=True)
        result_dir.mkdir(parents=True, exist_ok=True)

        registry.register_path("result_dir", str(result_dir))
        registry.register_path("output_dir", str(output_dir))

        self.result_dir = result_dir
        self.output_dir = output_dir

    def train(self):
        """
        Run the train / validate loop from start_epoch to max_epoch, then
        evaluate on the test splits.

        With validation splits, a "best" checkpoint is written whenever
        agg_metrics improves on the "val" split; without them a checkpoint is
        written after every training epoch.
        """
        start_time = time.time()
        best_agg_metric = 0
        best_epoch = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        # Keep cur_epoch defined for the testing phase even when the loop body
        # never runs (e.g. the resumed checkpoint already reached max_epoch).
        cur_epoch = max(self.start_epoch - 1, 0)

        for cur_epoch in range(self.start_epoch, self.max_epoch):
            # training phase
            if not self.evaluate_only:
                logging.info("Start training")
                train_stats = self.train_epoch(cur_epoch)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=cur_epoch
                    )
                    if val_log is not None:
                        if is_main_process():
                            assert (
                                "agg_metrics" in val_log
                            ), "No agg_metrics found in validation log."

                            agg_metrics = val_log["agg_metrics"]
                            if agg_metrics > best_agg_metric and split_name == "val":
                                best_epoch, best_agg_metric = cur_epoch, agg_metrics

                                self._save_checkpoint(cur_epoch, is_best=True)

                            val_log.update({"best_epoch": best_epoch})
                            self.log_stats(val_log, split_name)

            else:
                # if no validation split is provided, we just save the checkpoint at the end of each epoch.
                if not self.evaluate_only:
                    self._save_checkpoint(cur_epoch, is_best=False)

            if self.evaluate_only:
                break

            if self.config.run_cfg.distributed:
                dist.barrier()

        # testing phase
        test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def evaluate(self, cur_epoch="best", skip_reload=False):
        """
        Evaluate on every configured test split.

        Returns:
            dict: {split_name: result of eval_epoch}; empty when no test
            splits are configured.
        """
        test_logs = dict()

        if len(self.test_splits) > 0:
            for split_name in self.test_splits:
                test_logs[split_name] = self.eval_epoch(
                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
                )

        return test_logs

    def train_epoch(self, epoch):
        """Put the model in train mode and delegate one epoch to the task."""
        # train
        self.model.train()

        return self.task.train_epoch(
            epoch=epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )

    @torch.no_grad()
    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
        """
        Evaluate the model on a given split.

        Args:
            split_name (str): name of the split to evaluate on.
            cur_epoch (int or str): current epoch, or "best" to evaluate the
                best checkpoint.
            skip_reload (bool): when False and cur_epoch is "best", the best
                checkpoint is reloaded into the model before evaluating; when
                True the currently loaded weights are used.

        Returns:
            The task's after_evaluation result, or None when the task's
            evaluation returned None.
        """
        data_loader = self.dataloaders.get(split_name, None)
        assert data_loader, "data_loader for split {} is None.".format(split_name)

        # TODO In validation, you need to compute loss as well as metrics
        # TODO consider moving to model.before_evaluation()
        model = self.unwrap_dist_model(self.model)
        if not skip_reload and cur_epoch == "best":
            model = self._reload_best_model(model)
        model.eval()

        self.task.before_evaluation(
            model=model,
            dataset=self.datasets[split_name],
        )
        results = self.task.evaluation(model, data_loader)

        if results is not None:
            return self.task.after_evaluation(
                val_result=results,
                split_name=split_name,
                epoch=cur_epoch,
            )

    def unwrap_dist_model(self, model):
        """Return model.module under distributed training, else the model itself."""
        if self.use_distributed:
            return model.module
        else:
            return model

    def create_loaders(
        self,
        datasets,
        num_workers,
        batch_sizes,
        is_trains,
        collate_fns,
        dataset_ratios=None,
    ):
        """
        Create dataloaders for training and validation.

        The arguments `datasets`, `batch_sizes`, `is_trains` and `collate_fns`
        are parallel sequences with one entry per split. An entry of `datasets`
        that is a list/tuple yields a MultiIterLoader over one loader per
        member; its sampling ratios come from `dataset_ratios`, or from the
        members' `sample_ratio` attributes when that argument is None and the
        first member has one.

        Returns:
            list: one loader per entry of `datasets`, in the same order.
        """

        def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
            # create a single dataloader for each split
            if isinstance(dataset, ChainDataset) or isinstance(
                dataset, wds.DataPipeline
            ):
                # wds.WebdDataset instance are chained together
                # webdataset.DataPipeline has its own sampler and collate_fn
                loader = iter(
                    DataLoader(
                        dataset,
                        batch_size=bsz,
                        num_workers=num_workers,
                        pin_memory=True,
                    )
                )
            else:
                # map-style dataset are concatenated together
                # setup distributed sampler
                if self.use_distributed:
                    sampler = DistributedSampler(
                        dataset,
                        shuffle=is_train,
                        num_replicas=get_world_size(),
                        rank=get_rank(),
                    )
                    if not self.use_dist_eval_sampler:
                        # e.g. retrieval evaluation
                        sampler = sampler if is_train else None
                else:
                    sampler = None

                loader = DataLoader(
                    dataset,
                    batch_size=bsz,
                    num_workers=num_workers,
                    pin_memory=True,
                    sampler=sampler,
                    shuffle=sampler is None and is_train,
                    collate_fn=collate_fn,
                    drop_last=True if is_train else False,
                )
                loader = PrefetchLoader(loader)

                if is_train:
                    loader = IterLoader(loader, use_distributed=self.use_distributed)

            return loader

        loaders = []

        for dataset, bsz, is_train, collate_fn in zip(
            datasets, batch_sizes, is_trains, collate_fns
        ):
            if isinstance(dataset, list) or isinstance(dataset, tuple):
                if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None:
                    dataset_ratios = [d.sample_ratio for d in dataset]
                loader = MultiIterLoader(
                    loaders=[
                        _create_loader(d, num_workers, bsz[i], is_train, collate_fn[i])
                        for i, d in enumerate(dataset)
                    ],
                    ratios=dataset_ratios,
                )
            else:
                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)

            loaders.append(loader)

        return loaders

    @main_process
    def _save_checkpoint(self, cur_epoch, is_best=False):
        """
        Save the checkpoint at the current epoch.

        Parameters that do not require gradients are left out of the saved
        model state. The file is checkpoint_best.pth when is_best is set,
        otherwise checkpoint_<cur_epoch>.pth, inside output_dir.
        """
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }
        state_dict = model_no_ddp.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
                # delete parameters that do not require gradient
                del state_dict[k]
        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "epoch": cur_epoch,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
        )
        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
        torch.save(save_obj, save_to)

    def _reload_best_model(self, model):
        """
        Load the best checkpoint for evaluation.

        A strict load is tried first; on a RuntimeError (key mismatch) the
        state dict is loaded again with strict=False.
        """
        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")

        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        try:
            model.load_state_dict(checkpoint["model"])
        except RuntimeError:
            logging.warning(
                """
Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
Trying to load the model with strict=False.
"""
            )
            model.load_state_dict(checkpoint["model"], strict=False)
        return model

    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.

        Restores model weights (non-strict), optimizer state, the scaler state
        when both a scaler and a saved state exist, and sets start_epoch to
        the epoch after the saved one.

        Raises:
            RuntimeError: if the argument is neither a URL nor an existing file.
        """
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints (especially URLs) that come from a trusted source.
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_epoch = checkpoint["epoch"] + 1
        print("resume the checkpoint")
        logging.info("Resume checkpoint from {}".format(url_or_filename))

    @main_process
    def log_stats(self, stats, split_name):
        """
        Append `stats` as one JSON line to log.txt in output_dir, with every
        key prefixed by "<split_name>_". List-typed stats are ignored.
        """
        if isinstance(stats, dict):
            log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")
        elif isinstance(stats, list):
            pass

    @main_process
    def log_config(self):
        """Append the full run configuration as indented JSON to log.txt."""
        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
            f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
================================================
FILE: minigpt4/tasks/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from minigpt4.common.registry import registry
from minigpt4.tasks.base_task import BaseTask
from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask
def setup_task(cfg):
    """
    Look up the task class named by cfg.run_cfg.task in the registry and
    return the task it builds via its setup_task classmethod.
    """
    assert "task" in cfg.run_cfg, "Task name must be provided."

    task_name = cfg.run_cfg.task
    task_cls = registry.get_task_class(task_name)
    task = task_cls.setup_task(cfg=cfg)
    assert task is not None, "Task {} not properly registered.".format(task_name)

    return task
__all__ = [
"BaseTask",
"ImageTextPretrainTask",
]
================================================
FILE: minigpt4/tasks/base_task.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import torch
import torch.distributed as dist
from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
from minigpt4.common.logger import MetricLogger, SmoothedValue
from minigpt4.common.registry import registry
from minigpt4.datasets.data_utils import prepare_sample
import wandb
class BaseTask:
    """Base class for tasks: builds the model/datasets and runs the train and eval loops."""

    def __init__(self, **kwargs):
        super().__init__()

        self.inst_id_key = "instance_id"
        # Replaced by the real config object in build_model().
        self.cfg = ""

    @classmethod
    def setup_task(cls, **kwargs):
        """Create a task instance; keyword arguments are ignored here."""
        return cls()

    def build_model(self, cfg):
        """Remember `cfg` and build the model class named by cfg.model_cfg.arch."""
        self.cfg = cfg
        model_config = cfg.model_cfg

        model_cls = registry.get_model_class(model_config.arch)
        return model_cls.from_config(model_config)

    def build_datasets(self, cfg):
        """
        Build one dataset dictionary per entry of cfg.datasets_cfg, using the
        builder class registered under the dataset's name.

        The 'train' split of every dataset gets a `name` attribute and, when
        configured, a `sample_ratio` attribute.

        Args:
            cfg (common.config.Config): configuration providing datasets_cfg.

        Returns:
            dict: {dataset_name: value returned by the builder's build_datasets()}.
        """

        datasets = dict()

        datasets_config = cfg.datasets_cfg

        assert len(datasets_config) > 0, "At least one dataset has to be specified."

        for name in datasets_config:
            dataset_config = datasets_config[name]

            builder = registry.get_builder_class(name)(dataset_config)
            dataset = builder.build_datasets()

            dataset['train'].name = name
            if 'sample_ratio' in dataset_config:
                dataset['train'].sample_ratio = dataset_config.sample_ratio

            datasets[name] = dataset

        return datasets

    def train_step(self, model, samples):
        """Run the model on `samples` and return its "loss" output."""
        loss = model(samples)["loss"]
        return loss

    def valid_step(self, model, samples):
        """Per-batch evaluation hook; subclasses must implement it."""
        raise NotImplementedError

    def before_evaluation(self, model, dataset, **kwargs):
        """Forward to the model's before_evaluation hook."""
        model.before_evaluation(dataset=dataset, task_type=type(self))

    def after_evaluation(self, **kwargs):
        """Post-evaluation hook; does nothing (returns None) by default."""
        pass

    def inference_step(self):
        """Inference hook; subclasses must implement it."""
        raise NotImplementedError

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """
        Run valid_step over `data_loader` and collect the outputs in one list.
        A barrier follows the loop when torch.distributed is initialised.
        """
        metric_logger = MetricLogger(delimiter=" ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        results = []

        for samples in metric_logger.log_every(data_loader, print_freq, header):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            eval_output = self.valid_step(model=model, samples=samples)
            results.extend(eval_output)

        if is_dist_avail_and_initialized():
            dist.barrier()

        return results

    def train_epoch(
        self,
        epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
    ):
        """Epoch-based training: run lr_scheduler.iters_per_epoch iterations."""
        return self._train_inner_loop(
            epoch=epoch,
            iters_per_epoch=lr_scheduler.iters_per_epoch,
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def train_iters(
        self,
        epoch,
        start_iters,
        iters_per_inner_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
    ):
        """Iteration-based training: run iters_per_inner_epoch iterations from start_iters."""
        return self._train_inner_loop(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_epoch=iters_per_inner_epoch,
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def _train_inner_loop(
        self,
        epoch,
        iters_per_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        start_iters=None,
        log_freq=50,
        cuda_enabled=False,
        accum_grad_iters=1,
    ):
        """
        An inner training loop compatible with both epoch-based and iter-based training.

        When using epoch-based, training stops after one epoch; when using iter-based,
        training stops after #iters_per_epoch iterations.

        Returns:
            dict: meter name -> global average formatted with three decimals.
        """
        use_amp = scaler is not None

        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter=" ")
        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))

        # if iter-based runner, schedule lr based on inner epoch.
        logging.info(
            "Start training epoch {}, {} iters per inner epoch.".format(
                epoch, iters_per_epoch
            )
        )
        header = "Train: data epoch: [{}]".format(epoch)
        if start_iters is None:
            # epoch-based runner
            inner_epoch = epoch
        else:
            # In iter-based runner, we schedule the learning rate based on iterations.
            inner_epoch = start_iters // iters_per_epoch
            header = header + "; inner epoch [{}]".format(inner_epoch)

        # Resolve the wandb switch once, with the same .get() lookup used for
        # other optional run_cfg keys; a missing key (or a task whose
        # build_model() was never called) simply disables wandb logging.
        run_cfg = getattr(self.cfg, "run_cfg", None)
        log_to_wandb = bool(run_cfg.get("wandb_log", False)) if run_cfg is not None else False

        # range() bounds the loop, so exactly iters_per_epoch steps are run.
        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            samples = next(data_loader)

            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            samples.update(
                {
                    "epoch": inner_epoch,
                    "num_iters_per_epoch": iters_per_epoch,
                    "iters": i,
                }
            )

            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)

            with torch.cuda.amp.autocast(enabled=use_amp):
                loss = self.train_step(model=model, samples=samples)

            # after_train_step()
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Plain float used for both logging sinks below.
            loss_value = loss.item()

            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
                if log_to_wandb:
                    # Log the scalar, not the tensor that is still attached
                    # to the autograd graph.
                    wandb.log({"epoch": inner_epoch, "loss": loss_value})
            metric_logger.update(loss=loss_value)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # after train_epoch()
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        return {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }

    @staticmethod
    def save_result(result, result_dir, filename, remove_duplicate=""):
        """
        Write this rank's results to <filename>_rank<rank>.json and, on the
        main process, merge all ranks' files into <filename>.json.

        Args:
            result (list): JSON-serialisable results of this rank.
            result_dir (str): directory receiving the files.
            filename (str): base name without extension.
            remove_duplicate (str): when non-empty, name of the key used to
                drop later entries with an already seen value (values are
                assumed hashable, e.g. int/str ids).

        Returns:
            str: path of the merged file (written by the main process only).
        """
        import json

        result_file = os.path.join(
            result_dir, "%s_rank%d.json" % (filename, get_rank())
        )
        final_result_file = os.path.join(result_dir, "%s.json" % filename)

        # Close the file before the barrier so the data is on disk when the
        # main process reads every rank's file.
        with open(result_file, "w") as f:
            json.dump(result, f)

        if is_dist_avail_and_initialized():
            dist.barrier()

        if is_main_process():
            logging.warning("rank %d starts merging results." % get_rank())
            # combine results from all processes
            result = []

            for rank in range(get_world_size()):
                result_file = os.path.join(
                    result_dir, "%s_rank%d.json" % (filename, rank)
                )
                with open(result_file, "r") as f:
                    res = json.load(f)
                result += res

            if remove_duplicate:
                result_new = []
                seen_ids = set()
                for res in result:
                    if res[remove_duplicate] not in seen_ids:
                        seen_ids.add(res[remove_duplicate])
                        result_new.append(res)
                result = result_new

            with open(final_result_file, "w") as f:
                json.dump(result, f)
            print("result file saved to %s" % final_result_file)

        return final_result_file
================================================
FILE: minigpt4/tasks/image_text_pretrain.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from minigpt4.common.registry import registry
from minigpt4.tasks.base_task import BaseTask
@registry.register_task("image_text_pretrain")
class ImageTextPretrainTask(BaseTask):
    """Task registered as "image_text_pretrain"; its evaluation hook is a no-op."""

    def __init__(self):
        super().__init__()

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """Do nothing and return None; all three arguments are ignored."""
        return None
================================================
FILE: train.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb
import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
from minigpt4.common.logger import setup_logger
from minigpt4.common.optims import (
LinearWarmupCosineLRScheduler,
LinearWarmupStepLRScheduler,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import now
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args(argv=None):
    """
    Parse the command-line arguments of the training script.

    Args:
        argv (list[str] | None): argument strings to parse. With the default
            None, argparse reads the process command line.

    Returns:
        argparse.Namespace: `cfg_path` holds the required config path and
        `options` holds the list of "xxx=yyy" overrides (None when the flag
        is not given).
    """
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        # The previous help text sent users to a "--cfg-options" flag that this
        # parser does not define; "--options" is the only override flag.
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file.",
    )

    return parser.parse_args(argv)
def setup_seeds(config):
    """
    Seed the Python, NumPy and PyTorch random generators with
    `config.run_cfg.seed` offset by this process's rank, and configure cuDNN
    with benchmark mode off and deterministic mode on.
    """
    base_seed = config.run_cfg.seed
    rank_seed = base_seed + get_rank()

    # The three generators are independent, so one loop seeds them all.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    cudnn.deterministic = True
    cudnn.benchmark = False
def get_runner_class(cfg):
    """
    Look up the runner class named by `cfg.run_cfg["runner"]` in the registry.
    Falls back to "runner_base" when the config does not set a runner.
    """
    runner_name = cfg.run_cfg.get("runner", "runner_base")
    return registry.get_runner_class(runner_name)
def main():
    """
    Training entry point: load the config, initialise distributed mode and
    seeds, build the task, datasets and model, optionally start Weights &
    Biases logging, then hand everything to the configured runner.
    """
    # allow auto-dl completes on main process without timeout when using NCCL backend.
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"

    # Taken before init_distributed_mode(); the original author's note says
    # this keeps the job_id identical across all ranks.
    job_id = now()

    cfg = Config(parse_args())

    init_distributed_mode(cfg.run_cfg)
    setup_seeds(cfg)

    # Logger setup comes after init_distributed_mode() so only master logs.
    setup_logger()
    cfg.pretty_print()

    task = tasks.setup_task(cfg)
    datasets = task.build_datasets(cfg)
    model = task.build_model(cfg)

    if cfg.run_cfg.wandb_log:
        wandb.login()
        wandb.init(project="minigptv", name=cfg.run_cfg.job_name)
        wandb.watch(model)

    runner_cls = get_runner_class(cfg)
    runner = runner_cls(
        cfg=cfg,
        job_id=job_id,
        task=task,
        model=model,
        datasets=datasets,
    )
    runner.train()


if __name__ == "__main__":
    main()
================================================
FILE: train_configs/minigpt4_llama2_stage1_pretrain.yaml
================================================
# Model selection for stage-1 pretraining (LLaMA-2 variant).
model:
  arch: minigpt4
  model_type: pretrain_llama2

# Training datasets. Each entry declares its own batch size, train-time
# processors and sample_ratio.
datasets:
  laion:
    batch_size: 64
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 224
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 115
  cc_sbu:
    batch_size: 64
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 224
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 14
# Runtime / optimisation settings for stage-1 pretraining (LLaMA-2 variant).
run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  # Learning rates carry an explicit decimal point: YAML 1.1 resolvers (for
  # example stock PyYAML) load exponent-only forms such as 1e-4 as strings.
  init_lr: 1.0e-4
  min_lr: 8.0e-5
  warmup_lr: 1.0e-6

  weight_decay: 0.05
  max_epoch: 4
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/minigpt4_stage1_pretrain"

  amp: true
  resume_ckpt_path: null

  evaluate: false
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: true

  # Weights & Biases logging switch and the run name passed to it.
  wandb_log: true
  job_name: minigpt4_llama2_pretrain
================================================
FILE: train_configs/minigpt4_llama2_stage2_finetune.yaml
================================================
# Model settings for stage-2 finetuning (LLaMA-2 variant).
model:
  arch: minigpt4
  model_type: pretrain_llama2

  max_txt_len: 160
  # LLaMA-2 end-of-sequence token. An empty end_sym leaves no terminator
  # marker (presumably appended to each target answer -- confirm in the
  # model code); the sibling Vicuna config uses its "###" separator here.
  end_sym: "</s>"
  prompt_path: "prompts/alignment.txt"
  prompt_template: '[INST] {} [/INST] '
  # Placeholder: point this at the checkpoint produced by stage 1.
  ckpt: '/path/to/stage1/checkpoint/'

datasets:
  cc_sbu_align:
    batch_size: 12
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
# Runtime / optimisation settings for stage-2 finetuning (LLaMA-2 variant).
run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  # Learning rates carry an explicit decimal point: YAML 1.1 resolvers (for
  # example stock PyYAML) load exponent-only forms such as 3e-5 as strings.
  init_lr: 3.0e-5
  min_lr: 1.0e-5
  warmup_lr: 1.0e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: true
  resume_ckpt_path: null

  evaluate: false
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: true

  # Weights & Biases logging switch and the run name passed to it.
  wandb_log: true
  job_name: minigpt4_llama2_finetune
================================================
FILE: train_configs/minigpt4_stage1_pretrain.yaml
================================================
# Model selection for stage-1 pretraining (Vicuna-0 variant).
model:
  arch: minigpt4
  model_type: pretrain_vicuna0

# Training datasets. Each entry declares its own batch size, train-time
# processors and sample_ratio.
datasets:
  laion:
    batch_size: 64
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 224
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 115
  cc_sbu:
    batch_size: 64
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 224
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 14
# Runtime / optimisation settings for stage-1 pretraining (Vicuna-0 variant).
run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  # Learning rates carry an explicit decimal point: YAML 1.1 resolvers (for
  # example stock PyYAML) load exponent-only forms such as 1e-4 as strings.
  init_lr: 1.0e-4
  min_lr: 8.0e-5
  warmup_lr: 1.0e-6

  weight_decay: 0.05
  max_epoch: 4
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/minigpt4_stage1_pretrain"

  amp: true
  resume_ckpt_path: null

  evaluate: false
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: true

  # Weights & Biases logging switch and the run name passed to it.
  wandb_log: true
  job_name: minigpt4_pretrain
================================================
FILE: train_configs/minigpt4_stage2_finetune.yaml
================================================
# Model settings for stage-2 finetuning (Vicuna-0 variant).
model:
  arch: minigpt4
  model_type: pretrain_vicuna0

  max_txt_len: 160
  # Same "###" separator that the prompt template below uses.
  end_sym: '###'
  prompt_path: 'prompts/alignment.txt'
  prompt_template: '###Human: {} ###Assistant: '
  # Placeholder: point this at the checkpoint produced by stage 1.
  ckpt: '/path/to/stage1/checkpoint/'

datasets:
  cc_sbu_align:
    batch_size: 12
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 224
    text_processor:
      train:
        name: 'blip_caption'
# Runtime / optimisation settings for stage-2 finetuning (Vicuna-0 variant).
run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  # Learning rates carry an explicit decimal point: YAML 1.1 resolvers (for
  # example stock PyYAML) load exponent-only forms such as 3e-5 as strings.
  init_lr: 3.0e-5
  min_lr: 1.0e-5
  warmup_lr: 1.0e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: true
  resume_ckpt_path: null

  evaluate: false
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: true

  # Weights & Biases logging switch and the run name passed to it.
  wandb_log: true
  job_name: minigpt4_finetune
================================================
FILE: train_configs/minigptv2_finetune.yaml
================================================
# Model settings for MiniGPT-v2 finetuning.
model:
  arch: minigpt_v2
  model_type: pretrain
  max_txt_len: 1024
  image_size: 448
  # LLaMA-2 end-of-sequence token. An empty end_sym leaves no terminator
  # marker (presumably appended to each target answer -- confirm in the
  # model code).
  end_sym: "</s>"
  # Placeholders: replace with the local LLaMA weights and the pretrained
  # MiniGPT-v2 checkpoint.
  llama_model: "/path/to/llama_checkpoint"
  ckpt: "/path/to/pretrained_checkpoint"
  use_grad_checkpoint: true
  chat_template: true
  lora_r: 64
  lora_alpha: 16
# Finetuning data mixture. Every dataset uses the same train-time processors
# (448-pixel images) and differs only in batch_size and sample_ratio.
datasets:
  # --- conversation / instruction data ---
  multitask_conversation:
    batch_size: 2
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 50

  llava_conversation:
    batch_size: 2
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 30

  unnatural_instruction:
    batch_size: 1
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 10

  refvg:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 40

  llava_detail:
    batch_size: 4
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 20

  llava_reason:
    batch_size: 4
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 80

  # --- Flickr grounding data ---
  flickr_grounded_caption:
    batch_size: 2
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 80

  flickr_CaptionToPhrase:
    batch_size: 2
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 80

  flickr_ObjectToPhrase:
    batch_size: 2
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 80

  # --- captioning data ---
  coco_caption:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 10

  textcaps_caption:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 30

  # --- referring-expression data ---
  refcoco:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 25

  refcocop:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 25

  refcocog:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 25

  invrefcoco:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 10

  invrefcocop:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 10

  invrefcocog:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 10

  # --- visual question answering data ---
  coco_vqa:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 15

  ok_vqa:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 8

  aok_vqa:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 12

  gqa:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 50

  ocrvqa:
    batch_size: 6
    vis_processor:
      train:
        name: 'blip2_image_train'
        image_size: 448
    text_processor:
      train:
        name: 'blip_caption'
    sample_ratio: 30
# Runtime / optimisation settings for MiniGPT-v2 finetuning.
run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  # Learning rates carry an explicit decimal point: YAML 1.1 resolvers (for
  # example stock PyYAML) load exponent-only forms such as 1e-5 as strings.
  init_lr: 1.0e-5
  min_lr: 1.0e-6
  warmup_lr: 1.0e-6

  weight_decay: 0.05
  max_epoch: 50
  num_workers: 6
  warmup_steps: 1000
  iters_per_epoch: 1000

  seed: 42
  # Placeholder: replace with the directory that should receive checkpoints.
  output_dir: "/path/to/save_checkpoint"

  amp: true
  resume_ckpt_path: null

  evaluate: false
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: true

  # Weights & Biases logging switch and the run name passed to it.
  wandb_log: true
  job_name: minigptv2_finetune