Repository: Vision-CAIR/MiniGPT-4
Branch: main
Commit: d94738a7626e
Files: 126
Total size: 491.5 KB
Directory structure:
gitextract_bx6vsnx8/
├── .github/
│ └── ISSUE_TEMPLATE/
│ ├── bug_report.md
│ └── feature_request.md
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE.md
├── LICENSE_Lavis.md
├── MiniGPT4_Train.md
├── MiniGPTv2_Train.md
├── README.md
├── SECURITY.md
├── dataset/
│ ├── README_1_STAGE.md
│ ├── README_2_STAGE.md
│ ├── README_MINIGPTv2_FINETUNE.md
│ ├── convert_cc_sbu.py
│ └── convert_laion.py
├── demo.py
├── demo_v2.py
├── environment.yml
├── eval_configs/
│ ├── minigpt4_eval.yaml
│ ├── minigpt4_llama2_eval.yaml
│ ├── minigptv2_benchmark_evaluation.yaml
│ └── minigptv2_eval.yaml
├── eval_scripts/
│ ├── EVAL_README.md
│ ├── eval_ref.py
│ └── eval_vqa.py
├── minigpt4/
│ ├── __init__.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dist_utils.py
│ │ ├── eval_utils.py
│ │ ├── gradcam.py
│ │ ├── logger.py
│ │ ├── optims.py
│ │ ├── registry.py
│ │ ├── utils.py
│ │ └── vqa_tools/
│ │ ├── VQA/
│ │ │ ├── PythonEvaluationTools/
│ │ │ │ ├── vqaEvalDemo.py
│ │ │ │ └── vqaEvaluation/
│ │ │ │ ├── __init__.py
│ │ │ │ └── vqaEval.py
│ │ │ ├── PythonHelperTools/
│ │ │ │ ├── vqaDemo.py
│ │ │ │ └── vqaTools/
│ │ │ │ ├── __init__.py
│ │ │ │ └── vqa.py
│ │ │ ├── QuestionTypes/
│ │ │ │ ├── abstract_v002_question_types.txt
│ │ │ │ └── mscoco_question_types.txt
│ │ │ ├── README.md
│ │ │ └── license.txt
│ │ ├── __init__.py
│ │ ├── vqa.py
│ │ └── vqa_eval.py
│ ├── configs/
│ │ ├── datasets/
│ │ │ ├── aokvqa/
│ │ │ │ └── defaults.yaml
│ │ │ ├── cc_sbu/
│ │ │ │ ├── align.yaml
│ │ │ │ └── defaults.yaml
│ │ │ ├── coco/
│ │ │ │ ├── caption.yaml
│ │ │ │ └── defaults_vqa.yaml
│ │ │ ├── coco_bbox/
│ │ │ │ ├── invrefcoco.yaml
│ │ │ │ ├── invrefcocog.yaml
│ │ │ │ ├── invrefcocop.yaml
│ │ │ │ ├── refcoco.yaml
│ │ │ │ ├── refcocog.yaml
│ │ │ │ └── refcocop.yaml
│ │ │ ├── flickr/
│ │ │ │ ├── caption_to_phrase.yaml
│ │ │ │ ├── default.yaml
│ │ │ │ └── object_to_phrase.yaml
│ │ │ ├── gqa/
│ │ │ │ └── balanced_val.yaml
│ │ │ ├── laion/
│ │ │ │ └── defaults.yaml
│ │ │ ├── llava/
│ │ │ │ ├── conversation.yaml
│ │ │ │ ├── detail.yaml
│ │ │ │ └── reason.yaml
│ │ │ ├── multitask_conversation/
│ │ │ │ └── default.yaml
│ │ │ ├── nlp/
│ │ │ │ └── unnatural_instruction.yaml
│ │ │ ├── ocrvqa/
│ │ │ │ └── ocrvqa.yaml
│ │ │ ├── okvqa/
│ │ │ │ └── defaults.yaml
│ │ │ ├── textcaps/
│ │ │ │ └── caption.yaml
│ │ │ └── vg/
│ │ │ └── ref.yaml
│ │ ├── default.yaml
│ │ └── models/
│ │ ├── minigpt4_llama2.yaml
│ │ ├── minigpt4_vicuna0.yaml
│ │ └── minigpt_v2.yaml
│ ├── conversation/
│ │ ├── __init__.py
│ │ └── conversation.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── builders/
│ │ │ ├── __init__.py
│ │ │ ├── base_dataset_builder.py
│ │ │ └── image_text_pair_builder.py
│ │ ├── data_utils.py
│ │ └── datasets/
│ │ ├── __init__.py
│ │ ├── aok_vqa_datasets.py
│ │ ├── base_dataset.py
│ │ ├── caption_datasets.py
│ │ ├── cc_sbu_dataset.py
│ │ ├── coco_caption.py
│ │ ├── coco_dataset.py
│ │ ├── coco_vqa_datasets.py
│ │ ├── dataloader_utils.py
│ │ ├── flickr.py
│ │ ├── gqa_datasets.py
│ │ ├── laion_dataset.py
│ │ ├── llava_dataset.py
│ │ ├── multitask_conversation.py
│ │ ├── ocrvqa_dataset.py
│ │ ├── text_caps.py
│ │ ├── unnatural_instruction.py
│ │ ├── vg_dataset.py
│ │ └── vqa_datasets.py
│ ├── models/
│ │ ├── Qformer.py
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── eva_vit.py
│ │ ├── minigpt4.py
│ │ ├── minigpt_base.py
│ │ ├── minigpt_v2.py
│ │ └── modeling_llama.py
│ ├── processors/
│ │ ├── __init__.py
│ │ ├── base_processor.py
│ │ ├── blip_processors.py
│ │ └── randaugment.py
│ ├── runners/
│ │ ├── __init__.py
│ │ └── runner_base.py
│ └── tasks/
│ ├── __init__.py
│ ├── base_task.py
│ └── image_text_pretrain.py
├── train.py
└── train_configs/
├── minigpt4_llama2_stage1_pretrain.yaml
├── minigpt4_llama2_stage2_finetune.yaml
├── minigpt4_stage1_pretrain.yaml
├── minigpt4_stage2_finetune.yaml
└── minigptv2_finetune.yaml
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. iOS]
- Browser [e.g. chrome, safari]
- Version [e.g. 22]
**Smartphone (please complete the following information):**
- Device: [e.g. iPhone6]
- OS: [e.g. iOS8.1]
- Browser [e.g. stock browser, safari]
- Version [e.g. 22]
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
wandb/
jobs/logs/
*.out
*ipynb
.history/
*.json
*.sh
.ipynb_common
logs/
results/
prompts/
output/
ckpt/
divide_vqa.py
jobs/
*.slurm
slurm*
sbatch_generate*
eval_data/
dataset/Evaluation.md
jupyter_notebook.slurm
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
https://discord.gg/2aNvvYVv.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
================================================
FILE: LICENSE.md
================================================
BSD 3-Clause License
Copyright 2023 Deyao Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: LICENSE_Lavis.md
================================================
BSD 3-Clause License
Copyright (c) 2022 Salesforce, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: MiniGPT4_Train.md
================================================
## Training of MiniGPT-4
The training of MiniGPT-4 contains two alignment stages.
**1. First pretraining stage**
In the first pretraining stage, the model is trained using image-text pairs from the LAION and CC datasets
to align the vision and language model. To download and prepare the datasets, please check
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
After the first stage, the visual features are mapped and can be understood by the language
model.
To launch the first stage training, run the following command. In our experiments, we use 4 A100 GPUs.
You can change the save path in the config file
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml).
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
```
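`NUM_GPU` in the command above is a placeholder for the number of GPUs to use. As a minimal sketch, the same invocation can be assembled in Python using only the flags shown above. The helper name `build_torchrun_command` is hypothetical and not part of this repository, and the function only formats the command; it does not launch training.

```python
import shlex
from typing import List


def build_torchrun_command(num_gpus: int, cfg_path: str) -> List[str]:
    """Assemble the torchrun invocation shown above as an argument list.

    Hypothetical helper for illustration only.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return [
        "torchrun",
        "--nproc-per-node", str(num_gpus),
        "train.py",
        "--cfg-path", cfg_path,
    ]


if __name__ == "__main__":
    cmd = build_torchrun_command(4, "train_configs/minigpt4_stage1_pretrain.yaml")
    # Print a shell-safe version of the command for inspection.
    print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root if launching from Python is preferred.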
A MiniGPT-4 checkpoint with only stage one training can be downloaded
[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
Compared to the model after stage two, this checkpoint frequently generates incomplete and repeated sentences.
**2. Second finetuning stage**
In the second stage, we use a small high quality image-text pair dataset created by ourselves
and convert it to a conversation format to further align MiniGPT-4.
To download and prepare our second stage dataset, please check our
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
To launch the second stage alignment,
first specify the path to the checkpoint file trained in stage 1 in
[train_configs/minigpt4_stage2_finetune.yaml](train_configs/minigpt4_stage2_finetune.yaml).
You can also specify the output path there.
Then, run the following command. In our experiments, we use 1 A100 GPU.
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
```
After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and in a user-friendly way.
================================================
FILE: MiniGPTv2_Train.md
================================================
## Finetuning of MiniGPT-v2
First, prepare the dataset by following our [dataset preparation](dataset/README_MINIGPTv2_FINETUNE.md) instructions.
In `train_configs/minigptv2_finetune.yaml`, set the following paths:
llama_model checkpoint path: "/path/to/llama_checkpoint"
ckpt: "/path/to/pretrained_checkpoint"
ckpt save path: "/path/to/save_checkpoint"
For ckpt, you may load from our pretrained model checkpoints:
| MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo) |
|------------------------------|------------------------------|------------------------------|
| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1HkoUUrjzFGn33cSiUkI-KcT-zysCynAz/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigptv2_finetune.yaml
```
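Before launching, it can help to confirm that none of the paths listed above are still placeholders. The sketch below scans config text for lines that still contain `/path/to/`. It assumes the placeholders in your config look like the ones listed above; the function name `find_placeholder_lines` is hypothetical, and the example snippet is illustrative rather than a copy of the real config file.

```python
from typing import List

# Assumed placeholder prefix, matching the example paths listed above.
PLACEHOLDER = "/path/to/"


def find_placeholder_lines(config_text: str) -> List[str]:
    """Return non-comment config lines that still contain a placeholder path."""
    return [
        line.strip()
        for line in config_text.splitlines()
        if PLACEHOLDER in line and not line.lstrip().startswith("#")
    ]


if __name__ == "__main__":
    # Illustrative snippet only; the real config file has more keys.
    example = '''
    llama_model: "/path/to/llama_checkpoint"
    ckpt: "/data/checkpoints/minigptv2.pth"
    '''
    for line in find_placeholder_lines(example):
        print("still a placeholder:", line)
```

To check the real file, read it first, for example with `Path("train_configs/minigptv2_finetune.yaml").read_text()`, and pass the text to the function.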
================================================
FILE: README.md
================================================
# MiniGPT-V
<font size='5'>**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning**</font>
Jun Chen, Deyao Zhu, Xiaoqian Shen, Xiang Li, Zechun Liu, Pengchuan Zhang, Raghuraman Krishnamoorthi, Vikas Chandra, Yunyang Xiong☨, Mohamed Elhoseiny☨
☨equal last author
<a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2310.09478.pdf'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/MiniGPT-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Gradio-Demo-blue'></a> [YouTube](https://www.youtube.com/watch?v=atFCwV2hSY4)
<font size='5'> **MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models**</font>
Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
*equal contribution
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [YouTube](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
*King Abdullah University of Science and Technology*
## 💡 Get help - [Q&A](https://github.com/Vision-CAIR/MiniGPT-4/discussions/categories/q-a) or [Discord 💬](https://discord.gg/5WdJkjbAeE)
<font size='4'> **Example Community Efforts Built on Top of MiniGPT-4** </font>
* <a href='https://github.com/waltonfuture/InstructionGPT-4?tab=readme-ov-file'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **InstructionGPT-4**: A 200-Instruction Paradigm for Fine-Tuning MiniGPT-4 Lai Wei, Zihao Jiang, Weiran Huang, Lichao Sun, Arxiv, 2023
* <a href='https://openaccess.thecvf.com/content/ICCV2023W/CLVL/papers/Aubakirova_PatFig_Generating_Short_and_Long_Captions_for_Patent_Figures_ICCVW_2023_paper.pdf'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **PatFig**: Generating Short and Long Captions for Patent Figures, Dana Aubakirova, Kim Gerdes, and Lufei Liu, ICCVW, 2023
* <a href='https://github.com/JoshuaChou2018/SkinGPT-4'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **SkinGPT-4**: An Interactive Dermatology Diagnostic System with Visual Large Language Model, Juexiao Zhou and Xiaonan He and Liyuan Sun and Jiannan Xu and Xiuying Chen and Yuetan Chu and Longxi Zhou and Xingyu Liao and Bin Zhang and Xin Gao, Arxiv, 2023
* <a href='https://huggingface.co/Tyrannosaurus/ArtGPT-4'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **ArtGPT-4**: Artistic Vision-Language Understanding with Adapter-enhanced MiniGPT-4, Zhengqing Yuan, Huiwen Xue, Xinyi Wang, Yongming Liu, Zhuanzhe Zhao, and Kun Wang, Arxiv, 2023
</font>
## News
[Oct.31 2023] We release the evaluation code of our MiniGPT-v2.
[Oct.24 2023] We release the finetuning code of our MiniGPT-v2.
[Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2.
[Aug.28 2023] We now provide a Llama 2 version of MiniGPT-4.
## Online Demo
Chat with MiniGPT-v2 about your images in the [online demo](https://minigpt-v2.github.io/).
Chat with MiniGPT-4 about your images in the [online demo](https://minigpt-4.github.io).
## MiniGPT-v2 Examples

## MiniGPT-4 Examples
More examples can be found in the [project page](https://minigpt-4.github.io).
## Getting Started
### Installation
**1. Prepare the code and the environment**
Clone our repository, create a Python environment, and activate it with the following commands:
```bash
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
cd MiniGPT-4
conda env create -f environment.yml
conda activate minigptv
```
**2. Prepare the pretrained LLM weights**
**MiniGPT-v2** is based on Llama2 Chat 7B. For **MiniGPT-4**, we have both Vicuna V0 and Llama 2 versions.
Download the corresponding LLM weights from the following Hugging Face repositories by cloning them with git-lfs.
| Llama 2 Chat 7B | Vicuna V0 13B | Vicuna V0 7B |
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main)
Then, set the variable *llama_model* in the model config file to the LLM weight path.
* For MiniGPT-v2, set the LLM path
[here](minigpt4/configs/models/minigpt_v2.yaml#L15) at Line 14.
* For MiniGPT-4 (Llama2), set the LLM path
[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
* For MiniGPT-4 (Vicuna), set the LLM path
[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18.
**3. Prepare the pretrained model checkpoints**
Download the pretrained model checkpoints:
| MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo)|
|------------------------------|------------------------------|------------------------------|
| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1HkoUUrjzFGn33cSiUkI-KcT-zysCynAz/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8.
| MiniGPT-4 (Vicuna 13B) | MiniGPT-4 (Vicuna 7B) | MiniGPT-4 (LLaMA-2 Chat 7B) |
|----------------------------|---------------------------|---------------------------------|
| [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) |
For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for the Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for the Llama2 version.
### Launching Demo Locally
For MiniGPT-v2, run
```
python demo_v2.py --cfg-path eval_configs/minigptv2_eval.yaml --gpu-id 0
```
For MiniGPT-4 (Vicuna version), run
```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```
For MiniGPT-4 (Llama2 version), run
```
python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
```
To save GPU memory, the LLM is loaded in 8-bit by default, with a beam search width of 1.
This configuration requires about 23 GB of GPU memory for the 13B LLM and 11.5 GB for the 7B LLM.
On more powerful GPUs, you can run the model
in 16-bit by setting `low_resource` to `False` in the relevant config file:
* MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#6)
* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6)
* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6)
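As a rough check on the memory figures above, the memory taken by the LLM weights alone is the parameter count times the bytes per parameter. The sketch below is back-of-the-envelope arithmetic, not a measurement. It assumes nominal parameter counts of 7 and 13 billion and ignores the vision encoder, activations, and framework overhead, which is presumably why the totals quoted above are higher than these weight-only numbers.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for model weights alone, in gigabytes (1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


if __name__ == "__main__":
    # Nominal parameter counts; the real models differ slightly.
    for name, params in [("7B", 7e9), ("13B", 13e9)]:
        for bits in (8, 16):
            size = weight_memory_gb(params, bits)
            print(f"{name} LLM at {bits}-bit: ~{size:.1f} GB of weights")
```

Moving from 8-bit to 16-bit doubles the weight memory, so switching `low_resource` to `False` only makes sense when the GPU has enough headroom.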
Thanks to [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing).
### Training
For training details of MiniGPT-4, check [here](MiniGPT4_Train.md).
For finetuning details of MiniGPT-v2, check [here](MiniGPTv2_Train.md).
### Evaluation
For evaluation details of MiniGPT-v2, check [here](eval_scripts/EVAL_README.md).
## Acknowledgement
+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check out this great open-source work if you haven't seen it before!
+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
+ [LLaMA](https://github.com/facebookresearch/llama) The strong open-sourced LLaMA 2 language model.
If you're using MiniGPT-4/MiniGPT-v2 in your research or applications, please cite using this BibTeX:
```bibtex
@article{chen2023minigptv2,
title={MiniGPT-v2: large language model as a unified interface for vision-language multi-task learning},
author={Chen, Jun and Zhu, Deyao and Shen, Xiaoqian and Li, Xiang and Liu, Zechun and Zhang, Pengchuan and Krishnamoorthi, Raghuraman and Chandra, Vikas and Xiong, Yunyang and Elhoseiny, Mohamed},
year={2023},
journal={arXiv preprint arXiv:2310.09478},
}
@article{zhu2023minigpt,
title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
journal={arXiv preprint arXiv:2304.10592},
year={2023}
}
```
## License
This repository is under [BSD 3-Clause License](LICENSE.md).
Much of the code is based on [Lavis](https://github.com/salesforce/LAVIS), which is released under the
BSD 3-Clause License [here](LICENSE_Lavis.md).
================================================
FILE: SECURITY.md
================================================
# Security Policy
## Supported Versions
Use this section to tell people about which versions of your project are
currently being supported with security updates.
| Version | Supported |
| ------- | ------------------ |
| 5.1.x | :white_check_mark: |
| 5.0.x | :x: |
| 4.0.x | :white_check_mark: |
| < 4.0 | :x: |
## Reporting a Vulnerability
Use this section to tell people how to report a vulnerability.
Tell them where to go, how often they can expect to get an update on a
reported vulnerability, what to expect if the vulnerability is accepted or
declined, etc.
================================================
FILE: dataset/README_1_STAGE.md
================================================
## Download the filtered Conceptual Captions, SBU, LAION datasets
### Pre-training datasets download:
We use the filtered synthetic captions prepared by BLIP. For more details about the dataset, please refer to [BLIP](https://github.com/salesforce/BLIP).
Storing the LAION and CC3M+CC12M+SBU datasets requires about 2.3 TB of disk space.
Image source | Filtered synthetic caption by ViT-L
--- | :---:
CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
This downloads two JSON files:
```
ccs_synthetic_filtered_large.json
laion_synthetic_filtered_large.json
```
## Prepare the data step by step
### Set up the dataset folder and move the annotation files to the data storage folder
```
export MINIGPT4_DATASET=/YOUR/PATH/FOR/LARGE/DATASET/
mkdir ${MINIGPT4_DATASET}/cc_sbu
mkdir ${MINIGPT4_DATASET}/laion
mv ccs_synthetic_filtered_large.json ${MINIGPT4_DATASET}/cc_sbu
mv laion_synthetic_filtered_large.json ${MINIGPT4_DATASET}/laion
```
### Copy the scripts to the data storage folder
```
cp convert_cc_sbu.py ${MINIGPT4_DATASET}/cc_sbu
cp download_cc_sbu.sh ${MINIGPT4_DATASET}/cc_sbu
cp convert_laion.py ${MINIGPT4_DATASET}/laion
cp download_laion.sh ${MINIGPT4_DATASET}/laion
```
### Convert the LAION and CC_SBU annotation files to the img2dataset format
```
cd ${MINIGPT4_DATASET}/cc_sbu
python convert_cc_sbu.py
cd ${MINIGPT4_DATASET}/laion
python convert_laion.py
```
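As a rough illustration of what the conversion step produces, the sketch below turns a list of JSON records into tab-separated text with a header row, in the same spirit as `convert_cc_sbu.py`. The `url` and `caption` field names, the sample records, and the helper `records_to_tsv` are assumptions made for this example; check the keys in your downloaded JSON files.

```python
import csv
import io


def records_to_tsv(records):
    """Serialize a list of flat dicts to TSV text with a header row.

    Hypothetical helper for illustration only; it is not part of this repository.
    """
    if not records:
        return ""
    buffer = io.StringIO()
    fieldnames = list(records[0].keys())
    writer = csv.DictWriter(
        buffer, fieldnames=fieldnames, delimiter="\t", lineterminator="\n"
    )
    writer.writeheader()
    writer.writerows(records)
    return buffer.getvalue()


# Made-up sample records; the real files may use different keys.
sample = [
    {"url": "http://example.com/a.jpg", "caption": "a dog on a beach"},
    {"url": "http://example.com/b.jpg", "caption": "two cats"},
]
tsv_text = records_to_tsv(sample)
print(tsv_text)
```

The header row is taken from the keys of the first record, so the sketch assumes every record shares the same keys.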
### Download the datasets with img2dataset
```
cd ${MINIGPT4_DATASET}/cc_sbu
sh download_cc_sbu.sh
cd ${MINIGPT4_DATASET}/laion
sh download_laion.sh
```
The final dataset structure
```
.
├── ${MINIGPT4_DATASET}
│   ├── cc_sbu
│   │   ├── convert_cc_sbu.py
│   │   ├── download_cc_sbu.sh
│   │   ├── ccs_synthetic_filtered_large.json
│   │   ├── ccs_synthetic_filtered_large.tsv
│   │   └── cc_sbu_dataset
│   │       ├── 00000.tar
│   │       ├── 00000.parquet
│   │       ...
│   ├── laion
│   │   ├── convert_laion.py
│   │   ├── download_laion.sh
│   │   ├── laion_synthetic_filtered_large.json
│   │   ├── laion_synthetic_filtered_large.tsv
│   │   └── laion_dataset
│   │       ├── 00000.tar
│   │       ├── 00000.parquet
│   │       ...
...
```
## Set up the dataset configuration files
Then, set the LAION dataset loading path
[here](../minigpt4/configs/datasets/laion/defaults.yaml#L5) at Line 5 to
${MINIGPT4_DATASET}/laion/laion_dataset/{00000..10488}.tar
and the Conceptual Captions and SBU dataset loading path
[here](../minigpt4/configs/datasets/cc_sbu/defaults.yaml#L5) at Line 5 to
${MINIGPT4_DATASET}/cc_sbu/cc_sbu_dataset/{00000..01255}.tar
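
The `{00000..10488}.tar` pattern above names a contiguous range of zero-padded shard files. Before training, it can help to list which shards are actually on disk, since some downloads may fail. The sketch below is one way to do that; `expand_shards`, `missing_shards`, and the `/data/laion/laion_dataset` root are hypothetical names used only for illustration.

```python
import os


def expand_shards(dataset_dir, first, last, width=5):
    """Return the shard paths covered by a {first..last}.tar brace pattern."""
    root = dataset_dir.rstrip("/")
    return [f"{root}/{index:0{width}d}.tar" for index in range(first, last + 1)]


def missing_shards(dataset_dir, first, last, width=5):
    """Return the expected shard paths that are not present on disk."""
    return [
        path
        for path in expand_shards(dataset_dir, first, last, width)
        if not os.path.isfile(path)
    ]


# Example: the first three LAION shards under a placeholder root.
paths = expand_shards("/data/laion/laion_dataset", 0, 2)
print(paths)
```

If `missing_shards` reports gaps, either re-run the download or narrow the range in the config to the shards that exist.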
================================================
FILE: dataset/README_2_STAGE.md
================================================
## Second Stage Data Preparation
Our second stage dataset can be downloaded from
[here](https://drive.google.com/file/d/1nJXhoEcy3KTExr17I7BXqY5Y9Lx_-n-9/view?usp=share_link)
After extraction, you will get a data folder with the following structure:
```
cc_sbu_align
├── filter_cap.json
└── image
├── 2.jpg
├── 3.jpg
...
```
Put the folder in any path you want.
Then, set up the dataset path in the dataset config file
[here](../minigpt4/configs/datasets/cc_sbu/align.yaml#L5) at Line 5.
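
Before pointing the config at the extracted folder, a quick check that it matches the structure shown above can catch a wrong path early. This is a minimal sketch; `missing_align_items` is a hypothetical helper, not part of the repository, and it only checks for `filter_cap.json` and the `image` directory.

```python
import os
import tempfile


def missing_align_items(root):
    """Return the expected entries that are absent from a cc_sbu_align folder."""
    expected = {
        "filter_cap.json": os.path.isfile,
        "image": os.path.isdir,
    }
    return [
        name
        for name, exists in expected.items()
        if not exists(os.path.join(root, name))
    ]


# Demonstration on a throwaway directory that only contains the image folder.
with tempfile.TemporaryDirectory() as tmp:
    os.mkdir(os.path.join(tmp, "image"))
    demo_missing = missing_align_items(tmp)
print(demo_missing)
```

An empty list means both entries were found under the given root.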
================================================
FILE: dataset/README_MINIGPTv2_FINETUNE.md
================================================
## Download the dataset for finetuning the MiniGPT-v2
Download the dataset
Image source | Download path
--- | :---:
COCO 2014 images | <a href="http://images.cocodataset.org/zips/train2014.zip">images</a> <a href="https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json"> captions</a>
COCO VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json">vqa train</a> <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json"> vqa val</a>
Visual Genome | <a href="https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip">images part1</a> <a href="https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip">images part2</a> <a href="https://homes.cs.washington.edu/~ranjay/visualgenome/data/dataset/image_data.json.zip"> image meta data </a>
TextCaps | <a href="https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip">images</a> <a href="https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json"> annotations</a>
RefCOCO | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip"> annotations </a>
RefCOCO+ | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip"> annotations </a>
RefCOCOg | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip"> annotations </a>
OKVQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json"> annotations </a>
AOK-VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json"> annotations </a>
OCR-VQA | <a href="https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing"> annotations </a>
GQA | <a href="https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip">images</a> <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json"> annotations </a>
Filtered flickr-30k | <a href="https://drive.google.com/drive/folders/19c_ggBI77AvdtYlPbuI0ZpnPz73T5teX?usp=sharing"> annotations </a>
Multi-task conversation | <a href="https://drive.google.com/file/d/11HHqB2c29hbSk-WLxdta-nG8UCUrcCN1/view?usp=sharing"> annotations </a>
Filtered unnatural instruction | <a href="https://drive.google.com/file/d/1lXNnBcb5WU-sc8Fe2T2N8J0NRw4sBLev/view?usp=sharing"> annotations </a>
LLaVA | <a href="https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/complex_reasoning_77k.json"> Complex reasoning </a> <a href="https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/detail_23k.json"> Detailed description </a> <a href="https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/conversation_58k.json"> Conversation </a>
### COCO captions
Download the COCO 2014 images and captions
coco 2014 images path
```
${MINIGPTv2_DATASET}
├── coco
│ ├── images
...
```
coco caption annotation path
```
${MINIGPTv2_DATASET}
├── coco_captions
│ └── annotations
│ ├── coco_karpathy_train.json
...
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the coco_karpathy_train.json path
- [minigpt4/configs/datasets/coco/caption.yaml](../minigpt4/configs/datasets/coco/caption.yaml)
### COCO VQA
Download the vqa v2 train and validation json files
```
├── ${MINIGPTv2_DATASET}
│ ├── vqav2
│ ├── vqa_train.json
│ ├── vqa_val.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the vqa_train.json and vqa_val.json path
- [minigpt4/configs/datasets/coco/defaults_vqa.yaml](../minigpt4/configs/datasets/coco/defaults_vqa.yaml)
### Visual Genome
Download the Visual Genome images and annotation files
```
${MINIGPTv2_DATASET}
├── visual_genome
│ ├── VG_100K
│ ├── VG_100K_2
│ └── region_descriptions.json
│ └── image_data.json
...
```
Set **image_path** to visual_genome folder.
Similarly, set **ann_path** to the visual_genome folder.
- [minigpt4/configs/datasets/vg/ref.yaml](../minigpt4/configs/datasets/vg/ref.yaml)
### TextCaps
Download the TextCaps images and annotation files
```
├── ${MINIGPTv2_DATASET}
│ ├── textcaps
│ ├── train_images
│ ├── TextCaps_0.1_train.json
```
Set **image_path** to TextCaps train_images folder.
Similarly, set **ann_path** to the TextCaps_0.1_train.json path
- [minigpt4/configs/datasets/textcaps/caption.yaml](../minigpt4/configs/datasets/textcaps/caption.yaml)
### RefCOCO, RefCOCO+, RefCOCOg
Download the RefCOCO, RefCOCO+, RefCOCOg annotation files
```
${MINIGPTv2_DATASET}
├── refcoco_annotations
│ ├── refcoco
│ │ ├── instances.json
│ │ ├── refs(google).p
│ │ └── refs(unc).p
│ ├── refcoco+
│ │ ├── instances.json
│ │ └── refs(unc).p
│ └── refcocog
│ ├── instances.json
│ ├── refs(google).p
│ └── refs(umd).p
...
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** in all the following configs to the above folder *refcoco_annotations* that contains refcoco, refcoco+, and refcocog.
- [minigpt4/configs/datasets/coco_bbox/refcoco.yaml](../minigpt4/configs/datasets/coco_bbox/refcoco.yaml)
- [minigpt4/configs/datasets/coco_bbox/refcocog.yaml](../minigpt4/configs/datasets/coco_bbox/refcocog.yaml)
- [minigpt4/configs/datasets/coco_bbox/refcocop.yaml](../minigpt4/configs/datasets/coco_bbox/refcocop.yaml)
- [minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml)
- [minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml)
- [minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml)
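
The `refs(...).p` files appear to be Python pickle files (this is an assumption based on the `.p` extension). If so, a short script like the one below can confirm that a file loads before training starts. `load_refs` is a hypothetical helper, and the record fields in the synthetic example are invented for the demonstration. Only unpickle files from sources you trust, since unpickling can execute arbitrary code.

```python
import os
import pickle
import tempfile


def load_refs(path):
    """Load a refs(...).p file, assuming it is a Python pickle."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Synthetic stand-in for a refs file; the fields below are made up.
fake_refs = [{"ref_id": 1, "sentences": [{"sent": "the left dog"}]}]
with tempfile.TemporaryDirectory() as tmp:
    fake_path = os.path.join(tmp, "refs(unc).p")
    with open(fake_path, "wb") as f:
        pickle.dump(fake_refs, f)
    loaded = load_refs(fake_path)
print(type(loaded).__name__, len(loaded))
```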
### OKVQA
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── okvqa
│ ├── okvqa_train.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the location of the OKVQA dataset
- [minigpt4/configs/datasets/okvqa/defaults.yaml](../minigpt4/configs/datasets/okvqa/defaults.yaml)
### COCO-VQA
- [OK-VQA Input Questions](https://okvqa.allenai.org/static/data/OpenEnded_mscoco_train2014_questions.json.zip)
- [OK-VQA Annotations](https://okvqa.allenai.org/static/data/mscoco_train2014_annotations.json.zip)
### AOK-VQA
Download the AOK-VQA annotation dataset
```
export AOKVQA_DIR=YOUR_DATASET_PATH
mkdir -p ${AOKVQA_DIR}
curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
```
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── aokvqa
│ ├── aokvqa_v1p0_train.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the location of the AOKVQA dataset
- [minigpt4/configs/datasets/aokvqa/defaults.yaml](../minigpt4/configs/datasets/aokvqa/defaults.yaml)
### OCR-VQA
Download the OCR-VQA annotation files, then download the images with the loadDataset.py script
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── ocrvqa
│ ├── images
│ ├── dataset.json
```
Set **image_path** as the ocrvqa/images folder.
Similarly, set **ann_path** to the dataset.json
- [minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml](../minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml)
### GQA
Download the GQA annotation files and images
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── gqa
│ ├── images
│ ├── train_balanced_questions.json
```
Set **image_path** as the gqa/images folder.
Similarly, set **ann_path** to the train_balanced_questions.json
- [minigpt4/configs/datasets/gqa/balanced_val.yaml](../minigpt4/configs/datasets/gqa/balanced_val.yaml)
### Filtered Flickr-30k
Download filtered Flickr-30k images (fill this [form](https://forms.illinois.edu/sec/229675) on official website or from [kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset/download?datasetVersionNumber=1)) and annotation files
```
${MINIGPTv2_DATASET}
├── filtered_flickr
│ ├── images
│ ├── captiontobbox.json
│ ├── groundedcaption.json
│ └── phrasetobbox.json
...
```
Set **image_path** as the flickr-30k images folder.
Similarly, set **ann_path** to the groundedcaption.json, captiontobbox.json and phrasetobbox.json for the
grounded image caption, caption to bbox, and phrase to bbox datasets.
- [minigpt4/configs/datasets/flickr/default.yaml](../minigpt4/configs/datasets/flickr/default.yaml)
- [minigpt4/configs/datasets/flickr/caption_to_phrase.yaml](../minigpt4/configs/datasets/flickr/caption_to_phrase.yaml)
- [minigpt4/configs/datasets/flickr/object_to_phrase.yaml](../minigpt4/configs/datasets/flickr/object_to_phrase.yaml)
### Multi-task conversation
Download the multi-task conversation dataset
```
Location_you_like
${MINIGPTv2_DATASET}
├── multitask_conversation
│ └── multitask_conversation.json
...
```
Set **image_path** as the COCO 2014 images folder.
Similarly, set **ann_path** to the multitask_conversation.json file path
- [minigpt4/configs/datasets/multitask_conversation/default.yaml](../minigpt4/configs/datasets/multitask_conversation/default.yaml)
### Unnatural instruction
Download the filtered unnatural instruction annotation files (we remove the very long sentences from the original unnatural instruction dataset)
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── unnatural_instructions
│ ├── filtered_unnatural_instruction.json
```
This dataset has no images, so there is no **image_path** to set.
Set **ann_path** to the filtered_unnatural_instruction.json file path
- [minigpt4/configs/datasets/nlp/unnatural_instruction.yaml](../minigpt4/configs/datasets/nlp/unnatural_instruction.yaml)
### LLaVA
```
Location_you_like
├── ${MINIGPTv2_DATASET}
│ ├── llava
│ ├── conversation_58k.json
│ ├── detail_23k.json
│ ├── complex_reasoning_77k.json
```
Set **image_path** to the COCO 2014 image folder.
Similarly, set **ann_path** to the location of the previously downloaded conversation_58k.json,
detail_23k.json, and complex_reasoning_77k.json in conversation.yaml, detail.yaml, and reason.yaml, respectively.
- [minigpt4/configs/datasets/llava/conversation.yaml](../minigpt4/configs/datasets/llava/conversation.yaml)
- [minigpt4/configs/datasets/llava/detail.yaml](../minigpt4/configs/datasets/llava/detail.yaml)
- [minigpt4/configs/datasets/llava/reason.yaml](../minigpt4/configs/datasets/llava/reason.yaml)
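
With this many datasets, a wrong `image_path` or `ann_path` is easy to miss. The sketch below checks a handful of the relative paths from the directory trees in this README against a dataset root. `EXPECTED` is only a subset chosen for illustration, `missing_entries` is a hypothetical helper, and reading the root from a `MINIGPTv2_DATASET` environment variable is an assumption; adjust all three to your own layout.

```python
import os

# A subset of the relative paths shown in the directory trees above.
EXPECTED = [
    "coco/images",
    "coco_captions/annotations/coco_karpathy_train.json",
    "vqav2/vqa_train.json",
    "vqav2/vqa_val.json",
    "okvqa/okvqa_train.json",
    "aokvqa/aokvqa_v1p0_train.json",
    "llava/conversation_58k.json",
    "llava/detail_23k.json",
    "llava/complex_reasoning_77k.json",
]


def missing_entries(root, expected=EXPECTED):
    """Return the relative paths from `expected` that do not exist under `root`."""
    return [rel for rel in expected if not os.path.exists(os.path.join(root, rel))]


# Falls back to the current directory when the variable is not set.
dataset_root = os.environ.get("MINIGPTv2_DATASET", ".")
for rel in missing_entries(dataset_root):
    print("missing:", rel)
```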
================================================
FILE: dataset/convert_cc_sbu.py
================================================
import json
import csv
# specify input and output file paths
input_file = 'ccs_synthetic_filtered_large.json'
output_file = 'ccs_synthetic_filtered_large.tsv'
# load JSON data from input file
with open(input_file, 'r') as f:
    data = json.load(f)
# extract header and data from JSON
header = data[0].keys()
rows = [x.values() for x in data]
# write data to TSV file
# newline='' lets the csv module control line endings, as its documentation recommends
with open(output_file, 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(header)
    writer.writerows(rows)
================================================
FILE: dataset/convert_laion.py
================================================
import json
import csv
# specify input and output file paths
input_file = 'laion_synthetic_filtered_large.json'
output_file = 'laion_synthetic_filtered_large.tsv'
# load JSON data from input file
with open(input_file, 'r') as f:
    data = json.load(f)
# extract header and data from JSON
header = data[0].keys()
rows = [x.values() for x in data]
# write data to TSV file
# newline='' lets the csv module control line endings, as its documentation recommends
with open(output_file, 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(header)
    writer.writerows(rows)
================================================
FILE: demo.py
================================================
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr
from transformers import StoppingCriteriaList
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True
# ========================================
# Model Initialization
# ========================================
conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
'pretrain_llama2': CONV_VISION_LLama2}
print('Initializing Chat')
args = parse_args()
cfg = Config(args)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
CONV_VISION = conv_dict[model_config.model_type]
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
stop_words_ids = [[835], [2277, 29937]]
stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
print('Initialization Finished')
# ========================================
# Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list


def upload_img(gr_img, text_input, chat_state):
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
    chat_state = CONV_VISION.copy()
    img_list = []
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    chat.encode_img(img_list)
    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list


def gradio_ask(user_message, chatbot, chat_state):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              num_beams=num_beams,
                              temperature=temperature,
                              max_new_tokens=300,
                              max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list
title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
"""
#TODO show examples below
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")

            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

        with gr.Column(scale=2):
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='MiniGPT-4')
            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

    upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])

    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)

demo.launch(share=True, enable_queue=True)
================================================
FILE: demo_v2.py
================================================
import argparse
import os
import random
from collections import defaultdict
import cv2
import re
import numpy as np
from PIL import Image
import torch
import html
import gradio as gr
import torchvision.transforms as T
import torch.backends.cudnn as cudnn
from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
                        help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
cudnn.benchmark = False
cudnn.deterministic = True
print('Initializing Chat')
args = parse_args()
cfg = Config(args)
device = 'cuda:{}'.format(args.gpu_id)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
model = model.eval()
CONV_VISION = Conversation(
system="",
roles=(r"<s>[INST] ", r" [/INST]"),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="",
)
def extract_substrings(string):
    # first drop anything after the last closing brace (an unfinished bracket)
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    pattern = r'<p>(.*?)\}(?!<)'
    matches = re.findall(pattern, string)
    substrings = [match for match in matches]

    return substrings


def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou


def save_tmp_img(visual_img):
    file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
    file_path = "/tmp/gradio" + file_name
    visual_img.save(file_path)
    return file_path
def mask2bbox(mask):
    if mask is None:
        return ''

    mask = mask.resize([100, 100], resample=Image.NEAREST)
    mask = np.array(mask)[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if rows.sum():
        # Get the top, bottom, left, and right boundaries
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
    else:
        bbox = ''

    return bbox


def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text


def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text
colors = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(210, 210, 0),
(255, 0, 255),
(0, 255, 255),
(114, 128, 250),
(0, 165, 255),
(0, 128, 0),
(144, 238, 144),
(238, 238, 175),
(255, 191, 0),
(0, 128, 0),
(226, 43, 138),
(255, 0, 255),
(0, 215, 255),
]
color_map = {
f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
color_id, color in enumerate(colors)
}
used_colors = colors
def visualize_all_bbox_together(image, generation):
    if image is None:
        return None, ''

    generation = html.unescape(generation)

    image_width, image_height = image.size
    image = image.resize([500, int(500 / image_width * image_height)])
    image_width, image_height = image.size

    string_list = extract_substrings(generation)
    if string_list:  # it is grounding or detection
        mode = 'all'
        entities = defaultdict(list)
        i = 0
        j = 0
        for string in string_list:
            try:
                obj, string = string.split('</p>')
            except ValueError:
                print('wrong string: ', string)
                continue
            bbox_list = string.split('<delim>')
            flag = False
            for bbox_string in bbox_list:
                integers = re.findall(r'-?\d+', bbox_string)
                if len(integers) == 4:
                    x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
                    left = x0 / bounding_box_size * image_width
                    bottom = y0 / bounding_box_size * image_height
                    right = x1 / bounding_box_size * image_width
                    top = y1 / bounding_box_size * image_height

                    entities[obj].append([left, bottom, right, top])

                    j += 1
                    flag = True
            if flag:
                i += 1
    else:
        integers = re.findall(r'-?\d+', generation)

        if len(integers) == 4:  # it is refer
            mode = 'single'

            entities = list()
            x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
            left = x0 / bounding_box_size * image_width
            bottom = y0 / bounding_box_size * image_height
            right = x1 / bounding_box_size * image_width
            top = y1 / bounding_box_size * image_height
            entities.append([left, bottom, right, top])
        else:
            # don't detect any valid bbox to visualize
            return None, ''

    if len(entities) == 0:
        return None, ''

    if isinstance(image, Image.Image):
        image_h = image.height
        image_w = image.width
        image = np.array(image)

    elif isinstance(image, str):
        if os.path.exists(image):
            pil_img = Image.open(image).convert("RGB")
            image = np.array(pil_img)[:, :, [2, 1, 0]]
            image_h = pil_img.height
            image_w = pil_img.width
        else:
            raise ValueError(f"invalid image path, {image}")
    elif isinstance(image, torch.Tensor):

        image_tensor = image.cpu()
        reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
        reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
        image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
        pil_img = T.ToPILImage()(image_tensor)
        image_h = pil_img.height
        image_w = pil_img.width
        image = np.array(pil_img)[:, :, [2, 1, 0]]
    else:
        raise ValueError(f"invalid image format, {type(image)} for {image}")

    indices = list(range(len(entities)))

    new_image = image.copy()

    previous_bboxes = []
    # size of text
    text_size = 0.5
    # thickness of text
    text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
    box_line = 2
    (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
    base_height = int(text_height * 0.675)
    text_offset_original = text_height - base_height
    text_spaces = 2

    # num_bboxes = sum(len(x[-1]) for x in entities)
    used_colors = colors  # random.sample(colors, k=num_bboxes)

    color_id = -1
    for entity_idx, entity_name in enumerate(entities):
        if mode == 'single' or mode == 'identify':
            bboxes = entity_name
            bboxes = [bboxes]
        else:
            bboxes = entities[entity_name]
        color_id += 1
        for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
            skip_flag = False
            orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)

            color = used_colors[entity_idx % len(used_colors)]  # tuple(np.random.randint(0, 255, size=3).tolist())
            new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)

            if mode == 'all':
                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
                                                               text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
                        text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1

                for prev_bbox in previous_bboxes:
                    if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
                            prev_bbox['phrase'] == entity_name:
                        skip_flag = True
                        break
                    while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
                        text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                        text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                        y1 += (text_height + text_offset_original + 2 * text_spaces)

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                            text_bg_y2 = image_h
                            y1 = image_h
                            break
                if not skip_flag:
                    alpha = 0.5
                    for i in range(text_bg_y1, text_bg_y2):
                        for j in range(text_bg_x1, text_bg_x2):
                            if i < image_h and j < image_w:
                                if j < text_bg_x1 + 1.35 * c_width:
                                    # original color
                                    bg_color = color
                                else:
                                    # white
                                    bg_color = [255, 255, 255]
                                new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
                                    np.uint8)

                    cv2.putText(
                        new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
                        cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                    )

                    previous_bboxes.append(
                        {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})

    if mode == 'all':
        def color_iterator(colors):
            while True:
                for color in colors:
                    yield color

        color_gen = color_iterator(colors)

        # Add colors to phrases and remove <p></p>
        def colored_phrases(match):
            phrase = match.group(1)
            color = next(color_gen)
            return f'<span style="color:rgb{color}">{phrase}</span>'

        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
    else:
        generation_colored = ''

    pil_image = Image.fromarray(new_image)
    return pil_image, generation_colored
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
                                                                    interactive=True), chat_state, img_list


def image_upload_trigger(upload_flag, replace_flag, img_list):
    # set the upload flag to true when receive a new image.
    # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
    upload_flag = 1
    if img_list:
        replace_flag = 1
    return upload_flag, replace_flag


def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    # set the upload flag to true when receive a new image.
    # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
    upload_flag = 1
    if img_list or replace_flag == 1:
        replace_flag = 1

    return upload_flag, replace_flag
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
    if len(user_message) == 0:
        text_box_show = 'Input should not be empty!'
    else:
        text_box_show = ''

    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']
    else:
        mask = None

    if '[identify]' in user_message:
        # check if user provide bbox in the text input
        integers = re.findall(r'-?\d+', user_message)
        if len(integers) != 4:  # no bbox in text
            bbox = mask2bbox(mask)
            user_message = user_message + bbox

    if chat_state is None:
        chat_state = CONV_VISION.copy()

    if upload_flag:
        if replace_flag:
            chat_state = CONV_VISION.copy()  # new image, reset everything
            replace_flag = 0
            chatbot = []
        img_list = []
        llm_message = chat.upload_img(gr_img, chat_state, img_list)
        upload_flag = 0

    chat.ask(user_message, chat_state)

    chatbot = chatbot + [[user_message, None]]

    if '[identify]' in user_message:
        visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
        if visual_img is not None:
            file_path = save_tmp_img(visual_img)
            chatbot = chatbot + [[(file_path,), None]]

    return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
def gradio_answer(chatbot, chat_state, img_list, temperature):
llm_message = chat.answer(conv=chat_state,
img_list=img_list,
temperature=temperature,
max_new_tokens=500,
max_length=2000)[0]
chatbot[-1][1] = llm_message
return chatbot, chat_state
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
if len(img_list) > 0:
if not isinstance(img_list[0], torch.Tensor):
chat.encode_img(img_list)
streamer = chat.stream_answer(conv=chat_state,
img_list=img_list,
temperature=temperature,
max_new_tokens=500,
max_length=2000)
output = ''
for new_output in streamer:
escaped = escape_markdown(new_output)
output += escaped
chatbot[-1][1] = output
yield chatbot, chat_state
chat_state.messages[-1][1] = '</s>'
return chatbot, chat_state
def gradio_visualize(chatbot, gr_img):
if isinstance(gr_img, dict):
gr_img, mask = gr_img['image'], gr_img['mask']
unescaped = reverse_escape(chatbot[-1][1])
visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
if visual_img is not None:
if len(generation_color):
chatbot[-1][1] = generation_color
file_path = save_tmp_img(visual_img)
chatbot = chatbot + [[None, (file_path,)]]
return chatbot
def gradio_taskselect(idx):
prompt_list = [
'',
'[grounding] describe this image in detail',
'[refer] ',
'[detection] ',
'[identify] what is this ',
'[vqa] '
]
instruct_list = [
'**Hint:** Type in whatever you want',
'**Hint:** Send the command to generate a grounded image description',
'**Hint:** Type in a phrase about an object in the image and send the command',
'**Hint:** Type in a caption or phrase, and see object locations in the image',
'**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button on the top right of the image before redrawing',
'**Hint:** Send a question to get a short answer',
]
return prompt_list[idx], instruct_list[idx]
chat = Chat(model, vis_processor, device=device)
title = """<h1 align="center">MiniGPT-v2 Demo</h1>"""
description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
# article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4/blob/main/MiniGPTv2.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p><a href='https://www.youtube.com/watch?v=atFCwV2hSY4'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p>"""
article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
introduction = '''
For Abilities Involving Visual Grounding:
1. Grounding: CLICK **Send** to generate a grounded image description.
2. Refer: Input a referring object and CLICK **Send**.
3. Detection: Write a caption or phrase, and CLICK **Send**.
4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
5. VQA: Input a visual question and CLICK **Send**.
6. No Tag: Input whatever you want and CLICK **Send** without any tagging
You can also simply chat in free form!
'''
text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
scale=8)
with gr.Blocks() as demo:
gr.Markdown(title)
# gr.Markdown(description)
gr.Markdown(article)
with gr.Row():
with gr.Column(scale=0.5):
image = gr.Image(type="pil", tool='sketch', brush_radius=20)
temperature = gr.Slider(
minimum=0.1,
maximum=1.5,
value=0.6,
step=0.1,
interactive=True,
label="Temperature",
)
clear = gr.Button("Restart")
gr.Markdown(introduction)
with gr.Column():
chat_state = gr.State(value=None)
img_list = gr.State(value=[])
chatbot = gr.Chatbot(label='MiniGPT-v2')
dataset = gr.Dataset(
components=[gr.Textbox(visible=False)],
samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
type="index",
label='Task Shortcuts',
)
task_inst = gr.Markdown('**Hint:** Upload your image and chat')
with gr.Row():
text_input.render()
send = gr.Button("Send", variant='primary', size='sm', scale=1)
upload_flag = gr.State(value=0)
replace_flag = gr.State(value=0)
image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
with gr.Row():
with gr.Column():
gr.Examples(examples=[
["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
img_list],
["examples_v2/sofa.jpg", "[detection] sofas", upload_flag, replace_flag, img_list],
["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
img_list],
["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
replace_flag, img_list],
], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
outputs=[upload_flag, replace_flag])
with gr.Column():
gr.Examples(examples=[
["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
upload_flag, replace_flag, img_list],
["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
replace_flag, img_list],
], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
outputs=[upload_flag, replace_flag])
dataset.click(
gradio_taskselect,
inputs=[dataset],
outputs=[text_input, task_inst],
show_progress="hidden",
postprocess=False,
queue=False,
)
text_input.submit(
gradio_ask,
[text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
[text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
).success(
gradio_stream_answer,
[chatbot, chat_state, img_list, temperature],
[chatbot, chat_state]
).success(
gradio_visualize,
[chatbot, image],
[chatbot],
queue=False,
)
send.click(
gradio_ask,
[text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
[text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
).success(
gradio_stream_answer,
[chatbot, chat_state, img_list, temperature],
[chatbot, chat_state]
).success(
gradio_visualize,
[chatbot, image],
[chatbot],
queue=False,
)
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
demo.launch(share=True, enable_queue=True)
================================================
FILE: environment.yml
================================================
name: minigptv
channels:
- pytorch
- defaults
- anaconda
dependencies:
- python=3.9
- cudatoolkit
- pip
- pip:
- torch==2.0.0
- torchaudio
- torchvision
- huggingface-hub==0.18.0
- matplotlib==3.7.0
- psutil==5.9.4
- iopath
- pyyaml==6.0
- regex==2022.10.31
- tokenizers==0.13.2
- tqdm==4.64.1
- transformers==4.30.0
- timm==0.6.13
- webdataset==0.2.48
- omegaconf==2.3.0
- opencv-python==4.7.0.72
- decord==0.6.0
- peft==0.2.0
- sentence-transformers
- gradio==3.47.1
- accelerate==0.20.3
- bitsandbytes==0.37.0
- scikit-image
- visual-genome
- wandb
================================================
FILE: eval_configs/minigpt4_eval.yaml
================================================
model:
arch: minigpt4
model_type: pretrain_vicuna0
max_txt_len: 160
end_sym: "###"
low_resource: True
prompt_template: '###Human: {} ###Assistant: '
ckpt: 'please set this value to the path of pretrained checkpoint'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain
================================================
FILE: eval_configs/minigpt4_llama2_eval.yaml
================================================
model:
arch: minigpt4
model_type: pretrain_llama2
max_txt_len: 160
end_sym: "</s>"
low_resource: True
prompt_template: '[INST] {} [/INST] '
ckpt: 'please set this value to the path of pretrained checkpoint'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain
================================================
FILE: eval_configs/minigptv2_benchmark_evaluation.yaml
================================================
model:
arch: minigpt_v2
model_type: pretrain
max_txt_len: 500
end_sym: "</s>"
low_resource: False
prompt_template: '[INST] {} [/INST]'
llama_model: ""
ckpt: ""
lora_r: 64
lora_alpha: 16
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 448
text_processor:
train:
name: "blip_caption"
evaluation_datasets:
refcoco:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
refcocog:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
refcoco+:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
gqa:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
okvqa:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
vizwiz:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
iconvqa:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
vsr:
eval_file_path: cambridgeltl/vsr_zeroshot
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 10
hm:
eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path
max_new_tokens: 20
batch_size: 100
run:
task: image_text_pretrain
name: minigptv2_evaluation
save_path: /path/to/save/folder_path
================================================
FILE: eval_configs/minigptv2_eval.yaml
================================================
model:
arch: minigpt_v2
model_type: pretrain
max_txt_len: 500
end_sym: "</s>"
low_resource: True
prompt_template: '[INST] {} [/INST]'
ckpt: "please set this value to the path of pretrained checkpoint"
lora_r: 64
lora_alpha: 16
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 448
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain
================================================
FILE: eval_scripts/EVAL_README.md
================================================
## Evaluation Instruction for MiniGPT-v2
### Data preparation
Images download
Image source | Download path
--- | :---:
OKVQA| <a href="https://drive.google.com/drive/folders/1jxIgAhtaLu_YqnZEl8Ym11f7LhX3nptN?usp=sharing">annotations</a> <a href="http://images.cocodataset.org/zips/train2017.zip"> images</a>
gqa | <a href="https://drive.google.com/drive/folders/1-dF-cgFwstutS4qq2D9CFQTDS0UTmIft?usp=drive_link">annotations</a> <a href="https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip">images</a>
hateful meme | <a href="https://github.com/faizanahemad/facebook-hateful-memes">images and annotations</a>
iconqa | <a href="https://iconqa.github.io/#download">images and annotation</a>
vizwiz | <a href="https://vizwiz.org/tasks-and-datasets/vqa/">images and annotation</a>
RefCOCO | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip"> annotations </a>
RefCOCO+ | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip"> annotations </a>
RefCOCOg | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip"> annotations </a>
### Evaluation dataset structure
```
${MINIGPTv2_EVALUATION_DATASET}
├── gqa
│ └── test_balanced_questions.json
│ ├── testdev_balanced_questions.json
│ ├── gqa_images
├── hateful_meme
│ └── hm_images
│ ├── dev.jsonl
├── iconvqa
│ └── iconvqa_images
│ ├── choose_text_val.json
├── vizwiz
│ └── vizwiz_images
│ ├── val.json
├── vsr
│ └── vsr_images
├── okvqa
│ ├── okvqa_test_split.json
│ ├── mscoco_val2014_annotations_clean.json
│ ├── OpenEnded_mscoco_val2014_questions_clean.json
├── refcoco
│ └── instances.json
│ ├── refs(google).p
│ ├── refs(unc).p
├── refcoco+
│ └── instances.json
│ ├── refs(unc).p
├── refcocog
│ └── instances.json
│ ├── refs(google).p
│ ├── refs(umd).p
...
```
### environment setup
```
export PYTHONPATH=$PYTHONPATH:/path/to/directory/of/MiniGPT-4
```
### config file setup
In [eval_configs/minigptv2_benchmark_evaluation.yaml](../eval_configs/minigptv2_benchmark_evaluation.yaml):
Set **llama_model** to the path of the LLaMA model.
Set **ckpt** to the path of our pretrained model.
Set **eval_file_path** to the path of the annotation files for each evaluation dataset.
Set **img_path** to the image directory for each evaluation dataset.
Set **save_path** (under `run`) to the directory where prediction files are saved.
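Before launching an evaluation, it can help to confirm that no placeholder values remain in the config. The following is a minimal sketch and not part of the repository: `find_unset` is a hypothetical helper that walks a nested dict (for example, one obtained with `OmegaConf.to_container(OmegaConf.load(cfg_path))`) and reports keys whose values are still empty strings or begin with `/path/to`. The `example` dict only mirrors a few keys of the shipped YAML for illustration.

```python
def find_unset(cfg, prefix=""):
    """Return dotted key paths whose values still look like placeholders.

    Hypothetical helper: a value counts as unset if it is an empty string
    or starts with "/path/to", matching the defaults in the shipped YAML.
    """
    unset = []
    for key, value in cfg.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into nested sections such as model / evaluation_datasets.
            unset.extend(find_unset(value, path))
        elif isinstance(value, str) and (value == "" or value.startswith("/path/to")):
            unset.append(path)
    return unset


# Illustrative subset of the config; real values come from the YAML file.
example = {
    "model": {"llama_model": "", "ckpt": "", "lora_r": 64},
    "evaluation_datasets": {
        "refcoco": {
            "eval_file_path": "/path/to/eval/annotation/path",
            "img_path": "/data/coco/train2014",  # assumed to be filled in already
            "batch_size": 10,
        }
    },
    "run": {"save_path": "/path/to/save/folder_path"},
}

for key_path in find_unset(example):
    print("still unset:", key_path)
```

An empty result means every path-like field has been replaced; it does not verify that the paths actually exist on disk.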
### start evaluating RefCOCO, RefCOCO+, RefCOCOg
port=port_number
cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml
dataset names:
| refcoco | refcoco+ | refcocog |
| ------- | -------- | -------- |
```
torchrun --master-port ${port} --nproc_per_node 1 eval_ref.py \
--cfg-path ${cfg_path} --dataset refcoco,refcoco+,refcocog --resample
```
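For orientation, the score that `eval_ref.py` prints is the percentage of predictions whose IoU with the ground-truth box exceeds 0.5. The sketch below restates that scoring step in isolation, under stated assumptions: the model answers in the form `{<x1><y1><x2><y2>}` on a 0 to `res` grid (`--res`, default 100), and ground-truth boxes are `[x, y, w, h]` in pixels. The `iou` function here is a plain textbook version written for illustration; the repository's `computeIoU` in `minigpt4/common/eval_utils.py` may differ in detail (for example in pixel-inclusive conventions).

```python
import re

# Same answer format that eval_ref.py checks for.
BBOX_PATTERN = re.compile(r"\{<(\d{1,3})><(\d{1,3})><(\d{1,3})><(\d{1,3})>\}")


def parse_bbox(answer, width, height, res=100.0):
    """Convert a model answer into [x1, y1, x2, y2] pixels, or None if malformed."""
    cleaned = answer.replace("<unk>", "").replace(" ", "").strip()
    match = BBOX_PATTERN.match(cleaned)
    if match is None:
        return None
    x1, y1, x2, y2 = (int(v) for v in match.groups())
    return [x1 / res * width, y1 / res * height, x2 / res * width, y2 / res * height]


def xywh_to_xyxy(bbox):
    """Convert a ground-truth [x, y, w, h] box to corner format."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def iou(a, b):
    """Illustrative intersection-over-union for two [x1, y1, x2, y2] boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def is_hit(answer, gt_xywh, width, height, res=100.0, threshold=0.5):
    """True when the predicted box overlaps the ground truth by more than threshold."""
    pred = parse_bbox(answer, width, height, res)
    return pred is not None and iou(pred, xywh_to_xyxy(gt_xywh)) > threshold


print(is_hit("{<10><20><50><60>}", [20, 20, 80, 40], width=200, height=100))
```

Answers that do not match the pattern are the ones `--resample` sends back through the model.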
### start evaluating visual question answering
port=port_number
cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml
dataset names:
| okvqa | vizwiz | iconvqa | gqa | vsr | hm |
| ------- | -------- | -------- |-------- | -------- | -------- |
```
torchrun --master-port ${port} --nproc_per_node 1 eval_vqa.py \
--cfg-path ${cfg_path} --dataset okvqa,vizwiz,iconvqa,gqa,vsr,hm
```
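The VQA datasets are not all scored the same way. As a rough illustration of the VizWiz rule used in `eval_vqa.py`, the sketch below counts how many annotator answers equal the prediction and caps the per-question score at 1.0 once three annotators agree. The function name `vizwiz_accuracy` is hypothetical, and the `_`-joined ground-truth string is assumed to be the format the evaluation dataset yields, as the script's `split('_')` suggests.

```python
def vizwiz_accuracy(answer, gt_answers):
    """Per-question VizWiz score: min(#matching annotator answers / 3, 1).

    gt_answers is assumed to be a single string of annotator answers
    joined with underscores, e.g. "dog_dog_cat".
    """
    matches = sum(gt.lower() == answer.lower() for gt in gt_answers.split("_"))
    return min(matches / 3.0, 1.0)


def mean_accuracy(pairs):
    """Average the per-question score over (answer, gt_answers) pairs, in percent."""
    scores = [vizwiz_accuracy(answer, gts) for answer, gts in pairs]
    return sum(scores) / len(scores) * 100.0 if scores else 0.0


print(mean_accuracy([("Dog", "dog_dog_cat_dog"), ("bird", "dog_cat_bird")]))
```

The other datasets in this command (GQA, VSR, Hateful Memes) are scored in the script by exact string match against a single label instead.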
================================================
FILE: eval_scripts/eval_ref.py
================================================
import os
import re
import json
import argparse
from collections import defaultdict
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from minigpt4.common.config import Config
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.datasets.datasets.coco_caption import RefCOCOEvalData
def list_of_str(arg):
return list(map(str, arg.split(',')))
parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
parser.add_argument("--resample", action='store_true', help="re-query samples whose answer does not match the expected bbox format")
args = parser.parse_args()
cfg = Config(args)
eval_dict = {'refcoco': ['val','testA','testB'],
'refcoco+': ['val','testA','testB'],
'refcocog': ['val','test']}
model, vis_processor = init_model(args)
model.eval()
CONV_VISION = CONV_VISION_minigptv2
conv_temp = CONV_VISION.copy()
conv_temp.system = ""
#
model.eval()
save_path = cfg.run_cfg.save_path
for dataset in args.dataset:
for split in eval_dict[dataset]:
eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
with open(os.path.join(eval_file_path,f"{dataset}/{dataset}_{split}.json"), 'r') as f:
refcoco = json.load(f)
data = RefCOCOEvalData(refcoco, vis_processor, img_path)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
minigpt4_predict = defaultdict(list)
resamples = []
for images, questions, img_ids in tqdm(eval_dataloader):
texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, img_id, question in zip(answers, img_ids, questions):
answer = answer.replace("<unk>","").replace(" ","").strip()
pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
if re.match(pattern, answer):
minigpt4_predict[img_id].append(answer)
else:
resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of','').strip()]})
if args.resample:
for i in range(20):
data = RefCOCOEvalData(resamples, vis_processor, img_path)
resamples = []
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
for images, questions, img_ids in tqdm(eval_dataloader):
texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, img_id, question in zip(answers, img_ids, questions):
answer = answer.replace("<unk>","").replace(" ","").strip()
pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
if re.match(pattern, answer) or i == 4:
minigpt4_predict[img_id].append(answer)
else:
resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of','').strip()]})
if len(resamples) == 0:
break
file_save_path = os.path.join(save_path, f"{dataset}_{split}.json")  # use the current dataset name, not the whole list
with open(file_save_path,'w') as f:
json.dump(minigpt4_predict, f)
count=0
total=len(refcoco)
res=args.res
refcoco_dict = defaultdict()
for item in refcoco:
refcoco_dict[item['img_id']] = item
for img_id in refcoco_dict:
item = refcoco_dict[img_id]
bbox = item['bbox']
outputs = minigpt4_predict[img_id]
for output in outputs:
try:
integers = re.findall(r'\d+', output)
pred_bbox = [int(num) for num in integers]
height = item['height']
width = item['width']
pred_bbox[0] = pred_bbox[0] / res * width
pred_bbox[1] = pred_bbox[1] / res * height
pred_bbox[2] = pred_bbox[2] / res * width
pred_bbox[3] = pred_bbox[3] / res * height
gt_bbox = [0,0,0,0]
gt_bbox[0] = bbox[0]
gt_bbox[1] = bbox[1]
gt_bbox[2] = bbox[0] + bbox[2]
gt_bbox[3] = bbox[1] + bbox[3]
iou_score = computeIoU(pred_bbox, gt_bbox)
if iou_score > 0.5:
count+=1
except Exception:  # skip malformed predictions without swallowing KeyboardInterrupt
continue
print(f'{dataset} {split}:', count / total * 100, flush=True)
================================================
FILE: eval_scripts/eval_vqa.py
================================================
import os
import re
import json
import argparse
from collections import defaultdict
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from minigpt4.datasets.datasets.vqa_datasets import OKVQAEvalData,VizWizEvalData,IconQAEvalData,GQAEvalData,VSREvalData,HMEvalData
from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA
from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.common.config import Config
def list_of_str(arg):
return list(map(str, arg.split(',')))
parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
args = parser.parse_args()
cfg = Config(args)
model, vis_processor = init_model(args)
conv_temp = CONV_VISION_minigptv2.copy()
conv_temp.system = ""
model.eval()
save_path = cfg.run_cfg.save_path
if 'okvqa' in args.dataset:
eval_file_path = cfg.evaluation_datasets_cfg["okvqa"]["eval_file_path"]
img_path = cfg.evaluation_datasets_cfg["okvqa"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["okvqa"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["okvqa"]["max_new_tokens"]
evaluation_annotation_path = os.path.join(eval_file_path, "okvqa_test_split.json")
with open(evaluation_annotation_path) as f:
ok_vqa_test_split = json.load(f)
data = OKVQAEvalData(ok_vqa_test_split, vis_processor, img_path)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
minigpt4_predict = []
for images, questions, question_ids, img_ids in eval_dataloader:
texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, question_id, question, img_id in zip(answers, question_ids, questions, img_ids):
result = dict()
answer = answer.lower().replace('<unk>','').strip()
result['answer'] = answer
result['question_id'] = int(question_id)
minigpt4_predict.append(result)
file_save_path= os.path.join(save_path,"okvqa.json")
with open(file_save_path,'w') as f:
json.dump(minigpt4_predict, f)
annFile = os.path.join(eval_file_path,"mscoco_val2014_annotations_clean.json")
quesFile = os.path.join(eval_file_path,"OpenEnded_mscoco_val2014_questions_clean.json" )
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(file_save_path, quesFile)
vqaEval = VQAEval(vqa, vqaRes, n=2)
vqaEval.evaluate()
print ("Overall OKVQA Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']), flush=True)
if 'vizwiz' in args.dataset:
eval_file_path = cfg.evaluation_datasets_cfg["vizwiz"]["eval_file_path"]
img_path = cfg.evaluation_datasets_cfg["vizwiz"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["vizwiz"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["vizwiz"]["max_new_tokens"]
vizwiz = json.load(open(eval_file_path, 'r'))
data = VizWizEvalData(vizwiz, vis_processor, img_path)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
minigpt4_predict = []
total_acc = []
for images, texts, gt_answers in tqdm(eval_dataloader):
texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template
with torch.no_grad():
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False,repetition_penalty=1.0)
for answer, gt_answer in zip(answers, gt_answers):
result = dict()
result['answer'] = answer.replace('<unk>','').strip()
minigpt4_predict.append(result)
count=0
gt_answer = gt_answer.split('_')
for gt in gt_answer:
if gt.lower() == answer.lower():
count += 1
acc = min(count/3.0, 1.0)
total_acc.append(acc)
file_save_path = os.path.join(save_path, "vizwiz.json")
with open(file_save_path,'w') as f:
json.dump(minigpt4_predict, f)
print('vizwiz Acc: ', np.average(total_acc)* 100.0, flush=True)
if 'iconvqa' in args.dataset:
eval_file_path = cfg.evaluation_datasets_cfg["iconvqa"]["eval_file_path"]
img_path = cfg.evaluation_datasets_cfg["iconvqa"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["iconvqa"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["iconvqa"]["max_new_tokens"]
iconqa_text_val = json.load(open(eval_file_path,"r"))
data = IconQAEvalData(iconqa_text_val, vis_processor, img_path)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
count = 0
for images, texts, candidates, answers in tqdm(eval_dataloader):
candidates = [candidate.split('_') for candidate in candidates]
num_cand = [len(candidate) for candidate in candidates]
for candidate in candidates:
candidate.extend(['none'] * (max(num_cand) - len(candidate)))
candidates = [list(x) for x in zip(*candidates)]
instructions = ["<s>[INST] <Img><ImageHere></Img> {} [/INST]".format(text) for text in texts]
answer_ranks = model.multi_select(images, instructions, candidates, num_cand=num_cand)
for idx, answer in enumerate(answers):
if answer_ranks[idx][0] == answer:
count += 1
print('iconqa Acc: ', count / len(iconqa_text_val) * 100.0, flush=True)
if 'gqa' in args.dataset:
eval_file_path = cfg.evaluation_datasets_cfg["gqa"]["eval_file_path"]
img_path = cfg.evaluation_datasets_cfg["gqa"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["gqa"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["gqa"]["max_new_tokens"]
gqa = json.load(open(eval_file_path))
data = GQAEvalData(gqa, vis_processor, img_path)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
count=0
total=0
minigpt4_predict = []
for images, texts, labels in tqdm(eval_dataloader):
texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, label in zip(answers, labels):
result = dict()
result['pred'] = answer.lower().replace('<unk>','').strip()
result['gt'] = label
minigpt4_predict.append(result)
if answer.lower() == label:
count+=1
total+=1
print('gqa val:', count / total * 100, flush=True)
file_save_path = os.path.join(save_path, "gqa.json")
with open(file_save_path,'w') as f:
json.dump(minigpt4_predict, f)
if 'vsr' in args.dataset:
img_path = cfg.evaluation_datasets_cfg["vsr"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["vsr"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["vsr"]["max_new_tokens"]
annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test')
data = VSREvalData(annotation, vis_processor, img_path)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
count=0
total=0
minigpt4_predict = []
for images, texts, labels in tqdm(eval_dataloader):
texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, label in zip(answers, labels):
result = dict()
result['pred'] = answer.replace('<unk>','').strip()
result['gt'] = label
minigpt4_predict.append(result)
if answer.lower() == label.lower():
count+=1
total+=1
print('vsr test:', count / total * 100, flush=True)
file_save_path = os.path.join(save_path,"vsr.json")
with open(file_save_path,'w') as f:
json.dump(minigpt4_predict, f)
if 'hm' in args.dataset:
eval_file_path = cfg.evaluation_datasets_cfg["hm"]["eval_file_path"]
img_path = cfg.evaluation_datasets_cfg["hm"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["hm"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["hm"]["max_new_tokens"]
annotation = []
with open(eval_file_path, 'r') as jsonl_file:
for line in jsonl_file:
json_obj = json.loads(line)
annotation.append(json_obj)
data = HMEvalData(annotation, vis_processor, img_path)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
count=0
total=0
minigpt4_predict = []
for images, texts, labels in tqdm(eval_dataloader):
texts = prepare_texts(texts, conv_temp)  # wrap the texts with the conversation template
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, label in zip(answers, labels):
result = dict()
if answer.lower().strip() =="yes":
answer=1
elif answer.lower().strip()=="no":
answer=0
else:
print("non-matching answer",answer)
result['pred'] = answer
result['gt'] = int(label)
minigpt4_predict.append(result)
if answer == label:
count+=1
total+=1
print('hm val:', count / total * 100, flush=True)
file_save_path = os.path.join(save_path, "hm.json")
with open(file_save_path,'w') as f:
json.dump(minigpt4_predict, f)
================================================
FILE: minigpt4/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import sys
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.tasks import *
root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)
registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])
================================================
FILE: minigpt4/common/__init__.py
================================================
================================================
FILE: minigpt4/common/config.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import json
from typing import Dict
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
class Config:
def __init__(self, args):
self.config = {}
self.args = args
# Register the config and configuration for setup
registry.register("configuration", self)
user_config = self._build_opt_list(self.args.options)
config = OmegaConf.load(self.args.cfg_path)
runner_config = self.build_runner_config(config)
model_config = self.build_model_config(config, **user_config)
dataset_config = self.build_dataset_config(config)
evaluation_dataset_config = self.build_evaluation_dataset_config(config)
# Validate the user-provided runner configuration
# model and dataset configuration are supposed to be validated by the respective classes
# [TODO] validate the model/dataset configuration
# self._validate_runner_config(runner_config)
# Override the default configuration with user options.
self.config = OmegaConf.merge(
runner_config, model_config, dataset_config,evaluation_dataset_config, user_config
)
def _validate_runner_config(self, runner_config):
"""
This method validates the configuration, such that
1) all the user specified options are valid;
2) no type mismatches between the user specified options and the config.
"""
runner_config_validator = create_runner_config_validator()
runner_config_validator.validate(runner_config)
def _build_opt_list(self, opts):
opts_dot_list = self._convert_to_dot_list(opts)
return OmegaConf.from_dotlist(opts_dot_list)
@staticmethod
def build_model_config(config, **kwargs):
model = config.get("model", None)
assert model is not None, "Missing model configuration file."
model_cls = registry.get_model_class(model.arch)
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
model_type = kwargs.get("model.model_type", None)
if not model_type:
model_type = model.get("model_type", None)
# else use the model type selected by user.
assert model_type is not None, "Missing model_type."
model_config_path = model_cls.default_config_path(model_type=model_type)
model_config = OmegaConf.create()
# hierarchy override, customized config > default config
model_config = OmegaConf.merge(
model_config,
OmegaConf.load(model_config_path),
{"model": config["model"]},
)
return model_config
@staticmethod
def build_runner_config(config):
return {"run": config.run}
@staticmethod
def build_dataset_config(config):
datasets = config.get("datasets", None)
if datasets is None:
raise KeyError(
"Expecting 'datasets' as the root key for dataset configuration."
)
dataset_config = OmegaConf.create()
for dataset_name in datasets:
builder_cls = registry.get_builder_class(dataset_name)
dataset_config_type = datasets[dataset_name].get("type", "default")
dataset_config_path = builder_cls.default_config_path(
type=dataset_config_type
)
# hierarchy override, customized config > default config
dataset_config = OmegaConf.merge(
dataset_config,
OmegaConf.load(dataset_config_path),
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
)
return dataset_config
@staticmethod
def build_evaluation_dataset_config(config):
datasets = config.get("evaluation_datasets", None)
# if datasets is None:
# raise KeyError(
# "Expecting 'datasets' as the root key for dataset configuration."
# )
dataset_config = OmegaConf.create()
if datasets is not None:
for dataset_name in datasets:
builder_cls = registry.get_builder_class(dataset_name)
# hierarchy override, customized config > default config
dataset_config = OmegaConf.merge(
dataset_config,
{"evaluation_datasets": {dataset_name: config["evaluation_datasets"][dataset_name]}},
)
return dataset_config
def _convert_to_dot_list(self, opts):
if opts is None:
opts = []
if len(opts) == 0:
return opts
has_equal = opts[0].find("=") != -1
if has_equal:
return opts
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
def get_config(self):
return self.config
@property
def run_cfg(self):
return self.config.run
@property
def datasets_cfg(self):
return self.config.datasets
@property
def evaluation_datasets_cfg(self):
return self.config.evaluation_datasets
@property
def model_cfg(self):
return self.config.model
def pretty_print(self):
logging.info("\n===== Running Parameters =====")
logging.info(self._convert_node_to_json(self.config.run))
logging.info("\n====== Dataset Attributes ======")
datasets = self.config.datasets
for dataset in datasets:
if dataset in self.config.datasets:
logging.info(f"\n======== {dataset} =======")
dataset_config = self.config.datasets[dataset]
logging.info(self._convert_node_to_json(dataset_config))
else:
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
logging.info("\n====== Model Attributes ======")
logging.info(self._convert_node_to_json(self.config.model))
def _convert_node_to_json(self, node):
container = OmegaConf.to_container(node, resolve=True)
return json.dumps(container, indent=4, sort_keys=True)
def to_dict(self):
return OmegaConf.to_container(self.config)
def node_to_dict(node):
return OmegaConf.to_container(node)
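A minimal sketch of the dot-list conversion done by `_convert_to_dot_list` above, written as a standalone function so it runs without OmegaConf. The name `to_dot_list` and the sample option keys are illustrative assumptions, not part of the repository.

```python
def to_dot_list(opts):
    """Standalone mirror of the conversion logic (hypothetical helper name)."""
    if not opts:
        return []
    # If the first item already contains "=", assume every item is "key=value".
    if "=" in opts[0]:
        return list(opts)
    # Otherwise pair up alternating keys and values: [k1, v1, k2, v2, ...].
    return [f"{key}={value}" for key, value in zip(opts[0::2], opts[1::2])]


# Illustrative option keys; real keys depend on the yaml config in use.
pairs = to_dot_list(["run.seed", "42", "model.arch", "minigpt4"])
already_pairs = to_dot_list(["run.seed=42", "model.arch=minigpt4"])
```

Both calls produce the same `key=value` list, which is the form `OmegaConf.from_dotlist` expects.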
class ConfigValidator:
"""
This is a preliminary implementation to centralize and validate the configuration.
May be altered in the future.
A helper class to validate configurations from yaml file.
This serves the following purposes:
1. Ensure all the options in the yaml are defined; raise an error if not.
2. Raise an error when a type mismatch is found.
3. Provide a central place to store and display helpful messages for supported configurations.
"""
class _Argument:
def __init__(self, name, choices=None, type=None, help=None):
self.name = name
self.val = None
self.choices = choices
self.type = type
self.help = help
def __str__(self):
s = f"{self.name}={self.val}"
if self.type is not None:
s += f", ({self.type})"
if self.choices is not None:
s += f", choices: {self.choices}"
if self.help is not None:
s += f", ({self.help})"
return s
def __init__(self, description):
self.description = description
self.arguments = dict()
self.parsed_args = None
def __getitem__(self, key):
assert self.parsed_args is not None, "No arguments parsed yet."
return self.parsed_args[key]
def __str__(self) -> str:
return self.format_help()
def add_argument(self, *args, **kwargs):
"""
Assume the first argument is the name of the argument.
"""
self.arguments[args[0]] = self._Argument(*args, **kwargs)
def validate(self, config=None):
"""
Validate a dict-like config (e.g. loaded from yaml): each key must be a registered
argument, and its value must match the declared type and choices, if any.
"""
for k, v in config.items():
assert (
k in self.arguments
), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""
if self.arguments[k].type is not None:
try:
self.arguments[k].val = self.arguments[k].type(v)
except (TypeError, ValueError):
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
if self.arguments[k].choices is not None:
assert (
v in self.arguments[k].choices
), f"""{k} must be one of {self.arguments[k].choices}."""
return config
def format_arguments(self):
return str([f"{k}" for k in sorted(self.arguments.keys())])
def format_help(self):
# description + key-value pair string for each argument
help_msg = str(self.description)
return help_msg + ", available arguments: " + self.format_arguments()
def print_help(self):
# display help message
print(self.format_help())
def create_runner_config_validator():
validator = ConfigValidator(description="Runner configurations")
validator.add_argument(
"runner",
type=str,
choices=["runner_base", "runner_iter"],
help="""Runner to use. "runner_base" trains by epochs, while "runner_iter"
trains by iterations. Default: runner_base""",
)
# add arguments for training dataset ratios
validator.add_argument(
"train_dataset_ratios",
type=Dict[str, float],
help="""Ratios of the training datasets. Used only by the iteration-based runner.
Not supported by the epoch-based runner, because defining an epoch becomes tricky.
Default: None""",
)
validator.add_argument(
"max_iters",
type=float,
help="Maximum number of iterations to run.",
)
validator.add_argument(
"max_epoch",
type=int,
help="Maximum number of epochs to run.",
)
# add arguments for iters_per_inner_epoch
validator.add_argument(
"iters_per_inner_epoch",
type=float,
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
)
lr_scheds_choices = registry.list_lr_schedulers()
validator.add_argument(
"lr_sched",
type=str,
choices=lr_scheds_choices,
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
)
task_choices = registry.list_tasks()
validator.add_argument(
"task",
type=str,
choices=task_choices,
help="Task to use, from {}".format(task_choices),
)
# add arguments for init_lr
validator.add_argument(
"init_lr",
type=float,
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
)
# add arguments for min_lr
validator.add_argument(
"min_lr",
type=float,
help="Minimum learning rate (after decay).",
)
# add arguments for warmup_lr
validator.add_argument(
"warmup_lr",
type=float,
help="Starting learning rate for warmup.",
)
# add arguments for learning rate decay rate
validator.add_argument(
"lr_decay_rate",
type=float,
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
)
# add arguments for weight decay
validator.add_argument(
"weight_decay",
type=float,
help="Weight decay rate.",
)
# add arguments for training batch size
validator.add_argument(
"batch_size_train",
type=int,
help="Training batch size.",
)
# add arguments for evaluation batch size
validator.add_argument(
"batch_size_eval",
type=int,
help="Evaluation batch size, including validation and testing.",
)
# add arguments for number of workers for data loading
validator.add_argument(
"num_workers",
help="Number of workers for data loading.",
)
# add arguments for warm up steps
validator.add_argument(
"warmup_steps",
type=int,
help="Number of warmup steps. Required if a warmup schedule is used.",
)
# add arguments for random seed
validator.add_argument(
"seed",
type=int,
help="Random seed.",
)
# add arguments for output directory
validator.add_argument(
"output_dir",
type=str,
help="Output directory to save checkpoints and logs.",
)
# add arguments for whether only use evaluation
validator.add_argument(
"evaluate",
help="Whether to only evaluate the model. If true, training will not be performed.",
)
# add arguments for splits used for training, e.g. ["train", "val"]
validator.add_argument(
"train_splits",
type=list,
help="Splits to use for training.",
)
# add arguments for splits used for validation, e.g. ["val"]
validator.add_argument(
"valid_splits",
type=list,
help="Splits to use for validation. If not provided, will skip the validation.",
)
# add arguments for splits used for testing, e.g. ["test"]
validator.add_argument(
"test_splits",
type=list,
help="Splits to use for testing. If not provided, will skip the testing.",
)
# add arguments for accumulating gradient for iterations
validator.add_argument(
"accum_grad_iters",
type=int,
help="Number of iterations to accumulate gradient for.",
)
# ====== distributed training ======
validator.add_argument(
"device",
type=str,
choices=["cpu", "cuda"],
help="Device to use. Only 'cuda' and 'cpu' are supported for now.",
)
validator.add_argument(
"world_size",
type=int,
help="Number of processes participating in the job.",
)
validator.add_argument("dist_url", type=str)
validator.add_argument("distributed", type=bool)
# add argument for whether to use a distributed sampler during evaluation
validator.add_argument(
"use_dist_eval_sampler",
type=bool,
help="Whether to use distributed sampler during evaluation or not.",
)
# ====== task specific ======
# generation task specific arguments
# add arguments for maximal length of text output
validator.add_argument(
"max_len",
type=int,
help="Maximal length of text output.",
)
# add arguments for minimal length of text output
validator.add_argument(
"min_len",
type=int,
help="Minimal length of text output.",
)
# add arguments for number of beams
validator.add_argument(
"num_beams",
type=int,
help="Number of beams used for beam search.",
)
# vqa task specific arguments
# add arguments for number of answer candidates
validator.add_argument(
"num_ans_candidates",
type=int,
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
)
# add arguments for inference method
validator.add_argument(
"inference_method",
type=str,
choices=["generate", "rank"],
help="""Inference method to use for question answering. If rank, requires an answer list.""",
)
# ====== model specific ======
validator.add_argument(
"k_test",
type=int,
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
)
return validator
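The three checks that `ConfigValidator.validate` applies (known key, convertible type, allowed choice) can be sketched in isolation. This is a simplified stand-in under stated assumptions: `MiniValidator` is a hypothetical class, and it raises exceptions where the original uses `assert`.

```python
class MiniValidator:
    """Simplified stand-in for ConfigValidator (hypothetical, for illustration)."""

    def __init__(self):
        self.arguments = {}

    def add_argument(self, name, type=None, choices=None):
        self.arguments[name] = {"type": type, "choices": choices}

    def validate(self, config):
        for key, value in config.items():
            # 1) the key must be a declared argument
            if key not in self.arguments:
                raise KeyError(f"{key} is not a valid argument.")
            spec = self.arguments[key]
            # 2) the value must be convertible to the declared type
            if spec["type"] is not None:
                try:
                    spec["type"](value)
                except (TypeError, ValueError):
                    raise ValueError(f"{key} is not a valid {spec['type'].__name__}.")
            # 3) the value must be one of the declared choices, if any
            if spec["choices"] is not None and value not in spec["choices"]:
                raise ValueError(f"{key} must be one of {spec['choices']}.")
        return config


validator = MiniValidator()
validator.add_argument("runner", type=str, choices=["runner_base", "runner_iter"])
validator.add_argument("max_epoch", type=int)

ok = validator.validate({"runner": "runner_base", "max_epoch": 4})

try:
    validator.validate({"max_epoch": "four"})
    type_error = None
except ValueError as err:
    type_error = str(err)
```

Note that the type check only tests convertibility, so a string such as `"4"` would also pass for an `int` argument.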
================================================
FILE: minigpt4/common/dist_utils.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import functools
import os
import torch
import torch.distributed as dist
import timm.models.hub as timm_hub
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def init_distributed_mode(args):
if args.distributed is False:
print("Not using distributed mode")
return
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print(
"| distributed init (rank {}, world {}): {}".format(
args.rank, args.world_size, args.dist_url
),
flush=True,
)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
timeout=datetime.timedelta(
days=365
), # allow auto-downloading and de-compressing
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def get_dist_info():
if torch.__version__ < "1.0":
initialized = dist._initialized
else:
initialized = dist.is_initialized()
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else: # non-distributed training
rank = 0
world_size = 1
return rank, world_size
def main_process(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
def download_cached_file(url, check_hash=True, progress=False):
"""
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
"""
def get_cached_file_path():
# a hack to sync the file path across processes
parts = torch.hub.urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
return cached_file
if is_main_process():
timm_hub.download_cached_file(url, check_hash, progress)
if is_dist_avail_and_initialized():
dist.barrier()
return get_cached_file_path()
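A minimal sketch of the rank-gating pattern behind the `main_process` decorator above. To keep it runnable without `torch.distributed`, the rank lookup is injected as a callable; `main_process_only` and the two sample functions are hypothetical names used only for illustration.

```python
import functools


def main_process_only(get_rank):
    """Build a decorator that runs the wrapped function only when get_rank() == 0."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if get_rank() == 0:
                return func(*args, **kwargs)
            # Non-main ranks skip the call and implicitly get None.
            return None

        return wrapper

    return decorator


@main_process_only(lambda: 0)  # pretend this process is rank 0
def save_on_main(name):
    return f"saved {name}"


@main_process_only(lambda: 1)  # pretend this process is rank 1
def save_on_worker(name):
    return f"saved {name}"


main_result = save_on_main("ckpt")
worker_result = save_on_worker("ckpt")
```

Because non-main ranks receive `None`, callers should not rely on the return value of a gated function unless they know they are on rank 0.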
================================================
FILE: minigpt4/common/eval_utils.py
================================================
import argparse
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
from minigpt4.common.registry import registry
from minigpt4.common.config import Config
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def eval_parser():
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
parser.add_argument("--name", type=str, default='A2', help="evaluation name")
parser.add_argument("--ckpt", type=str, help="path to the model checkpoint.")
parser.add_argument("--eval_opt", type=str, default='all', help="evaluation option.")
parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
parser.add_argument(
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
"in xxx=yyy format will be merged into config file (deprecated), "
"change to --cfg-options instead.",
)
return parser
def prepare_texts(texts, conv_temp):
    convs = [conv_temp.copy() for _ in range(len(texts))]
    for conv, text in zip(convs, texts):
        # first role's turn: image placeholder followed by the text
        conv.append_message(conv.roles[0], '<Img><ImageHere></Img> {}'.format(text))
        # second role's turn is left empty so the prompt ends where generation starts
        conv.append_message(conv.roles[1], None)
    return [conv.get_prompt() for conv in convs]
def init_model(args):
print('Initializing Model')
cfg = Config(args)
# cfg.model_cfg.ckpt = args.ckpt
# cfg.model_cfg.lora_r = args.lora_r
# cfg.model_cfg.lora_alpha = args.lora_alpha
model_config = cfg.model_cfg
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:0')
# import pudb; pudb.set_trace()
key = list(cfg.datasets_cfg.keys())[0]
vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
print('Initialization Finished')
return model, vis_processor
def computeIoU(bbox1, bbox2):
x1, y1, x2, y2 = bbox1
x3, y3, x4, y4 = bbox2
intersection_x1 = max(x1, x3)
intersection_y1 = max(y1, y3)
intersection_x2 = min(x2, x4)
intersection_y2 = min(y2, y4)
intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
union_area = bbox1_area + bbox2_area - intersection_area
iou = intersection_area / union_area
return iou
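A worked example of the IoU computation above, as a standalone copy. Boxes are assumed to be `(x1, y1, x2, y2)` tuples; the `+ 1` terms are consistent with treating coordinates as inclusive pixel indices, which is an interpretation rather than something the source states. The name `compute_iou` and the sample boxes are illustrative.

```python
def compute_iou(bbox1, bbox2):
    """Standalone copy of the IoU logic; boxes are (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    # Width and height of the overlap, clamped at zero for disjoint boxes.
    inter_w = max(0, min(x2, x4) - max(x1, x3) + 1)
    inter_h = max(0, min(y2, y4) - max(y1, y3) + 1)
    intersection = inter_w * inter_h
    area1 = (x2 - x1 + 1) * (y2 - y1 + 1)
    area2 = (x4 - x3 + 1) * (y4 - y3 + 1)
    return intersection / (area1 + area2 - intersection)


# Two 10x10 boxes sharing a 5x5 corner: 25 / (100 + 100 - 25).
overlapping = compute_iou((0, 0, 9, 9), (5, 5, 14, 14))
disjoint = compute_iou((0, 0, 4, 4), (10, 10, 14, 14))
identical = compute_iou((2, 3, 7, 8), (2, 3, 7, 8))
```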
================================================
FILE: minigpt4/common/gradcam.py
================================================
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from skimage import transform as skimage_transform


def getAttMap(img, attMap, blur=True, overlap=True):
    # normalize the attention map to [0, 1] (in place)
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    # upsample to the image resolution with bicubic interpolation
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        # guard against an all-zero map, as in the first normalization
        if attMap.max() > 0:
            attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)  # drop the alpha channel
    if overlap:
        # blend the image with the colored heatmap, weighted by attention
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap
================================================
FILE: minigpt4/common/logger.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import logging
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
from minigpt4.common import dist_utils
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not dist_utils.is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value,
)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
)
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def global_avg(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
log_msg = [
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
if torch.cuda.is_available():
log_msg.append("max mem: {memory:.0f}")
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(
"{} Total time: {} ({:.4f} s / it)".format(
header, total_time_str, total_time / len(iterable)
)
)
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def setup_logger():
logging.basicConfig(
level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()],
)
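A minimal sketch of how `SmoothedValue` separates windowed statistics from the global average, using only the standard library. `WindowedMeter` is a hypothetical name, and `statistics.median_low` is used here as a stand-in for the torch median on the window (both pick the lower middle element for even-length windows).

```python
from collections import deque
from statistics import median_low


class WindowedMeter:
    """Stdlib-only stand-in for SmoothedValue (hypothetical, for illustration)."""

    def __init__(self, window_size=3):
        self.window = deque(maxlen=window_size)  # keeps only the most recent values
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.window.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        return median_low(self.window)

    @property
    def global_avg(self):
        return self.total / self.count


meter = WindowedMeter(window_size=3)
for loss in [4.0, 3.0, 2.0, 1.0]:
    meter.update(loss)

window_median = meter.median   # over the last three values: 3.0, 2.0, 1.0
overall_avg = meter.global_avg  # over all four values
```

The windowed median reacts quickly to recent values, while the global average reflects the whole run; the default format string in `SmoothedValue` prints both.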
================================================
FILE: minigpt4/common/optims.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import math
from minigpt4.common.registry import registry
@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
def __init__(
self,
optimizer,
max_epoch,
min_lr,
init_lr,
decay_rate=1,
warmup_start_lr=-1,
warmup_steps=0,
**kwargs
):
self.optimizer = optimizer
self.max_epoch = max_epoch
self.min_lr = min_lr
self.decay_rate = decay_rate
self.init_lr = init_lr
self.warmup_steps = warmup_steps
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
def step(self, cur_epoch, cur_step):
if cur_epoch == 0:
warmup_lr_schedule(
step=cur_step,
optimizer=self.optimizer,
max_step=self.warmup_steps,
init_lr=self.warmup_start_lr,
max_lr=self.init_lr,
)
else:
step_lr_schedule(
epoch=cur_epoch,
optimizer=self.optimizer,
init_lr=self.init_lr,
min_lr=self.min_lr,
decay_rate=self.decay_rate,
)
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
def __init__(
self,
optimizer,
max_epoch,
iters_per_epoch,
min_lr,
init_lr,
warmup_steps=0,
warmup_start_lr=-1,
**kwargs
):
self.optimizer = optimizer
self.max_epoch = max_epoch
self.iters_per_epoch = iters_per_epoch
self.min_lr = min_lr
self.init_lr = init_lr
self.warmup_steps = warmup_steps
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
def step(self, cur_epoch, cur_step):
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
if total_cur_step < self.warmup_steps:
warmup_lr_schedule(
step=total_cur_step,  # global step keeps warmup monotonic across epoch boundaries
optimizer=self.optimizer,
max_step=self.warmup_steps,
init_lr=self.warmup_start_lr,
max_lr=self.init_lr,
)
else:
cosine_lr_schedule(
epoch=total_cur_step,
optimizer=self.optimizer,
max_epoch=self.max_epoch * self.iters_per_epoch,
init_lr=self.init_lr,
min_lr=self.min_lr,
)
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
"""Decay the learning rate"""
lr = (init_lr - min_lr) * 0.5 * (
1.0 + math.cos(math.pi * epoch / max_epoch)
) + min_lr
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
"""Warmup the learning rate"""
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
"""Decay the learning rate"""
lr = max(min_lr, init_lr * (decay_rate**epoch))
for param_group in optimizer.param_groups:
param_group["lr"] = lr
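The three schedule formulas above can be checked as pure functions that return the learning rate instead of writing it into an optimizer. This is a sketch under that assumption; the function names and the sample hyperparameters are illustrative, not values taken from the repository's configs.

```python
import math


def warmup_lr(step, max_step, start_lr, peak_lr):
    # Linear ramp from start_lr to peak_lr, capped at peak_lr.
    return min(peak_lr, start_lr + (peak_lr - start_lr) * step / max(max_step, 1))


def cosine_lr(step, max_step, init_lr, min_lr):
    # Half-cosine decay from init_lr (step 0) to min_lr (step == max_step).
    return (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * step / max_step)) + min_lr


def step_lr(epoch, init_lr, min_lr, decay_rate):
    # Exponential per-epoch decay, floored at min_lr.
    return max(min_lr, init_lr * (decay_rate**epoch))


warmup_start = warmup_lr(0, 100, 1e-6, 1e-4)
warmup_mid = warmup_lr(50, 100, 1e-6, 1e-4)
warmup_capped = warmup_lr(200, 100, 1e-6, 1e-4)

cosine_start = cosine_lr(0, 1000, 1e-4, 1e-5)
cosine_mid = cosine_lr(500, 1000, 1e-4, 1e-5)
cosine_end = cosine_lr(1000, 1000, 1e-4, 1e-5)

step_decayed = step_lr(2, 1e-4, 1e-5, 0.5)
step_floored = step_lr(10, 1e-4, 1e-5, 0.5)
```

At the midpoint the cosine schedule sits halfway between `init_lr` and `min_lr`, and the step schedule stops decaying once it reaches `min_lr`.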
================================================
FILE: minigpt4/common/registry.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
class Registry:
mapping = {
"builder_name_mapping": {},
"task_name_mapping": {},
"processor_name_mapping": {},
"model_name_mapping": {},
"lr_scheduler_name_mapping": {},
"runner_name_mapping": {},
"state": {},
"paths": {},
}
@classmethod
def register_builder(cls, name):
r"""Register a dataset builder to registry with key 'name'
Args:
name: Key with which the builder will be registered.
Usage:
from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
"""
def wrap(builder_cls):
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
assert issubclass(
builder_cls, BaseDatasetBuilder
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
builder_cls
)
if name in cls.mapping["builder_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["builder_name_mapping"][name]
)
)
cls.mapping["builder_name_mapping"][name] = builder_cls
return builder_cls
return wrap
@classmethod
def register_task(cls, name):
r"""Register a task to registry with key 'name'
Args:
name: Key with which the task will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(task_cls):
from minigpt4.tasks.base_task import BaseTask
assert issubclass(
task_cls, BaseTask
), "All tasks must inherit BaseTask class"
if name in cls.mapping["task_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["task_name_mapping"][name]
)
)
cls.mapping["task_name_mapping"][name] = task_cls
return task_cls
return wrap
@classmethod
def register_model(cls, name):
r"""Register a model to registry with key 'name'
Args:
name: Key with which the model will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(model_cls):
from minigpt4.models import BaseModel
assert issubclass(
model_cls, BaseModel
), "All models must inherit BaseModel class"
if name in cls.mapping["model_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["model_name_mapping"][name]
)
)
cls.mapping["model_name_mapping"][name] = model_cls
return model_cls
return wrap
@classmethod
def register_processor(cls, name):
r"""Register a processor to registry with key 'name'
Args:
name: Key with which the processor will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(processor_cls):
from minigpt4.processors import BaseProcessor
assert issubclass(
processor_cls, BaseProcessor
), "All processors must inherit BaseProcessor class"
if name in cls.mapping["processor_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["processor_name_mapping"][name]
)
)
cls.mapping["processor_name_mapping"][name] = processor_cls
return processor_cls
return wrap
@classmethod
def register_lr_scheduler(cls, name):
r"""Register a learning rate scheduler to registry with key 'name'
Args:
name: Key with which the learning rate scheduler will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(lr_sched_cls):
if name in cls.mapping["lr_scheduler_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["lr_scheduler_name_mapping"][name]
)
)
cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
return lr_sched_cls
return wrap
@classmethod
def register_runner(cls, name):
r"""Register a runner to registry with key 'name'
Args:
name: Key with which the runner will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(runner_cls):
if name in cls.mapping["runner_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["runner_name_mapping"][name]
)
)
cls.mapping["runner_name_mapping"][name] = runner_cls
return runner_cls
return wrap
@classmethod
def register_path(cls, name, path):
r"""Register a path to registry with key 'name'
Args:
name: Key with which the path will be registered.
Usage:
from minigpt4.common.registry import registry
"""
assert isinstance(path, str), "All paths must be str."
if name in cls.mapping["paths"]:
raise KeyError("Name '{}' already registered.".format(name))
cls.mapping["paths"][name] = path
@classmethod
def register(cls, name, obj):
r"""Register an item to registry with key 'name'
Args:
name: Key with which the item will be registered.
Usage::
from minigpt4.common.registry import registry
registry.register("config", {})
"""
path = name.split(".")
current = cls.mapping["state"]
for part in path[:-1]:
if part not in current:
current[part] = {}
current = current[part]
current[path[-1]] = obj
# @classmethod
# def get_trainer_class(cls, name):
# return cls.mapping["trainer_name_mapping"].get(name, None)
@classmethod
def get_builder_class(cls, name):
return cls.mapping["builder_name_mapping"].get(name, None)
@classmethod
def get_model_class(cls, name):
return cls.mapping["model_name_mapping"].get(name, None)
@classmethod
def get_task_class(cls, name):
return cls.mapping["task_name_mapping"].get(name, None)
@classmethod
def get_processor_class(cls, name):
return cls.mapping["processor_name_mapping"].get(name, None)
@classmethod
def get_lr_scheduler_class(cls, name):
return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
@classmethod
def get_runner_class(cls, name):
return cls.mapping["runner_name_mapping"].get(name, None)
@classmethod
def list_runners(cls):
return sorted(cls.mapping["runner_name_mapping"].keys())
@classmethod
def list_models(cls):
return sorted(cls.mapping["model_name_mapping"].keys())
@classmethod
def list_tasks(cls):
return sorted(cls.mapping["task_name_mapping"].keys())
@classmethod
def list_processors(cls):
return sorted(cls.mapping["processor_name_mapping"].keys())
@classmethod
def list_lr_schedulers(cls):
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
@classmethod
def list_datasets(cls):
return sorted(cls.mapping["builder_name_mapping"].keys())
@classmethod
def get_path(cls, name):
return cls.mapping["paths"].get(name, None)
@classmethod
def get(cls, name, default=None, no_warning=False):
r"""Get an item from registry with key 'name'
Args:
name (string): Key whose value needs to be retrieved.
default: If passed and key is not in registry, default value will
be returned with a warning. Default: None
no_warning (bool): If passed as True, warning when key doesn't exist
will not be generated. Useful for internal operations.
Default: False
"""
original_name = name
name = name.split(".")
value = cls.mapping["state"]
for subname in name:
value = value.get(subname, default)
if value is default:
break
if (
"writer" in cls.mapping["state"]
and value == default
and no_warning is False
):
cls.mapping["state"]["writer"].warning(
"Key {} is not present in registry, returning default value "
"of {}".format(original_name, default)
)
return value
@classmethod
def unregister(cls, name):
r"""Remove an item from registry with key 'name'
Args:
name: Key which needs to be removed.
Usage::
from minigpt4.common.registry import registry
config = registry.unregister("config")
"""
return cls.mapping["state"].pop(name, None)
registry = Registry()
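A minimal sketch of the decorator-based registration pattern used by `Registry`, reduced to a single mapping so it runs on its own. `MiniRegistry`, `ConstantLR` and the key `"constant_lr"` are hypothetical names chosen for illustration.

```python
class MiniRegistry:
    """Single-mapping stand-in for Registry (hypothetical, for illustration)."""

    mapping = {"lr_scheduler_name_mapping": {}}

    @classmethod
    def register_lr_scheduler(cls, name):
        def wrap(sched_cls):
            # Refuse duplicate names so a later import cannot silently override.
            if name in cls.mapping["lr_scheduler_name_mapping"]:
                raise KeyError(f"Name '{name}' already registered.")
            cls.mapping["lr_scheduler_name_mapping"][name] = sched_cls
            return sched_cls  # the class itself is returned unchanged

        return wrap

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())


@MiniRegistry.register_lr_scheduler("constant_lr")
class ConstantLR:
    def __init__(self, lr):
        self.lr = lr


looked_up = MiniRegistry.get_lr_scheduler_class("constant_lr")
missing = MiniRegistry.get_lr_scheduler_class("not_registered")
names = MiniRegistry.list_lr_schedulers()
```

Registration happens at import time, when the decorated class body is executed. That is why modules such as `eval_utils.py` import the builder, model, processor, runner and task packages before looking anything up.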
================================================
FILE: minigpt4/common/utils.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse
import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from minigpt4.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
check_integrity,
download_file_from_google_drive,
extract_archive,
)
def now():
from datetime import datetime
return datetime.now().strftime("%Y%m%d%H%M")[:-1]
def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def get_cache_path(rel_path):
return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
def get_abs_path(rel_path):
return os.path.join(registry.get_path("library_root"), rel_path)
def load_json(filename):
with open(filename, "r") as f:
return json.load(f)
# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
def makedir(dir_path):
"""
Create the directory if it does not exist.
"""
is_success = False
try:
if not g_pathmgr.exists(dir_path):
g_pathmgr.mkdirs(dir_path)
is_success = True
except BaseException:
print(f"Error creating directory: {dir_path}")
return is_success
def get_redirected_url(url: str):
"""
Given a URL, returns the URL it redirects to or the
original URL in case of no indirection
"""
import requests
with requests.Session() as session:
with session.get(url, stream=True, allow_redirects=True) as response:
if response.history:
return response.url
else:
return url
def to_google_drive_download_url(view_url: str) -> str:
"""
Utility function to transform a view URL of google drive
to a download URL for google drive
Example input:
https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
Example output:
https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
"""
splits = view_url.split("/")
assert splits[-1] == "view"
file_id = splits[-2]
return f"https://drive.google.com/uc?export=download&id={file_id}"
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
"""
Download a file from Google Drive.
Downloading a URL from Google Drive requires confirmation when
the file is too large (Google Drive warns that
antivirus checks cannot be performed on such files).
"""
import requests
with requests.Session() as session:
# First get the confirmation token and append it to the URL
with session.get(url, stream=True, allow_redirects=True) as response:
for k, v in response.cookies.items():
if k.startswith("download_warning"):
url = url + "&confirm=" + v
# Then download the content of the file
with session.get(url, stream=True, verify=True) as response:
makedir(output_path)
path = os.path.join(output_path, output_file_name)
total_size = int(response.headers.get("Content-length", 0))
with open(path, "wb") as file:
from tqdm import tqdm
with tqdm(total=total_size) as progress_bar:
for block in response.iter_content(
chunk_size=io.DEFAULT_BUFFER_SIZE
):
file.write(block)
progress_bar.update(len(block))
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # read() returns bytes, so the end-of-stream sentinel is b""
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    # advance by the bytes actually read; the last chunk may be short
                    pbar.update(len(chunk))
                    fh.write(chunk)
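# Illustrative sketch (not part of the original module): the download loop
# above uses the two-argument form of iter(), which keeps calling read()
# until it returns the sentinel value. For binary streams the end-of-stream
# value is b"". The example uses in-memory streams so it runs without network
# access; `copy_in_chunks`, `source` and `sink` are hypothetical names.

```python
import io


def copy_in_chunks(src, dst, chunk_size=1024):
    """Copy a binary stream chunk by chunk; return the number of bytes written."""
    total = 0
    for chunk in iter(lambda: src.read(chunk_size), b""):
        dst.write(chunk)
        total += len(chunk)  # the final chunk may be shorter than chunk_size
    return total


source = io.BytesIO(b"x" * 2500)
sink = io.BytesIO()
copied = copy_in_chunks(source, sink, chunk_size=1024)
```

# Counting len(chunk) rather than chunk_size keeps the total exact when the
# payload size is not a multiple of the chunk size.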
def download_url(
url: str,
root: str,
filename: Optional[str] = None,
md5: Optional[str] = None,
) -> None:
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under.
If None, use the basename of the URL.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
makedir(root)
# check if file is already present locally
if check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
return
# expand redirect chain if needed
url = get_redirected_url(url)
# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)
# download the file
try:
print("Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
if url[:5] == "https":
url = url.replace("https:", "http:")
print(
"Failed download. Trying https -> http instead."
" Downloading " + url + " to " + fpath
)
_urlretrieve(url, fpath)
else:
raise e
# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")
def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root))
extract_archive(archive, extract_root, remove_finished)
def cache_url(url: str, cache_dir: str) -> str:
"""
This implementation downloads the remote resource and caches it locally.
The resource will only be downloaded if not previously requested.
"""
parsed_url = urlparse(url)
dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
makedir(dirname)
filename = url.split("/")[-1]
cached = os.path.join(dirname, filename)
with file_lock(cached):
if not os.path.isfile(cached):
logging.info(f"Downloading {url} to {cached} ...")
cached = download(url, dirname, filename=filename)
logging.info(f"URL {url} cached in {cached}")
return cached
# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
"""
Create a symlink at file2 that points to file1, replacing file2
if it already exists. Useful during model checkpointing to link to
the latest successful checkpoint.
"""
try:
if g_pathmgr.exists(file2):
g_pathmgr.rm(file2)
g_pathmgr.symlink(file1, file2)
except Exception as e:
logging.info(f"Could NOT create symlink. Error: {e}")
def save_file(data, filename, append_to_json=True, verbose=True):
"""
Common i/o utility to handle saving data to various file formats.
Supported:
.pkl, .pickle, .npy, .json, .yaml
Specifically for .json, users have the option to either append (default)
or rewrite by passing in Boolean value to append_to_json.
"""
if verbose:
logging.info(f"Saving data to file: {filename}")
file_ext = os.path.splitext(filename)[1]
if file_ext in [".pkl", ".pickle"]:
with g_pathmgr.open(filename, "wb") as fopen:
pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
elif file_ext == ".npy":
with g_pathmgr.open(filename, "wb") as fopen:
np.save(fopen, data)
elif file_ext == ".json":
if append_to_json:
with g_pathmgr.open(filename, "a") as fopen:
fopen.write(json.dumps(data, sort_keys=True) + "\n")
fopen.flush()
else:
with g_pathmgr.open(filename, "w") as fopen:
fopen.write(json.dumps(data, sort_keys=True) + "\n")
fopen.flush()
elif file_ext == ".yaml":
with g_pathmgr.open(filename, "w") as fopen:
dump = yaml.dump(data)
fopen.write(dump)
fopen.flush()
else:
raise Exception(f"Saving {file_ext} is not supported yet")
if verbose:
logging.info(f"Saved data to file: {filename}")
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
"""
Common i/o utility to handle loading data from various file formats.
Supported:
.txt, .pkl, .pickle, .npy, .json, .yaml, .csv
For the npy files, we support reading the files in mmap_mode.
If the mmap_mode of reading is not successful, we load data without the
mmap_mode.
"""
if verbose:
logging.info(f"Loading data from file: {filename}")
file_ext = os.path.splitext(filename)[1]
if file_ext == ".txt":
with g_pathmgr.open(filename, "r") as fopen:
data = fopen.readlines()
elif file_ext in [".pkl", ".pickle"]:
with g_pathmgr.open(filename, "rb") as fopen:
data = pickle.load(fopen, encoding="latin1")
elif file_ext == ".npy":
if mmap_mode:
try:
with g_pathmgr.open(filename, "rb") as fopen:
data = np.load(
fopen,
allow_pickle=allow_pickle,
encoding="latin1",
mmap_mode=mmap_mode,
)
except ValueError as e:
logging.info(
f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
)
data = np.load(
filename,
allow_pickle=allow_pickle,
encoding="latin1",
mmap_mode=mmap_mode,
)
logging.info("Successfully loaded without g_pathmgr")
except Exception:
logging.info("Could not mmap without g_pathmgr. Trying without mmap")
with g_pathmgr.open(filename, "rb") as fopen:
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
else:
with g_pathmgr.open(filename, "rb") as fopen:
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
elif file_ext == ".json":
with g_pathmgr.open(filename, "r") as fopen:
data = json.load(fopen)
elif file_ext == ".yaml":
with g_pathmgr.open(filename, "r") as fopen:
data = yaml.load(fopen, Loader=yaml.FullLoader)
elif file_ext == ".csv":
with g_pathmgr.open(filename, "r") as fopen:
data = pd.read_csv(fopen)
else:
raise Exception(f"Reading from {file_ext} is not supported yet")
return data
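# Illustrative sketch (not part of the original module): with
# append_to_json=True, save_file writes one JSON document per line, so after
# several calls the file holds several documents. json.load, which load_file
# uses for .json, expects a single document and raises on such a file.
# One way to read the appended records back is line by line, as below.
# `append_record`, `read_records` and the record fields are hypothetical.

```python
import json
import os
import tempfile


def append_record(path, record):
    """Append one JSON document per line, like save_file's append mode."""
    with open(path, "a") as f:
        f.write(json.dumps(record, sort_keys=True) + "\n")


def read_records(path):
    """Parse each non-empty line as its own JSON document."""
    with open(path, "r") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    log_path = os.path.join(tmp, "log.json")
    append_record(log_path, {"epoch": 0, "loss": 1.5})
    append_record(log_path, {"epoch": 1, "loss": 1.2})
    records = read_records(log_path)
```

# Passing append_to_json=False to save_file instead rewrites the file with a
# single document, which load_file can then read directly.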
def abspath(resource_path: str):
"""
Make a path absolute, but take into account prefixes like
"http://" or "manifold://"
"""
regex = re.compile(r"^\w+://")
if regex.match(resource_path) is None:
return os.path.abspath(resource_path)
else:
return resource_path
def makedir(dir_path):
"""
Create the directory if it does not exist.
"""
is_success = False
try:
if not g_pathmgr.exists(dir_path):
g_pathmgr.mkdirs(dir_path)
is_success = True
except BaseException:
logging.info(f"Error creating directory: {dir_path}")
return is_success
def is_url(input_url):
"""
Check whether an input string is a URL: look for a leading http(s)://, ignoring case.
"""
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
return is_url
def cleanup_dir(dir):
"""
Utility for deleting a directory. Useful for cleaning the storage space
that contains various training artifacts like checkpoints, data etc.
"""
if os.path.exists(dir):
logging.info(f"Deleting directory: {dir}")
shutil.rmtree(dir)
logging.info(f"Deleted contents of directory: {dir}")
def get_file_size(filename):
"""
Given a file, get the size of file in MB
"""
size_in_mb = os.path.getsize(filename) / float(1024**2)
return size_in_mb
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
================================================
# coding: utf-8
import sys
dataDir = '../../VQA'
sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
from vqa import VQA
from vqaEvaluation.vqaEval import VQAEval
import matplotlib.pyplot as plt
import skimage.io as io
import json
import random
import os
# set up file names and paths
versionType ='v2_' # this should be '' when using VQA v1.0 dataset
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType ='train2014'
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
resultType ='fake'
fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
# An example result json file has been provided in './Results' folder.
[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
resultType, fileType) for fileType in fileTypes]
# create vqa object and vqaRes object
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(resFile, quesFile)
# create vqaEval object by taking vqa and vqaRes
vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
# evaluate results
"""
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
By default it uses all the question ids in annotation file
"""
vqaEval.evaluate()
# print accuracies
print("\n")
print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
print("Per Question Type Accuracy is the following:")
for quesType in vqaEval.accuracy['perQuestionType']:
    print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType]))
print("\n")
print("Per Answer Type Accuracy is the following:")
for ansType in vqaEval.accuracy['perAnswerType']:
    print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
print("\n")
# demo how to use evalQA to retrieve low score result
evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
if len(evals) > 0:
    print('ground truth answers')
    randomEval = random.choice(evals)
    randomAnn = vqa.loadQA(randomEval)
    vqa.showQA(randomAnn)
    print('\n')
    print('generated answer (accuracy %.02f)' % (vqaEval.evalQA[randomEval]))
    ann = vqaRes.loadQA(randomEval)[0]
    print("Answer: %s\n" % (ann['answer']))
    imgId = randomAnn[0]['image_id']
    imgFilename = 'COCO_' + dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
    if os.path.isfile(imgDir + imgFilename):
        I = io.imread(imgDir + imgFilename)
        plt.imshow(I)
        plt.axis('off')
        plt.show()
# plot accuracy for various question types
plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), list(vqaEval.accuracy['perQuestionType'].values()), align='center')
plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), list(vqaEval.accuracy['perQuestionType'].keys()), rotation=0, fontsize=10)
plt.title('Per Question Type Accuracy', fontsize=10)
plt.xlabel('Question Types', fontsize=10)
plt.ylabel('Accuracy', fontsize=10)
plt.show()
# save evaluation results to ./Results folder
json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
================================================
author='aagrawal'
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
================================================
# coding=utf-8
__author__='aagrawal'
import re
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
class VQAEval:
def __init__(self, vqa, vqaRes, n=2):
self.n = n
self.accuracy = {}
self.evalQA = {}
self.evalQuesType = {}
self.evalAnsType = {}
self.vqa = vqa
self.vqaRes = vqaRes
self.params = {'question_id': vqa.getQuesIds()}
self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
"couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
"hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
"he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
"Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
"maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
"mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
"ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
"she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
"somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
"somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
"someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
"something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
"there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
"they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
"wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
"whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
"whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
"whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
"wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
"y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
"youll": "you'll", "youre": "you're", "youve": "you've"}
self.manualMap = { 'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10'
}
self.articles = ['a',
'an',
'the'
]
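# NOTE: "(?!<=\d)" is a negative lookahead for the literal text "<=" followed by a digit,
# not a lookbehind; the pattern is left unchanged here so computed accuracies are not altered.
self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile(r"(\d)(\,)(\d)")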
self.punct = [';', r"/", '[', ']', '"', '{', '}',
'(', ')', '=', '+', '\\', '_', '-',
'>', '<', '@', '`', ',', '?', '!']
def evaluate(self, quesIds=None):
if quesIds is None:
quesIds = [quesId for quesId in self.params['question_id']]
gts = {}
res = {}
for quesId in quesIds:
gts[quesId] = self.vqa.qa[quesId]
res[quesId] = self.vqaRes.qa[quesId]
# =================================================
# Compute accuracy
# =================================================
accQA = []
accQuesType = {}
accAnsType = {}
# print "computing accuracy"
step = 0
for quesId in quesIds:
for ansDic in gts[quesId]['answers']:
ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
ansDic['answer'] = ansDic['answer'].strip()
resAns = res[quesId]['answer']
resAns = resAns.replace('\n', ' ')
resAns = resAns.replace('\t', ' ')
resAns = resAns.strip()
gtAcc = []
gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
if len(set(gtAnswers)) > 1:
for ansDic in gts[quesId]['answers']:
ansDic['answer'] = self.processPunctuation(ansDic['answer'])
ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
resAns = self.processPunctuation(resAns)
resAns = self.processDigitArticle(resAns)
for gtAnsDatum in gts[quesId]['answers']:
otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
acc = min(1, float(len(matchingAns))/3)
gtAcc.append(acc)
quesType = gts[quesId]['question_type']
ansType = gts[quesId]['answer_type']
avgGTAcc = float(sum(gtAcc))/len(gtAcc)
accQA.append(avgGTAcc)
if quesType not in accQuesType:
accQuesType[quesType] = []
accQuesType[quesType].append(avgGTAcc)
if ansType not in accAnsType:
accAnsType[ansType] = []
accAnsType[ansType].append(avgGTAcc)
self.setEvalQA(quesId, avgGTAcc)
self.setEvalQuesType(quesId, quesType, avgGTAcc)
self.setEvalAnsType(quesId, ansType, avgGTAcc)
if step%100 == 0:
self.updateProgress(step/float(len(quesIds)))
step = step + 1
self.setAccuracy(accQA, accQuesType, accAnsType)
# print "Done computing accuracy"
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
outText = self.periodStrip.sub("",
outText,
re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = ' '.join(outText)
return outText
def setAccuracy(self, accQA, accQuesType, accAnsType):
self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
def setEvalQA(self, quesId, acc):
self.evalQA[quesId] = round(100*acc, self.n)
def setEvalQuesType(self, quesId, quesType, acc):
if quesType not in self.evalQuesType:
self.evalQuesType[quesType] = {}
self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
def setEvalAnsType(self, quesId, ansType, acc):
if ansType not in self.evalAnsType:
self.evalAnsType[ansType] = {}
self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
def updateProgress(self, progress):
barLength = 20
status = ""
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = "error: progress var must be float\r\n"
if progress < 0:
progress = 0
status = "Halt...\r\n"
if progress >= 1:
progress = 1
status = "Done...\r\n"
block = int(round(barLength*progress))
text = "\rFinished Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
sys.stdout.write(text)
sys.stdout.flush()
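# Illustrative sketch (not part of the original module): the scoring loop in
# VQAEval.evaluate holds out each ground-truth answer in turn, counts how many
# of the remaining answers match the prediction, and scores min(1, matches/3);
# the per-question accuracy is the mean over all hold-outs. The standalone
# function below reproduces only that arithmetic. It assumes answers are
# already normalized (no punctuation, digit or article processing), and
# `vqa_accuracy` and the toy answers are hypothetical.

```python
def vqa_accuracy(prediction, gt_answers):
    """Leave-one-out VQA accuracy for a single question, in [0, 1]."""
    pred = prediction.lower()
    scores = []
    for held_out in range(len(gt_answers)):
        # All ground-truth answers except the one currently held out.
        others = gt_answers[:held_out] + gt_answers[held_out + 1:]
        matches = sum(1 for ans in others if ans.lower() == pred)
        scores.append(min(1.0, matches / 3))
    return sum(scores) / len(scores)


answers = ["yes"] * 3 + ["no"] * 7
score_yes = vqa_accuracy("yes", answers)
score_no = vqa_accuracy("no", answers)
score_other = vqa_accuracy("maybe", answers)
```

# With three "yes" annotations out of ten, predicting "yes" scores 2/3 on the
# three hold-outs that remove a "yes" and 1 on the other seven, giving 0.9.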
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
================================================
# coding: utf-8
from vqaTools.vqa import VQA
import random
import skimage.io as io
import matplotlib.pyplot as plt
import os
dataDir ='../../VQA'
versionType ='v2_' # this should be '' when using VQA v1.0 dataset
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType ='train2014'
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
# initialize VQA api for QA annotations
vqa=VQA(annFile, quesFile)
# load and display QA annotations for given question types
"""
All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
"""
annIds = vqa.getQuesIds(quesTypes='how many')
anns = vqa.loadQA(annIds)
randomAnn = random.choice(anns)
vqa.showQA([randomAnn])
imgId = randomAnn['image_id']
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
if os.path.isfile(imgDir + imgFilename):
I = io.imread(imgDir + imgFilename)
plt.imshow(I)
plt.axis('off')
plt.show()
# load and display QA annotations for given answer types
"""
ansTypes can be one of the following
yes/no
number
other
"""
annIds = vqa.getQuesIds(ansTypes='yes/no')
anns = vqa.loadQA(annIds)
randomAnn = random.choice(anns)
vqa.showQA([randomAnn])
imgId = randomAnn['image_id']
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
if os.path.isfile(imgDir + imgFilename):
I = io.imread(imgDir + imgFilename)
plt.imshow(I)
plt.axis('off')
plt.show()
# load and display QA annotations for given images
"""
Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
"""
ids = vqa.getImgIds()
annIds = vqa.getQuesIds(imgIds=random.sample(ids,5))
anns = vqa.loadQA(annIds)
randomAnn = random.choice(anns)
vqa.showQA([randomAnn])
imgId = randomAnn['image_id']
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
if os.path.isfile(imgDir + imgFilename):
I = io.imread(imgDir + imgFilename)
plt.imshow(I)
plt.axis('off')
plt.show()
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
================================================
__author__ = 'aagrawal'
================================================
FILE: minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
================================================
__author__ = 'aagrawal'
__version__ = '0.9'
# Interface for accessing the VQA dataset.
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
# The following functions are defined:
# VQA - VQA class that loads VQA annotation file and prepares data structures.
# getQuesIds - Get question ids that satisfy given filter conditions.
# getImgIds - Get image ids that satisfy given filter conditions.
# loadQA - Load questions and answers with the specified question ids.
# showQA - Display the specified questions and answers.
# loadRes - Load result file and create result object.
# Help on each function can be accessed by: "help(VQA.function)"
import json
import datetime
import copy
class VQA:
def __init__(self, annotation_file=None, question_file=None):
"""
Constructor of VQA helper class for reading and visualizing questions and answers.
:param annotation_file (str): location of VQA annotation file
:param question_file (str): location of VQA question file
:return:
"""
# load dataset
self.dataset = {}
self.questions = {}
self.qa = {}
self.qqa = {}
self.imgToQA = {}
if not annotation_file == None and not question_file == None:
# print 'loading VQA annotations and questions into memory...'
time_t = datetime.datetime.utcnow()
dataset = json.load(open(annotation_file, 'r'))
questions = json.load(open(question_file, 'r'))
# print datetime.datetime.utcnow() - time_t
self.dataset = dataset
self.questions = questions
self.createIndex()
def createIndex(self):
imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
for ann in self.dataset['annotations']:
imgToQA[ann['image_id']] += [ann]
qa[ann['question_id']] = ann
for ques in self.questions['questions']:
qqa[ques['question_id']] = ques
# print 'index created!'
# create class members
self.qa = qa
self.qqa = qqa
self.imgToQA = imgToQA
def info(self):
"""
Print information about the VQA annotation file.
:return:
"""
# for key, value in self.dataset['info'].items():
# print('%s: %s' % (key, value))
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
"""
Get question ids that satisfy given filter conditions. default skips that filter
:param imgIds (int array) : get question ids for given imgs
quesTypes (str array) : get question ids for given question types
ansTypes (str array) : get question ids for given answer types
:return: ids (int array) : integer array of question ids
"""
imgIds = imgIds if type(imgIds) == list else [imgIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset['annotations']
else:
if not len(imgIds) == 0:
anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
else:
anns = self.dataset['annotations']
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
ids = [ann['question_id'] for ann in anns]
return ids
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
"""
Get image ids that satisfy given filter conditions. default skips that filter
:param quesIds (int array) : get image ids for given question ids
quesTypes (str array) : get image ids for given question types
ansTypes (str array) : get image ids for given answer types
:return: ids (int array) : integer array of image ids
"""
quesIds = quesIds if type(quesIds) == list else [quesIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset['annotations']
else:
if not len(quesIds) == 0:
# self.qa maps each question id to a single annotation dict, so collect the dicts directly
anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]
else:
anns = self.dataset['annotations']
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
ids = [ann['image_id'] for ann in anns]
return ids
def loadQA(self, ids=[]):
"""
Load questions and answers with the specified question ids.
:param ids (int array) : integer ids specifying question ids
:return: qa (object array) : loaded qa objects
"""
if type(ids) == list:
return [self.qa[id] for id in ids]
elif type(ids) == int:
return [self.qa[ids]]
def showQA(self, anns):
"""
Display the specified annotations.
:param anns (array of object): annotations to display
:return: None
"""
if len(anns) == 0:
return 0
for ann in anns:
quesId = ann['question_id']
print("Question: %s" % (self.qqa[quesId]['question']))
for ans in ann['answers']:
print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
def loadRes(self, resFile, quesFile):
"""
Load result file and return a result object.
:param resFile (str) : file name of result file
:param quesFile (str) : file name of question file
:return: res (obj) : result api object
"""
res = VQA()
res.questions = json.load(open(quesFile))
res.dataset['info'] = copy.deepcopy(self.questions['info'])
res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
res.dataset['license'] = copy.deepcopy(self.questions['license'])
# print 'Loading and preparing results... '
time_t = datetime.datetime.utcnow()
anns = json.load(open(resFile))
assert type(anns) == list, 'results is not an array of objects'
annsQuesIds = [ann['question_id'] for ann in anns]
assert set(annsQuesIds) == set(self.getQuesIds()), \
'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is at least one question id that does not belong to the question ids in the annotation file.'
for ann in anns:
quesId = ann['question_id']
if res.dataset['task_type'] == 'Multiple Choice':
assert ann['answer'] in self.qqa[quesId][
'multiple_choices'], 'predicted answer is not one of the multiple choices'
qaAnn = self.qa[quesId]
ann['image_id'] = qaAnn['image_id']
ann['question_type'] = qaAnn['question_type']
ann['answer_type'] = qaAnn['answer_type']
# print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
res.dataset['annotations'] = anns
res.createIndex()
return res
================================================
FILE: minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt
================================================
how many
what color is the
is the
where is the
what
what is
are the
what is the
is there a
does the
is the woman
is the man
what is on the
is it
is the girl
is the boy
is the dog
are they
who is
what kind of
what color are the
what is in the
what is the man
is there
what is the woman
what are the
what is the boy
are there
what is the girl
is this
how
which
how many people are
is the cat
why is the
are
will the
what type of
what is the dog
do
is she
does
do the
is
is the baby
are there any
is the lady
can
what animal is
where are the
is the sun
what are they
did the
what is the cat
what is the lady
how many clouds are
is that
is the little girl
is he
are these
how many trees are
how many pillows
are the people
why
is the young
how many windows are
is this a
what is the little
is the tv
how many animals are
who
how many pictures
how many plants are
how many birds are
what color is
what is the baby
is anyone
what color
how many bushes
is the old man
none of the above
================================================
FILE: minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt
================================================
how many
is the
what
what color is the
what is the
is this
is this a
what is
are the
what kind of
is there a
what type of
is it
what are the
where is the
is there
does the
what color are the
are these
are there
which
is
what is the man
is the man
are
how
does this
what is on the
what does the
how many people are
what is in the
what is this
do
what are
are they
what time
what sport is
are there any
is he
what color is
why
where are the
what color
who is
what animal is
is the woman
is this an
do you
how many people are in
what room is
has
is this person
what is the woman
can you
why is the
is the person
what is the color of the
what is the person
could
was
is that a
what number is
what is the name
what brand
none of the above
================================================
FILE: minigpt4/common/vqa_tools/VQA/README.md
================================================
Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
===================
## VQA v2.0 release ##
This release consists of
- Real
    - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from the [MS COCO website](http://mscoco.org/dataset/#download))
    - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
    - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
There is only one type of task
- Open-ended task
## VQA v1.0 release ##
This release consists of
- Real
    - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from the [MS COCO website](http://mscoco.org/dataset/#download))
    - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
    - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
- Abstract
    - 20,000 training images, 10,000 validation images and 20,000 testing images
    - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
    - 600,000 answers for training and 300,000 answers for validation (10 per question)
There are two types of tasks
- Open-ended task
- Multiple-choice task (18 choices per question)
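The v1.0 counts above follow two stated ratios: 3 questions per image and 10 answers per question. A small sanity check, written as a sketch, makes that arithmetic explicit. The dictionary layout and the `check_split` helper are illustrative names, not part of the VQA API.

```python
# Counts copied from the v1.0 release notes above (train and val only,
# because answer counts are not listed for the test split).
V1_REAL = {
    "train": {"images": 82_783, "questions": 248_349, "answers": 2_483_490},
    "val": {"images": 40_504, "questions": 121_512, "answers": 1_215_120},
}
V1_ABSTRACT = {
    "train": {"images": 20_000, "questions": 60_000, "answers": 600_000},
    "val": {"images": 10_000, "questions": 30_000, "answers": 300_000},
}


def check_split(counts, questions_per_image=3, answers_per_question=10):
    """Return True if the counts match the stated per-image and per-question ratios."""
    return (
        counts["questions"] == questions_per_image * counts["images"]
        and counts["answers"] == answers_per_question * counts["questions"]
    )


for name, release in (("real", V1_REAL), ("abstract", V1_ABSTRACT)):
    for split, counts in release.items():
        print(name, split, check_split(counts))
```

The same 10-answers-per-question ratio holds for the v2.0 train and val counts, although v2.0 does not state a fixed number of questions per image.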
## Requirements ##
- python 2.7
- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
## Files ##
./Questions
- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
    - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
    - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
./Annotations
- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
    - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
    - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
./Images
- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
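The directory layout described for `./Images` can be created in one pass. The following is a minimal sketch: `make_image_dirs` is a hypothetical helper, and it only creates the empty folders. Downloading the images is still a manual step.

```python
import tempfile
from pathlib import Path


def make_image_dirs(root):
    """Create the Images/<dataset>/<split> folders named in the README."""
    layout = {
        "mscoco": ["train2014", "val2014", "test2015"],
        "abstract_v002": ["train2015", "val2015", "test2015"],
    }
    created = []
    for dataset, splits in layout.items():
        for split in splits:
            path = Path(root) / "Images" / dataset / split
            path.mkdir(parents=True, exist_ok=True)  # safe to re-run
            created.append(path)
    return created


# Demonstration in a throwaway directory; pass the repository root instead.
with tempfile.TemporaryDirectory() as tmp:
    for path in make_image_dirs(tmp):
        print(path.relative_to(tmp))
```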
./PythonHelperTools
- This directory contains the Python API to read and visualize the VQA dataset
    - vqaDemo.py (demo script)
    - vqaTools (API to read and visualize data)
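As a rough illustration of what the helper API does when it loads data, the sketch below builds the same three lookups that `createIndex` keeps (`qa`, `qqa` and `imgToQA`) from in-memory lists. The `build_index` function and the toy records are assumptions made for this example. The real class reads the annotation and question JSON files instead.

```python
from collections import defaultdict


def build_index(annotations, questions):
    """Index annotations by question id and image id, and questions by question id."""
    img_to_qa = defaultdict(list)
    qa = {}
    for ann in annotations:
        img_to_qa[ann["image_id"]].append(ann)  # one image can have many questions
        qa[ann["question_id"]] = ann            # one annotation per question
    qqa = {q["question_id"]: q for q in questions}
    return qa, qqa, dict(img_to_qa)


# Made-up records that follow the field names used by the API.
annotations = [
    {"image_id": 1, "question_id": 10, "question_type": "what color is the", "answer_type": "other"},
    {"image_id": 1, "question_id": 11, "question_type": "how many", "answer_type": "number"},
    {"image_id": 2, "question_id": 20, "question_type": "is the dog", "answer_type": "yes/no"},
]
questions = [
    {"image_id": 1, "question_id": 10, "question": "what color is the car"},
    {"image_id": 1, "question_id": 11, "question": "how many people are there"},
    {"image_id": 2, "question_id": 20, "question": "is the dog asleep"},
]

qa, qqa, img_to_qa = build_index(annotations, questions)
print(sorted(img_to_qa))  # image ids that have at least one question
```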
./PythonEvaluationTools
- This directory contains the Python evaluation code
    - vqaEvalDemo.py (evaluation demo script)
    - vqaEvaluation (evaluation code)
./Results
- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
- Visit the [VQA evaluation page](http://visualqa.org/evaluation) for more details.
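Judging from the checks in `loadRes`, a results file is a JSON array of objects, each carrying a `question_id` and an `answer`, and the set of question ids must match the annotation file exactly. The sketch below mirrors those two checks. `check_results` and the sample ids are hypothetical, and the evaluation page remains the authoritative description of the format.

```python
import json


def check_results(results, annotation_question_ids):
    """Apply the two structural checks that loadRes performs on a results file."""
    if not isinstance(results, list):
        raise ValueError("results is not an array of objects")
    result_ids = {entry["question_id"] for entry in results}
    if result_ids != set(annotation_question_ids):
        raise ValueError("results do not cover exactly the annotated question ids")
    return True


# Hypothetical predictions for two made-up question ids.
results = [
    {"question_id": 10, "answer": "red"},
    {"question_id": 11, "answer": "2"},
]
serialized = json.dumps(results)  # what would be written to the results file
print(check_results(json.loads(serialized), [10, 11]))
```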
./QuestionTypes
- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
    - mscoco_question_types.txt
    - abstract_v002_question_types.txt
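The rule above amounts to a longest-prefix match: a question belongs to the longest listed type whose words open the question. Here is a minimal sketch of that reading. The `question_type` helper, the lowercasing and the handling of the trailing question mark are assumptions. The source does not say how the dataset authors normalised questions before assigning types.

```python
def question_type(question, types, fallback="none of the above"):
    """Return the longest type whose words are a prefix of the question's words."""
    words = question.lower().rstrip("?").split()
    best, best_len = fallback, 0
    for candidate in types:
        if candidate == fallback:
            continue  # the fallback label is not a real prefix
        cand_words = candidate.split()
        n = len(cand_words)
        if n > best_len and words[:n] == cand_words:
            best, best_len = candidate, n
    return best


# A handful of entries taken from the lists in this directory.
types = ["what", "what is", "what is the", "what color is the", "is the", "none of the above"]

print(question_type("What is the man holding?", types))
print(question_type("What color is the sky?", types))
print(question_type("Where are the keys?", types))
```

Under this reading, "What is the man holding?" falls under "what is the" rather than "what" or "what is", which is the exclusion the paragraph describes.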
## References ##
- [VQA: Visual Question Answering](http://visualqa.org/)
- [Microsoft COCO](http://mscoco.org/)
## Developers ##
- Aishwarya Agrawal (Virginia Tech)
- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
================================================
FILE: minigpt4/common/vqa_tools/VQA/license.txt
================================================
Copyright (c) 2014, Aishwarya Agrawal
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
The views and conclusions contained in the software and documentation are
those
of the authors and should not be interpreted as representing official
policies,
either expressed or implied, of the FreeBSD Project.
================================================
FILE: minigpt4/common/vqa_tools/__init__.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
__author__ = "aagrawal"
================================================
FILE: minigpt4/common/vqa_tools/vqa.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
__author__ = "aagrawal"
__version__ = "0.9"
# Interface for accessing the VQA dataset.
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
# The following functions are defined:
# VQA - VQA class that loads VQA annotation file and prepares data structures.
# getQuesIds - Get question ids that satisfy given filter conditions.
# getImgIds - Get image ids that satisfy given filter conditions.
# loadQA - Load questions and answers with the specified question ids.
# showQA - Display the specified questions and answers.
# loadRes - Load result file and create result object.
# Help on each function can be accessed by: "help(VQA.function)"
import json
import datetime
import copy


class VQA:
    def __init__(self, annotation_file=None, question_file=None):
        """
        Constructor of VQA helper class for reading and visualizing questions and answers.
        :param annotation_file (str): location of VQA annotation file
        :param question_file (str): location of VQA question file
        :return:
        """
        # load dataset
        self.dataset = {}
        self.questions = {}
        self.qa = {}
        self.qqa = {}
        self.imgToQA = {}
        if annotation_file is not None and question_file is not None:
            print("loading VQA annotations and questions into memory...")
            time_t = datetime.datetime.utcnow()
            with open(annotation_file, "r") as f:
                dataset = json.load(f)
            with open(question_file, "r") as f:
                questions = json.load(f)
            self.dataset = dataset
            self.questions = questions
            self.createIndex()

    def createIndex(self):
        # create index
        print("creating index...")
        imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
        qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
        qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
        for ann in self.dataset["annotations"]:
            imgToQA[ann["image_id"]] += [ann]
            qa[ann["question_id"]] = ann
        for ques in self.questions["questions"]:
            qqa[ques["question_id"]] = ques
        print("index created!")

        # create class members
        self.qa = qa
        self.qqa = qqa
        self.imgToQA = imgToQA

    def info(self):
        """
        Print information about the VQA annotation file.
        :return:
        """
        # fixed: the attribute is `dataset`, not `datset`
        for key, value in self.dataset["info"].items():
            print("%s: %s" % (key, value))

    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
        """
        Get question ids that satisfy given filter conditions. default skips that filter
        :param  imgIds    (int array)   : get question ids for given imgs
                quesTypes (str array)   : get question ids for given question types
                ansTypes  (str array)   : get question ids for given answer types
        :return:    ids   (int array)   : integer array of question ids
        """
        imgIds = imgIds if type(imgIds) == list else [imgIds]
        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

        if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset["annotations"]
        else:
            if not len(imgIds) == 0:
                # imgToQA maps each image id to a list of annotations, so flatten
                anns = sum(
                    [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
                    [],
                )
            else:
                anns = self.dataset["annotations"]
            anns = (
                anns
                if len(quesTypes) == 0
                else [ann for ann in anns if ann["question_type"] in quesTypes]
            )
            anns = (
                anns
                if len(ansTypes) == 0
                else [ann for ann in anns if ann["answer_type"] in ansTypes]
            )
        ids = [ann["question_id"] for ann in anns]
        return ids

    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
        """
        Get image ids that satisfy given filter conditions. default skips that filter
        :param quesIds   (int array)   : get image ids for given question ids
               quesTypes (str array)   : get image ids for given question types
               ansTypes  (str array)   : get image ids for given answer types
        :return: ids     (int array)   : integer array of image ids
        """
        quesIds = quesIds if type(quesIds) == list else [quesIds]
        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

        if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset["annotations"]
        else:
            if not len(quesIds) == 0:
                # fixed: self.qa maps each question id to a single annotation dict,
                # so collect the dicts directly instead of summing them as lists
                anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]
            else:
                anns = self.dataset["annotations"]
            anns = (
                anns
                if len(quesTypes) == 0
                else [ann for ann in anns if ann["question_type"] in quesTypes]
            )
            anns = (
                anns
                if len(ansTypes) == 0
                else [ann for ann in anns if ann["answer_type"] in ansTypes]
            )
        ids = [ann["image_id"] for ann in anns]
        return ids

    def loadQA(self, ids=[]):
        """
        Load questions and answers with the specified question ids.
        :param ids (int array)       : integer ids specifying question ids
        :return: qa (object array)   : loaded qa objects
        """
        if type(ids) == list:
            return [self.qa[id] for id in ids]
        elif type(ids) == int:
            return [self.qa[ids]]

    def showQA(self, anns):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: None
        """
        if len(anns) == 0:
            return 0
        for ann in anns:
            quesId = ann["question_id"]
            print("Question: %s" % (self.qqa[quesId]["question"]))
            for ans in ann["answers"]:
                print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))

    def loadRes(self, resFile, quesFile):
        """
        Load result file and return a result object.
        :param   resFile (str)     : file name of result file
        :param   quesFile (str)    : file name of question file
        :return: res (obj)         : result api object
        """
        res = VQA()
        with open(quesFile) as f:
            res.questions = json.load(f)
        res.dataset["info"] = copy.deepcopy(self.questions["info"])
        res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
        res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
        res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
        res.dataset["license"] = copy.deepcopy(self.questions["license"])

        print("Loading and preparing results...     ")
        time_t = datetime.datetime.utcnow()
        with open(resFile) as f:
            anns = json.load(f)
        assert type(anns) == list, "results is not an array of objects"
        annsQuesIds = [ann["question_id"] for ann in anns]
        assert set(annsQuesIds) == set(
            self.getQuesIds()
        ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is at least one question id that does not belong to the question ids in the annotation file."
        for ann in anns:
            quesId = ann["question_id"]
            if res.dataset["task_type"] == "Multiple Choice":
                assert (
                    ann["answer"] in self.qqa[quesId]["multiple_choices"]
                ), "predicted answer is not one of the multiple choices"
            # copy image and type metadata from the ground-truth annotation
            qaAnn = self.qa[quesId]
            ann["image_id"] = qaAnn["image_id"]
            ann["question_type"] = qaAnn["question_type"]
            ann["answer_type"] = qaAnn["answer_type"]
        print(
            "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
        )

        res.dataset["annotations"] = anns
        res.createIndex()
        return res
================================================
FILE: minigpt4/common/vqa_tools/vqa_eval.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
# coding=utf-8
__author__ = "aagrawal"
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re
class VQAEval:
    def __init__(self, vqa=None, vqaRes=None, n=2):
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        if vqa is not None:
            self.params = {"question_id": vqa.getQuesIds()}
        self.contractions = {
            "aint": "ain't",
SYMBOL INDEX (708 symbols across 56 files)
FILE: demo.py
function parse_args (line 25) | def parse_args():
function setup_seeds (line 40) | def setup_seeds(config):
function gradio_reset (line 85) | def gradio_reset(chat_state, img_list):
function upload_img (line 93) | def upload_img(gr_img, text_input, chat_state):
function gradio_ask (line 103) | def gradio_ask(user_message, chatbot, chat_state):
function gradio_answer (line 111) | def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
FILE: demo_v2.py
function parse_args (line 31) | def parse_args():
function extract_substrings (line 81) | def extract_substrings(string):
function is_overlapping (line 94) | def is_overlapping(rect1, rect2):
function computeIoU (line 100) | def computeIoU(bbox1, bbox2):
function save_tmp_img (line 115) | def save_tmp_img(visual_img):
function mask2bbox (line 122) | def mask2bbox(mask):
function escape_markdown (line 142) | def escape_markdown(text):
function reverse_escape (line 153) | def reverse_escape(text):
function visualize_all_bbox_together (line 189) | def visualize_all_bbox_together(image, generation):
function gradio_reset (line 383) | def gradio_reset(chat_state, img_list):
function image_upload_trigger (line 392) | def image_upload_trigger(upload_flag, replace_flag, img_list):
function example_trigger (line 401) | def example_trigger(text_input, image, upload_flag, replace_flag, img_li...
function gradio_ask (line 411) | def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, uplo...
function gradio_answer (line 454) | def gradio_answer(chatbot, chat_state, img_list, temperature):
function gradio_stream_answer (line 464) | def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
function gradio_visualize (line 483) | def gradio_visualize(chatbot, gr_img):
function gradio_taskselect (line 498) | def gradio_taskselect(idx):
FILE: eval_scripts/eval_ref.py
function list_of_str (line 18) | def list_of_str(arg):
FILE: eval_scripts/eval_vqa.py
function list_of_str (line 24) | def list_of_str(arg):
FILE: minigpt4/common/config.py
class Config (line 16) | class Config:
method __init__ (line 17) | def __init__(self, args):
method _validate_runner_config (line 44) | def _validate_runner_config(self, runner_config):
method _build_opt_list (line 53) | def _build_opt_list(self, opts):
method build_model_config (line 58) | def build_model_config(config, **kwargs):
method build_runner_config (line 85) | def build_runner_config(config):
method build_dataset_config (line 89) | def build_dataset_config(config):
method build_evaluation_dataset_config (line 117) | def build_evaluation_dataset_config(config):
method _convert_to_dot_list (line 138) | def _convert_to_dot_list(self, opts):
method get_config (line 152) | def get_config(self):
method run_cfg (line 156) | def run_cfg(self):
method datasets_cfg (line 160) | def datasets_cfg(self):
method evaluation_datasets_cfg (line 164) | def evaluation_datasets_cfg(self):
method model_cfg (line 168) | def model_cfg(self):
method pretty_print (line 171) | def pretty_print(self):
method _convert_node_to_json (line 189) | def _convert_node_to_json(self, node):
method to_dict (line 193) | def to_dict(self):
function node_to_dict (line 197) | def node_to_dict(node):
class ConfigValidator (line 201) | class ConfigValidator:
class _Argument (line 215) | class _Argument:
method __init__ (line 216) | def __init__(self, name, choices=None, type=None, help=None):
method __str__ (line 223) | def __str__(self):
method __init__ (line 233) | def __init__(self, description):
method __getitem__ (line 240) | def __getitem__(self, key):
method __str__ (line 245) | def __str__(self) -> str:
method add_argument (line 248) | def add_argument(self, *args, **kwargs):
method validate (line 254) | def validate(self, config=None):
method format_arguments (line 276) | def format_arguments(self):
method format_help (line 279) | def format_help(self):
method print_help (line 284) | def print_help(self):
function create_runner_config_validator (line 289) | def create_runner_config_validator():
FILE: minigpt4/common/dist_utils.py
function setup_for_distributed (line 17) | def setup_for_distributed(is_master):
function is_dist_avail_and_initialized (line 33) | def is_dist_avail_and_initialized():
function get_world_size (line 41) | def get_world_size():
function get_rank (line 47) | def get_rank():
function is_main_process (line 53) | def is_main_process():
function init_distributed_mode (line 57) | def init_distributed_mode(args):
function get_dist_info (line 96) | def get_dist_info():
function main_process (line 110) | def main_process(func):
function download_cached_file (line 120) | def download_cached_file(url, check_hash=True, progress=False):
FILE: minigpt4/common/eval_utils.py
function eval_parser (line 17) | def eval_parser():
function prepare_texts (line 37) | def prepare_texts(texts, conv_temp):
function init_model (line 46) | def init_model(args):
function computeIoU (line 64) | def computeIoU(bbox1, bbox2):
FILE: minigpt4/common/gradcam.py
function getAttMap (line 7) | def getAttMap(img, attMap, blur=True, overlap=True):
FILE: minigpt4/common/logger.py
class SmoothedValue (line 19) | class SmoothedValue(object):
method __init__ (line 24) | def __init__(self, window_size=20, fmt=None):
method update (line 32) | def update(self, value, n=1):
method synchronize_between_processes (line 37) | def synchronize_between_processes(self):
method median (line 51) | def median(self):
method avg (line 56) | def avg(self):
method global_avg (line 61) | def global_avg(self):
method max (line 65) | def max(self):
method value (line 69) | def value(self):
method __str__ (line 72) | def __str__(self):
class MetricLogger (line 82) | class MetricLogger(object):
method __init__ (line 83) | def __init__(self, delimiter="\t"):
method update (line 87) | def update(self, **kwargs):
method __getattr__ (line 94) | def __getattr__(self, attr):
method __str__ (line 103) | def __str__(self):
method global_avg (line 109) | def global_avg(self):
method synchronize_between_processes (line 115) | def synchronize_between_processes(self):
method add_meter (line 119) | def add_meter(self, name, meter):
method log_every (line 122) | def log_every(self, iterable, print_freq, header=None):
class AttrDict (line 184) | class AttrDict(dict):
method __init__ (line 185) | def __init__(self, *args, **kwargs):
function setup_logger (line 190) | def setup_logger():
FILE: minigpt4/common/optims.py
class LinearWarmupStepLRScheduler (line 14) | class LinearWarmupStepLRScheduler:
method __init__ (line 15) | def __init__(
method step (line 37) | def step(self, cur_epoch, cur_step):
class LinearWarmupCosineLRScheduler (line 57) | class LinearWarmupCosineLRScheduler:
method __init__ (line 58) | def __init__(
method step (line 79) | def step(self, cur_epoch, cur_step):
function cosine_lr_schedule (line 99) | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
function warmup_lr_schedule (line 108) | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
function step_lr_schedule (line 115) | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
FILE: minigpt4/common/registry.py
class Registry (line 9) | class Registry:
method register_builder (line 22) | def register_builder(cls, name):
method register_task (line 54) | def register_task(cls, name):
method register_model (line 83) | def register_model(cls, name):
method register_processor (line 112) | def register_processor(cls, name):
method register_lr_scheduler (line 141) | def register_lr_scheduler(cls, name):
method register_runner (line 165) | def register_runner(cls, name):
method register_path (line 189) | def register_path(cls, name, path):
method register (line 205) | def register(cls, name, obj):
method get_builder_class (line 232) | def get_builder_class(cls, name):
method get_model_class (line 236) | def get_model_class(cls, name):
method get_task_class (line 240) | def get_task_class(cls, name):
method get_processor_class (line 244) | def get_processor_class(cls, name):
method get_lr_scheduler_class (line 248) | def get_lr_scheduler_class(cls, name):
method get_runner_class (line 252) | def get_runner_class(cls, name):
method list_runners (line 256) | def list_runners(cls):
method list_models (line 260) | def list_models(cls):
method list_tasks (line 264) | def list_tasks(cls):
method list_processors (line 268) | def list_processors(cls):
method list_lr_schedulers (line 272) | def list_lr_schedulers(cls):
method list_datasets (line 276) | def list_datasets(cls):
method get_path (line 280) | def get_path(cls, name):
method get (line 284) | def get(cls, name, default=None, no_warning=False):
method unregister (line 315) | def unregister(cls, name):
FILE: minigpt4/common/utils.py
function now (line 35) | def now():
function is_url (line 41) | def is_url(url_or_filename):
function get_cache_path (line 46) | def get_cache_path(rel_path):
function get_abs_path (line 50) | def get_abs_path(rel_path):
function load_json (line 54) | def load_json(filename):
function makedir (line 64) | def makedir(dir_path):
function get_redirected_url (line 78) | def get_redirected_url(url: str):
function to_google_drive_download_url (line 93) | def to_google_drive_download_url(view_url: str) -> str:
function download_google_drive_url (line 108) | def download_google_drive_url(url: str, output_path: str, output_file_na...
function _get_google_drive_file_id (line 141) | def _get_google_drive_file_id(url: str) -> Optional[str]:
function _urlretrieve (line 154) | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
function download_url (line 167) | def download_url(
function download_and_extract_archive (line 221) | def download_and_extract_archive(
function cache_url (line 242) | def cache_url(url: str, cache_dir: str) -> str:
function create_file_symlink (line 261) | def create_file_symlink(file1, file2):
function save_file (line 275) | def save_file(data, filename, append_to_json=True, verbose=True):
function load_file (line 313) | def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
function abspath (line 374) | def abspath(resource_path: str):
function makedir (line 386) | def makedir(dir_path):
function is_url (line 400) | def is_url(input_url):
function cleanup_dir (line 408) | def cleanup_dir(dir):
function get_file_size (line 419) | def get_file_size(filename):
FILE: minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
class VQAEval (line 11) | class VQAEval:
method __init__ (line 12) | def __init__(self, vqa, vqaRes, n=2):
method evaluate (line 69) | def evaluate(self, quesIds=None):
method processPunctuation (line 130) | def processPunctuation(self, inText):
method processDigitArticle (line 142) | def processDigitArticle(self, inText):
method setAccuracy (line 157) | def setAccuracy(self, accQA, accQuesType, accAnsType):
method setEvalQA (line 162) | def setEvalQA(self, quesId, acc):
method setEvalQuesType (line 165) | def setEvalQuesType(self, quesId, quesType, acc):
method setEvalAnsType (line 170) | def setEvalAnsType(self, quesId, ansType, acc):
method updateProgress (line 175) | def updateProgress(self, progress):
FILE: minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
class VQA (line 24) | class VQA:
method __init__ (line 25) | def __init__(self, annotation_file=None, question_file=None):
method createIndex (line 47) | def createIndex(self):
method info (line 63) | def info(self):
method getQuesIds (line 72) | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
method getImgIds (line 96) | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
method loadQA (line 120) | def loadQA(self, ids=[]):
method showQA (line 131) | def showQA(self, anns):
method loadRes (line 145) | def loadRes(self, resFile, quesFile):
FILE: minigpt4/common/vqa_tools/vqa.py
class VQA (line 31) | class VQA:
method __init__ (line 32) | def __init__(self, annotation_file=None, question_file=None):
method createIndex (line 53) | def createIndex(self):
method info (line 71) | def info(self):
method getQuesIds (line 79) | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
method getImgIds (line 114) | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
method loadQA (line 148) | def loadQA(self, ids=[]):
method showQA (line 159) | def showQA(self, anns):
method loadRes (line 173) | def loadRes(self, resFile, quesFile):
FILE: minigpt4/common/vqa_tools/vqa_eval.py
class VQAEval (line 18) | class VQAEval:
method __init__ (line 19) | def __init__(self, vqa=None, vqaRes=None, n=2):
method evaluate (line 193) | def evaluate(self, quesIds=None):
method processPunctuation (line 249) | def processPunctuation(self, inText):
method processDigitArticle (line 261) | def processDigitArticle(self, inText):
method setAccuracy (line 276) | def setAccuracy(self, accQA, accQuesType, accAnsType):
method setEvalQA (line 292) | def setEvalQA(self, quesId, acc):
method setEvalQuesType (line 295) | def setEvalQuesType(self, quesId, quesType, acc):
method setEvalAnsType (line 300) | def setEvalAnsType(self, quesId, ansType, acc):
method updateProgress (line 305) | def updateProgress(self, progress):
FILE: minigpt4/conversation/conversation.py
class SeparatorStyle (line 17) | class SeparatorStyle(Enum):
class Conversation (line 24) | class Conversation:
method get_prompt (line 38) | def get_prompt(self):
method append_message (line 59) | def append_message(self, role, message):
method to_gradio_chatbot (line 62) | def to_gradio_chatbot(self):
method copy (line 71) | def copy(self):
method dict (line 83) | def dict(self):
class StoppingCriteriaSub (line 96) | class StoppingCriteriaSub(StoppingCriteria):
method __init__ (line 98) | def __init__(self, stops=[], encounters=1):
method __call__ (line 102) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
class Chat (line 139) | class Chat:
method __init__ (line 140) | def __init__(self, model, vis_processor, device='cuda:0', stopping_cri...
method ask (line 151) | def ask(self, text, conv):
method answer_prepare (line 158) | def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams...
method answer (line 185) | def answer(self, conv, img_list, **kargs):
method stream_answer (line 196) | def stream_answer(self, conv, img_list, **kargs):
method model_generate (line 204) | def model_generate(self, *args, **kwargs):
method encode_img (line 210) | def encode_img(self, img_list):
method upload_img (line 227) | def upload_img(self, image, conv, img_list):
FILE: minigpt4/datasets/builders/__init__.py
function load_dataset (line 23) | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
class DatasetZoo (line 61) | class DatasetZoo:
method __init__ (line 62) | def __init__(self) -> None:
method get_names (line 68) | def get_names(self):
FILE: minigpt4/datasets/builders/base_dataset_builder.py
class BaseDatasetBuilder (line 25) | class BaseDatasetBuilder:
method __init__ (line 28) | def __init__(self, cfg=None):
method build_datasets (line 45) | def build_datasets(self):
method build_processors (line 61) | def build_processors(self):
method _build_proc_from_cfg (line 80) | def _build_proc_from_cfg(cfg):
method default_config_path (line 88) | def default_config_path(cls, type="default"):
method _download_data (line 91) | def _download_data(self):
method _download_ann (line 95) | def _download_ann(self):
method _download_vis (line 152) | def _download_vis(self):
method build (line 166) | def build(self):
function load_dataset_config (line 232) | def load_dataset_config(cfg_path):
FILE: minigpt4/datasets/builders/image_text_pair_builder.py
class MultitaskConversationBuilder (line 24) | class MultitaskConversationBuilder(BaseDatasetBuilder):
method build_datasets (line 30) | def build_datasets(self):
class UnnaturalInstructionBuilder (line 50) | class UnnaturalInstructionBuilder(BaseDatasetBuilder):
method build_datasets (line 56) | def build_datasets(self):
class LlavaDetailBuilder (line 75) | class LlavaDetailBuilder(BaseDatasetBuilder):
method build_datasets (line 81) | def build_datasets(self):
class LlavaReasonBuilder (line 102) | class LlavaReasonBuilder(BaseDatasetBuilder):
method build_datasets (line 108) | def build_datasets(self):
class LlavaReasonBuilder (line 127) | class LlavaReasonBuilder(BaseDatasetBuilder):
method build_datasets (line 133) | def build_datasets(self):
class AllRefCOCOBuilder (line 152) | class AllRefCOCOBuilder(BaseDatasetBuilder):
method build_datasets (line 154) | def build_datasets(self):
class RefCOCOBuilder (line 185) | class RefCOCOBuilder(AllRefCOCOBuilder):
class RefCOCOPBuilder (line 192) | class RefCOCOPBuilder(AllRefCOCOBuilder):
class RefCOCOGBuilder (line 200) | class RefCOCOGBuilder(AllRefCOCOBuilder):
class RefCOCOBuilder (line 207) | class RefCOCOBuilder(AllRefCOCOBuilder):
class RefCOCOPBuilder (line 215) | class RefCOCOPBuilder(AllRefCOCOBuilder):
class RefCOCOGBuilder (line 223) | class RefCOCOGBuilder(AllRefCOCOBuilder):
class RefVisualGenomeBuilder (line 230) | class RefVisualGenomeBuilder(BaseDatasetBuilder):
method build_datasets (line 236) | def build_datasets(self):
class TextcapCaptionBuilder (line 257) | class TextcapCaptionBuilder(BaseDatasetBuilder):
method _download_ann (line 262) | def _download_ann(self):
method _download_vis (line 265) | def _download_vis(self):
method build (line 268) | def build(self):
class COCOVQABuilder (line 289) | class COCOVQABuilder(BaseDatasetBuilder):
class OKVQABuilder (line 297) | class OKVQABuilder(COCOVQABuilder):
class AOKVQABuilder (line 304) | class AOKVQABuilder(BaseDatasetBuilder):
class GQABuilder (line 311) | class GQABuilder(BaseDatasetBuilder):
class GroundedCaptionBuilder (line 321) | class GroundedCaptionBuilder(BaseDatasetBuilder):
method build_datasets (line 327) | def build_datasets(self):
class CaptionToPhraseBuilder (line 347) | class CaptionToPhraseBuilder(BaseDatasetBuilder):
method build_datasets (line 353) | def build_datasets(self):
class CaptionToPhraseBuilder (line 372) | class CaptionToPhraseBuilder(BaseDatasetBuilder):
method build_datasets (line 378) | def build_datasets(self):
class DocumentVQABuilder (line 399) | class DocumentVQABuilder(BaseDatasetBuilder):
method _download_ann (line 400) | def _download_ann(self):
method _download_vis (line 403) | def _download_vis(self):
method build (line 406) | def build(self):
class OCRVQABuilder (line 425) | class OCRVQABuilder(DocumentVQABuilder):
class CCSBUBuilder (line 431) | class CCSBUBuilder(BaseDatasetBuilder):
method _download_ann (line 436) | def _download_ann(self):
method _download_vis (line 439) | def _download_vis(self):
method build (line 442) | def build(self):
class LaionBuilder (line 463) | class LaionBuilder(BaseDatasetBuilder):
method _download_ann (line 468) | def _download_ann(self):
method _download_vis (line 471) | def _download_vis(self):
method build (line 474) | def build(self):
class COCOCapBuilder (line 496) | class COCOCapBuilder(BaseDatasetBuilder):
class CCSBUAlignBuilder (line 506) | class CCSBUAlignBuilder(BaseDatasetBuilder):
method build_datasets (line 513) | def build_datasets(self):
FILE: minigpt4/datasets/data_utils.py
class ChainDataset (line 33) | class ChainDataset(wds.DataPipeline):
method __init__ (line 43) | def __init__(self, datasets: List[wds.DataPipeline]) -> None:
method __iter__ (line 59) | def __iter__(self):
function apply_to_sample (line 66) | def apply_to_sample(f, sample):
function move_to_cuda (line 83) | def move_to_cuda(sample):
function prepare_sample (line 90) | def prepare_sample(samples, cuda_enabled=True):
function reorg_datasets_by_split (line 99) | def reorg_datasets_by_split(datasets, batch_sizes):
function concat_datasets (line 128) | def concat_datasets(datasets):
FILE: minigpt4/datasets/datasets/aok_vqa_datasets.py
class __DisplMixin (line 19) | class __DisplMixin:
method displ_item (line 20) | def displ_item(self, index):
class AOKVQADataset (line 35) | class AOKVQADataset(VQADataset, __DisplMixin):
method __init__ (line 36) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method get_data (line 51) | def get_data(self, index):
method __getitem__ (line 80) | def __getitem__(self, index):
class AOKVQGDataset (line 95) | class AOKVQGDataset(AOKVQADataset):
method __init__ (line 97) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 108) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/base_dataset.py
class BaseDataset (line 17) | class BaseDataset(Dataset):
method __init__ (line 18) | def __init__(
method __len__ (line 43) | def __len__(self):
method collater (line 46) | def collater(self, samples):
method set_processors (line 49) | def set_processors(self, vis_processor, text_processor):
method _add_instance_ids (line 53) | def _add_instance_ids(self, key="instance_id"):
class ConcatDataset (line 59) | class ConcatDataset(ConcatDataset):
method __init__ (line 60) | def __init__(self, datasets: Iterable[Dataset]) -> None:
method collater (line 63) | def collater(self, samples):
FILE: minigpt4/datasets/datasets/caption_datasets.py
class __DisplMixin (line 16) | class __DisplMixin:
method displ_item (line 17) | def displ_item(self, index):
class CaptionDataset (line 29) | class CaptionDataset(BaseDataset, __DisplMixin):
method __init__ (line 30) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 45) | def __getitem__(self, index):
class COCOCaptionDataset (line 65) | class COCOCaptionDataset(BaseDataset, __DisplMixin):
method __init__ (line 66) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 108) | def __getitem__(self, index):
class CaptionEvalDataset (line 129) | class CaptionEvalDataset(BaseDataset, __DisplMixin):
method __init__ (line 130) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 138) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/cc_sbu_dataset.py
class CCSBUDataset (line 8) | class CCSBUDataset(BaseDataset):
method __init__ (line 9) | def __init__(self, vis_processor, text_processor, location):
method to_dict (line 22) | def to_dict(self, sample):
class CCSBUAlignDataset (line 29) | class CCSBUAlignDataset(CaptionDataset):
method __getitem__ (line 31) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/coco_caption.py
class COCOCapEvalDataset (line 26) | class COCOCapEvalDataset(CaptionEvalDataset):
method __init__ (line 27) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 35) | def __getitem__(self, index):
class NoCapsEvalDataset (line 52) | class NoCapsEvalDataset(CaptionEvalDataset):
method __init__ (line 53) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 61) | def __getitem__(self, index):
class RefCOCOEvalData (line 78) | class RefCOCOEvalData(torch.utils.data.Dataset):
method __init__ (line 79) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 84) | def __len__(self):
method __getitem__ (line 87) | def __getitem__(self, idx):
class EvalCaptionData (line 97) | class EvalCaptionData(torch.utils.data.Dataset):
method __init__ (line 98) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 108) | def __len__(self):
method __getitem__ (line 111) | def __getitem__(self, idx):
FILE: minigpt4/datasets/datasets/coco_dataset.py
class ReferCOCODataset (line 21) | class ReferCOCODataset(Dataset):
method __init__ (line 22) | def __init__(self, vis_processor, text_processor, vis_root, ann_path, ...
method __len__ (line 46) | def __len__(self):
method preprocess (line 49) | def preprocess(self, index):
method __getitem__ (line 82) | def __getitem__(self, index):
class InvReferCOCODataset (line 96) | class InvReferCOCODataset(ReferCOCODataset):
method __init__ (line 97) | def __init__(self, *args, **kwargs):
method __getitem__ (line 110) | def __getitem__(self, index):
class REFER (line 125) | class REFER:
method __init__ (line 126) | def __init__(self, data_root, vis_root, dataset='refcoco', splitBy='un...
method createIndex (line 158) | def createIndex(self):
method getRefIds (line 221) | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
method getAnnIds (line 252) | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
method getImgIds (line 272) | def getImgIds(self, ref_ids=[]):
method getCatIds (line 281) | def getCatIds(self):
method loadRefs (line 284) | def loadRefs(self, ref_ids=[]):
method loadAnns (line 290) | def loadAnns(self, ann_ids=[]):
method loadImgs (line 296) | def loadImgs(self, image_ids=[]):
method loadCats (line 302) | def loadCats(self, cat_ids=[]):
method getRefBox (line 308) | def getRefBox(self, ref_id):
method showRef (line 313) | def showRef(self, ref, seg_box='box'):
FILE: minigpt4/datasets/datasets/coco_vqa_datasets.py
class __DisplMixin (line 19) | class __DisplMixin:
method displ_item (line 20) | def displ_item(self, index):
class COCOVQADataset (line 34) | class COCOVQADataset(VQADataset, __DisplMixin):
method __init__ (line 35) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method get_data (line 51) | def get_data(self, index):
method __getitem__ (line 81) | def __getitem__(self, index):
class COCOVQAEvalDataset (line 94) | class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin):
method __init__ (line 95) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 126) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/dataloader_utils.py
class MultiIterLoader (line 15) | class MultiIterLoader:
method __init__ (line 24) | def __init__(self, loaders, ratios=None):
method __next__ (line 40) | def __next__(self):
class PrefetchLoader (line 46) | class PrefetchLoader(object):
method __init__ (line 54) | def __init__(self, loader):
method __iter__ (line 58) | def __iter__(self):
method __len__ (line 73) | def __len__(self):
method preload (line 76) | def preload(self, it):
method next (line 101) | def next(self, it):
method __getattr__ (line 109) | def __getattr__(self, name):
function record_cuda_stream (line 114) | def record_cuda_stream(batch):
class IterLoader (line 127) | class IterLoader:
method __init__ (line 135) | def __init__(self, dataloader: DataLoader, use_distributed: bool = Fal...
method epoch (line 142) | def epoch(self) -> int:
method __next__ (line 145) | def __next__(self):
method __iter__ (line 158) | def __iter__(self):
method __len__ (line 161) | def __len__(self):
FILE: minigpt4/datasets/datasets/flickr.py
class GroundedDetailDataset (line 21) | class GroundedDetailDataset(Dataset):
method __init__ (line 22) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 42) | def __len__(self):
method __getitem__ (line 45) | def __getitem__(self, index):
class CaptionToObjectDataset (line 68) | class CaptionToObjectDataset(Dataset):
method __init__ (line 69) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 86) | def __len__(self):
method __getitem__ (line 89) | def __getitem__(self, index):
class PhraseToObjectDataset (line 117) | class PhraseToObjectDataset(Dataset):
method __init__ (line 118) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 135) | def __len__(self):
method __getitem__ (line 138) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/gqa_datasets.py
class __DisplMixin (line 18) | class __DisplMixin:
method displ_item (line 19) | def displ_item(self, index):
class GQADataset (line 33) | class GQADataset(VQADataset, __DisplMixin):
method __init__ (line 34) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
method __getitem__ (line 41) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/laion_dataset.py
class LaionDataset (line 12) | class LaionDataset(BaseDataset):
method __init__ (line 13) | def __init__(self, vis_processor, text_processor, location):
method to_dict (line 26) | def to_dict(self, sample):
FILE: minigpt4/datasets/datasets/llava_dataset.py
class LlavaDetailDataset (line 18) | class LlavaDetailDataset(Dataset):
method __init__ (line 19) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 32) | def __len__(self):
method __getitem__ (line 35) | def __getitem__(self, index):
class LlavaReasonDataset (line 55) | class LlavaReasonDataset(Dataset):
method __init__ (line 56) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 69) | def __len__(self):
method __getitem__ (line 72) | def __getitem__(self, index):
class LlavaConversationDataset (line 95) | class LlavaConversationDataset(Dataset):
method __init__ (line 96) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 114) | def __len__(self):
method __getitem__ (line 117) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/multitask_conversation.py
class MultiTaskConversationDataset (line 23) | class MultiTaskConversationDataset(Dataset):
method __init__ (line 24) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 40) | def __len__(self):
method __getitem__ (line 43) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/ocrvqa_dataset.py
class OCRVQADataset (line 21) | class OCRVQADataset(Dataset):
method __init__ (line 22) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method create_data (line 38) | def create_data(self, ann_path):
method __len__ (line 59) | def __len__(self):
method __getitem__ (line 62) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/text_caps.py
class TextCapDataset (line 22) | class TextCapDataset(Dataset):
method __init__ (line 23) | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
method __len__ (line 57) | def __len__(self):
method __getitem__ (line 61) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/unnatural_instruction.py
class UnnaturalDataset (line 21) | class UnnaturalDataset(Dataset):
method __init__ (line 22) | def __init__(self, text_processor, ann_path):
method __len__ (line 32) | def __len__(self):
method __getitem__ (line 35) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/vg_dataset.py
class ReferVisualGenomeDataset (line 16) | class ReferVisualGenomeDataset(Dataset):
method __init__ (line 17) | def __init__(self, vis_processor, text_processor, data_dir):
method __len__ (line 45) | def __len__(self):
method preprocess (line 48) | def preprocess(self, index):
method __getitem__ (line 77) | def __getitem__(self, index):
FILE: minigpt4/datasets/datasets/vqa_datasets.py
class VQADataset (line 15) | class VQADataset(BaseDataset):
method __init__ (line 16) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
class VQAEvalDataset (line 20) | class VQAEvalDataset(BaseDataset):
method __init__ (line 21) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
class OKVQAEvalData (line 25) | class OKVQAEvalData(torch.utils.data.Dataset):
method __init__ (line 26) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 31) | def __len__(self):
method __getitem__ (line 34) | def __getitem__(self, idx):
class VizWizEvalData (line 46) | class VizWizEvalData(torch.utils.data.Dataset):
method __init__ (line 47) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 52) | def __len__(self):
method __getitem__ (line 55) | def __getitem__(self, idx):
class IconQAEvalData (line 67) | class IconQAEvalData(torch.utils.data.Dataset):
method __init__ (line 68) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 73) | def __len__(self):
method __getitem__ (line 76) | def __getitem__(self, idx):
class GQAEvalData (line 88) | class GQAEvalData(torch.utils.data.Dataset):
method __init__ (line 89) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 94) | def __len__(self):
method __getitem__ (line 97) | def __getitem__(self, idx):
class HMEvalData (line 109) | class HMEvalData(torch.utils.data.Dataset):
method __init__ (line 110) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 115) | def __len__(self):
method __getitem__ (line 118) | def __getitem__(self, idx):
class VSREvalData (line 130) | class VSREvalData(torch.utils.data.Dataset):
method __init__ (line 131) | def __init__(self, loaded_data, vis_processor, root_path):
method __len__ (line 136) | def __len__(self):
method __getitem__ (line 139) | def __getitem__(self, idx):
FILE: minigpt4/models/Qformer.py
class BertEmbeddings (line 51) | class BertEmbeddings(nn.Module):
method __init__ (line 54) | def __init__(self, config):
method forward (line 78) | def forward(
class BertSelfAttention (line 111) | class BertSelfAttention(nn.Module):
method __init__ (line 112) | def __init__(self, config, is_cross_attention):
method save_attn_gradients (line 149) | def save_attn_gradients(self, attn_gradients):
method get_attn_gradients (line 152) | def get_attn_gradients(self):
method save_attention_map (line 155) | def save_attention_map(self, attention_map):
method get_attention_map (line 158) | def get_attention_map(self):
method transpose_for_scores (line 161) | def transpose_for_scores(self, x):
method forward (line 169) | def forward(
class BertSelfOutput (line 278) | class BertSelfOutput(nn.Module):
method __init__ (line 279) | def __init__(self, config):
method forward (line 285) | def forward(self, hidden_states, input_tensor):
class BertAttention (line 292) | class BertAttention(nn.Module):
method __init__ (line 293) | def __init__(self, config, is_cross_attention=False):
method prune_heads (line 299) | def prune_heads(self, heads):
method forward (line 322) | def forward(
class BertIntermediate (line 349) | class BertIntermediate(nn.Module):
method __init__ (line 350) | def __init__(self, config):
method forward (line 358) | def forward(self, hidden_states):
class BertOutput (line 364) | class BertOutput(nn.Module):
method __init__ (line 365) | def __init__(self, config):
method forward (line 371) | def forward(self, hidden_states, input_tensor):
class BertLayer (line 378) | class BertLayer(nn.Module):
method __init__ (line 379) | def __init__(self, config, layer_num):
method forward (line 402) | def forward(
method feed_forward_chunk (line 476) | def feed_forward_chunk(self, attention_output):
method feed_forward_chunk_query (line 481) | def feed_forward_chunk_query(self, attention_output):
class BertEncoder (line 487) | class BertEncoder(nn.Module):
method __init__ (line 488) | def __init__(self, config):
method forward (line 495) | def forward(
class BertPooler (line 592) | class BertPooler(nn.Module):
method __init__ (line 593) | def __init__(self, config):
method forward (line 598) | def forward(self, hidden_states):
class BertPredictionHeadTransform (line 607) | class BertPredictionHeadTransform(nn.Module):
method __init__ (line 608) | def __init__(self, config):
method forward (line 617) | def forward(self, hidden_states):
class BertLMPredictionHead (line 624) | class BertLMPredictionHead(nn.Module):
method __init__ (line 625) | def __init__(self, config):
method forward (line 638) | def forward(self, hidden_states):
class BertOnlyMLMHead (line 644) | class BertOnlyMLMHead(nn.Module):
method __init__ (line 645) | def __init__(self, config):
method forward (line 649) | def forward(self, sequence_output):
class BertPreTrainedModel (line 654) | class BertPreTrainedModel(PreTrainedModel):
method _init_weights (line 664) | def _init_weights(self, module):
class BertModel (line 677) | class BertModel(BertPreTrainedModel):
method __init__ (line 687) | def __init__(self, config, add_pooling_layer=False):
method get_input_embeddings (line 699) | def get_input_embeddings(self):
method set_input_embeddings (line 702) | def set_input_embeddings(self, value):
method _prune_heads (line 705) | def _prune_heads(self, heads_to_prune):
method get_extended_attention_mask (line 713) | def get_extended_attention_mask(
method forward (line 804) | def forward(
class BertLMHeadModel (line 968) | class BertLMHeadModel(BertPreTrainedModel):
method __init__ (line 973) | def __init__(self, config):
method get_output_embeddings (line 981) | def get_output_embeddings(self):
method set_output_embeddings (line 984) | def set_output_embeddings(self, new_embeddings):
method forward (line 987) | def forward(
method prepare_inputs_for_generation (line 1097) | def prepare_inputs_for_generation(
method _reorder_cache (line 1120) | def _reorder_cache(self, past, beam_idx):
class BertForMaskedLM (line 1131) | class BertForMaskedLM(BertPreTrainedModel):
method __init__ (line 1136) | def __init__(self, config):
method get_output_embeddings (line 1144) | def get_output_embeddings(self):
method set_output_embeddings (line 1147) | def set_output_embeddings(self, new_embeddings):
method forward (line 1150) | def forward(
FILE: minigpt4/models/__init__.py
function load_model (line 29) | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint...
function load_preprocess (line 63) | def load_preprocess(config):
function load_model_and_preprocess (line 115) | def load_model_and_preprocess(name, model_type, is_eval=False, device="c...
class ModelZoo (line 163) | class ModelZoo:
method __init__ (line 174) | def __init__(self) -> None:
method __str__ (line 180) | def __str__(self) -> str:
method __iter__ (line 195) | def __iter__(self):
method __len__ (line 198) | def __len__(self):
FILE: minigpt4/models/base_model.py
class BaseModel (line 30) | class BaseModel(nn.Module):
method __init__ (line 33) | def __init__(self):
method device (line 37) | def device(self):
method load_checkpoint (line 40) | def load_checkpoint(self, url_or_filename):
method from_pretrained (line 70) | def from_pretrained(cls, model_type):
method default_config_path (line 86) | def default_config_path(cls, model_type):
method load_checkpoint_from_config (line 92) | def load_checkpoint_from_config(self, cfg, **kwargs):
method before_evaluation (line 113) | def before_evaluation(self, **kwargs):
method show_n_params (line 116) | def show_n_params(self, return_str=True):
method maybe_autocast (line 131) | def maybe_autocast(self, dtype=torch.float16):
method init_vision_encoder (line 142) | def init_vision_encoder(
method init_llm (line 171) | def init_llm(cls, llama_model_path, low_resource=False, low_res_device...
method load_from_pretrained (line 210) | def load_from_pretrained(self, url_or_filename):
function disabled_train (line 231) | def disabled_train(self, mode=True):
class LayerNorm (line 237) | class LayerNorm(nn.LayerNorm):
method forward (line 240) | def forward(self, x: torch.Tensor):
FILE: minigpt4/models/eva_vit.py
function _cfg (line 20) | def _cfg(url='', **kwargs):
class DropPath (line 30) | class DropPath(nn.Module):
method __init__ (line 33) | def __init__(self, drop_prob=None):
method forward (line 37) | def forward(self, x):
method extra_repr (line 40) | def extra_repr(self) -> str:
class Mlp (line 44) | class Mlp(nn.Module):
method __init__ (line 45) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 54) | def forward(self, x):
class Attention (line 64) | class Attention(nn.Module):
method __init__ (line 65) | def __init__(
method forward (line 118) | def forward(self, x, rel_pos_bias=None):
class Block (line 151) | class Block(nn.Module):
method __init__ (line 153) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
method forward (line 173) | def forward(self, x, rel_pos_bias=None):
class PatchEmbed (line 183) | class PatchEmbed(nn.Module):
method __init__ (line 186) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
method forward (line 198) | def forward(self, x, **kwargs):
class RelativePositionBias (line 207) | class RelativePositionBias(nn.Module):
method __init__ (line 209) | def __init__(self, window_size, num_heads):
method forward (line 238) | def forward(self):
class VisionTransformer (line 246) | class VisionTransformer(nn.Module):
method __init__ (line 249) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
method fix_init_weight (line 300) | def fix_init_weight(self):
method _init_weights (line 308) | def _init_weights(self, m):
method get_classifier (line 317) | def get_classifier(self):
method reset_classifier (line 320) | def reset_classifier(self, num_classes, global_pool=''):
method forward_features (line 324) | def forward_features(self, x):
method forward (line 349) | def forward(self, x):
method get_intermediate_layers (line 354) | def get_intermediate_layers(self, x):
function interpolate_pos_embed (line 373) | def interpolate_pos_embed(model, checkpoint_model):
function convert_weights_to_fp16 (line 397) | def convert_weights_to_fp16(model: nn.Module):
function create_eva_vit_g (line 415) | def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=Fals...
FILE: minigpt4/models/minigpt4.py
class MiniGPT4 (line 15) | class MiniGPT4(MiniGPTBase):
method __init__ (line 25) | def __init__(
method init_Qformer (line 88) | def init_Qformer(cls, num_query_token, vision_width, freeze):
method encode_img (line 118) | def encode_img(self, image):
method from_config (line 148) | def from_config(cls, cfg):
FILE: minigpt4/models/minigpt_base.py
class MiniGPTBase (line 14) | class MiniGPTBase(BaseModel):
method __init__ (line 19) | def __init__(
method vit_to_cpu (line 62) | def vit_to_cpu(self):
method get_context_emb (line 68) | def get_context_emb(self, prompt, img_list):
method prompt_wrap (line 83) | def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
method concat_emb_input_output (line 137) | def concat_emb_input_output(self, input_embs, input_atts, output_embs,...
method tokenize_conversation (line 166) | def tokenize_conversation(self, conv_q, conv_a):
method preparing_embedding (line 211) | def preparing_embedding(self, samples):
method forward (line 273) | def forward(self, samples, reduction='mean'):
method embed_tokens (line 310) | def embed_tokens(self, token_ids):
method generate (line 318) | def generate(
method multi_select (line 394) | def multi_select(self, images, texts, answers, num_cand=None):
FILE: minigpt4/models/minigpt_v2.py
class MiniGPTv2 (line 15) | class MiniGPTv2(MiniGPTBase):
method __init__ (line 24) | def __init__(
method encode_img (line 75) | def encode_img(self, image):
method from_config (line 92) | def from_config(cls, cfg):
FILE: minigpt4/models/modeling_llama.py
class LlamaForCausalLM (line 14) | class LlamaForCausalLM(LlamaForCausalLMOrig):
method forward (line 18) | def forward(
FILE: minigpt4/processors/__init__.py
function load_processor (line 25) | def load_processor(name, cfg=None):
FILE: minigpt4/processors/base_processor.py
class BaseProcessor (line 11) | class BaseProcessor:
method __init__ (line 12) | def __init__(self):
method __call__ (line 16) | def __call__(self, item):
method from_config (line 20) | def from_config(cls, cfg=None):
method build (line 23) | def build(self, **kwargs):
FILE: minigpt4/processors/blip_processors.py
class BlipImageBaseProcessor (line 18) | class BlipImageBaseProcessor(BaseProcessor):
method __init__ (line 19) | def __init__(self, mean=None, std=None):
class BlipCaptionProcessor (line 29) | class BlipCaptionProcessor(BaseProcessor):
method __init__ (line 30) | def __init__(self, prompt="", max_words=50):
method __call__ (line 34) | def __call__(self, caption):
method from_config (line 40) | def from_config(cls, cfg=None):
method pre_caption (line 49) | def pre_caption(self, caption):
class Blip2ImageTrainProcessor (line 72) | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
method __init__ (line 73) | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5,...
method __call__ (line 87) | def __call__(self, item):
method from_config (line 91) | def from_config(cls, cfg=None):
class Blip2ImageEvalProcessor (line 113) | class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
method __init__ (line 114) | def __init__(self, image_size=224, mean=None, std=None):
method __call__ (line 127) | def __call__(self, item):
method from_config (line 131) | def from_config(cls, cfg=None):
FILE: minigpt4/processors/randaugment.py
function identity_func (line 15) | def identity_func(img):
function autocontrast_func (line 19) | def autocontrast_func(img, cutoff=0):
function equalize_func (line 52) | def equalize_func(img):
function rotate_func (line 76) | def rotate_func(img, degree, fill=(0, 0, 0)):
function solarize_func (line 87) | def solarize_func(img, thresh=128):
function color_func (line 97) | def color_func(img, factor):
function contrast_func (line 115) | def contrast_func(img, factor):
function brightness_func (line 129) | def brightness_func(img, factor):
function sharpness_func (line 138) | def sharpness_func(img, factor):
function shear_x_func (line 159) | def shear_x_func(img, factor, fill=(0, 0, 0)):
function translate_x_func (line 168) | def translate_x_func(img, offset, fill=(0, 0, 0)):
function translate_y_func (line 180) | def translate_y_func(img, offset, fill=(0, 0, 0)):
function posterize_func (line 192) | def posterize_func(img, bits):
function shear_y_func (line 200) | def shear_y_func(img, factor, fill=(0, 0, 0)):
function cutout_func (line 209) | def cutout_func(img, pad_size, replace=(0, 0, 0)):
function enhance_level_to_args (line 223) | def enhance_level_to_args(MAX_LEVEL):
function shear_level_to_args (line 230) | def shear_level_to_args(MAX_LEVEL, replace_value):
function translate_level_to_args (line 240) | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
function cutout_level_to_args (line 250) | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
function solarize_level_to_args (line 258) | def solarize_level_to_args(MAX_LEVEL):
function none_level_to_args (line 266) | def none_level_to_args(level):
function posterize_level_to_args (line 270) | def posterize_level_to_args(MAX_LEVEL):
function rotate_level_to_args (line 278) | def rotate_level_to_args(MAX_LEVEL, replace_value):
class RandomAugment (line 326) | class RandomAugment(object):
method __init__ (line 327) | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
method get_random_ops (line 336) | def get_random_ops(self):
method __call__ (line 340) | def __call__(self, img):
class VideoRandomAugment (line 352) | class VideoRandomAugment(object):
method __init__ (line 353) | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
method get_random_ops (line 363) | def get_random_ops(self):
method __call__ (line 367) | def __call__(self, frames):
method _aug (line 386) | def _aug(self, img, ops, apply_or_not):
FILE: minigpt4/runners/runner_base.py
class RunnerBase (line 38) | class RunnerBase:
method __init__ (line 46) | def __init__(self, cfg, task, model, datasets, job_id):
method device (line 68) | def device(self):
method use_distributed (line 75) | def use_distributed(self):
method model (line 79) | def model(self):
method optimizer (line 99) | def optimizer(self):
method scaler (line 132) | def scaler(self):
method lr_scheduler (line 142) | def lr_scheduler(self):
method dataloaders (line 182) | def dataloaders(self) -> dict:
method cuda_enabled (line 277) | def cuda_enabled(self):
method max_epoch (line 281) | def max_epoch(self):
method log_freq (line 285) | def log_freq(self):
method init_lr (line 290) | def init_lr(self):
method min_lr (line 294) | def min_lr(self):
method accum_grad_iters (line 298) | def accum_grad_iters(self):
method valid_splits (line 302) | def valid_splits(self):
method test_splits (line 311) | def test_splits(self):
method train_splits (line 317) | def train_splits(self):
method evaluate_only (line 326) | def evaluate_only(self):
method use_dist_eval_sampler (line 333) | def use_dist_eval_sampler(self):
method resume_ckpt_path (line 337) | def resume_ckpt_path(self):
method train_loader (line 341) | def train_loader(self):
method setup_output_dir (line 346) | def setup_output_dir(self):
method train (line 362) | def train(self):
method evaluate (line 422) | def evaluate(self, cur_epoch="best", skip_reload=False):
method train_epoch (line 433) | def train_epoch(self, epoch):
method eval_epoch (line 450) | def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
method unwrap_dist_model (line 484) | def unwrap_dist_model(self, model):
method create_loaders (line 490) | def create_loaders(
method _save_checkpoint (line 575) | def _save_checkpoint(self, cur_epoch, is_best=False):
method _reload_best_model (line 602) | def _reload_best_model(self, model):
method _load_checkpoint (line 622) | def _load_checkpoint(self, url_or_filename):
method log_stats (line 648) | def log_stats(self, stats, split_name):
method log_config (line 657) | def log_config(self):
FILE: minigpt4/tasks/__init__.py
function setup_task (line 13) | def setup_task(cfg):
FILE: minigpt4/tasks/base_task.py
class BaseTask (line 19) | class BaseTask:
method __init__ (line 20) | def __init__(self, **kwargs):
method setup_task (line 27) | def setup_task(cls, **kwargs):
method build_model (line 30) | def build_model(self, cfg):
method build_datasets (line 37) | def build_datasets(self, cfg):
method train_step (line 69) | def train_step(self, model, samples):
method valid_step (line 73) | def valid_step(self, model, samples):
method before_evaluation (line 76) | def before_evaluation(self, model, dataset, **kwargs):
method after_evaluation (line 79) | def after_evaluation(self, **kwargs):
method inference_step (line 82) | def inference_step(self):
method evaluation (line 85) | def evaluation(self, model, data_loader, cuda_enabled=True):
method train_epoch (line 104) | def train_epoch(
method train_iters (line 129) | def train_iters(
method _train_inner_loop (line 157) | def _train_inner_loop(
method save_result (line 253) | def save_result(result, result_dir, filename, remove_duplicate=""):
FILE: minigpt4/tasks/image_text_pretrain.py
class ImageTextPretrainTask (line 13) | class ImageTextPretrainTask(BaseTask):
method __init__ (line 14) | def __init__(self):
method evaluation (line 17) | def evaluation(self, model, data_loader, cuda_enabled=True):
FILE: train.py
function parse_args (line 36) | def parse_args():
function setup_seeds (line 52) | def setup_seeds(config):
function get_runner_class (line 63) | def get_runner_class(cfg):
function main (line 72) | def main():
Condensed preview: 126 files, each showing path, character count, and a content snippet (full structured content: 533K chars).
[
{
"path": ".github/ISSUE_TEMPLATE/bug_report.md",
"chars": 834,
"preview": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the b"
},
{
"path": ".github/ISSUE_TEMPLATE/feature_request.md",
"chars": 595,
"preview": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Is your fea"
},
{
"path": ".gitignore",
"chars": 3293,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 5229,
"preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participa"
},
{
"path": "LICENSE.md",
"chars": 1497,
"preview": "BSD 3-Clause License\n\nCopyright 2023 Deyao Zhu\nAll rights reserved.\n\nRedistribution and use in source and binary forms, "
},
{
"path": "LICENSE_Lavis.md",
"chars": 1502,
"preview": "BSD 3-Clause License\n\nCopyright (c) 2022 Salesforce, Inc.\nAll rights reserved.\n\nRedistribution and use in source and bin"
},
{
"path": "MiniGPT4_Train.md",
"chars": 2094,
"preview": "## Training of MiniGPT-4\n\nThe training of MiniGPT-4 contains two alignment stages.\n\n**1. First pretraining stage**\n\nIn t"
},
{
"path": "MiniGPTv2_Train.md",
"chars": 1068,
"preview": "## Finetune of MiniGPT-4\n\n\nYou firstly need to prepare the dataset. you can follow this step to prepare the dataset.\nour"
},
{
"path": "README.md",
"chars": 10655,
"preview": "# MiniGPT-V\n\n<font size='5'>**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Lea"
},
{
"path": "SECURITY.md",
"chars": 619,
"preview": "# Security Policy\n\n## Supported Versions\n\nUse this section to tell people about which versions of your project are\ncurre"
},
{
"path": "dataset/README_1_STAGE.md",
"chars": 2798,
"preview": "## Download the filtered Conceptual Captions, SBU, LAION datasets\n\n### Pre-training datasets download:\nWe use the filter"
},
{
"path": "dataset/README_2_STAGE.md",
"chars": 511,
"preview": "## Second Stage Data Preparation\n\nOur second stage dataset can be downloaded from \n[here](https://drive.google.com/file/"
},
{
"path": "dataset/README_MINIGPTv2_FINETUNE.md",
"chars": 10533,
"preview": "## Download the dataset for finetuning the MiniGPT-v2\n\n\nDownload the dataset\n\nImage source | Download path\n--- | :---:\nC"
},
{
"path": "dataset/convert_cc_sbu.py",
"chars": 504,
"preview": "import json\nimport csv\n\n# specify input and output file paths\ninput_file = 'ccs_synthetic_filtered_large.json'\noutput_fi"
},
{
"path": "dataset/convert_laion.py",
"chars": 508,
"preview": "import json\nimport csv\n\n# specify input and output file paths\ninput_file = 'laion_synthetic_filtered_large.json'\noutput_"
},
{
"path": "demo.py",
"chars": 6463,
"preview": "import argparse\nimport os\nimport random\n\nimport numpy as np\nimport torch\nimport torch.backends.cudnn as cudnn\nimport gra"
},
{
"path": "demo_v2.py",
"chars": 23445,
"preview": "import argparse\nimport os\nimport random\nfrom collections import defaultdict\n\nimport cv2\nimport re\n\nimport numpy as np\nfr"
},
{
"path": "environment.yml",
"chars": 666,
"preview": "name: minigptv\nchannels:\n - pytorch\n - defaults\n - anaconda\ndependencies:\n - python=3.9\n - cudatoolkit\n - pip\n - "
},
{
"path": "eval_configs/minigpt4_eval.yaml",
"chars": 443,
"preview": "model:\n arch: minigpt4\n model_type: pretrain_vicuna0\n max_txt_len: 160\n end_sym: \"###\"\n low_resource: True\n prompt"
},
{
"path": "eval_configs/minigpt4_llama2_eval.yaml",
"chars": 434,
"preview": "model:\n arch: minigpt4\n model_type: pretrain_llama2\n max_txt_len: 160\n end_sym: \"</s>\"\n low_resource: True\n prompt"
},
{
"path": "eval_configs/minigptv2_benchmark_evaluation.yaml",
"chars": 1832,
"preview": "model:\n arch: minigpt_v2\n model_type: pretrain\n max_txt_len: 500\n end_sym: \"</s>\"\n low_resource: False\n prompt_tem"
},
{
"path": "eval_configs/minigptv2_eval.yaml",
"chars": 458,
"preview": "model:\n arch: minigpt_v2\n model_type: pretrain\n max_txt_len: 500\n end_sym: \"</s>\"\n low_resource: True\n prompt_temp"
},
{
"path": "eval_scripts/EVAL_README.md",
"chars": 3189,
"preview": "## Evaluation Instruction for MiniGPT-v2\n\n### Data preparation\nImages download\nImage source | Download path\n--- | :---:\n"
},
{
"path": "eval_scripts/eval_ref.py",
"chars": 5406,
"preview": "import os\nimport re\nimport json\nimport argparse\nfrom collections import defaultdict\nimport random\nimport numpy as np\nfro"
},
{
"path": "eval_scripts/eval_vqa.py",
"chars": 9899,
"preview": "import os\nimport re\nimport json\nimport argparse\nfrom collections import defaultdict\n\nimport numpy as np\nfrom PIL import "
},
{
"path": "minigpt4/__init__.py",
"chars": 951,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "minigpt4/common/config.py",
"chars": 16083,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/dist_utils.py",
"chars": 3715,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/eval_utils.py",
"chars": 3031,
"preview": "import argparse\nimport numpy as np\nfrom nltk.translate.bleu_score import sentence_bleu\n\nfrom minigpt4.common.registry im"
},
{
"path": "minigpt4/common/gradcam.py",
"chars": 815,
"preview": "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom scipy.ndimage import filters\nfrom skimage import transform "
},
{
"path": "minigpt4/common/logger.py",
"chars": 6001,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/optims.py",
"chars": 3516,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/registry.py",
"chars": 9915,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/utils.py",
"chars": 13807,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py",
"chars": 3686,
"preview": "# coding: utf-8\n\nimport sys\ndataDir = '../../VQA'\nsys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))\nfrom vq"
},
{
"path": "minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py",
"chars": 18,
"preview": "author='aagrawal'\n"
},
{
"path": "minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py",
"chars": 8448,
"preview": "# coding=utf-8\n\n__author__='aagrawal'\n\nimport re\n# This code is based on the code written by Tsung-Yi Lin for MSCOCO Pyt"
},
{
"path": "minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py",
"chars": 2434,
"preview": "# coding: utf-8\n\nfrom vqaTools.vqa import VQA\nimport random\nimport skimage.io as io\nimport matplotlib.pyplot as plt\nimpo"
},
{
"path": "minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py",
"chars": 24,
"preview": "__author__ = 'aagrawal'\n"
},
{
"path": "minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py",
"chars": 8000,
"preview": "__author__ = 'aagrawal'\n__version__ = '0.9'\n\n# Interface for accessing the VQA dataset.\n\n# This code is based on the cod"
},
{
"path": "minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt",
"chars": 979,
"preview": "how many\nwhat color is the\nis the\nwhere is the\nwhat\nwhat is\nare the\nwhat is the\nis there a\ndoes the\nis the woman\nis the "
},
{
"path": "minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt",
"chars": 734,
"preview": "how many\nis the\nwhat\nwhat color is the\nwhat is the\nis this\nis this a\nwhat is\nare the\nwhat kind of\nis there a\nwhat type o"
},
{
"path": "minigpt4/common/vqa_tools/VQA/README.md",
"chars": 5382,
"preview": "Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.\n===================\n## VQA v2.0 release ##"
},
{
"path": "minigpt4/common/vqa_tools/VQA/license.txt",
"chars": 1521,
"preview": "Copyright (c) 2014, Aishwarya Agrawal\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or w"
},
{
"path": "minigpt4/common/vqa_tools/__init__.py",
"chars": 246,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/vqa_tools/vqa.py",
"chars": 8634,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/common/vqa_tools/vqa_eval.py",
"chars": 11016,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/configs/datasets/aokvqa/defaults.yaml",
"chars": 720,
"preview": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full"
},
{
"path": "minigpt4/configs/datasets/cc_sbu/align.yaml",
"chars": 102,
"preview": "datasets:\n cc_sbu_align:\n data_type: images\n build_info:\n storage: /path/to/cc_sbu_align/\n"
},
{
"path": "minigpt4/configs/datasets/cc_sbu/defaults.yaml",
"chars": 116,
"preview": "datasets:\n cc_sbu:\n data_type: images\n build_info:\n storage: /path/to/cc_sbu_dataset/{00000..01255}.tar\n"
},
{
"path": "minigpt4/configs/datasets/coco/caption.yaml",
"chars": 832,
"preview": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full"
},
{
"path": "minigpt4/configs/datasets/coco/defaults_vqa.yaml",
"chars": 789,
"preview": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full"
},
{
"path": "minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml",
"chars": 190,
"preview": "datasets:\n invrefcoco:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: /pa"
},
{
"path": "minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml",
"chars": 192,
"preview": "datasets:\n invrefcocog:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: /p"
},
{
"path": "minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml",
"chars": 192,
"preview": "datasets:\n invrefcocop:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: /p"
},
{
"path": "minigpt4/configs/datasets/coco_bbox/refcoco.yaml",
"chars": 184,
"preview": "datasets:\n refcoco:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: /path/"
},
{
"path": "minigpt4/configs/datasets/coco_bbox/refcocog.yaml",
"chars": 186,
"preview": "datasets:\n refcocog:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: /path"
},
{
"path": "minigpt4/configs/datasets/coco_bbox/refcocop.yaml",
"chars": 186,
"preview": "datasets:\n refcocop:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: /path"
},
{
"path": "minigpt4/configs/datasets/flickr/caption_to_phrase.yaml",
"chars": 184,
"preview": "datasets:\n flickr_CaptionToPhrase:\n data_type: images\n build_info:\n image_path: /path/to/filtered_flikcr/ima"
},
{
"path": "minigpt4/configs/datasets/flickr/default.yaml",
"chars": 187,
"preview": "datasets:\n flickr_grounded_caption:\n data_type: images\n build_info:\n image_path: /path/to/filtered_flikcr/im"
},
{
"path": "minigpt4/configs/datasets/flickr/object_to_phrase.yaml",
"chars": 182,
"preview": "datasets:\n flickr_ObjectToPhrase:\n data_type: images\n build_info:\n image_path: /path/to/filtered_flikcr/imag"
},
{
"path": "minigpt4/configs/datasets/gqa/balanced_val.yaml",
"chars": 732,
"preview": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full"
},
{
"path": "minigpt4/configs/datasets/laion/defaults.yaml",
"chars": 114,
"preview": "datasets:\n laion:\n data_type: images\n build_info:\n storage: /path/to/laion_dataset/{00000..10488}.tar\n"
},
{
"path": "minigpt4/configs/datasets/llava/conversation.yaml",
"chars": 162,
"preview": "datasets:\n\n llava_conversation:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_"
},
{
"path": "minigpt4/configs/datasets/llava/detail.yaml",
"chars": 149,
"preview": "datasets:\n llava_detail:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: /"
},
{
"path": "minigpt4/configs/datasets/llava/reason.yaml",
"chars": 161,
"preview": "datasets:\n\n llava_reason:\n data_type: images\n build_info:\n image_path: /path/to/coco/images\n ann_path: "
},
{
"path": "minigpt4/configs/datasets/multitask_conversation/default.yaml",
"chars": 194,
"preview": "datasets:\n multitask_conversation:\n data_type: images\n build_info:\n \n image_path: /path/to/coco/images\n "
},
{
"path": "minigpt4/configs/datasets/nlp/unnatural_instruction.yaml",
"chars": 154,
"preview": "datasets:\n unnatural_instruction:\n data_type: text\n build_info:\n ann_path: /path/to/unnatural_instructions/f"
},
{
"path": "minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml",
"chars": 143,
"preview": "datasets:\n ocrvqa:\n data_type: images\n build_info:\n image_path: /path/to/ocrvqa/images\n ann_path: /path"
},
{
"path": "minigpt4/configs/datasets/okvqa/defaults.yaml",
"chars": 761,
"preview": " # Copyright (c) 2022, salesforce.com, inc.\n # All rights reserved.\n # SPDX-License-Identifier: BSD-3-Clause\n # For full"
},
{
"path": "minigpt4/configs/datasets/textcaps/caption.yaml",
"chars": 182,
"preview": "datasets:\n textcaps_caption:\n data_type: images\n \n build_info:\n image_path: /path/to/textcaps/train_image"
},
{
"path": "minigpt4/configs/datasets/vg/ref.yaml",
"chars": 95,
"preview": "datasets:\n refvg:\n data_type: images\n build_info:\n data_dir: /path/to/visual_genome"
},
{
"path": "minigpt4/configs/default.yaml",
"chars": 141,
"preview": "env:\n # For default users\n # cache_root: \"cache\"\n # For internal use with persistent storage\n cache_root: \"/export/h"
},
{
"path": "minigpt4/configs/models/minigpt4_llama2.yaml",
"chars": 574,
"preview": "model:\n arch: minigpt4\n\n # vit encoder\n image_size: 224\n drop_path_rate: 0\n use_grad_checkpoint: False\n vit_precis"
},
{
"path": "minigpt4/configs/models/minigpt4_vicuna0.yaml",
"chars": 610,
"preview": "model:\n arch: minigpt4\n\n # vit encoder\n image_size: 224\n drop_path_rate: 0\n use_grad_checkpoint: False\n vit_precis"
},
{
"path": "minigpt4/configs/models/minigpt_v2.yaml",
"chars": 586,
"preview": "model:\n arch: minigpt_v2\n\n # vit encoder\n image_size: 448\n drop_path_rate: 0\n use_grad_checkpoint: False\n vit_prec"
},
{
"path": "minigpt4/conversation/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "minigpt4/conversation/conversation.py",
"chars": 7921,
"preview": "import argparse\nimport time\nfrom threading import Thread\nfrom PIL import Image\n\nimport torch\nfrom transformers import Au"
},
{
"path": "minigpt4/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "minigpt4/datasets/builders/__init__.py",
"chars": 1897,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/builders/base_dataset_builder.py",
"chars": 8105,
"preview": "\"\"\"\n This file is from\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-C"
},
{
"path": "minigpt4/datasets/builders/image_text_pair_builder.py",
"chars": 16900,
"preview": "import os\nimport logging\nimport warnings\n\nfrom minigpt4.common.registry import registry\nfrom minigpt4.datasets.builders."
},
{
"path": "minigpt4/datasets/data_utils.py",
"chars": 6511,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "minigpt4/datasets/datasets/aok_vqa_datasets.py",
"chars": 4061,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/base_dataset.py",
"chars": 2409,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/caption_datasets.py",
"chars": 5097,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/cc_sbu_dataset.py",
"chars": 1603,
"preview": "import os\nfrom PIL import Image\nimport webdataset as wds\nfrom minigpt4.datasets.datasets.base_dataset import BaseDataset"
},
{
"path": "minigpt4/datasets/datasets/coco_caption.py",
"chars": 3767,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/coco_dataset.py",
"chars": 13665,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport itertools\n\nimport numpy as np\nfrom PIL import Image"
},
{
"path": "minigpt4/datasets/datasets/coco_vqa_datasets.py",
"chars": 4626,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/dataloader_utils.py",
"chars": 5258,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/flickr.py",
"chars": 4793,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport itertools\n\nimport numpy as np\nfrom PIL import Image"
},
{
"path": "minigpt4/datasets/datasets/gqa_datasets.py",
"chars": 1782,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/laion_dataset.py",
"chars": 1170,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/datasets/datasets/llava_dataset.py",
"chars": 4690,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport numpy as np\nfrom PIL import Image\nimport skimage.io"
},
{
"path": "minigpt4/datasets/datasets/multitask_conversation.py",
"chars": 2219,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport itertools\n\nimport numpy as np\nfrom PIL import Image"
},
{
"path": "minigpt4/datasets/datasets/ocrvqa_dataset.py",
"chars": 2624,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport itertools\n\nimport numpy as np\nfrom PIL import Image"
},
{
"path": "minigpt4/datasets/datasets/text_caps.py",
"chars": 2628,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport itertools\n\nimport numpy as np\nfrom PIL import Image"
},
{
"path": "minigpt4/datasets/datasets/unnatural_instruction.py",
"chars": 1338,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport itertools\n\nimport numpy as np\nfrom PIL import Image"
},
{
"path": "minigpt4/datasets/datasets/vg_dataset.py",
"chars": 2873,
"preview": "import os\nimport json\nimport pickle\nimport random\nimport time\nimport itertools\n\nimport numpy as np\nfrom PIL import Image"
},
{
"path": "minigpt4/datasets/datasets/vqa_datasets.py",
"chars": 5594,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/models/Qformer.py",
"chars": 48386,
"preview": "\"\"\"\n * Copyright (c) 2023, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For "
},
{
"path": "minigpt4/models/__init__.py",
"chars": 5829,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/models/base_model.py",
"chars": 8199,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/models/eva_vit.py",
"chars": 19529,
"preview": "# Based on EVA, BEIT, timm and DeiT code bases\n# https://github.com/baaivision/EVA\n# https://github.com/rwightman/pytorc"
},
{
"path": "minigpt4/models/minigpt4.py",
"chars": 7621,
"preview": "import logging\nimport random\n\nimport torch\nfrom torch.cuda.amp import autocast as autocast\nimport torch.nn as nn\n\nfrom m"
},
{
"path": "minigpt4/models/minigpt_base.py",
"chars": 17268,
"preview": "import logging\nimport random\n\nimport torch\nfrom torch.cuda.amp import autocast as autocast\nimport torch.nn as nn\n\nfrom m"
},
{
"path": "minigpt4/models/minigpt_v2.py",
"chars": 4867,
"preview": "import logging\nimport random\n\nimport torch\nfrom torch.cuda.amp import autocast as autocast\nimport torch.nn as nn\n\nfrom m"
},
{
"path": "minigpt4/models/modeling_llama.py",
"chars": 4974,
"preview": "import math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn "
},
{
"path": "minigpt4/processors/__init__.py",
"chars": 823,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/processors/base_processor.py",
"chars": 610,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/processors/blip_processors.py",
"chars": 3956,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/processors/randaugment.py",
"chars": 11298,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/runners/__init__.py",
"chars": 306,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/runners/runner_base.py",
"chars": 23281,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/tasks/__init__.py",
"chars": 736,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/tasks/base_task.py",
"chars": 9168,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "minigpt4/tasks/image_text_pretrain.py",
"chars": 538,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "train.py",
"chars": 2749,
"preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
},
{
"path": "train_configs/minigpt4_llama2_stage1_pretrain.yaml",
"chars": 989,
"preview": "model:\n arch: minigpt4\n model_type: pretrain_llama2\n\n\ndatasets:\n laion:\n batch_size: 64\n vis_processor:\n t"
},
{
"path": "train_configs/minigpt4_llama2_stage2_finetune.yaml",
"chars": 914,
"preview": "model:\n arch: minigpt4\n model_type: pretrain_llama2\n\n max_txt_len: 160\n end_sym: \"</s>\"\n prompt_path: \"prompts/alig"
},
{
"path": "train_configs/minigpt4_stage1_pretrain.yaml",
"chars": 983,
"preview": "model:\n arch: minigpt4\n model_type: pretrain_vicuna0\n\n\ndatasets:\n laion:\n batch_size: 64\n vis_processor:\n "
},
{
"path": "train_configs/minigpt4_stage2_finetune.yaml",
"chars": 916,
"preview": "model:\n arch: minigpt4\n model_type: pretrain_vicuna0\n\n max_txt_len: 160\n end_sym: \"###\"\n prompt_path: \"prompts/alig"
},
{
"path": "train_configs/minigptv2_finetune.yaml",
"chars": 5391,
"preview": "model:\n arch: minigpt_v2\n model_type: pretrain\n max_txt_len: 1024\n image_size: 448\n end_sym: \"</s>\"\n llama_model: "
}
]
About this extraction
This page contains the full source code of the Vision-CAIR/MiniGPT-4 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 126 files (491.5 KB), approximately 123.5k tokens, and a symbol index with 708 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.