Full Code of MILVLG/prophet for AI

Repository: MILVLG/prophet
Branch: main
Commit: 3e92892ec5ec
Files: 50
Total size: 157.6 KB

Directory structure:
prophet/

├── .gitignore
├── LICENSE
├── README.md
├── assets/
│   └── .gitkeep
├── ckpts/
│   └── .gitkeep
├── configs/
│   ├── finetune.yml
│   ├── path_cfgs.py
│   ├── pretrain.yml
│   ├── prompt.yml
│   ├── task_cfgs.py
│   └── task_to_split.py
├── datasets/
│   └── .gitkeep
├── environment.yml
├── evaluation/
│   ├── ans_punct.py
│   ├── aok_utils/
│   │   ├── eval_predictions.py
│   │   ├── load_aokvqa.py
│   │   └── remap_predictions.py
│   ├── aokvqa_evaluate.py
│   ├── okvqa_evaluate.py
│   └── vqa_utils/
│       ├── vqa.py
│       └── vqaEval.py
├── main.py
├── misc/
│   └── tree.txt
├── outputs/
│   ├── ckpts/
│   │   └── .gitkeep
│   ├── logs/
│   │   └── .gitkeep
│   └── results/
│       └── .gitkeep
├── preds/
│   └── .gitkeep
├── prophet/
│   ├── __init__.py
│   ├── stage1/
│   │   ├── finetune.py
│   │   ├── heuristics.py
│   │   ├── model/
│   │   │   ├── layers.py
│   │   │   ├── mcan.py
│   │   │   ├── mcan_for_finetune.py
│   │   │   ├── net_utils.py
│   │   │   └── rope2d.py
│   │   ├── pretrain.py
│   │   └── utils/
│   │       ├── load_data.py
│   │       └── optim.py
│   └── stage2/
│       ├── prompt.py
│       └── utils/
│           ├── data_utils.py
│           └── fancy_pbar.py
├── scripts/
│   ├── evaluate_file.sh
│   ├── evaluate_model.sh
│   ├── extract_img_feats.sh
│   ├── finetune.sh
│   ├── heuristics_gen.sh
│   ├── pretrain.sh
│   └── prompt.sh
└── tools/
    ├── extract_img_feats.py
    └── transforms.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
**/__pycache__/
datasets/*/
!datasets/.gitkeep
assets/*
!assets/.gitkeep
ckpts/*
!ckpts/.gitkeep
outputs/ckpts/*
!outputs/ckpts/.gitkeep
outputs/logs/*
!outputs/logs/.gitkeep
outputs/results/*
!outputs/results/.gitkeep
preds/*
!preds/.gitkeep
tmp

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Prophet

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/prompting-large-language-models-with-answer/visual-question-answering-on-a-okvqa)](https://paperswithcode.com/sota/visual-question-answering-on-a-okvqa?p=prompting-large-language-models-with-answer)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/prompting-large-language-models-with-answer/visual-question-answering-on-ok-vqa)](https://paperswithcode.com/sota/visual-question-answering-on-ok-vqa?p=prompting-large-language-models-with-answer)

This repository is the official implementation of Prophet, a two-stage framework designed to prompt GPT-3 with answer heuristics for knowledge-based VQA. In stage one, we train a vanilla VQA model on a specific knowledge-based VQA dataset and extract two types of complementary answer heuristics from the model: answer candidates and answer-aware examples. In stage two, the answer heuristics are used to prompt GPT-3 to generate better answers. Prophet significantly outperforms existing state-of-the-art methods on two datasets, delivering 61.1% on OK-VQA and 55.7% on A-OKVQA. Please refer to our [paper](https://arxiv.org/pdf/2303.01903.pdf) for details.

![prophet](misc/framework.png)

## Updates
April 28, 2023
- Added pretrained and finetuned models on A-OKVQA.

March 10, 2023
- Training and testing code for the two-stage Prophet framework.
- Pretrained and finetuned models on OK-VQA.

## Table of Contents

- [Prerequisites](#prerequisites)
- [Usage](#usage)
- [Evaluation](#evaluation)
- [Citation](#citation)
- [License](#license)
<!-- - [Acknowledgement](#acknowledgement) -->

## Prerequisites

### Hardware and Software Requirements

To conduct the following experiments, a machine with at least 1 RTX 3090 GPU, 50GB memory, and 300GB free disk space is recommended. We strongly recommend using an SSD drive to guarantee high-speed I/O.

The following software is needed:

1. [Python](https://www.python.org/downloads/) >= 3.9
2. [CUDA](https://developer.nvidia.com/cuda-toolkit) >= 11.3
3. [PyTorch](https://pytorch.org/get-started/locally/) >= 1.12.0
4. the packages listed in [environment.yml](environment.yml)

We recommend downloading [Anaconda](https://www.anaconda.com/) first and then creating a new environment with the following command:

``` shell
$ conda env create -f environment.yml
```

This command will create a new environment named `prophet` with all the required packages. To activate the environment, run:

``` shell
$ conda activate prophet
```

### Data Preparation

Before running the code, prepare two folders: `datasets` and `assets`. The `datasets` folder contains all the datasets and features used in this project, and the `assets` folder contains the pre-computed resources and other intermediate files (you can use them to skip some early experiment steps and save time).

First, download the [datasets](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/Ebzd7EANzHVHnh3FvYvCJ7kBkJf56iT1Obe5L2PZAzgM2g?download=1) and [assets](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/Ec5NPIswAxlEqi74qwGjIf0BKInF0O6nwW5dtn4h3GOUsQ?download=1), then put the `datasets` and `assets` folders in the root directory of this project. Download the MSCOCO 2014 and 2017 images from [here](https://cocodataset.org/#download) (you can skip MSCOCO 2017 if you only experiment on OK-VQA) and put them in the `datasets` folder. Run the following command to extract image features:

``` shell
$ bash scripts/extract_img_feats.sh
```

After that, the `datasets` and `assets` folder will have the following structure:

<details>
<summary>Click to expand</summary>

```
datasets
├── aokvqa
│   ├── aokvqa_v1p0_test.json
│   ├── aokvqa_v1p0_train.json
│   └── aokvqa_v1p0_val.json
├── coco2014
│   ├── train2014
│   └── val2014
├── coco2014_feats
│   ├── train2014
│   └── val2014
├── coco2017
│   ├── test2017
│   ├── train2017
│   └── val2017
├── coco2017_feats
│   ├── test2017
│   ├── train2017
│   └── val2017
├── okvqa
│   ├── mscoco_train2014_annotations.json
│   ├── mscoco_val2014_annotations.json
│   ├── OpenEnded_mscoco_train2014_questions.json
│   └── OpenEnded_mscoco_val2014_questions.json
└── vqav2
    ├── v2_mscoco_train2014_annotations.json
    ├── v2_mscoco_val2014_annotations.json
    ├── v2_OpenEnded_mscoco_train2014_questions.json
    ├── v2_OpenEnded_mscoco_val2014_questions.json
    ├── v2valvg_no_ok_annotations.json
    ├── v2valvg_no_ok_questions.json
    ├── vg_annotations.json
    └── vg_questions.json
```
</details>

We've also provided a tree structure of the entire project in [misc/tree.txt](misc/tree.txt).

## Usage

We provide bash scripts for each stage of the Prophet framework. You can find them in the `scripts` directory. There are two common arguments you should take care of when running each script:

- `--task`: specifies the task (i.e., the target dataset) you want to deal with. The available options are `ok` (train on the `train` set of OK-VQA and evaluate on the `test` set of OK-VQA), `aok_val` (train on the `train` set of A-OKVQA and evaluate on the `val` set of A-OKVQA), and `aok_test` (train on the `train` and `val` sets of A-OKVQA and evaluate on the `test` set of A-OKVQA);

- `--version`: specifies the version name of this run. This name is used to create a new folder in the `outputs` directory to store the results of this run.

Note that although Prophet uses the VQA v2 dataset for pretraining, there are slight differences in how the datasets are split for the different tasks (`ok`, `aok_val`, and `aok_test`), as detailed in [configs/task_to_split.py](configs/task_to_split.py). This means that pretraining must be run separately for each task.

You can omit any of these arguments when invoking the scripts below; the default values written in the script files will then be used.
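The task-to-split assignments can be pictured with a small sketch. This is a hypothetical illustration only: the actual structure lives in [configs/task_to_split.py](configs/task_to_split.py) and may differ; the split names here follow the keys defined in [configs/path_cfgs.py](configs/path_cfgs.py).

```python
# Hypothetical sketch of the task -> dataset-split mapping; the real
# structure is in configs/task_to_split.py and may differ.
TASK_TO_SPLIT = {
    # finetune on OK-VQA train, evaluate on OK-VQA test
    'ok':       {'train': ['oktrain'],            'eval': 'oktest'},
    # finetune on A-OKVQA train, evaluate on A-OKVQA val
    'aok_val':  {'train': ['aoktrain'],           'eval': 'aokval'},
    # finetune on A-OKVQA train + val, evaluate on A-OKVQA test
    'aok_test': {'train': ['aoktrain', 'aokval'], 'eval': 'aoktest'},
}

print(TASK_TO_SPLIT['aok_test']['train'])  # ['aoktrain', 'aokval']
```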

Before running any script, you can also update the configuration files (`*.yml`) in the `configs` directory to change hyperparameters.

### 1. OK-VQA

Take OK-VQA as an example. Prophet consists of two stages: stage one trains a vanilla VQA model and extracts answer heuristics; stage two prompts GPT-3 with those answer heuristics.

#### **Stage one**

At this stage, we train an improved MCAN model (see the [paper](https://arxiv.org/pdf/2303.01903.pdf) for a detailed description) by pretraining on VQA v2 and finetuning on the target dataset. Multiple GPUs are supported by setting, e.g., `--gpu 0,1,2,3`. Run the pretraining step with:

```shell
$ bash scripts/pretrain.sh \
    --task ok --version okvqa_pretrain_1 --gpu 0
```
We've provided a pretrained model for OK-VQA [here](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EcdTatraOqRJnZXBDXfr7QQBPtn8QYCa2m3Pvq0LlEml9Q?download=1). Then, run the finetuning step with:

```shell
$ bash scripts/finetune.sh \
    --task ok --version okvqa_finetune_1 --gpu 0 \
    --pretrained_model outputs/okvqa_pretrain_1/ckpts/epoch_13.pkl
```

All epoch checkpoints are saved in `outputs/ckpts/{your_version_name}`. We've also provided a finetuned model for OK-VQA [here](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/ESUb093PgyZFtLnU_RIYJQsBN_PU0jJdu-eFUb1-4T4mIQ?download=1). You may pick one to generate answer heuristics by running the following command:

```shell
$ bash scripts/heuristics_gen.sh \
    --task ok --version okvqa_heuristics_1 \
    --gpu 0 --ckpt_path outputs/okvqa_finetune_1/ckpts/epoch_6.pkl \
    --candidate_num 10 --example_num 100
```

The extracted answer heuristics will be stored as `candidates.json` and `examples.json` in `outputs/results/{your_version_name}` directory.

#### **Stage two**

This stage requires the `candidates.json` and `examples.json` files generated in stage one. **Alternatively, you can skip stage one and use the answer-heuristic files we provide in `assets`. Specifically, the counterparts of `examples.json` and `candidates.json` for OK-VQA are `answer_aware_examples_okvqa.json` and `candidates_okvqa.json`, respectively.** To prompt GPT-3 with answer heuristics and generate better answers, run the following command:

```shell
$ bash scripts/prompt.sh \
    --task ok --version okvqa_prompt_1 \
    --examples_path outputs/results/okvqa_heuristics_1/examples.json \
    --candidates_path outputs/results/okvqa_heuristics_1/candidates.json \
    --openai_key sk-xxxxxxxxxxxxxxxxxxxxxx
```
The result file will be stored as `result.json` in `outputs/results/{your_version_name}` directory.


We also provide example scripts for the `aok_val` and `aok_test` modes on A-OKVQA.
<details>
<summary>Click to expand</summary>

### 2. A-OKVQA (val)

#### **Stage one**
Similarly, for the `aok_val` task, run the pretraining step with:

```shell
$ bash scripts/pretrain.sh \
    --task aok_val --version aokvqa_val_pretrain_1 --gpu 0
```
We've provided a pretrained model for `aok_val` [here](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EYeIgGR521pNsEjxliqRkmEBGpcwS5p-qrMGTC9ro_SF6g?download=1). Then, run the finetuning step with:

```shell
$ bash scripts/finetune.sh \
    --task aok_val --version aokvqa_val_finetune_1 --gpu 0 \
    --pretrained_model outputs/aokvqa_val_pretrain_1/ckpts/epoch_13.pkl
```

All epoch checkpoints are saved in `outputs/ckpts/{your_version_name}`. We've also provided a finetuned model for `aok_val` [here](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EQXIIjAIiJJFrOpobVhyH9oBBeBAY-VttHqfS91qPOKlJw?download=1). You may pick one to generate answer heuristics by running the following command:

```shell
$ bash scripts/heuristics_gen.sh \
    --task aok_val --version aokvqa_val_heuristics_1 \
    --gpu 0 --ckpt_path outputs/aokvqa_val_finetune_1/ckpts/epoch_6.pkl \
    --candidate_num 10 --example_num 100
```

The extracted answer heuristics will be stored as `candidates.json` and `examples.json` in `outputs/results/{your_version_name}` directory.

#### **Stage two**

This stage requires the `candidates.json` and `examples.json` files generated in stage one. **Alternatively, you can skip stage one and use the answer-heuristic files we provide in `assets`. Specifically, the counterparts of `examples.json` and `candidates.json` for `aok_val` are `examples_aokvqa_val.json` and `candidates_aokvqa_val.json`, respectively.** To prompt GPT-3 with answer heuristics and generate better answers, run the following command:

```shell
$ bash scripts/prompt.sh \
    --task aok_val --version aokvqa_val_prompt_1 \
    --examples_path outputs/results/aokvqa_val_heuristics_1/examples.json \
    --candidates_path outputs/results/aokvqa_val_heuristics_1/candidates.json \
    --captions_path assets/captions_aokvqa.json \
    --openai_key sk-xxxxxxxxxxxxxxxxxxxxxx
```
The result file will be stored as `result.json` in `outputs/results/{your_version_name}` directory.



### 3. A-OKVQA (test)

#### **Stage one**
For the `aok_test` task, run the pretraining step with:
```shell
$ bash scripts/pretrain.sh \
    --task aok_test --version aokvqa_test_pretrain_1 --gpu 0
```
We've provided a pretrained model for `aok_test` [here](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EWSBB1OrjIlBoPdTMso6RFABNQKYKBWo1iU4l0w2NVDvuQ?download=1). Then, run the finetuning step with:

```shell
$ bash scripts/finetune.sh \
    --task aok_test --version aokvqa_test_finetune_1 --gpu 0 \
    --pretrained_model outputs/aokvqa_test_pretrain_1/ckpts/epoch_13.pkl
```

All epoch checkpoints are saved in `outputs/ckpts/{your_version_name}`. We've also provided a finetuned model for `aok_test` [here](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EQ6gvWbv9VhHrhh0D08G79kBk6JEA_eqXEt5ULgueCf1tA?download=1). You may pick one to generate answer heuristics by running the following command:

```shell
$ bash scripts/heuristics_gen.sh \
    --task aok_test --version aokvqa_test_heuristics_1 \
    --gpu 0 --ckpt_path outputs/aokvqa_test_finetune_1/ckpts/epoch_6.pkl \
    --candidate_num 10 --example_num 100
```

The extracted answer heuristics will be stored as `candidates.json` and `examples.json` in `outputs/results/{your_version_name}` directory.

#### **Stage two**

This stage requires the `candidates.json` and `examples.json` files generated in stage one. **Alternatively, you can skip stage one and use the answer-heuristic files we provide in `assets`. Specifically, the counterparts of `examples.json` and `candidates.json` for `aok_test` are `examples_aokvqa_test.json` and `candidates_aokvqa_test.json`, respectively.** To prompt GPT-3 with answer heuristics and generate better answers, run the following command:

```shell
$ bash scripts/prompt.sh \
    --task aok_test --version aokvqa_test_prompt_1 \
    --examples_path outputs/results/aokvqa_test_heuristics_1/examples.json \
    --candidates_path outputs/results/aokvqa_test_heuristics_1/candidates.json \
    --captions_path assets/captions_aokvqa.json \
    --openai_key sk-xxxxxxxxxxxxxxxxxxxxxx
```
The result file will be stored as `result.json` in `outputs/results/{your_version_name}` directory.

</details>

## Evaluation

For the `ok` and `aok_val` tasks, whose annotations are available, scores are computed automatically after finetuning and prompting. You can also evaluate a result file produced by finetuning or prompting by running:

```shell
$ bash scripts/evaluate_file.sh \
    --task ok --result_path outputs/results/okvqa_prompt_1/result.json
```
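The reported OK-VQA scores follow the standard VQA soft-accuracy metric (implemented in `evaluation/vqa_utils/vqaEval.py`), which compares a prediction against the ten human answers with leave-one-out averaging. The sketch below shows the core of that metric; it omits the answer normalization performed in `evaluation/ans_punct.py`, and the example answers are invented.

```python
def vqa_soft_accuracy(pred, human_answers):
    """Leave-one-out VQA accuracy: for each annotator held out, the
    prediction scores min(#matches among the rest / 3, 1); the final
    score averages over all held-out annotators."""
    accs = []
    for i in range(len(human_answers)):
        rest = human_answers[:i] + human_answers[i + 1:]
        matches = sum(a == pred for a in rest)
        accs.append(min(matches / 3.0, 1.0))
    return sum(accs) / len(accs)

# Invented example: 2 of 10 annotators said "grass", 8 said "lawn".
humans = ["grass"] * 2 + ["lawn"] * 8
print(vqa_soft_accuracy("lawn", humans))            # 1.0
print(round(vqa_soft_accuracy("grass", humans), 4))  # 0.6
```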

Using the corresponding result files and the evaluation script above, we obtain the accuracies in the following table.


<table border="2">
<tr><th> OK-VQA</th><th> A-OKVQA (val) </th><th> A-OKVQA (test) </th></tr>
<tr><td>

| MCAN | Prophet |
|:--:|:--:|
| [53.0%](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EVPAUDjTWX9Gn3GIqj7JwUoB5HMWwL3SRnNf18dSckJBOw?download=1) | [61.1%](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EUqH0N4fLVdPsLYJ48Wl_gsBneZzyGR23Tv5P9RskOBwNQ?download=1) |
</td><td>

| MCAN | Prophet |
|:--:|:--:|
| [52.0%](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EdBYZeS55iFEjdlOhUbyWRsBtYnQ3-zerho13mYj2YQ0Ag?download=1) |[58.2%](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EXDUxT3_LrpDugZ7xj-0BMYBynuFDJQS88M3EGeFEhU5dg?download=1) |
</td><td>

| MCAN | Prophet |
|:--:|:--:|
| 45.6% | 55.7% |
</td></tr>
</table>

For the task of `aok_test`, you need to submit the result file to the [A-OKVQA Leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/public) to evaluate the result.


## Citation

If you use this code in your research, please cite our paper:

```BibTex
@inproceedings{shao2023prompting,
  title={Prompting Large Language Models with Answer Heuristics for Knowledge-based Visual Question Answering},
  author={Shao, Zhenwei and Yu, Zhou and Wang, Meng and Yu, Jun},
  booktitle={Computer Vision and Pattern Recognition (CVPR)},
  pages={14974--14983},
  year={2023}
}
```

## License

This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.


================================================
FILE: assets/.gitkeep
================================================


================================================
FILE: ckpts/.gitkeep
================================================


================================================
FILE: configs/finetune.yml
================================================
# Network
IMG_RESOLUTION: 512
IMG_FEAT_GRID: 16
IMG_FEAT_SIZE: 4096
BERT_VERSION: bert-large-uncased
MAX_TOKEN: 32
ARCH_CEIL: {
  enc: ['SA', 'FFN'],
  dec: ['SA_v', 'GA', 'FFN'],
}
LANG_FEAT_SIZE: 1024
LAYER: 6
HIDDEN_SIZE: 1024
FF_SIZE: 4096
MULTI_HEAD: 8
DROPOUT_R: 0.1
FLAT_MLP_SIZE: 1024
FLAT_GLIMPSES: 1
FLAT_OUT_SIZE: 2048

# Training
BATCH_SIZE: 64
EVAL_BATCH_SIZE: 64
BERT_LR_MULT: 0.01
LR_BASE: 0.00005
LR_DECAY_R: 0.2
LR_DECAY_LIST: [5,]
WARMUP_EPOCH: 0
MAX_EPOCH: 6
GRAD_NORM_CLIP: -1
OPT: AdamW
OPT_PARAMS: {betas: '(0.9, 0.98)', eps: '1e-9'}
## optimizer for finetuning warmup (i.e., only update the new appended parameters as a warm-up)
EPOPH_FTW: 1
OPT_FTW: Adam
LR_BASE_FTW: 0.001
OPT_PARAMS_FTW: {betas: '(0.9, 0.98)', eps: '1e-9'}
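To make the schedule keys above concrete: `LR_BASE`, `LR_DECAY_R`, and `LR_DECAY_LIST` describe a step decay in which the learning rate is multiplied by `LR_DECAY_R` at each epoch listed in `LR_DECAY_LIST`. The following is a minimal sketch of that rule, a guess at what `prophet/stage1/utils/optim.py` implements (warmup omitted), not a copy of it:

```python
def lr_at_epoch(epoch, lr_base=5e-5, decay_r=0.2, decay_list=(5,)):
    """Step decay: multiply the base LR by decay_r once per listed
    milestone epoch that has been reached. Defaults mirror the
    LR_BASE / LR_DECAY_R / LR_DECAY_LIST values in finetune.yml."""
    lr = lr_base
    for milestone in decay_list:
        if epoch >= milestone:
            lr *= decay_r
    return lr

# Epochs 0-4 run at 5e-5; from epoch 5 onward the LR drops to 1e-5.
assert lr_at_epoch(0) == 5e-5
assert abs(lr_at_epoch(5) - 1e-5) < 1e-12
```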

================================================
FILE: configs/path_cfgs.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: set const paths and dirs
# ------------------------------------------------------------------------------ #

import os

class PATH:
    def __init__(self):

        self.LOG_ROOT = 'outputs/logs/'
        self.CKPT_ROOT = 'outputs/ckpts/'
        self.RESULTS_ROOT = 'outputs/results/'
        self.DATASET_ROOT = 'datasets/'
        self.ASSETS_ROOT = 'assets/'


        self.IMAGE_DIR = {
            'train2014': self.DATASET_ROOT + 'coco2014/train2014/',
            'val2014': self.DATASET_ROOT + 'coco2014/val2014/',
            # 'test2015': self.DATASET_ROOT + 'coco2015/test2015/',
            'train2017': self.DATASET_ROOT + 'coco2017/train2017/',
            'val2017': self.DATASET_ROOT + 'coco2017/val2017/',
            'test2017': self.DATASET_ROOT + 'coco2017/test2017/',
        }

        self.FEATS_DIR = {
            'train2014': self.DATASET_ROOT + 'coco2014_feats/train2014/',
            'val2014': self.DATASET_ROOT + 'coco2014_feats/val2014/',
            'train2017': self.DATASET_ROOT + 'coco2017_feats/train2017/',
            'val2017': self.DATASET_ROOT + 'coco2017_feats/val2017/',
            'test2017': self.DATASET_ROOT + 'coco2017_feats/test2017/',
        }

        self.QUESTION_PATH = {
            'v2train': self.DATASET_ROOT + 'vqav2/v2_OpenEnded_mscoco_train2014_questions.json',
            'v2val': self.DATASET_ROOT + 'vqav2/v2_OpenEnded_mscoco_val2014_questions.json',
            'vg': self.DATASET_ROOT + 'vqav2/vg_questions.json',
            'v2valvg_no_ok': self.DATASET_ROOT + 'vqav2/v2valvg_no_ok_questions.json',
            'oktrain': self.DATASET_ROOT + 'okvqa/OpenEnded_mscoco_train2014_questions.json',
            'oktest': self.DATASET_ROOT + 'okvqa/OpenEnded_mscoco_val2014_questions.json',
            'aoktrain': self.DATASET_ROOT + 'aokvqa/aokvqa_v1p0_train.json',
            'aokval': self.DATASET_ROOT + 'aokvqa/aokvqa_v1p0_val.json',
            'aoktest': self.DATASET_ROOT + 'aokvqa/aokvqa_v1p0_test.json',
        }

        self.ANSWER_PATH = {
            'v2train': self.DATASET_ROOT + 'vqav2/v2_mscoco_train2014_annotations.json',
            'v2val': self.DATASET_ROOT + 'vqav2/v2_mscoco_val2014_annotations.json',
            'vg': self.DATASET_ROOT + 'vqav2/vg_annotations.json',
            'v2valvg_no_ok': self.DATASET_ROOT + 'vqav2/v2valvg_no_ok_annotations.json',
            'oktrain': self.DATASET_ROOT + 'okvqa/mscoco_train2014_annotations.json',
            'oktest': self.DATASET_ROOT + 'okvqa/mscoco_val2014_annotations.json',
            'aoktrain': self.DATASET_ROOT + 'aokvqa/aokvqa_v1p0_train.json',
            'aokval': self.DATASET_ROOT + 'aokvqa/aokvqa_v1p0_val.json',
        }

        self.ANSWER_DICT_PATH = {
            'v2': self.ASSETS_ROOT + 'answer_dict_vqav2.json',
            'ok': self.ASSETS_ROOT + 'answer_dict_okvqa.json',
            'aok': self.ASSETS_ROOT + 'answer_dict_aokvqa.json',
        }
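These split keys line up with the split-to-image-set mapping `SPLIT_TO_IMGS` in configs/task_to_split.py. A minimal, runnable sketch of how a split name resolves to its question file and image directory, using trimmed copies of the tables above (`resolve_split` is a hypothetical helper, not part of the repo):

```python
# Trimmed copies of PATH.QUESTION_PATH / PATH.IMAGE_DIR above and
# SPLIT_TO_IMGS from configs/task_to_split.py.
DATASET_ROOT = 'datasets/'
QUESTION_PATH = {
    'oktrain': DATASET_ROOT + 'okvqa/OpenEnded_mscoco_train2014_questions.json',
    'oktest': DATASET_ROOT + 'okvqa/OpenEnded_mscoco_val2014_questions.json',
}
IMAGE_DIR = {
    'train2014': DATASET_ROOT + 'coco2014/train2014/',
    'val2014': DATASET_ROOT + 'coco2014/val2014/',
}
SPLIT_TO_IMGS = {'oktrain': 'train2014', 'oktest': 'val2014'}

def resolve_split(split):
    """Return (question_file, image_dir) for a dataset split key."""
    return QUESTION_PATH[split], IMAGE_DIR[SPLIT_TO_IMGS[split]]

print(resolve_split('oktrain'))
```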




================================================
FILE: configs/pretrain.yml
================================================
# Network
IMG_RESOLUTION: 512
IMG_FEAT_GRID: 16
IMG_FEAT_SIZE: 4096
BERT_VERSION: bert-large-uncased
MAX_TOKEN: 32
ARCH_CEIL: {
  enc: ['SA', 'FFN'],
  dec: ['SA_v', 'GA', 'FFN'],
}
LANG_FEAT_SIZE: 1024
LAYER: 6
HIDDEN_SIZE: 1024
FF_SIZE: 4096
MULTI_HEAD: 8
DROPOUT_R: 0.1
FLAT_MLP_SIZE: 1024
FLAT_GLIMPSES: 1
FLAT_OUT_SIZE: 2048

# Training
BATCH_SIZE: 64
EVAL_BATCH_SIZE: 64
BERT_LR_MULT: 0.01
LR_BASE: 0.00007
LR_DECAY_R: 0.2
LR_DECAY_LIST: [10, 12]
WARMUP_EPOCH: 3
MAX_EPOCH: 13
GRAD_NORM_CLIP: 2.0
OPT: Adam
OPT_PARAMS: {betas: '(0.9, 0.98)', eps: '1e-9'}
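The training hyperparameters above describe a warmup-then-step-decay learning-rate schedule. The sketch below assumes a linear warmup over `WARMUP_EPOCH` epochs followed by a multiplicative decay of `LR_DECAY_R` at each epoch in `LR_DECAY_LIST`; the exact warmup form is an assumption, and the authoritative implementation lives in prophet/stage1/utils/optim.py.

```python
def lr_at_epoch(epoch, lr_base=7e-5, warmup_epoch=3,
                decay_list=(10, 12), decay_r=0.2):
    # Linear warmup over the first `warmup_epoch` epochs (assumed form),
    # then multiply by `decay_r` once per passed milestone in `decay_list`.
    lr = lr_base * min(1.0, (epoch + 1) / warmup_epoch)
    for milestone in decay_list:
        if epoch >= milestone:
            lr *= decay_r
    return lr

# Epoch 12 has passed both milestones, so lr = LR_BASE * 0.2 * 0.2.
print(lr_at_epoch(12))
```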


================================================
FILE: configs/prompt.yml
================================================
MODEL: text-davinci-002
TEMPERATURE: 0.
MAX_TOKENS: 8
SLEEP_PER_INFER: 10

PROMPT_HEAD: "Please answer the question according to the context and candidate answers. Each candidate answer is associated with a confidence score within a bracket. The true answer may not be included in the candidate answers.\n\n"
LINE_PREFIX: "===\n"
N_EXAMPLES: 20
K_CANDIDATES: 10
T_INFER: 5
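A sketch of how these fields could assemble one in-context example for the LLM. The exact line layout is produced by prophet/stage2/prompt.py; the `Context/Question/Candidates/Answer` field names and the `format_example` helper below are illustrative assumptions, not the repo's code.

```python
LINE_PREFIX = "===\n"  # value of LINE_PREFIX above

def format_example(context, question, candidates, answer=None):
    # candidates: list of (answer, confidence) pairs, highest first;
    # each confidence is printed in brackets after its candidate, as
    # described by PROMPT_HEAD above.
    cand_str = ", ".join(f"{a}({c:.2f})" for a, c in candidates)
    lines = [
        LINE_PREFIX + f"Context: {context}",
        LINE_PREFIX + f"Question: {question}",
        LINE_PREFIX + f"Candidates: {cand_str}",
        LINE_PREFIX + "Answer:" + (f" {answer}" if answer else ""),
    ]
    return "\n".join(lines)

print(format_example("a dog on grass", "What animal is this?",
                     [("dog", 0.93), ("puppy", 0.04)], "dog"))
```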

================================================
FILE: configs/task_cfgs.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Object that manages the configuration of the experiments.
# ------------------------------------------------------------------------------ #

import os
import random
import torch
import numpy as np
from datetime import datetime

from .path_cfgs import PATH
from .task_to_split import *


class Cfgs(PATH):
    
    def __init__(self, args):
        super(Cfgs, self).__init__()
        self.set_silent_attr()

        self.GPU = getattr(args, 'GPU', None)
        if self.GPU is not None:
            self.GPU_IDS = [int(i) for i in self.GPU.split(',')]
            # print(f'Available GPUs: {torch.cuda.device_count()}')
            # print(f'Using GPU {self.GPU}')
            self.CURRENT_GPU = self.GPU_IDS[0]
            torch.cuda.set_device(f'cuda:{self.CURRENT_GPU}')
            self.N_GPU = len(self.GPU_IDS)
            self.SEED = getattr(args, 'SEED', 1111)
            torch.manual_seed(self.SEED)
            # torch.manual_seed_all(self.SEED)
            if self.N_GPU < 2:
                torch.cuda.manual_seed(self.SEED)
            else:
                torch.cuda.manual_seed_all(self.SEED)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.SEED)
            random.seed(self.SEED)
            torch.set_num_threads(2)

        # -------------------------
        # ---- Version Control ----
        # -------------------------
        self.TIMESTAMP = datetime.now().strftime('%Y%m%d%H%M%S')
        self.VERSION = getattr(args, 'VERSION', self.TIMESTAMP)
        
        # paths and dirs
        self.CKPTS_DIR = os.path.join(self.CKPT_ROOT, self.VERSION)
        self.LOG_PATH = os.path.join(
            self.LOG_ROOT, 
            self.VERSION, 
            f'log_{self.TIMESTAMP}.txt'
        )
        self.RESULT_DIR = os.path.join(self.RESULTS_ROOT, self.VERSION)
        self.RESULT_PATH = os.path.join(
            self.RESULTS_ROOT,
            self.VERSION,
            'result_' + self.TIMESTAMP + '.json'
        )

        # about resume; RUN_MODE is read from args here because self.RUN_MODE
        # is only assigned in the task-control section below
        self.RESUME = getattr(args, 'RESUME', False)
        if self.RESUME and getattr(args, 'RUN_MODE', 'finetune') == 'pretrain':
            self.RESUME_VERSION = getattr(args, 'RESUME_VERSION', self.VERSION)
            self.RESUME_EPOCH = getattr(args, 'RESUME_EPOCH', None)
            resume_path = getattr(args, 'RESUME_PATH', None)
            self.RESUME_PATH = os.path.join(
                self.CKPTS_DIR, 
                self.RESUME_VERSION, 
                f'epoch_{self.RESUME_EPOCH}.pkl'
            ) if resume_path is None else resume_path
        
        # for testing and heuristics generation
        self.CKPT_PATH = getattr(args, 'CKPT_PATH', None)

        # ----------------------
        # ---- Task Control ----
        # ----------------------

        self.TASK = getattr(args, 'TASK', 'ok')
        assert self.TASK in ['ok', 'aok_val', 'aok_test']

        self.RUN_MODE = getattr(args, 'RUN_MODE', 'finetune')
        assert self.RUN_MODE in ['pretrain', 'finetune', 'finetune_test', 'heuristics', 'prompt']

        if self.RUN_MODE == 'pretrain':
            self.DATA_TAG = 'v2'  # used to config answer dict
            self.DATA_MODE = 'pretrain'
        else:
            self.DATA_TAG = self.TASK.split('_')[0]  # used to config answer dict
            self.DATA_MODE = 'finetune'

        
        # config pipeline...
        self.EVAL_NOW = True
        if self.RUN_MODE == 'pretrain' or self.TASK == 'aok_test':
            self.EVAL_NOW = False
        # print(f'Eval Now: {self.EVAL_NOW}')

        # ------------------------
        # ---- Model Training ----
        # ------------------------

        self.NUM_WORKERS = 8
        self.PIN_MEM = True

        # --------------------------------
        # ---- Heuristics Generations ----
        # --------------------------------

        self.CANDIDATE_NUM = getattr(args, 'CANDIDATE_NUM', None)
        if self.CANDIDATE_NUM is not None:
            self.CANDIDATE_FILE_PATH = os.path.join(
                self.RESULTS_ROOT,
                self.VERSION,
                'candidates.json'
            )
            self.EXAMPLE_FILE_PATH = os.path.join(
                self.RESULTS_ROOT,
                self.VERSION,
                'examples.json'
            )
            self.ANSWER_LATENTS_DIR = os.path.join(
                self.RESULTS_ROOT,
                self.VERSION,
                'answer_latents'
            ) # where answer latents will be saved


        # write rest arguments to self
        for attr in args.__dict__:
            setattr(self, attr, getattr(args, attr))
    
    def __repr__(self):
        _str = ''
        for attr in self.__dict__:
            if attr in self.__silent or getattr(self, attr) is None:
                continue
            _str += '{ %-17s }-> %s\n' % (attr, getattr(self, attr))
        
        return _str
    
    def override_from_dict(self, dict_):
        for key, value in dict_.items():
            setattr(self, key, value)
    
    def set_silent_attr(self):
        self.__silent = []
        for attr in self.__dict__:
            self.__silent.append(attr)
        
    @property
    def TRAIN_SPLITS(self):
        return TASK_TO_SPLIT[self.TASK][self.DATA_MODE]['train_split']
    
    @property
    def EVAL_SPLITS(self):
        return TASK_TO_SPLIT[self.TASK][self.DATA_MODE]['eval_split']
        
    @property
    def FEATURE_SPLIT(self):
        FEATURE_SPLIT = []
        for split in self.TRAIN_SPLITS + self.EVAL_SPLITS:
            feat_split = SPLIT_TO_IMGS[split]
            if feat_split not in FEATURE_SPLIT:
                FEATURE_SPLIT.append(feat_split)
        return FEATURE_SPLIT
    
    @property
    def EVAL_QUESTION_PATH(self):
        # if not self.EVAL_NOW:
        #     return []
        return self.QUESTION_PATH[self.EVAL_SPLITS[0]]
    
    @property
    def EVAL_ANSWER_PATH(self):
        if not self.EVAL_NOW:
            return []
        return self.ANSWER_PATH[self.EVAL_SPLITS[0]]
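The `FEATURE_SPLIT` property walks the train and eval splits and keeps each image set the first time it appears. A self-contained re-creation of that logic, using entries from `SPLIT_TO_IMGS` in configs/task_to_split.py:

```python
# Subset of SPLIT_TO_IMGS from configs/task_to_split.py.
SPLIT_TO_IMGS = {'oktrain': 'train2014', 'oktest': 'val2014',
                 'v2train': 'train2014', 'v2valvg_no_ok': 'val2014'}

def feature_split(train_splits, eval_splits):
    # Same logic as Cfgs.FEATURE_SPLIT: map each split to its image set
    # and deduplicate while preserving first-seen order.
    out = []
    for split in train_splits + eval_splits:
        imgs = SPLIT_TO_IMGS[split]
        if imgs not in out:
            out.append(imgs)
    return out

# TASK='ok', DATA_MODE='finetune' resolves to these splits:
print(feature_split(['oktrain'], ['oktest']))  # ['train2014', 'val2014']
```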

================================================
FILE: configs/task_to_split.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Defines the mapping from task and data mode to dataset splits.
# ------------------------------------------------------------------------------ #

class DictSafe(dict):

    def __init__(self, data={}):
        dict.__init__(self, data)
        for key, value in data.items():
            if isinstance(value, dict):
                self[key] = DictSafe(value)

    def __getitem__(self, key):
        return self.get(key, [])

# TASK_TO_SPLIT[TASK][DATA_MODE]['train_split'] is a list of dataset split names for training
# TASK_TO_SPLIT[TASK][DATA_MODE]['eval_split'] is a list of dataset split names for evaluation
# The 'pretrain' mode is used for pretraining and therefore has no 'eval_split';
# the 'finetune' mode covers finetuning, heuristics generation, and prompting.
TASK_TO_SPLIT = {
    'ok': {
        'pretrain': {
            'train_split': ['v2train', 'v2valvg_no_ok'],
            # The OK-VQA test set draws its images from a subset of MSCOCO val2014,
            # so that subset is removed from the pretraining data to avoid leakage.
        },
        'finetune': {
            'train_split': ['oktrain'],
            'eval_split': ['oktest'],
        }
    },
    'aok_val': {
        'pretrain': {
            'train_split': ['v2train'],
        },
        'finetune': {
            'train_split': ['aoktrain'],
            'eval_split': ['aokval'],
        }
    },
    'aok_test': {
        'pretrain': {
            'train_split': ['v2train', 'v2val', 'vg'],
        },
        'finetune': {
            'train_split': ['aoktrain', 'aokval'],
            'eval_split': ['aoktest'],
        }
    },
}
TASK_TO_SPLIT = DictSafe(TASK_TO_SPLIT)

SPLIT_TO_IMGS = {
    'v2train': 'train2014',
    'v2val': 'val2014',
    'v2valvg_no_ok': 'val2014',
    'vg': 'val2014',
    'oktrain': 'train2014',
    'oktest': 'val2014',
    'aoktrain': 'train2017',
    'aokval': 'val2017',
    'aoktest': 'test2017',
}


if __name__ == '__main__':
    print(TASK_TO_SPLIT['ok']['finetune']['train_split'])
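`DictSafe` makes leaf-level lookups fail soft: a missing final key yields `[]` instead of raising `KeyError` (a miss at an intermediate level still returns a plain list, so only the last level is safe to probe). A runnable copy of the class, with a `None` default in place of the mutable `{}` default:

```python
class DictSafe(dict):
    """dict whose missing keys yield [] instead of raising KeyError."""

    def __init__(self, data=None):
        data = data or {}
        dict.__init__(self, data)
        # Recursively wrap nested dicts so inner lookups fail soft too.
        for key, value in data.items():
            if isinstance(value, dict):
                self[key] = DictSafe(value)

    def __getitem__(self, key):
        return self.get(key, [])

cfg = DictSafe({'ok': {'pretrain': {'train_split': ['v2train']}}})
print(cfg['ok']['pretrain']['train_split'])  # present key
print(cfg['ok']['pretrain']['eval_split'])   # absent leaf -> [] (no KeyError)
```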

================================================
FILE: datasets/.gitkeep
================================================


================================================
FILE: environment.yml
================================================
name: prophet
channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
  - pytorch
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - conda-forge
  - defaults
dependencies:
  - numpy=1.21.2=py39h20f2e39_0
  - opt_einsum=3.3.0=pyhd8ed1ab_1
  - pip=21.2.4=py39h06a4308_0
  - python=3.9.11=h12debd9_2
  - pytorch=1.12.0=py3.9_cuda11.3_cudnn8.3.2_0
  - rich=12.5.1=py39h06a4308_0
  - torchvision=0.13.0=py39_cu113
  - pip:
    - pyyaml==6.0
    - einops==0.6.0
    - huggingface-hub==0.12.1
    - openai==0.18.0
    - opencv-python==4.5.5.64
    - pillow==9.3.0
    - sentence-transformers==2.2.2
    - sentencepiece==0.1.96
    - tokenizers==0.11.6
    - tqdm==4.63.0
    - transformers==4.26.1
    - git+https://github.com/openai/CLIP.git



================================================
FILE: evaluation/ans_punct.py
================================================
# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# based on VQA Evaluation Code
# --------------------------------------------------------

import re

contractions = {
    "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
    "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
    "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
    "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
    "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
    "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
    "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
    "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
    "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
    "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
    "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
    "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
    "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
    "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
    "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
    "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
    "shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
    "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
    "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
    "someoned've": "someone'd've", "someone'dve": "someone'd've",
    "someonell": "someone'll", "someones": "someone's", "somethingd":
    "something'd", "somethingd've": "something'd've", "something'dve":
    "something'd've", "somethingll": "something'll", "thats":
    "that's", "thered": "there'd", "thered've": "there'd've",
    "there'dve": "there'd've", "therere": "there're", "theres":
    "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
    "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
    "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
    "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
    "weren't", "whatll": "what'll", "whatre": "what're", "whats":
    "what's", "whatve": "what've", "whens": "when's", "whered":
    "where'd", "wheres": "where's", "whereve": "where've", "whod":
    "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
    "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
    "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
    "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
    "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
    "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
    "you'll", "youre": "you're", "youve": "you've"
}

manual_map = { 'none': '0',
              'zero': '0',
              'one': '1',
              'two': '2',
              'three': '3',
              'four': '4',
              'five': '5',
              'six': '6',
              'seven': '7',
              'eight': '8',
               'nine': '9',
              'ten': '10'}
articles = ['a', 'an', 'the']
period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")
comma_strip = re.compile(r"(\d)(\,)(\d)")
punct = [';', r"/", '[', ']', '"', '{', '}',
                '(', ')', '=', '+', '\\', '_', '-',
                '>', '<', '@', '`', ',', '?', '!']

def process_punctuation(inText):
    outText = inText
    for p in punct:
        if (p + ' ' in inText or ' ' + p in inText) \
           or (re.search(comma_strip, inText) != None):
            outText = outText.replace(p, '')
        else:
            outText = outText.replace(p, ' ')
    # NB: re.UNICODE is passed positionally as sub()'s `count` argument; this
    # quirk is kept from the original VQA evaluation code for compatibility.
    outText = period_strip.sub("", outText, re.UNICODE)
    return outText


def process_digit_article(inText):
    outText = []
    tempText = inText.lower().split()
    for word in tempText:
        word = manual_map.get(word, word)
        if word not in articles:
            outText.append(word)
        else:
            pass
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = ' '.join(outText)
    return outText


def prep_ans(answer):
    answer = process_digit_article(process_punctuation(answer))
    answer = answer.replace(',', '')
    return answer
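A runnable slice of this normalization pipeline, with the lookup tables trimmed to a few entries and the two passes of `process_digit_article` folded into one loop (behavior-matching for these inputs; the full module carries the complete contraction map and punctuation handling):

```python
import re

# Trimmed copies of the tables above.
manual_map = {'two': '2', 'three': '3'}
articles = ['a', 'an', 'the']
contractions = {'isnt': "isn't"}
period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")

def prep_ans(answer):
    # Strip periods (except decimal points), map number words to digits,
    # drop articles, and restore apostrophes in known contractions.
    answer = period_strip.sub("", answer)
    words = []
    for word in answer.lower().split():
        word = manual_map.get(word, word)
        if word not in articles:
            words.append(contractions.get(word, word))
    return ' '.join(words).replace(',', '')

print(prep_ans("Two dogs."))  # "2 dogs"
print(prep_ans("a dog"))      # "dog"
print(prep_ans("isnt"))       # "isn't"
```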


================================================
FILE: evaluation/aok_utils/eval_predictions.py
================================================
import argparse
import pathlib
import json
import glob

from .load_aokvqa import load_aokvqa


def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True):

    if isinstance(dataset, list):
        dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }

    # print(f'Loaded dataset size: {len(dataset)}')
    if multiple_choice is False:
        dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False}
    # print(f'Loaded dataset size: {len(dataset)}')

    if strict:
        dataset_qids = set(dataset.keys())
        preds_qids = set(preds.keys())
        assert dataset_qids.issubset(preds_qids)

    # dataset = q_id (str) : dataset element (dict)
    # preds = q_id (str) : prediction (str)

    acc = []

    for q in dataset.keys():
        if q not in preds.keys():
            acc.append(0.0)
            continue

        pred = preds[q]
        choices = dataset[q]['choices']
        direct_answers = dataset[q]['direct_answers']

        ## Multiple Choice setting
        if multiple_choice:
            if strict:
                assert pred in choices, 'Prediction must be a valid choice'
            correct_choice_idx = dataset[q]['correct_choice_idx']
            acc.append( float(pred == choices[correct_choice_idx]) )
        ## Direct Answer setting
        else:
            num_match = sum([pred == da for da in direct_answers])
            vqa_acc = min(1.0, num_match / 3.0)
            # with open('2.txt', 'a') as f:
            #     f.write(q + ' ' + str(vqa_acc) + '\n')
            acc.append(vqa_acc)

    acc = sum(acc) / len(acc) * 100

    return acc


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True)
    parser.add_argument('--preds', type=str, required=True, dest='prediction_files')
    args = parser.parse_args()

    dataset = load_aokvqa(args.aokvqa_dir, args.split)

    for prediction_file in glob.glob(args.prediction_files):
        predictions = json.load(open(prediction_file, 'r'))

        # Multiple choice

        mc_predictions = {}

        for q in predictions.keys():
            if 'multiple_choice' in predictions[q].keys():
                mc_predictions[q] = predictions[q]['multiple_choice']

        if mc_predictions != {}:
            mc_acc = eval_aokvqa(
                dataset,
                mc_predictions,
                multiple_choice=True,
                strict=False
            )
            print(prediction_file, 'MC', mc_acc)

        # Direct Answer

        da_predictions = {}

        for q in predictions.keys():
            if 'direct_answer' in predictions[q].keys():
                da_predictions[q] = predictions[q]['direct_answer']

        if da_predictions != {}:
            da_acc = eval_aokvqa(
                dataset,
                da_predictions,
                multiple_choice=False,
                strict=False
            )
            print(prediction_file, 'DA', da_acc)
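The direct-answer branch of `eval_aokvqa` above implements the soft VQA accuracy `min(1, matches / 3)`: a prediction earns full credit once it matches at least three of the annotators' direct answers. Isolated as a standalone helper:

```python
def direct_answer_acc(pred, direct_answers):
    # Soft VQA accuracy, as computed per question in eval_aokvqa above.
    num_match = sum(pred == da for da in direct_answers)
    return min(1.0, num_match / 3.0)

answers = ['dog', 'dog', 'puppy', 'dog', 'canine']
print(direct_answer_acc('dog', answers))    # 3 matches -> full credit, 1.0
print(direct_answer_acc('puppy', answers))  # 1 match   -> 1/3
```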


================================================
FILE: evaluation/aok_utils/load_aokvqa.py
================================================
import os
import json


def load_aokvqa(aokvqa_dir, split, version='v1p0'):
    assert split in ['train', 'val', 'test', 'test_w_ans']
    dataset = json.load(open(
        os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
    ))
    return dataset

def get_coco_path(split, image_id, coco_dir):
    return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")


================================================
FILE: evaluation/aok_utils/remap_predictions.py
================================================
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import argparse
import pathlib
import json
from tqdm import tqdm

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from .load_aokvqa import load_aokvqa


def map_to_choices(dataset, predictions, device='cpu'):
    if isinstance(dataset, list):
        dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }

    if all([p in dataset[q]['choices'] for q, p in predictions.items()]):
        return predictions

    model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')
    model.to(device)
    for q in tqdm(predictions.keys()):
        choices = dataset[q]['choices']
        if predictions[q] not in choices:
            choice_embeddings = model.encode([predictions[q]] + choices, convert_to_tensor=True)
            a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item()
            predictions[q] = choices[a_idx]

    return predictions


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    args = parser.parse_args()


    dataset = load_aokvqa(args.aokvqa_dir, args.split)
    predictions = json.load(args.prediction_file)
    # predictions = {qid: predictions[qid]['direct_answer'] for qid in predictions }
    # json.dump(predictions, open('cache/mcan_da.json', 'w'))
    predictions = map_to_choices(dataset, predictions)

    json.dump(predictions, args.output_file)
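`map_to_choices` snaps an out-of-vocabulary prediction to the nearest multiple-choice option using GloVe sentence embeddings. The sketch below illustrates the same remapping step with a simple bag-of-words cosine as a stand-in similarity (an assumption made purely to keep the example dependency-free; the repo's version uses sentence-transformers `cos_sim`):

```python
from collections import Counter
from math import sqrt

def cos_sim_bow(a, b):
    # Bag-of-words cosine similarity: stand-in for the GloVe sentence
    # embeddings used by map_to_choices above.
    ca, cb = Counter(a.lower().split()), Counter(b.lower().split())
    dot = sum(ca[w] * cb[w] for w in ca)
    na = sqrt(sum(v * v for v in ca.values()))
    nb = sqrt(sum(v * v for v in cb.values()))
    return dot / (na * nb) if na and nb else 0.0

def map_pred_to_choices(pred, choices):
    # Same idea as map_to_choices: keep a valid prediction as-is,
    # otherwise snap it to the most similar multiple-choice option.
    if pred in choices:
        return pred
    return max(choices, key=lambda c: cos_sim_bow(pred, c))

print(map_pred_to_choices('a red fire truck', ['fire truck', 'bus', 'car']))
```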


================================================
FILE: evaluation/aokvqa_evaluate.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Evaluation script for A-OKVQA
# ------------------------------------------------------------------------------ #

import json
from evaluation.aok_utils.eval_predictions import eval_aokvqa
from evaluation.aok_utils.remap_predictions import map_to_choices
from .ans_punct import prep_ans
import argparse

class AOKEvaluater:
    def __init__(self, annotation_path: str, question_path: str):
        self.annotation_path = annotation_path
        self.question_path = question_path
        self.dataset = json.load(open(question_path, 'r'))
        self.result_file = {}
        self.result_path = None
        self.multiple_choice = False
        self.map_to_mc = True
    
    def init(self):
        self.result_file = {}
    
    def set_mode(self, multiple_choice=None, map_to_mc=None):
        if multiple_choice is not None:
            self.multiple_choice = multiple_choice
        if map_to_mc is not None:
            self.map_to_mc = map_to_mc
    
    def prep_ans(self, answer):
        return prep_ans(answer)
    
    def add(self, qid, answer):
        if self.multiple_choice:
            self.result_file[qid] = {
                'multiple_choice': answer,
            }
        else:
            self.result_file[qid] = {
                'direct_answer': answer,
            }
    
    def save(self, result_path: str):
        self.result_path = result_path
        if not self.multiple_choice and self.map_to_mc:
            predictions = {qid: item['direct_answer'] for qid, item in self.result_file.items()}
            predictions = map_to_choices(self.dataset, predictions, 'cuda:0')
            for qid, answer in predictions.items():
                self.result_file[qid]['multiple_choice'] = answer
        json.dump(self.result_file, open(self.result_path, 'w'))
    
    def evaluate(self, logfile=None):
        assert self.result_path is not None, "Please save the result file first."

        direct_answer = not self.multiple_choice
        multiple_choice = self.multiple_choice or self.map_to_mc
        eval_str = _evaluate(self.dataset, self.result_file, direct_answer=direct_answer, multiple_choice=multiple_choice)
        print(eval_str)
        if logfile is not None:
            print(eval_str + '\n', file=logfile)


def _evaluate(dataset, results, direct_answer=True, multiple_choice=True):
    result_str = ''

    if direct_answer:
        # Direct Answer Evaluation
        da_predictions = {}
        for qid, item in results.items():
            da_predictions[qid] = item['direct_answer']

        da_acc = eval_aokvqa(
            dataset,
            da_predictions,
            multiple_choice=False,
            strict=False
        )
        result_str += f'DA: {da_acc: .2f}\n'
        
    if multiple_choice:
        # Multiple Choice Evaluation
        mc_predictions = {}
        for qid, item in results.items():
            mc_predictions[qid] = item['multiple_choice']

        mc_acc = eval_aokvqa(
            dataset,
            mc_predictions,
            multiple_choice=True,
            strict=False
        )
        result_str += f'MC: {mc_acc: .2f}\n'
    return result_str

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate A-OKVQA result file.')
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--result_path', type=str, required=True)
    parser.add_argument('--direct_answer', action='store_true')
    parser.add_argument('--multiple_choice', action='store_true')
    args = parser.parse_args()
    dataset = json.load(open(args.dataset_path, 'r'))
    result = json.load(open(args.result_path, 'r'))
    result_str = _evaluate(dataset, result, direct_answer=args.direct_answer, multiple_choice=args.multiple_choice)
    print(result_str)

================================================
FILE: evaluation/okvqa_evaluate.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Evaluation script for OK-VQA
# ------------------------------------------------------------------------------ #

import json
from evaluation.vqa_utils.vqa import VQA
from evaluation.vqa_utils.vqaEval import VQAEval
from .ans_punct import prep_ans
import argparse

class OKEvaluater:
    def __init__(self, annotation_path: str, question_path: str):
        self.annotation_path = annotation_path
        self.question_path = question_path
        # print(f'== Annotation file: {self.annotation_path}')
        # print(f'== Question file: {self.question_path}')
        self.result_file = []
        self.result_path = None

    def init(self):
        self.result_file = []

    def prep_ans(self, answer):
        return prep_ans(answer)
    
    def add(self, qid, answer):
        qid = int(qid)
        self.result_file.append({
            'question_id': qid,
            'answer': answer
        })
    
    def save(self, result_path: str):
        self.result_path = result_path
        json.dump(self.result_file, open(self.result_path, 'w'))
    
    def evaluate(self, logfile=None):
        assert self.result_path is not None, "Please save the result file first."

        eval_str = _evaluate(self.annotation_path, self.question_path, self.result_path)
        print()
        print(eval_str)
        if logfile is not None:
            print(eval_str + '\n', file=logfile)


def _evaluate(annotation_file: str, question_file: str, result_file: str):
    # print(f'== Annotation file: {annotation_file}')
    # print(f'== Question file: {question_file}')
    vqa = VQA(annotation_file, question_file)
    vqaRes_prophet = vqa.loadRes(result_file, question_file)
    vqaEval_prophet = VQAEval(vqa, vqaRes_prophet, n=2)
    vqaEval_prophet.evaluate()

    question_types = {
        "eight": "Plants and Animals",
        "nine": "Science and Technology",
        "four": "Sports and Recreation",
        "six": "Geography, History, Language and Culture",
        "two": "Brands, Companies and Products",
        "one": "Vehicles and Transportation",
        "five": "Cooking and Food",
        "ten": "Weather and Climate",
        "seven": "People and Everyday life",
        "three": "Objects, Material and Clothing"
        # "other": "Other",
    }

    result_str = ''
    result_str += "Overall Accuracy is: %.02f\n" % (vqaEval_prophet.accuracy['overall'])
    result_str += f"{'Question Type':40s}\t{'Prophet'}\n"
    for quesType in question_types:
        result_str += "%-40s\t%.02f\n" % (question_types[quesType], vqaEval_prophet.accuracy['perQuestionType'][quesType])
    # print(result_str)
    return result_str

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate OK-VQA result file.')
    parser.add_argument('--annotation_path', type=str, required=True)
    parser.add_argument('--question_path', type=str, required=True)
    parser.add_argument('--result_path', type=str, required=True)
    args = parser.parse_args()
    result_str = _evaluate(args.annotation_path, args.question_path, args.result_path)
    print(result_str)

================================================
FILE: evaluation/vqa_utils/vqa.py
================================================
__author__ = 'aagrawal'
__version__ = '0.9'

# Interface for accessing the VQA dataset.

# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).

# The following functions are defined:
#  VQA        - VQA class that loads VQA annotation file and prepares data structures.
#  getQuesIds - Get question ids that satisfy given filter conditions.
#  getImgIds  - Get image ids that satisfy given filter conditions.
#  loadQA     - Load questions and answers with the specified question ids.
#  showQA     - Display the specified questions and answers.
#  loadRes    - Load result file and create result object.

# Help on each function can be accessed by: "help(COCO.function)"

import json
import datetime
import copy


class VQA:
	def __init__(self, annotation_file=None, question_file=None):
		"""
       	Constructor of VQA helper class for reading and visualizing questions and answers.
        :param annotation_file (str): location of VQA annotation file
        :return:
		"""
		# load dataset
		self.dataset = {}
		self.questions = {}
		self.qa = {}
		self.qqa = {}
		self.imgToQA = {}
		if not annotation_file == None and not question_file == None:
			print('loading VQA annotations and questions into memory...')
			time_t = datetime.datetime.utcnow()
			dataset = json.load(open(annotation_file, 'r'))
			questions = json.load(open(question_file, 'r'))
			print(datetime.datetime.utcnow() - time_t)
			self.dataset = dataset
			self.questions = questions
			self.createIndex()

	def createIndex(self):
		# create index
		print('creating index...')
		imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
		qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
		qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
		for ann in self.dataset['annotations']:
			imgToQA[ann['image_id']] += [ann]
			qa[ann['question_id']] = ann
		for ques in self.questions['questions']:
			qqa[ques['question_id']] = ques
		print('index created!')

		# create class members
		self.qa = qa
		self.qqa = qqa
		self.imgToQA = imgToQA

	def info(self):
		"""
		Print information about the VQA annotation file.
		:return:
		"""
		for key, value in self.dataset['info'].items():
			print('%s: %s' % (key, value))

	def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
		"""
		Get question ids that satisfy given filter conditions. default skips that filter
		:param 	imgIds    (int array)   : get question ids for given imgs
				quesTypes (str array)   : get question ids for given question types
				ansTypes  (str array)   : get question ids for given answer types
		:return:    ids   (int array)   : integer array of question ids
		"""
		imgIds = imgIds if type(imgIds) == list else [imgIds]
		quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
		ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

		if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
			anns = self.dataset['annotations']
		else:
			if not len(imgIds) == 0:
				anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
			else:
				anns = self.dataset['annotations']
			anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
			anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
		ids = [ann['question_id'] for ann in anns]
		return ids

	def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
		"""
		Get image ids that satisfy given filter conditions. default skips that filter
		:param quesIds   (int array)   : get image ids for given question ids
               quesTypes (str array)   : get image ids for given question types
               ansTypes  (str array)   : get image ids for given answer types
		:return: ids     (int array)   : integer array of image ids
		"""
		quesIds = quesIds if type(quesIds) == list else [quesIds]
		quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
		ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

		if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
			anns = self.dataset['annotations']
		else:
			if not len(quesIds) == 0:
				anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
			else:
				anns = self.dataset['annotations']
			anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
			anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
		ids = [ann['image_id'] for ann in anns]
		return ids

	def loadQA(self, ids=[]):
		"""
		Load questions and answers with the specified question ids.
		:param ids (int array)       : integer ids specifying question ids
		:return: qa (object array)   : loaded qa objects
		"""
		if type(ids) == list:
			return [self.qa[id] for id in ids]
		elif type(ids) == int:
			return [self.qa[ids]]

	def showQA(self, anns):
		"""
		Display the specified annotations.
		:param anns (array of object): annotations to display
		:return: None
		"""
		if len(anns) == 0:
			return 0
		for ann in anns:
			quesId = ann['question_id']
			print("Question: %s" % (self.qqa[quesId]['question']))
			for ans in ann['answers']:
				print("Answer %d: %s" % (ans['answer_id'], ans['answer']))

	def loadRes(self, resFile, quesFile):
		"""
		Load result file and return a result object.
		:param   resFile (str)     : file name of result file
		:return: res (obj)         : result api object
		"""
		res = VQA()
		res.questions = json.load(open(quesFile))
		res.dataset['info'] = copy.deepcopy(self.questions['info'])
		res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
		res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
		res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
		res.dataset['license'] = copy.deepcopy(self.questions['license'])

		print('Loading and preparing results...     ')
		time_t = datetime.datetime.utcnow()
		anns = json.load(open(resFile))
		assert type(anns) == list, 'results is not an array of objects'
		annsQuesIds = [ann['question_id'] for ann in anns]
		assert set(annsQuesIds) == set(self.getQuesIds()), \
			'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in the annotation file, or there is at least one question id that does not belong to the question ids in the annotation file.'
		for ann in anns:
			quesId = ann['question_id']
			if res.dataset['task_type'] == 'Multiple Choice':
				assert ann['answer'] in self.qqa[quesId][
					'multiple_choices'], 'predicted answer is not one of the multiple choices'
			qaAnn = self.qa[quesId]
			ann['image_id'] = qaAnn['image_id']
			ann['question_type'] = qaAnn['question_type']
			ann['answer_type'] = qaAnn['answer_type']
		print('DONE (t=%0.2fs)' % ((datetime.datetime.utcnow() - time_t).total_seconds()))

		res.dataset['annotations'] = anns
		res.createIndex()
		return res

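`createIndex` builds the three lookup tables that back `getQuesIds`, `getImgIds`, and `loadQA`. The core indexing logic can be sketched standalone (the annotations and questions below are toy data mimicking the VQA json layout, not real VQA files):

```python
# Toy records mimicking the VQA annotation/question json layout (hypothetical data).
annotations = [
    {'image_id': 1, 'question_id': 10, 'answers': [{'answer_id': 1, 'answer': 'cat'}]},
    {'image_id': 1, 'question_id': 11, 'answers': [{'answer_id': 1, 'answer': 'dog'}]},
    {'image_id': 2, 'question_id': 20, 'answers': [{'answer_id': 1, 'answer': 'red'}]},
]
questions = [
    {'question_id': 10, 'question': 'What animal is this?'},
    {'question_id': 11, 'question': 'What animal is that?'},
    {'question_id': 20, 'question': 'What color is the car?'},
]

# Same three indices as VQA.createIndex:
imgToQA = {ann['image_id']: [] for ann in annotations}  # image id -> its annotations
qa = {}                                                 # question id -> annotation
qqa = {}                                                # question id -> question
for ann in annotations:
    imgToQA[ann['image_id']].append(ann)
    qa[ann['question_id']] = ann
for ques in questions:
    qqa[ques['question_id']] = ques

print(len(imgToQA[1]))      # 2 questions asked on image 1
print(qqa[20]['question'])  # What color is the car?
```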

================================================
FILE: evaluation/vqa_utils/vqaEval.py
================================================
# coding=utf-8

__author__='aagrawal'

# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re

class VQAEval:
	def __init__(self, vqa, vqaRes, n=2):
		self.n 			  = n
		self.accuracy     = {}
		self.evalQA       = {}
		self.evalQuesType = {}
		self.evalAnsType  = {}
		self.vqa 		  = vqa
		self.vqaRes       = vqaRes
		self.params		  = {'question_id': vqa.getQuesIds()}
		self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
							 "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
							 "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
							 "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
							 "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
							 "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
							 "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
							 "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
							 "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
							 "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
							 "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
							 "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
							 "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
							 "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
							 "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
							 "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
							 "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
							 "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
							 "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
							 "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
							 "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
							 "youll": "you'll", "youre": "you're", "youve": "you've"}
		self.manualMap    = { 'none': '0',
							  'zero': '0',
							  'one': '1',
							  'two': '2',
							  'three': '3',
							  'four': '4',
							  'five': '5',
							  'six': '6',
							  'seven': '7',
							  'eight': '8',
							  'nine': '9',
							  'ten': '10'
							}
		self.articles     = ['a',
							 'an',
							 'the'
							]
 

		self.periodStrip  = re.compile(r"(?<!\d)(\.)(?!\d)")
		self.commaStrip   = re.compile(r"(\d)(,)(\d)")
		self.punct        = [';', r"/", '[', ']', '"', '{', '}',
							 '(', ')', '=', '+', '\\', '_', '-',
							 '>', '<', '@', '`', ',', '?', '!']

	
	def evaluate(self, quesIds=None):
		if quesIds is None:
			quesIds = list(self.params['question_id'])
		gts = {}
		res = {}
		for quesId in quesIds:
			gts[quesId] = self.vqa.qa[quesId]
			res[quesId] = self.vqaRes.qa[quesId]
		
		# =================================================
		# Compute accuracy
		# =================================================
		accQA       = []
		accQuesType = {}
		accAnsType  = {}
		print ("computing accuracy")
		step = 0
		for quesId in quesIds:
			resAns      = res[quesId]['answer']
			resAns      = resAns.replace('\n', ' ')
			resAns      = resAns.replace('\t', ' ')
			resAns      = resAns.strip()
			resAns      = self.processPunctuation(resAns)
			resAns      = self.processDigitArticle(resAns)
			gtAcc  = []
			gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
			if len(set(gtAnswers)) > 1: 
				for ansDic in gts[quesId]['answers']:
					ansDic['answer'] = self.processPunctuation(ansDic['answer'])
			for gtAnsDatum in gts[quesId]['answers']:
				otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
				matchingAns = [item for item in otherGTAns if item['answer']==resAns]
				acc = min(1, float(len(matchingAns))/3)
				gtAcc.append(acc)
			quesType    = gts[quesId]['question_type']
			ansType     = gts[quesId]['answer_type']
			avgGTAcc = float(sum(gtAcc))/len(gtAcc)
			accQA.append(avgGTAcc)
			if quesType not in accQuesType:
				accQuesType[quesType] = []
			accQuesType[quesType].append(avgGTAcc)
			if ansType not in accAnsType:
				accAnsType[ansType] = []
			accAnsType[ansType].append(avgGTAcc)
			self.setEvalQA(quesId, avgGTAcc)
			self.setEvalQuesType(quesId, quesType, avgGTAcc)
			self.setEvalAnsType(quesId, ansType, avgGTAcc)
			if step % 100 == 0:
				self.updateProgress(step / float(len(quesIds)))
			step += 1

		self.setAccuracy(accQA, accQuesType, accAnsType)
		print ("Done computing accuracy")
	
	def processPunctuation(self, inText):
		outText = inText
		for p in self.punct:
			if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
				outText = outText.replace(p, '')
			else:
				outText = outText.replace(p, ' ')	
		outText = self.periodStrip.sub("",
									  outText,
									  re.UNICODE)
		return outText
	
	def processDigitArticle(self, inText):
		outText = []
		tempText = inText.lower().split()
		for word in tempText:
			word = self.manualMap.get(word, word)
			if word not in self.articles:
				outText.append(word)
		for wordId, word in enumerate(outText):
			if word in self.contractions: 
				outText[wordId] = self.contractions[word]
		outText = ' '.join(outText)
		return outText

	def setAccuracy(self, accQA, accQuesType, accAnsType):
		self.accuracy['overall']         = round(100*float(sum(accQA))/len(accQA), self.n)
		self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
		self.accuracy['perAnswerType']   = {ansType:  round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
			
	def setEvalQA(self, quesId, acc):
		self.evalQA[quesId] = round(100*acc, self.n)

	def setEvalQuesType(self, quesId, quesType, acc):
		if quesType not in self.evalQuesType:
			self.evalQuesType[quesType] = {}
		self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
	
	def setEvalAnsType(self, quesId, ansType, acc):
		if ansType not in self.evalAnsType:
			self.evalAnsType[ansType] = {}
		self.evalAnsType[ansType][quesId] = round(100*acc, self.n)

	def updateProgress(self, progress):
		barLength = 20
		status = ""
		if isinstance(progress, int):
			progress = float(progress)
		if not isinstance(progress, float):
			progress = 0
			status = "error: progress var must be float\r\n"
		if progress < 0:
			progress = 0
			status = "Halt...\r\n"
		if progress >= 1:
			progress = 1
			status = "Done...\r\n"
		block = int(round(barLength*progress))
		text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
		sys.stdout.write(text)
		sys.stdout.flush()

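The per-question score computed in `evaluate` is the standard VQA soft accuracy: each ground-truth answer is held out in turn, the prediction is matched against the remaining answers, and `min(1, #matches / 3)` is averaged over the hold-outs. A standalone sketch of just that formula (answer normalization via `processPunctuation`/`processDigitArticle` is omitted, and the answers below are toy data):

```python
def vqa_soft_accuracy(pred, gt_answers):
    """Average over leave-one-out subsets of min(1, matches / 3),
    mirroring the gtAcc loop in VQAEval.evaluate."""
    accs = []
    for i in range(len(gt_answers)):
        others = gt_answers[:i] + gt_answers[i + 1:]      # hold out answer i
        matches = sum(1 for a in others if a == pred)
        accs.append(min(1.0, matches / 3.0))
    return sum(accs) / len(accs)

# 10 annotator answers: 'cat' twice, 'dog' eight times.
answers = ['cat'] * 2 + ['dog'] * 8
print(vqa_soft_accuracy('cat', answers))  # ~0.6
print(vqa_soft_accuracy('dog', answers))  # 1.0
```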


================================================
FILE: main.py
================================================
import argparse
import yaml
import torch

from evaluation.okvqa_evaluate import OKEvaluater
from evaluation.aokvqa_evaluate import AOKEvaluater
from configs.task_cfgs import Cfgs
from prophet import get_args, get_runner

# parse cfgs and args
args = get_args()
__C = Cfgs(args)
with open(args.cfg_file, 'r') as f:
    yaml_dict = yaml.load(f, Loader=yaml.FullLoader)
__C.override_from_dict(yaml_dict)
print(__C)

# build runner
if __C.RUN_MODE == 'pretrain':
    evaluater = None
elif 'aok' in __C.TASK:
    evaluater = AOKEvaluater(
        __C.EVAL_ANSWER_PATH,
        __C.EVAL_QUESTION_PATH,
    )
else:
    evaluater = OKEvaluater(
        __C.EVAL_ANSWER_PATH,
        __C.EVAL_QUESTION_PATH,
    )

runner = get_runner(__C, evaluater)

# run
runner.run()

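`main.py` layers the YAML settings on top of the CLI-derived config via `Cfgs.override_from_dict`. The override pattern can be sketched with a simplified stand-in (`MiniCfgs` below is hypothetical, not the real `configs.task_cfgs.Cfgs`):

```python
class MiniCfgs:
    """Simplified stand-in for configs.task_cfgs.Cfgs: options are attributes."""
    def __init__(self):
        self.BATCH_SIZE = 64   # defaults, as if set from CLI args
        self.MAX_EPOCH = 10

    def override_from_dict(self, d):
        # every yaml key becomes (or overwrites) an attribute on the config
        for key, value in d.items():
            setattr(self, key, value)

cfg = MiniCfgs()
# As if produced by yaml.load() on a config file; inlined here so the
# sketch runs without the repo.
yaml_dict = {'BATCH_SIZE': 128, 'LR': 1e-4}
cfg.override_from_dict(yaml_dict)
print(cfg.BATCH_SIZE, cfg.MAX_EPOCH, cfg.LR)  # 128 10 0.0001
```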

================================================
FILE: misc/tree.txt
================================================
prophet
├── assets
│   ├── answer_aware_examples_okvqa.json
│   ├── answer_dict_aokvqa.json
│   ├── answer_dict_okvqa.json
│   ├── answer_dict_vqav2.json
│   ├── candidates_aokvqa_test.json
│   ├── candidates_aokvqa_val.json
│   ├── candidates_okvqa.json
│   ├── captions_aokvqa.json
│   ├── captions_okvqa.json
│   ├── examples_aokvqa_test.json.json
│   └── examples_aokvqa_val.json.json
├── ckpts
│   ├── mcan_ft_aokvqa_test.pkl
│   ├── mcan_ft_aokvqa_val.pkl
│   ├── mcan_ft_okvqa.pkl
│   ├── mcan_pt_aokvqa_test.pkl
│   ├── mcan_pt_aokvqa_val.pkl
│   └── mcan_pt_okvqa.pkl
├── configs
│   ├── finetune.yml
│   ├── path_cfgs.py
│   ├── pretrain.yml
│   ├── prompt.yml
│   ├── task_cfgs.py
│   └── task_to_split.py
├── datasets
│   ├── aokvqa
│   │   ├── aokvqa_v1p0_test.json
│   │   ├── aokvqa_v1p0_train.json
│   │   └── aokvqa_v1p0_val.json
│   ├── coco2014
│   ├── coco2014_feats
│   ├── coco2017
│   ├── coco2017_feats
│   ├── okvqa
│   │   ├── mscoco_train2014_annotations.json
│   │   ├── mscoco_val2014_annotations.json
│   │   ├── OpenEnded_mscoco_train2014_questions.json
│   │   └── OpenEnded_mscoco_val2014_questions.json
│   └── vqav2
│       ├── v2_mscoco_train2014_annotations.json
│       ├── v2_mscoco_val2014_annotations.json
│       ├── v2_OpenEnded_mscoco_train2014_questions.json
│       ├── v2_OpenEnded_mscoco_val2014_questions.json
│       ├── v2valvg_no_ok_annotations.json
│       ├── v2valvg_no_ok_questions.json
│       ├── vg_annotations.json
│       └── vg_questions.json
├── environment.yml
├── evaluation
│   ├── ans_punct.py
│   ├── aok_utils
│   │   ├── eval_predictions.py
│   │   ├── load_aokvqa.py
│   │   └── remap_predictions.py
│   ├── aokvqa_evaluate.py
│   ├── okvqa_evaluate.py
│   └── vqa_utils
│       ├── vqaEval.py
│       └── vqa.py
├── main.py
├── misc
│   └── framework.png
├── outputs
│   ├── ckpts
│   ├── logs
│   └── results
├── preds
│   ├── mcan_530_okvqa.json
│   └── prophet_611_okvqa.json
├── prophet
│   ├── __init__.py
│   ├── stage1
│   │   ├── finetune.py
│   │   ├── heuristics.py
│   │   ├── model
│   │   │   ├── layers.py
│   │   │   ├── mcan_for_finetune.py
│   │   │   ├── mcan.py
│   │   │   ├── net_utils.py
│   │   │   └── rope2d.py
│   │   ├── pretrain.py
│   │   └── utils
│   │       ├── load_data.py
│   │       └── optim.py
│   └── stage2
│       ├── prompt.py
│       └── utils
│           ├── data_utils.py
│           └── fancy_pbar.py
├── README.md
├── scripts
│   ├── evaluate_model.sh
│   ├── extract_img_feats.sh
│   ├── finetune.sh
│   ├── heuristics_gen.sh
│   ├── pretrain.sh
│   └── prompt.sh
└── tools
    ├── extract_img_feats.py
    └── transforms.py


================================================
FILE: outputs/ckpts/.gitkeep
================================================


================================================
FILE: outputs/logs/.gitkeep
================================================


================================================
FILE: outputs/results/.gitkeep
================================================


================================================
FILE: preds/.gitkeep
================================================


================================================
FILE: prophet/__init__.py
================================================
__author__ = 'Zhenwei Shao'
__version__ = '1.0'

import argparse

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', dest='TASK', help="task name, one of ['ok', 'aok_val', 'aok_test']", type=str, required=True)
    parser.add_argument('--run_mode', dest='RUN_MODE', help="run mode, one of ['pretrain', 'finetune', 'finetune_test', 'heuristics', 'prompt']", type=str, required=True)
    parser.add_argument('--cfg', dest='cfg_file', help='config file', type=str, required=True)
    parser.add_argument('--version', dest='VERSION', help='version name, output folder will be named as version name', type=str, required=True)
    parser.add_argument('--ckpt_path', dest='CKPT_PATH', help='checkpoint path for test', type=str, default=None)
    parser.add_argument('--pretrained_model', dest='PRETRAINED_MODEL_PATH', help='pretrained model path', type=str, default=None)
    parser.add_argument('--debug', dest='DEBUG', help='debug mode', action='store_true')
    parser.add_argument('--resume', dest='RESUME', help='resume previous run', action='store_true')
    parser.add_argument('--gpu', dest='GPU', help='gpu id', type=str, default=None)
    parser.add_argument('--grad_accu', dest='GRAD_ACCU_STEPS', help='gradient accumulation steps', type=int, default=None)
    parser.add_argument('--seed', dest='SEED', help='random seed', type=int, default=99)
    parser.add_argument('--candidate_num', dest='CANDIDATE_NUM', help='topk candidates', type=int, default=None)
    parser.add_argument('--example_num', dest='EXAMPLE_NUM', help='number of most similar examples to be searched, default: 200', type=int, default=None)
    parser.add_argument('--examples_path', dest='EXAMPLES_PATH', help='answer-aware example file path, default: "assets/answer_aware_examples_for_ok.json"', type=str, default=None)
    parser.add_argument('--candidates_path', dest='CANDIDATES_PATH', help='candidates file path, default: "assets/candidates_for_ok.json"', type=str, default=None)
    parser.add_argument('--captions_path', dest='CAPTIONS_PATH', help='captions file path, default: "assets/captions_for_ok.json"', type=str, default=None)
    parser.add_argument('--openai_key', dest='OPENAI_KEY', help='openai api key', type=str, default=None)
    args = parser.parse_args()
    return args



def get_runner(__C, evaluater):
    if __C.RUN_MODE == 'pretrain':
        from .stage1.pretrain import Runner
    elif __C.RUN_MODE == 'finetune':
        from .stage1.finetune import Runner
    elif __C.RUN_MODE == 'finetune_test':
        from .stage1.finetune import Runner
    elif __C.RUN_MODE == 'heuristics':
        from .stage1.heuristics import Runner
    elif __C.RUN_MODE == 'prompt':
        from .stage2.prompt import Runner
    else:
        raise NotImplementedError
    runner = Runner(__C, evaluater)
    return runner

================================================
FILE: prophet/stage1/finetune.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Runner that handles the finetuning and evaluation process
# ------------------------------------------------------------------------------ #

import os, sys
# sys.path.append(os.getcwd())

from datetime import datetime
import pickle, random, math, time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import argparse
from pathlib import Path
from copy import deepcopy
import yaml

from configs.task_cfgs import Cfgs
from .utils.load_data import CommonData, DataSet
from .model.mcan_for_finetune import MCANForFinetune
from .utils.optim import get_optim_for_finetune as get_optim

class Runner(object):
    def __init__(self, __C, evaluater):
        self.__C = __C
        self.evaluater = evaluater
        
    def train(self, train_set, eval_set=None):
        data_size = train_set.data_size

        # Define the MCAN model
        net = MCANForFinetune(self.__C, train_set.ans_size)

        ## load the pretrained model
        if self.__C.PRETRAINED_MODEL_PATH is not None:
            print(f'Loading pretrained model from {self.__C.PRETRAINED_MODEL_PATH}')
            ckpt = torch.load(self.__C.PRETRAINED_MODEL_PATH, map_location='cpu')
            net.load_state_dict(ckpt['state_dict'], strict=False)
            net.parameter_init()
            print('Finish loading.')

        # Define the optimizer
        if self.__C.RESUME:
            raise NotImplementedError('Resume training is not needed as the finetuning is fast')
        else:
            optim = get_optim(self.__C, net)
            start_epoch = 0

        # load to gpu
        net.cuda()
        # Define the multi-gpu training if needed
        if self.__C.N_GPU > 1:
            net = nn.DataParallel(net, device_ids=self.__C.GPU_IDS)

        # Define the binary cross entropy loss
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='sum')
        epoch_loss = 0

        # Define multi-thread dataloader
        dataloader = Data.DataLoader(
            train_set,
            batch_size=self.__C.BATCH_SIZE,
            shuffle=True,
            num_workers=self.__C.NUM_WORKERS,
            pin_memory=self.__C.PIN_MEM,
            drop_last=True
        )

        # Training script
        for epoch in range(start_epoch, self.__C.MAX_EPOCH):
            net.train()
            # Save log information
            with open(self.__C.LOG_PATH, 'a+') as logfile:
                logfile.write(
                    f'nowTime: {datetime.now():%Y-%m-%d %H:%M:%S}\n'
                )

            time_start = time.time()

            # Iteration
            for step, input_tuple in enumerate(dataloader):
                iteration_loss = 0
                optim.zero_grad()
                input_tuple = [x.cuda() for x in input_tuple]
                SUB_BATCH_SIZE = self.__C.BATCH_SIZE // self.__C.GRAD_ACCU_STEPS
                for accu_step in range(self.__C.GRAD_ACCU_STEPS):

                    sub_tuple = [x[accu_step * SUB_BATCH_SIZE:
                        (accu_step + 1) * SUB_BATCH_SIZE] for x in input_tuple]
                    
                    sub_ans_iter = sub_tuple[-1]
                    pred = net(sub_tuple[:-1])
                    loss = loss_fn(pred, sub_ans_iter)
                    loss.backward()
                    loss_item = loss.item()
                    iteration_loss += loss_item
                    epoch_loss += loss_item  # * self.__C.GRAD_ACCU_STEPS

                print("\r[version %s][epoch %2d][step %4d/%4d][Task %s][Mode %s] loss: %.4f, lr: %.2e" % (
                    self.__C.VERSION,
                    epoch + 1,
                    step,
                    int(data_size / self.__C.BATCH_SIZE),
                    self.__C.TASK,
                    self.__C.RUN_MODE,
                    iteration_loss / self.__C.BATCH_SIZE,
                    optim.current_lr(),
                ), end='          ')

                optim.step()

            time_end = time.time()
            print('Finished in {}s'.format(int(time_end - time_start)))

            # Logging
            with open(self.__C.LOG_PATH, 'a+') as logfile:
                logfile.write(f'epoch = {epoch + 1}  loss = {epoch_loss / data_size}\nlr = {optim.current_lr()}\n\n')
            
            optim.schedule_step(epoch)

            # Save checkpoint
            state = {
                'state_dict': net.state_dict() if self.__C.N_GPU == 1 \
                    else net.module.state_dict(),
                'optimizer': optim.optimizer.state_dict(),
                'warmup_lr_scale': optim.warmup_lr_scale,
                'decay_lr_scale': optim.decay_lr_scale,
            }
            torch.save(
                state,
                f'{self.__C.CKPTS_DIR}/epoch{epoch + 1}.pkl'
            )


            # Eval after every epoch
            if eval_set is not None:
                self.eval(
                    eval_set,
                    net,
                    eval_now=True
                )
            
            epoch_loss = 0

    # Evaluation
    @torch.no_grad()
    def eval(self, dataset, net=None, eval_now=False):
        data_size = dataset.data_size

        # if eval_now and self.evaluater is None:
        #     self.build_evaluator(dataset)
        
        if net is None:
            # Load parameters
            path = self.__C.CKPT_PATH

            print('Loading ckpt {}'.format(path))
            net = MCANForFinetune(self.__C, dataset.ans_size)
            ckpt = torch.load(path, map_location='cpu')
            net.load_state_dict(ckpt['state_dict'], strict=False)
            net.cuda()
            if self.__C.N_GPU > 1:
                net = nn.DataParallel(net, device_ids=self.__C.GPU_IDS)
            print('Finish!')

        net.eval()
        
        dataloader = Data.DataLoader(
            dataset,
            batch_size=self.__C.EVAL_BATCH_SIZE,
            shuffle=False,
            num_workers=self.__C.NUM_WORKERS,
            pin_memory=True
        )

        qid_idx = 0
        self.evaluater.init()

        for step, input_tuple in enumerate(dataloader):
            print("\rEvaluation: [step %4d/%4d]" % (
                step,
                int(data_size / self.__C.EVAL_BATCH_SIZE),
            ), end='          ')

            input_tuple = [x.cuda() for x in input_tuple]


            pred = net(input_tuple[:-1])
            pred_np = pred.cpu().numpy()
            pred_argmax = np.argmax(pred_np, axis=1)

            # collect answers for every batch
            for i in range(len(pred_argmax)):
                qid = dataset.qids[qid_idx]
                qid_idx += 1
                ans_id = int(pred_argmax[i])
                ans = dataset.ix_to_ans[ans_id]
                # log result to evaluater
                self.evaluater.add(qid, ans)
        
        print()
        self.evaluater.save(self.__C.RESULT_PATH)
        # evaluate if eval_now is True
        if eval_now:
            with open(self.__C.LOG_PATH, 'a+') as logfile:
                self.evaluater.evaluate(logfile)

    # def build_evaluator(self, valid_set):
    #     if 'aok' in self.__C.TASK:
    #         from evaluation.aokvqa_evaluate import Evaluater
    #     elif 'ok' in self.__C.TASK:
    #         from evaluation.okvqa_evaluate import Evaluater
    #     else:
    #         raise ValueError('Unknown dataset')
    #     self.evaluater = Evaluater(
    #         valid_set.annotation_path,
    #         valid_set.question_path,
    #     )

    def run(self):
        # Set ckpts and log path
        ## where checkpoints will be saved
        Path(self.__C.CKPTS_DIR).mkdir(parents=True, exist_ok=True)
        ## where logs will be saved
        Path(self.__C.LOG_PATH).parent.mkdir(parents=True, exist_ok=True)
        ## where eval results will be saved
        Path(self.__C.RESULT_PATH).parent.mkdir(parents=True, exist_ok=True)
        with open(self.__C.LOG_PATH, 'w') as f:
            f.write(str(self.__C) + '\n')

        # build dataset entities        
        common_data = CommonData(self.__C)

        if self.__C.RUN_MODE == 'finetune':
            train_set = DataSet(
                self.__C, 
                common_data,
                self.__C.TRAIN_SPLITS
            )
            valid_set = None
            if self.__C.EVAL_NOW:
                valid_set = DataSet(
                    self.__C,
                    common_data,
                    self.__C.EVAL_SPLITS
                )
            self.train(train_set, valid_set)
        elif self.__C.RUN_MODE == 'finetune_test':
            test_set = DataSet(
                self.__C,
                common_data,
                self.__C.EVAL_SPLITS
            )
            self.eval(test_set, eval_now=self.__C.EVAL_NOW)
        else:
            raise ValueError('Invalid run mode')

def finetune_login_args(parser):
    parser.add_argument('--task', dest='TASK', help='task name, e.g., ok, aok_val, aok_test', type=str, required=True)
    parser.add_argument('--run_mode', dest='RUN_MODE', help='run mode', type=str, required=True)
    parser.add_argument('--cfg', dest='cfg_file', help='optional config file', type=str, required=True)
    parser.add_argument('--version', dest='VERSION', help='version name', type=str, required=True)
    parser.add_argument('--resume', dest='RESUME', help='resume training', type=bool, default=False)
    parser.add_argument('--resume_version', dest='RESUME_VERSION', help='checkpoint version name', type=str, default='')
    parser.add_argument('--resume_epoch', dest='RESUME_EPOCH', help='checkpoint epoch', type=int, default=1)
    parser.add_argument('--resume_path', dest='RESUME_PATH', help='checkpoint path', type=str, default='')
    parser.add_argument('--ckpt_path', dest='CKPT_PATH', help='checkpoint path for test', type=str, default=None)
    parser.add_argument('--gpu', dest='GPU', help='gpu id', type=str, default=None)
    parser.add_argument('--seed', dest='SEED', help='random seed', type=int, default=None)
    parser.add_argument('--grad_accu', dest='GRAD_ACCU_STEPS', help='gradient accumulation steps', type=int, default=None)
    parser.add_argument('--pretrained_model', dest='PRETRAINED_MODEL_PATH', help='pretrained model path', type=str, default=None)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameters for finetuning')
    finetune_login_args(parser)
    args = parser.parse_args()
    __C = Cfgs(args)
    with open(args.cfg_file, 'r') as f:
        yaml_dict = yaml.load(f, Loader=yaml.FullLoader)
    __C.override_from_dict(yaml_dict)
    print(__C)
    runner = Runner(__C, None)  # standalone run: no evaluater
    runner.run()

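The gradient-accumulation loop in `Runner.train` emulates a large batch by slicing each loaded batch into `GRAD_ACCU_STEPS` contiguous sub-batches and calling `backward()` per slice before a single `optim.step()`. The slicing arithmetic can be sketched without torch (plain lists stand in for tensors):

```python
def split_for_grad_accu(batch, grad_accu_steps):
    """Mirror the sub-batch slicing in Runner.train: accumulation step k sees
    batch[k * sub : (k + 1) * sub] with sub = BATCH_SIZE // GRAD_ACCU_STEPS."""
    sub = len(batch) // grad_accu_steps
    return [batch[k * sub:(k + 1) * sub] for k in range(grad_accu_steps)]

batch = list(range(8))                 # BATCH_SIZE = 8
subs = split_for_grad_accu(batch, 4)   # GRAD_ACCU_STEPS = 4
print(subs)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```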

================================================
FILE: prophet/stage1/heuristics.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Runner that handles the heuristics generation process
# ------------------------------------------------------------------------------ #

import os, sys
# sys.path.append(os.getcwd())

from datetime import datetime
import pickle, random, math, time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import torch.utils.data as Data
import argparse
from pathlib import Path
import yaml
from copy import deepcopy
from tqdm import tqdm

from configs.task_cfgs import Cfgs
from .utils.load_data import CommonData, DataSet
from .model.mcan_for_finetune import MCANForFinetune
from .utils.optim import get_optim_for_finetune as get_optim

class Runner(object):
    def __init__(self, __C, *args, **kwargs):
        self.__C = __C
        self.net = None

    # heuristics generation
    @torch.no_grad()
    def eval(self, dataset):
        data_size = dataset.data_size

        if self.net is None:
            # Load parameters
            path = self.__C.CKPT_PATH
            print('Loading ckpt {}'.format(path))
            net = MCANForFinetune(self.__C, dataset.ans_size)
            ckpt = torch.load(path, map_location='cpu')
            net.load_state_dict(ckpt['state_dict'], strict=False)
            net.cuda()
            if self.__C.N_GPU > 1:
                net = nn.DataParallel(net, device_ids=self.__C.GPU_IDS)
            print('Finish!')
            self.net = net
        else:
            net = self.net


        net.eval()
        
        dataloader = Data.DataLoader(
            dataset,
            batch_size=self.__C.EVAL_BATCH_SIZE,
            shuffle=False,
            num_workers=self.__C.NUM_WORKERS,
            pin_memory=True
        )

        qid_idx = 0
        topk_results = {}
        latent_results = []
        k = self.__C.CANDIDATE_NUM

        for step, input_tuple in enumerate(dataloader):
            print("\rEvaluation: [step %4d/%4d]" % (
                step,
                int(data_size / self.__C.EVAL_BATCH_SIZE),
            ), end='          ')

            input_tuple = [x.cuda() for x in input_tuple]


            pred, answer_latents = net(input_tuple[:-1], output_answer_latent=True)
            pred_np = pred.sigmoid().cpu().numpy()
            answer_latents_np = answer_latents.cpu().numpy()

            # collect answers for every batch
            for i in range(len(pred_np)):
                qid = dataset.qids[qid_idx]
                qid_idx += 1
                ans_np = pred_np[i]
                ans_idx = np.argsort(-ans_np)[:k]
                ans_item = []
                for idx in ans_idx:
                    ans_item.append(
                        {
                            'answer': dataset.ix_to_ans[idx],
                            'confidence': float(ans_np[idx])
                        }
                    )
                topk_results[qid] = ans_item

                latent_np = answer_latents_np[i]
                latent_results.append(latent_np)
                np.save(
                    os.path.join(self.__C.ANSWER_LATENTS_DIR, f'{qid}.npy'),
                    latent_np
                )
        print()
        
        return topk_results, latent_results

    def run(self):
        # Set ckpts and log path
        ## where checkpoints will be saved
        Path(self.__C.CKPTS_DIR).mkdir(parents=True, exist_ok=True)
        ## where the result file of topk candidates will be saved
        Path(self.__C.CANDIDATE_FILE_PATH).parent.mkdir(parents=True, exist_ok=True)
        ## where answer latents will be saved
        Path(self.__C.ANSWER_LATENTS_DIR).mkdir(parents=True, exist_ok=True)

        # build dataset entities        
        common_data = CommonData(self.__C)
        train_set = DataSet(
            self.__C,
            common_data,
            self.__C.TRAIN_SPLITS
        )
        test_set = DataSet(
            self.__C,
            common_data,
            self.__C.EVAL_SPLITS
        )

        # forward VQA model
        train_topk_results, train_latent_results = self.eval(train_set)
        test_topk_results, test_latent_results = self.eval(test_set)

        # save topk candidates
        topk_results = train_topk_results | test_topk_results
        json.dump(
            topk_results,
            open(self.__C.CANDIDATE_FILE_PATH, 'w'),
            indent=4
        )

        # search similar examples
        train_features = np.vstack(train_latent_results)
        train_features = train_features / np.linalg.norm(train_features, axis=1, keepdims=True)

        test_features = np.vstack(test_latent_results)
        test_features = test_features / np.linalg.norm(test_features, axis=1, keepdims=True)

        # compute top-E similar examples for each testing input
        E = self.__C.EXAMPLE_NUM
        similar_qids = {}
        print(f'\ncompute top-{E} similar examples for each testing input')
        for i, test_qid in enumerate(tqdm(test_set.qids)):
            # cosine similarity
            dists = np.dot(test_features[i], train_features.T)
            top_E = np.argsort(-dists)[:E]
            similar_qids[test_qid] = [train_set.qids[j] for j in top_E]
        
        # save similar qids
        with open(self.__C.EXAMPLE_FILE_PATH, 'w') as f:
            json.dump(similar_qids, f)

def heuristics_login_args(parser):
    parser.add_argument('--task', dest='TASK', help='task name, e.g., ok, aok_val, aok_test', type=str, required=True)
    parser.add_argument('--cfg', dest='cfg_file', help='optional config file', type=str, required=True)
    parser.add_argument('--version', dest='VERSION', help='version name', type=str, required=True)
    parser.add_argument('--ckpt_path', dest='CKPT_PATH', help='checkpoint path for heuristics', type=str, default=None)
    parser.add_argument('--gpu', dest='GPU', help='gpu id', type=str, default=None)
    parser.add_argument('--candidate_num', dest='CANDIDATE_NUM', help='topk candidates', type=int, default=None)
    parser.add_argument('--example_num', dest='EXAMPLE_NUM', help='number of most similar examples to be searched, default: 200', type=int, default=None)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameters for heuristics generation')
    heuristics_login_args(parser)
    args = parser.parse_args()
    __C = Cfgs(args)
    with open(args.cfg_file, 'r') as f:
        yaml_dict = yaml.load(f, Loader=yaml.FullLoader)
    __C.override_from_dict(yaml_dict)
    print(__C)
    runner = Runner(__C)
    runner.run()
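The `run()` method above retrieves, for each test question, the top-E most similar training examples by L2-normalizing the answer latents and ranking dot products (i.e., cosine similarity). A minimal standalone sketch of that retrieval step (the function name `top_e_similar` and all shapes are illustrative, not part of the repo):

```python
import numpy as np

def top_e_similar(test_feats, train_feats, e):
    """For each test row, return indices of the e most similar train rows
    by cosine similarity (both inputs are L2-normalized first)."""
    train = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = test @ train.T                    # cosine similarity matrix
    return np.argsort(-sims, axis=1)[:, :e]  # descending order, keep top e

rng = np.random.default_rng(0)
train = rng.standard_normal((10, 8))
query = train[3:4] * 2.0                     # same direction as row 3, different norm
idx = top_e_similar(query, train, e=3)       # row 3 should rank first
```

Because the features are normalized, scaling a vector does not change its ranking, which is why the runner can compare latents from different forward passes directly.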


================================================
FILE: prophet/stage1/model/layers.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: basic layers & blocks of MCAN
# ------------------------------------------------------------------------------ #

import torch
from torch import nn
from torch.nn import functional as F
import math

from .net_utils import *
from .rope2d import RoPE2d

class AttFlat(nn.Module):
    def __init__(self, __C):
        super(AttFlat, self).__init__()
        self.__C = __C

        self.mlp = MLP(
            in_size=__C.HIDDEN_SIZE,
            mid_size=__C.FLAT_MLP_SIZE,
            out_size=__C.FLAT_GLIMPSES,
            dropout_r=__C.DROPOUT_R,
            use_relu=True
        )

        self.linear_merge = nn.Linear(
            __C.HIDDEN_SIZE * __C.FLAT_GLIMPSES,
            __C.FLAT_OUT_SIZE
        )

    def forward(self, x, x_mask):
        att = self.mlp(x)
        if x_mask is not None:
            att = att.masked_fill(
                x_mask.squeeze(1).squeeze(1).unsqueeze(2),
                -1e9
            )
        att = F.softmax(att, dim=1)

        att_list = []
        for i in range(self.__C.FLAT_GLIMPSES):
            att_list.append(
                torch.sum(att[:, :, i: i + 1] * x, dim=1)
            )

        x_atted = torch.cat(att_list, dim=1)
        x_atted = self.linear_merge(x_atted)

        return x_atted


class MHAtt(nn.Module):
    def __init__(self, __C):
        super().__init__()
        self.__C = __C
        self.n_head = __C.MULTI_HEAD
        self.external_dim = __C.HIDDEN_SIZE
        self.internal_dim = __C.HIDDEN_SIZE // self.n_head

        self.linear_v = nn.Linear(self.external_dim, self.external_dim, bias=False)
        self.linear_k = nn.Linear(self.external_dim, self.external_dim)
        self.linear_q = nn.Linear(self.external_dim, self.external_dim)
        self.linear_merge = nn.Linear(self.external_dim, self.external_dim)

        self.dropout = nn.Dropout(__C.DROPOUT_R)

    def forward(self, v, k, q, mask):
        n_batches = q.size(0)

        v = self.linear_v(v).view(
            n_batches, -1, self.n_head, self.internal_dim
        ).transpose(1, 2)

        k = self.linear_k(k).view(
            n_batches, -1, self.n_head, self.internal_dim
        ).transpose(1, 2)

        q = self.linear_q(q).view(
            n_batches, -1, self.n_head, self.internal_dim
        ).transpose(1, 2)

        atted = self.att(v, k, q, mask)
        atted = atted.transpose(1, 2).contiguous().view(
            n_batches, -1, self.external_dim
        )
        atted = self.linear_merge(atted)

        return atted

    def att(self, value, key, query, mask):
        d_k = query.size(-1)

        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)

        return torch.matmul(att_map, value)


class SA_v(nn.Module):
    def __init__(self, __C):
        super().__init__()
        self.__C = __C
        self.n_head = __C.MULTI_HEAD
        self.external_dim = __C.HIDDEN_SIZE
        self.internal_dim = __C.HIDDEN_SIZE // self.n_head

        self.linear_v = nn.Linear(self.external_dim, self.external_dim, bias=False)
        self.linear_k = nn.Linear(self.external_dim, self.external_dim)
        self.linear_q = nn.Linear(self.external_dim, self.external_dim)
        self.linear_merge = nn.Linear(self.external_dim, self.external_dim)

        self.dropout = nn.Dropout(__C.DROPOUT_R)


        self.dropout1 = nn.Dropout(__C.DROPOUT_R)
        self.norm1 = nn.LayerNorm(__C.HIDDEN_SIZE)
        self.rope = RoPE2d(self.internal_dim, __C.IMG_FEAT_GRID)

    def forward(self, *args):
        x, *_ = args
        n_batches = x.size(0)

        v = self.linear_v(x).view(
            n_batches, -1, self.n_head, self.internal_dim
        ).transpose(1, 2)

        k = self.linear_k(x).view(
            n_batches, -1, self.n_head, self.internal_dim
        ).transpose(1, 2)

        q = self.linear_q(x).view(
            n_batches, -1, self.n_head, self.internal_dim
        ).transpose(1, 2)

        q, k = self.rope(q, k)

        atted = self.att(v, k, q, None)
        atted = atted.transpose(1, 2).contiguous().view(
            n_batches, -1, self.external_dim
        )
        atted = self.linear_merge(atted)

        x = self.norm1(x + self.dropout1(atted))

        return x

    def att(self, value, key, query, mask):
        d_k = query.size(-1)

        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)

        return torch.matmul(att_map, value)


class FFN(nn.Module):
    def __init__(self, __C):
        super(FFN, self).__init__()

        self.mlp = MLP(
            in_size=__C.HIDDEN_SIZE,
            mid_size=__C.FF_SIZE,
            out_size=__C.HIDDEN_SIZE,
            dropout_r=__C.DROPOUT_R,
            use_relu=True
        )
        self.dropout1 = nn.Dropout(__C.DROPOUT_R)
        self.norm1 = nn.LayerNorm(__C.HIDDEN_SIZE)

    def forward(self, x, *args):
        x = self.norm1(x + self.dropout1(
            self.mlp(x)
        ))
        return x


class SA(nn.Module):
    def __init__(self, __C):
        super(SA, self).__init__()

        self.mhatt = MHAtt(__C)

        self.dropout1 = nn.Dropout(__C.DROPOUT_R)
        self.norm1 = nn.LayerNorm(__C.HIDDEN_SIZE)

    def forward(self, x, x_mask, *args):
        x = self.norm1(x + self.dropout1(
            self.mhatt(x, x, x, x_mask)
        ))

        return x


class GA(nn.Module):
    def __init__(self, __C):
        super().__init__()

        self.mhatt1 = MHAtt(__C)

        self.dropout1 = nn.Dropout(__C.DROPOUT_R)
        self.norm1 = nn.LayerNorm(__C.HIDDEN_SIZE)

    def forward(self, x, y, x_mask, y_mask, *args):

        x = self.norm1(x + self.dropout1(
            self.mhatt1(y, y, x, y_mask)
        ))

        return x
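`MHAtt.att` (and its copy in `SA_v.att`) is standard scaled dot-product attention, with masked positions filled with a large negative value before the softmax so they receive near-zero weight. A self-contained sketch with illustrative shapes (`scaled_dot_product_att` is a name introduced here, not part of the repo):

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_att(q, k, v, mask=None):
    """softmax(QK^T / sqrt(d_k)) V, with masked positions set to -1e9
    before the softmax, mirroring MHAtt.att."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)                  # (batch, heads, q_len, d_k)
k = torch.randn(1, 2, 5, 8)                  # (batch, heads, k_len, d_k)
v = torch.randn(1, 2, 5, 8)
out = scaled_dot_product_att(q, k, v)        # (1, 2, 4, 8)

# masking every key except the first collapses attention onto v[..., 0, :]
mask = torch.ones(1, 1, 1, 5, dtype=torch.bool)
mask[..., 0] = False
out_masked = scaled_dot_product_att(q, k, v, mask)
```

The `(batch, 1, 1, key_len)` mask shape matches what `MCAN.make_mask` produces, broadcasting over heads and query positions.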

================================================
FILE: prophet/stage1/model/mcan.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: the definition of the improved MCAN
# ------------------------------------------------------------------------------ #

import torch
from torch import nn
from torch.nn import functional as F
import math
from transformers import AutoModel, logging
logging.set_verbosity_error()

from .net_utils import *
from .layers import *


class MCA_ED(nn.Module):
    """
    The definition of the encoder-decoder backbone of MCAN.
    """
    def __init__(self, __C):
        super(MCA_ED, self).__init__()

        enc = __C.ARCH_CEIL['enc'] * __C.LAYER
        dec = __C.ARCH_CEIL['dec'] * __C.LAYER
        self.enc_list = nn.ModuleList([eval(layer)(__C) for layer in enc])
        self.dec_list = nn.ModuleList([eval(layer)(__C) for layer in dec])

    def forward(self, x, y, x_mask, y_mask):
        for enc in self.enc_list:
            x = enc(x, x_mask)

        for dec in self.dec_list:
            y = dec(y, x, y_mask, x_mask)

        return x, y



class MCAN(nn.Module):
    """
    The definition of the complete network of the improved MCAN, mainly includes:
    1. A pretrained BERT model used to encode questions (already represented as tokens)
    2. A linear layer to project CLIP vision features (extracted beforehand, so the CLIP
        model is not included) to a common embedding space
    3. An encoder-decoder backbone to fuse question and image features in depth
    4. A classifier head based on `AttFlat`
    """
    def __init__(self, __C, answer_size):
        super().__init__()

        # answer_size = trainset.ans_size

        self.__C = __C

        self.bert = AutoModel.from_pretrained(__C.BERT_VERSION)

        # self.clip_visual = trainset.clip_model.visual
        # self.clip_visual.layer4 = Identity()
        # self.clip_visual.float()

        # for p in self.clip_visual.parameters():
        #     p.requires_grad = False

        self.img_feat_linear = nn.Sequential(
            nn.Linear(__C.IMG_FEAT_SIZE, __C.HIDDEN_SIZE, bias=False),
        )
        self.lang_adapt = nn.Sequential(
            nn.Linear(__C.LANG_FEAT_SIZE, __C.HIDDEN_SIZE),
            nn.Tanh(),
        )

        self.backbone = MCA_ED(__C)
        self.attflat_img = AttFlat(__C)
        self.attflat_lang = AttFlat(__C)
        self.proj_norm = nn.LayerNorm(__C.FLAT_OUT_SIZE)
        self.proj = nn.Linear(__C.FLAT_OUT_SIZE, answer_size)

    def forward(self, input_tuple, output_answer_latent=False):
        img_feat, ques_ix = input_tuple

        # Make mask
        lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2))
        img_feat_mask = None  # self.make_mask(img_feat)

        # Pre-process Language Feature
        lang_feat = self.bert(
            ques_ix, 
            attention_mask= ~lang_feat_mask.squeeze(1).squeeze(1)
        )[0]
        lang_feat = self.lang_adapt(lang_feat)

        # Pre-process Image Feature
        img_feat = self.img_feat_linear(img_feat)


        # Backbone Framework
        # img_feat = flatten(img_feat)
        lang_feat, img_feat = self.backbone(
            lang_feat,
            img_feat,
            lang_feat_mask,
            img_feat_mask
        )
        lang_feat = self.attflat_lang(
            lang_feat,
            lang_feat_mask
        )
        img_feat = self.attflat_img(
            img_feat,
            img_feat_mask
        )

        proj_feat = lang_feat + img_feat
        answer_latent = self.proj_norm(proj_feat)
        proj_feat = self.proj(answer_latent)

        if output_answer_latent:
            return proj_feat, answer_latent

        return proj_feat

    # Masking
    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)
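`make_mask` above flags a position as padding iff its feature vector is all zeros, and reshapes the result to `(batch, 1, 1, seq_len)` so it broadcasts over attention score maps. A small sketch of the same rule applied to a padded token sequence (the token ids are made up for illustration):

```python
import torch

def make_pad_mask(feature):
    # Same rule as MCAN.make_mask: padding iff the feature vector is all
    # zeros; output shape (batch, 1, 1, seq_len) for broadcasting over
    # attention scores of shape (batch, heads, q_len, k_len).
    return (torch.sum(torch.abs(feature), dim=-1) == 0).unsqueeze(1).unsqueeze(2)

ques_ix = torch.tensor([[101, 2054, 102, 0, 0]])   # token ids, 0 = padding
mask = make_pad_mask(ques_ix.unsqueeze(2))         # each id treated as a 1-dim feature
```

This is also why `MCAN.forward` can derive the BERT `attention_mask` by inverting the same tensor with `~`.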


================================================
FILE: prophet/stage1/model/mcan_for_finetune.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: the definition of a wrapper of MCAN for finetuning with the
# strategy described in the paper.
# ------------------------------------------------------------------------------ #

import torch
from torch import nn
from torch.nn import functional as F

from .mcan import *


class MCANForFinetune(MCAN):
    """
    A wrapper of MCAN for finetuning with the strategy described 
    in the paper. We inherit the parameters of existing answers 
    and append new parameters for the new answers.
    """
    def __init__(self, __C, answer_size, base_answer_size=3129):
        super().__init__(__C, base_answer_size)

        self.proj1 = nn.Linear(__C.FLAT_OUT_SIZE, answer_size - base_answer_size)

    @torch.no_grad()
    def parameter_init(self):
        self.proj1.weight.data.zero_()
        self.proj1.bias.data = self.proj.bias.data.mean() + torch.zeros(self.proj1.bias.data.shape)

    def forward(self, input_tuple, output_answer_latent=False):
        proj_feat, answer_latent = super().forward(input_tuple, output_answer_latent=True)
        proj_feat = torch.cat([
            proj_feat,
            self.proj1(answer_latent)
        ], dim=1)
        
        if output_answer_latent:
            return proj_feat, answer_latent

        return proj_feat
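The wrapper above keeps the pretrained projection for the first `base_answer_size` answers and appends a fresh projection (`proj1`) for the extra answers, initialized with zero weights and a mean-of-pretrained-bias bias so new answers start near the average logit. An isolated sketch of just this head (class name and sizes are hypothetical, introduced here for illustration):

```python
import torch
from torch import nn

class ExtendedHead(nn.Module):
    """Sketch of the finetuning head: reuse a projection over the base answer
    vocab and append a zero-initialized projection for the new answers."""
    def __init__(self, hidden, base_answers, total_answers):
        super().__init__()
        self.proj = nn.Linear(hidden, base_answers)   # stands in for the pretrained head
        self.proj1 = nn.Linear(hidden, total_answers - base_answers)
        # mirrors MCANForFinetune.parameter_init
        with torch.no_grad():
            self.proj1.weight.zero_()
            self.proj1.bias.fill_(self.proj.bias.mean().item())

    def forward(self, latent):
        # concatenate old and new answer logits, as in MCANForFinetune.forward
        return torch.cat([self.proj(latent), self.proj1(latent)], dim=1)

head = ExtendedHead(hidden=16, base_answers=10, total_answers=13)
logits = head(torch.randn(2, 16))
```

With zero weights, every new answer initially gets the same logit regardless of input, so finetuning starts without biasing toward any particular new answer.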


================================================
FILE: prophet/stage1/model/net_utils.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Utilities for layer definitions
# ------------------------------------------------------------------------------ #

from torch import nn
import math

class FC(nn.Module):
    def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
        super(FC, self).__init__()
        self.dropout_r = dropout_r
        self.use_relu = use_relu

        self.linear = nn.Linear(in_size, out_size)

        if use_relu:
            self.relu = nn.ReLU(inplace=True)

        if dropout_r > 0:
            self.dropout = nn.Dropout(dropout_r)

    def forward(self, x):
        x = self.linear(x)

        if self.use_relu:
            x = self.relu(x)

        if self.dropout_r > 0:
            x = self.dropout(x)

        return x


class MLP(nn.Module):
    def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
        super(MLP, self).__init__()

        self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
        self.linear = nn.Linear(mid_size, out_size)

    def forward(self, x):
        return self.linear(self.fc(x))


def flatten(x):
    x = x.view(x.shape[0], x.shape[1], -1)\
        .permute(0, 2, 1).contiguous()
    return x


def unflatten(x, shape):
    x = x.permute(0, 2, 1).contiguous()\
        .view(x.shape[0], -1, shape[0], shape[1])
    return x


class Identity(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return x
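`flatten` and `unflatten` above convert between a `(B, C, H, W)` feature grid and a `(B, H*W, C)` token sequence. A quick round-trip check (shapes are illustrative):

```python
import torch

# flatten maps (B, C, H, W) grid features to a (B, H*W, C) token sequence;
# unflatten inverts it given the (H, W) grid shape (mirrors net_utils).
def flatten(x):
    return x.view(x.shape[0], x.shape[1], -1).permute(0, 2, 1).contiguous()

def unflatten(x, shape):
    return x.permute(0, 2, 1).contiguous().view(x.shape[0], -1, shape[0], shape[1])

x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4)
tokens = flatten(x)                       # (2, 16, 3)
restored = unflatten(tokens, (4, 4))      # back to (2, 3, 4, 4)
```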


================================================
FILE: prophet/stage1/model/rope2d.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: A 2D version of rotary positional embeddings 
# (https://arxiv.org/abs/2104.09864).
# ------------------------------------------------------------------------------ #


import math
import torch
import torch.nn.functional as F
from torch import nn
# from einops import rearrange, repeat

def rotate_every_two(x):
    shape = x.shape
    # x = rearrange(x, '... (d j) -> ... d j', j = 2)
    # x1, x2 = x.unbind(dim = -1)
    x = x.view(*shape[:-1], -1, 2)[..., [1, 0]]
    x = x.view(*shape)
    return x

def apply_rotary_pos_emb(q, k, sinu_pos):
    sin, cos = sinu_pos
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    return q, k

# rotary embeddings for 2d position
class RoPE2d(nn.Module):
    def __init__(self, in_dim, size):
        super().__init__()
        dim = in_dim // 2
        inv_freq = 1. / (40 ** (torch.arange(0, dim, 2).float() / dim))
        position = torch.arange(0, size, dtype=torch.float)
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        _sin = sinusoid_inp.sin()
        _cos = sinusoid_inp.cos()
        _sin, _cos = map(
            lambda x: x.unsqueeze(-1).repeat(1, 1, 2),
            (_sin, _cos)
        )
        _sin[..., 0] = -_sin[..., 0]
        _sin, _cos = map(lambda x: x.view(*x.shape[:-2], -1), (_sin, _cos))
        _sin, _cos = map(
            lambda x: torch.cat([
                x.unsqueeze(0).repeat(size, 1, 1),
                x.unsqueeze(1).repeat(1, size, 1)
            ], dim=-1).view(-1, in_dim),
            (_sin, _cos)
        )
        self.register_buffer('sin', _sin)
        self.register_buffer('cos', _cos)

    def forward(self, q, k):
        q, k = apply_rotary_pos_emb(q, k, (self.sin, self.cos))
        return q, k

if __name__ == '__main__':
    rope = RoPE2d(512, size=4)
    q = torch.randn(1, 16, 512)
    k = torch.randn(1, 16, 512)
    q, k = rope(q, k)
    print(q.shape, k.shape)
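The rotation in this file is built from `rotate_every_two`, which swaps each adjacent channel pair; the sign flip needed for a proper 2D rotation is baked into the precomputed `sin` buffer (`_sin[..., 0] = -_sin[..., 0]`) rather than into the swap itself. A tiny demo of the swap in isolation:

```python
import torch

def rotate_every_two(x):
    # Pairs adjacent channels and swaps them: (x0, x1, x2, x3) -> (x1, x0, x3, x2).
    shape = x.shape
    x = x.view(*shape[:-1], -1, 2)[..., [1, 0]]
    return x.view(*shape)

x = torch.tensor([[1., 2., 3., 4.]])
r = rotate_every_two(x)   # [[2., 1., 4., 3.]]
```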
    


================================================
FILE: prophet/stage1/pretrain.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Runner that handles the pretraining process
# ------------------------------------------------------------------------------ #

import os, sys
# sys.path.append(os.getcwd())

from datetime import datetime
import pickle, random, math, time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import argparse
from pathlib import Path
from copy import deepcopy
import yaml

from configs.task_cfgs import Cfgs
from .utils.load_data import CommonData, DataSet
from .model.mcan import MCAN
from .utils.optim import get_optim

class Runner(object):
    def __init__(self, __C, *args, **kwargs):
        self.__C = __C

    def train(self, train_set, eval_set=None):
        data_size = train_set.data_size

        # Define the MCAN model
        net = MCAN(self.__C, train_set.ans_size)

        # Define the optimizer
        # Load checkpoint if resume training
        if self.__C.RESUME:
            print(' ========== Resume training')

            path = self.__C.RESUME_PATH

            # Load the network parameters
            print('Loading ckpt {}'.format(path))
            ckpt = torch.load(path, map_location='cpu')
            print('Finish loading.')
            net.load_state_dict(ckpt['state_dict'])

            # Load the optimizer parameters
            optim = get_optim(self.__C, net)
            optim.warmup_lr_scale = ckpt['warmup_lr_scale']
            optim.decay_lr_scale = ckpt['decay_lr_scale']
            optim.optimizer.load_state_dict(ckpt['optimizer'])
            start_epoch = self.__C.CKPT_EPOCH

        else:
            optim = get_optim(self.__C, net)
            start_epoch = 0

        # load to gpu
        net.cuda()
        # Define the multi-gpu training if needed
        if self.__C.N_GPU > 1:
            net = nn.DataParallel(net, device_ids=self.__C.GPU_IDS)

        # Define the binary cross entropy loss
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='sum')
        epoch_loss = 0

        # Define multi-thread dataloader
        dataloader = Data.DataLoader(
            train_set,
            batch_size=self.__C.BATCH_SIZE,
            shuffle=True,
            num_workers=self.__C.NUM_WORKERS,
            pin_memory=self.__C.PIN_MEM,
            drop_last=True
        )

        # Training script
        for epoch in range(start_epoch, self.__C.MAX_EPOCH):
            net.train()
            # Save log information
            with open(self.__C.LOG_PATH, 'a+') as logfile:
                logfile.write(
                    f'nowTime: {datetime.now():%Y-%m-%d %H:%M:%S}\n'
                )

            time_start = time.time()

            # Iteration
            for step, input_tuple in enumerate(dataloader):
                iteration_loss = 0
                optim.zero_grad()
                input_tuple = [x.cuda() for x in input_tuple]
                SUB_BATCH_SIZE = self.__C.BATCH_SIZE // self.__C.GRAD_ACCU_STEPS
                for accu_step in range(self.__C.GRAD_ACCU_STEPS):

                    sub_tuple = [x[accu_step * SUB_BATCH_SIZE:
                        (accu_step + 1) * SUB_BATCH_SIZE] for x in input_tuple]
                    
                    sub_ans_iter = sub_tuple[-1]
                    pred = net(sub_tuple[:-1])
                    loss = loss_fn(pred, sub_ans_iter)
                    loss.backward()
                    loss_item = loss.item()
                    iteration_loss += loss_item
                    epoch_loss += loss_item  # * self.__C.GRAD_ACCU_STEPS

                print("\r[version %s][epoch %2d][step %4d/%4d][Task %s][Mode %s] loss: %.4f, lr: %.2e" % (
                    self.__C.VERSION,
                    epoch + 1,
                    step,
                    int(data_size / self.__C.BATCH_SIZE),
                    self.__C.TASK,
                    self.__C.RUN_MODE,
                    iteration_loss / self.__C.BATCH_SIZE,
                    optim.current_lr(),
                ), end='          ')

                optim.step()

            time_end = time.time()
            print('Finished in {}s'.format(int(time_end - time_start)))

            # Logging
            with open(self.__C.LOG_PATH, 'a+') as logfile:
                logfile.write(f'epoch = {epoch + 1}  loss = {epoch_loss / data_size}\nlr = {optim.current_lr()}\n\n')
            
            optim.schedule_step(epoch)

            # Save checkpoint
            state = {
                'state_dict': net.state_dict() if self.__C.N_GPU == 1 \
                    else net.module.state_dict(),
                'optimizer': optim.optimizer.state_dict(),
                'warmup_lr_scale': optim.warmup_lr_scale,
                'decay_lr_scale': optim.decay_lr_scale,
            }
            torch.save(
                state,
                f'{self.__C.CKPTS_DIR}/epoch{epoch + 1}.pkl'
            )

            epoch_loss = 0

    def run(self):
        # Set ckpts and log path
        Path(self.__C.CKPTS_DIR).mkdir(parents=True, exist_ok=True)
        Path(self.__C.LOG_PATH).parent.mkdir(parents=True, exist_ok=True)
        with open(self.__C.LOG_PATH, 'w') as f:
            f.write(str(self.__C) + '\n')
        
        common_data = CommonData(self.__C)
        train_set = DataSet(
            self.__C, 
            common_data,
            self.__C.TRAIN_SPLITS
        )
        valid_set = None
        self.train(train_set, valid_set)

def pretrain_login_args(parser):
    parser.add_argument('--task', dest='TASK', help='task name, e.g., ok, aok_val, aok_test', type=str, required=True)
    parser.add_argument('--cfg', dest='cfg_file', help='optional config file', type=str, required=True)
    parser.add_argument('--version', dest='VERSION', help='version name', type=str, required=True)
    parser.add_argument('--resume', dest='RESUME', help='resume training', type=bool, default=False)
    parser.add_argument('--resume_version', dest='RESUME_VERSION', help='checkpoint version name', type=str, default=None)
    parser.add_argument('--resume_epoch', dest='RESUME_EPOCH', help='checkpoint epoch', type=int, default=None)
    parser.add_argument('--resume_path', dest='RESUME_PATH', help='checkpoint path', type=str, default=None)
    parser.add_argument('--gpu', dest='GPU', help='gpu id', type=str, default=None)
    parser.add_argument('--seed', dest='SEED', help='random seed', type=int, default=None)
    parser.add_argument('--grad_accu', dest='GRAD_ACCU_STEPS', help='gradient accumulation steps', type=int, default=None)



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameters for pretraining')
    pretrain_login_args(parser)
    args = parser.parse_args()
    __C = Cfgs(args)
    with open(args.cfg_file, 'r') as f:
        yaml_dict = yaml.load(f, Loader=yaml.FullLoader)
    __C.override_from_dict(yaml_dict)
    print(__C)
    runner = Runner(__C)
    runner.run()
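The inner loop of `Runner.train` splits each batch into `GRAD_ACCU_STEPS` sub-batches and calls `backward()` on each, so gradients sum in `.grad` between `zero_grad()` and `step()`; using `reduction='sum'` in the loss makes the accumulated gradient identical to a single full-batch step. A minimal sketch of this equivalence, with a toy model and made-up sizes:

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
ref = copy.deepcopy(model)               # identical weights, separate grads
loss_fn = nn.MSELoss(reduction='sum')    # 'sum' makes sub-batch losses additive
x, y = torch.randn(8, 4), torch.randn(8, 1)

# gradients accumulated over 2 sub-batches of 4 (as in Runner.train)
accu_steps, sub = 2, 4
model.zero_grad()
for s in range(accu_steps):
    loss_fn(model(x[s * sub:(s + 1) * sub]), y[s * sub:(s + 1) * sub]).backward()

# single full-batch gradient for comparison
ref.zero_grad()
loss_fn(ref(x), y).backward()
```

The same argument applies to the `BCEWithLogitsLoss(reduction='sum')` used in the runner.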


================================================
FILE: prophet/stage1/utils/load_data.py
================================================
# --------------------------------------------------------------------------------- #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Data loading and preprocessing. Note that for the sake of simplicity,
#              the code only supports the following datasets for now:
#              * VQA 2.0
#              * OK-VQA
#              * A-OKVQA
#              Transferring to other datasets is easy. You may need to modify a few
#              lines of code in this file.
# --------------------------------------------------------------------------------- #

import numpy as np
import glob, json, pickle, random
import torch
import torch.utils.data as Data
from transformers import AutoTokenizer

from evaluation.ans_punct import prep_ans
# from .transforms import _transform


def soft_target(answers, ans_to_ix, preprocess=True):
    ans_score = np.zeros(len(ans_to_ix), np.float32)
    for ans in answers:
        if preprocess:
            ans = prep_ans(ans)
        if ans in ans_to_ix:
            ans_score[ans_to_ix[ans]] = min(1.0, ans_score[ans_to_ix[ans]] + 0.3)
    return ans_score


class CommonData:
    """
    load common data for all dataset objects:
    * imgid_to_path
    * bert tokenizer
    * ans_to_ix, ix_to_ans
    """
    def __init__(self, __C) -> None:
        print('Loading common data...')
        
        # load imgid_to_path
        self.img_feat_path_list = []
        for split in __C.FEATURE_SPLIT:
            feats_dir = __C.FEATS_DIR[split]
            self.img_feat_path_list += glob.glob(feats_dir + '*.npz')
        self.imgid_to_path = {}
        for feat_path in self.img_feat_path_list:
            img_id = int(feat_path.split('/')[-1].split('_')[-1].split('.')[0])
            self.imgid_to_path[img_id] = feat_path
        # self.preprocess = _transform(__C.RESOLUTION)
        print(f'== Total image number: {len(self.imgid_to_path)}')

        # load bert tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(__C.BERT_VERSION)
        self.token_size = self.tokenizer.vocab_size
        print(f'== BertTokenizer loaded, vocab size: {self.token_size}')

        # load ans_to_ix, ix_to_ans
        ans_dict_path = __C.ANSWER_DICT_PATH[__C.DATA_TAG]
        self.ix_to_ans = json.load(open(ans_dict_path, 'r'))
        self.ans_to_ix = {ans: ix for ix, ans in enumerate(self.ix_to_ans)}
        self.ans_size = len(self.ans_to_ix)
        print(f'== Answer vocab size: {self.ans_size}')

        print('Common data process is done.\n')
        

class DataSet(Data.Dataset):
    def __init__(self, __C, common_data, split_name_list):
        self.__C = __C
        print(f'Loading dataset for {self.__C.TASK}|{self.__C.RUN_MODE}({split_name_list})')
        self.split_name_list = split_name_list

        # load all attributes from common data
        self.imgid_to_path = common_data.imgid_to_path
        self.tokenizer = common_data.tokenizer
        self.token_size = common_data.token_size
        self.ans_to_ix = common_data.ans_to_ix
        self.ix_to_ans = common_data.ix_to_ans
        self.ans_size = common_data.ans_size

        # Loading question and answer list
        self.ques_list = []
        self.ans_list = []

        for split_name in split_name_list:
            ques_list = json.load(open(__C.QUESTION_PATH[split_name], 'r'))
            if 'questions' in ques_list:
                ques_list = ques_list['questions']
            self.ques_list += ques_list
            if split_name in __C.ANSWER_PATH:
                ans_list = json.load(open(__C.ANSWER_PATH[split_name], 'r'))
                if 'annotations' in ans_list:
                    ans_list = ans_list['annotations']
                self.ans_list += ans_list

        # indexing data; note that all question_ids are cast to str
        # and all image_ids to int
        if len(self.ans_list) == len(self.ques_list):
            self.annotated = True
            self.qids = [str(ans['question_id']) for ans in self.ans_list]
        elif len(self.ans_list) < len(self.ques_list):
            self.annotated = False
            self.qids = [str(ques['question_id']) for ques in self.ques_list]
        else:
            raise ValueError('Answer list is longer than question list!')

        self.data_size = len(self.qids)
        print(f'== data size: {self.data_size}\n')

        self.qid_to_ques = {str(ques['question_id']): ques for ques in self.ques_list}
        self.qid_to_ans = {str(ans['question_id']): ans for ans in self.ans_list}


    def __getitem__(self, idx):
        # get question in token ids, image in features,
        # and answer in binary-label vector

        __C = self.__C

        # Default placeholders for safety; overwritten below when data is available
        img_feat  = np.zeros(1)
        ques_ids  = np.zeros(1)
        ans_vec   = np.zeros(1)

        qid = self.qids[idx]
        ques_info = self.qid_to_ques[qid]
        
        # Process question
        ques_str = ques_info['question']
        ques_ids = self.bert_tokenize(ques_str, __C.MAX_TOKEN)

        # Process image feature
        img_id = int(ques_info['image_id'])
        img_feat = np.load(self.imgid_to_path[img_id])['x']
        assert img_feat.shape == (__C.IMG_FEAT_GRID, __C.IMG_FEAT_GRID, __C.IMG_FEAT_SIZE)
        img_feat = img_feat.reshape(-1, __C.IMG_FEAT_SIZE)

        # Process answer
        # The code is compatible with VQA v2, OK-VQA, and A-OKVQA.
        # There is no guarantee that it works for other datasets. If
        # you want to use another dataset, please modify the following
        # code to fit it.
        if self.annotated:
            ans_info = self.qid_to_ans[qid]
            if 'answers' in ans_info:
                ans_list = [ans['answer'] for ans in ans_info['answers']]
            elif 'direct_answers' in ans_info:
                ans_list = ans_info['direct_answers']
            else:
                raise ValueError('Error: annotation format is not supported!')
            assert isinstance(ans_list[0], str), 'Error: answer format is not supported!'
            ans_vec = soft_target(ans_list, self.ans_to_ix)

        return  torch.tensor(img_feat, dtype=torch.float), \
                torch.tensor(ques_ids, dtype=torch.long), \
                torch.tensor(ans_vec, dtype=torch.float)


    def __len__(self):
        return self.data_size

    def bert_tokenize(self, text, max_token):
        text = text.lower().replace('?', '')
        tokens = self.tokenizer.tokenize(text)
        if len(tokens) > max_token - 2:
            tokens = tokens[:max_token-2]
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        ids = ids + [0] * (max_token - len(ids))
        ids = np.array(ids, np.int64)

        return ids
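The tokenization contract above can be sketched without loading the BERT tokenizer itself: truncate so the two special tokens still fit within `max_token`, wrap with `[CLS]`/`[SEP]`, then right-pad with 0. The ids 101 and 102 below are the actual `[CLS]`/`[SEP]` ids in `bert-base-uncased`; the input token ids are made up.

```python
def pad_to_max(token_ids, max_token):
    # Truncate so that [CLS] and [SEP] still fit within max_token.
    ids = token_ids[:max_token - 2]
    ids = [101] + ids + [102]  # 101 = [CLS], 102 = [SEP] in bert-base-uncased
    # Right-pad with 0 ([PAD]) to a fixed length.
    return ids + [0] * (max_token - len(ids))

print(pad_to_max([5, 6, 7], 8))        # -> [101, 5, 6, 7, 102, 0, 0, 0]
print(pad_to_max(list(range(20)), 8))  # long questions are truncated
```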

================================================
FILE: prophet/stage1/utils/optim.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Utilities for optimization
# ------------------------------------------------------------------------------ #

import torch
import torch.optim as Optim
from torch.nn.utils import clip_grad_norm_

class OptimizerWrapper(object):
    """
    A wrapper around optimizers that supports learning rate warmup and decay.
    It also supports holding multiple optimizers and switching between them
    at specified steps.
    """
    def __init__(self, optimizers, 
                 warmup_schd_steps,
                 decay_schd_step_list,
                 decay_rate, 
                 cur_schd_step=-1,
                 change_optim_step_list=None
        ):
        self.optimizer_list = optimizers
        self.groups_lr_list = []
        for _optim in self.optimizer_list:
            self.groups_lr_list.append([])
            for group in _optim.param_groups:
                self.groups_lr_list[-1].append(group['lr'])
        self.curr_optim_id = 0
        self.optimizer = self.optimizer_list[self.curr_optim_id]
        self.change_optim_step_list = change_optim_step_list
        # self.total_schd_steps = total_schd_steps
        self.warmup_schd_steps = warmup_schd_steps
        self.decay_schd_step_list = decay_schd_step_list
        self.decay_rate = decay_rate
        self._step = 0
        self.warmup_lr_scale = 1.0
        self.decay_lr_scale = 1.0
        self.schedule_step(cur_schd_step)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self, step=None, schd_step=False):
        if step is None:
            step = self._step
        if schd_step:
            self.schedule_step(step)
        
        for group in self.optimizer.param_groups:
            if '_grad_norm_clip' in group:
                if group['_grad_norm_clip'] > 0:
                    clip_grad_norm_(group['params'], group['_grad_norm_clip'])
        
        self.optimizer.step()
        self._step += 1
    
    def schedule_step(self, schd_step):
        schd_step += 1
        self.warmup_lr_scale = min(1., float(schd_step + 1) / float(self.warmup_schd_steps + 1))
        if schd_step in self.decay_schd_step_list:
            self.decay_lr_scale = self.decay_lr_scale * self.decay_rate
        lr_scale = self.warmup_lr_scale * self.decay_lr_scale
        # lr actually changes in following lines
        if self.change_optim_step_list is not None:
            if schd_step in self.change_optim_step_list:
                self.curr_optim_id += 1
                self.optimizer = self.optimizer_list[self.curr_optim_id]
        for i, group in enumerate(self.optimizer.param_groups):
            group['lr'] = lr_scale * self.groups_lr_list[self.curr_optim_id][i]

    def current_lr(self):
        return self.optimizer.param_groups[0]['lr']

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
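The effective learning rate produced by `schedule_step` is `base_lr * warmup_scale * decay_scale`: a linear warmup ramp plus multiplicative decay at each listed step. A torch-free sketch of the same arithmetic, assuming `schedule_step` is called once per epoch in order (so the stateful decay is equivalent to counting the decay steps already passed):

```python
def lr_scale(schd_step, warmup_steps, decay_step_list, decay_rate):
    # Mirrors schedule_step: linear warmup over warmup_steps, then
    # multiplicative decay at each step listed in decay_step_list.
    schd_step += 1  # same off-by-one shift as schedule_step
    warmup = min(1., float(schd_step + 1) / float(warmup_steps + 1))
    decay = decay_rate ** sum(1 for s in decay_step_list if s <= schd_step)
    return warmup * decay

print(lr_scale(0, 3, [10, 15], 0.5))   # still warming up
print(lr_scale(9, 3, [10, 15], 0.5))   # first decay applied
print(lr_scale(20, 3, [10, 15], 0.5))  # both decays applied
```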
    

def get_optim(__C, model):
    optim_class = getattr(Optim, __C.OPT)
    params = [
        {'params': [], 'lr': __C.LR_BASE, '_grad_norm_clip': __C.GRAD_NORM_CLIP},
        {'params': [], 'lr': __C.LR_BASE * __C.BERT_LR_MULT, '_grad_norm_clip': __C.GRAD_NORM_CLIP},
    ]
    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'bert' in name:
                params[1]['params'].append(param)
            else:
                params[0]['params'].append(param)
    hyper_params = {k: eval(v) for k, v in __C.OPT_PARAMS.items()}
    return OptimizerWrapper(
        [optim_class(
            params,
            **hyper_params
        ),],
        warmup_schd_steps=__C.WARMUP_EPOCH,
        decay_schd_step_list=__C.LR_DECAY_LIST,
        decay_rate=__C.LR_DECAY_R,
    )


def get_optim_for_finetune(__C, model, new_params_name='proj1'):
    # optimizer for finetuning warmup
    optim_class1 = getattr(Optim, __C.OPT_FTW)
    params1 = []
    for name, param in model.named_parameters():
        if new_params_name in name and param.requires_grad:
            params1.append(param)
    hyper_params1 = {k: eval(v) for k, v in __C.OPT_PARAMS_FTW.items()}
    optimizer1 = optim_class1(
        params1,
        lr=__C.LR_BASE_FTW,
        **hyper_params1
    )

    optim_class2 = getattr(Optim, __C.OPT)
    params2 = [
        {'params': [], 'lr': __C.LR_BASE, '_grad_norm_clip': __C.GRAD_NORM_CLIP},
        {'params': [], 'lr': __C.LR_BASE * __C.BERT_LR_MULT, '_grad_norm_clip': __C.GRAD_NORM_CLIP},
    ]
    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'bert' in name:
                params2[1]['params'].append(param)
            else:
                params2[0]['params'].append(param)
    hyper_params2 = {k: eval(v) for k, v in __C.OPT_PARAMS.items()}
    optimizer2 = optim_class2(
        params2,
        **hyper_params2
    )
    return OptimizerWrapper(
        [optimizer1, optimizer2],
        warmup_schd_steps=__C.WARMUP_EPOCH,
        decay_schd_step_list=__C.LR_DECAY_LIST,
        decay_rate=__C.LR_DECAY_R,
        change_optim_step_list=[__C.EPOPH_FTW,]        
    )


================================================
FILE: prophet/stage2/prompt.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Runner that handles the prompting process
# ------------------------------------------------------------------------------ #

import os, sys
# sys.path.append(os.getcwd())

import pickle
import json, time
import math
import random
import argparse
from datetime import datetime
from copy import deepcopy
import yaml
from pathlib import Path
import openai

from .utils.fancy_pbar import progress, info_column
from .utils.data_utils import Qid2Data
from configs.task_cfgs import Cfgs


class Runner:
    def __init__(self, __C, evaluater):
        self.__C = __C
        self.evaluater = evaluater
        openai.api_key = __C.OPENAI_KEY
    
    def gpt3_infer(self, prompt_text, _retry=0):
        # print(prompt_text)
        # exponential backoff
        if _retry > 0:
            print('retrying...')
            st = 2 ** _retry
            time.sleep(st)
        
        if self.__C.DEBUG:
            # print(prompt_text)
            time.sleep(0.05)
            return 0, 0

        try:
            # print('calling gpt3...')
            response = openai.Completion.create(
                engine=self.__C.MODEL,
                prompt=prompt_text,
                temperature=self.__C.TEMPERATURE,
                max_tokens=self.__C.MAX_TOKENS,
                logprobs=1,
                stop=["\n", "<|endoftext|>"],
                # timeout=20,
            )
            # print('gpt3 called.')
        except Exception as e:
            print(type(e), e)
            if str(e) == 'You exceeded your current quota, please check your plan and billing details.':
                exit(1)
            return self.gpt3_infer(prompt_text, _retry + 1)

        response_txt = response.choices[0].text.strip()
        # print(response_txt)
        plist = []
        for ii in range(len(response['choices'][0]['logprobs']['tokens'])):
            if response['choices'][0]['logprobs']['tokens'][ii] in ["\n", "<|endoftext|>"]:
                break
            plist.append(response['choices'][0]['logprobs']['token_logprobs'][ii])
        prob = math.exp(sum(plist))
        
        return response_txt, prob
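The confidence returned by `gpt3_infer` is the joint probability of the generated answer: the per-token log-probabilities are summed up to the stop token and exponentiated. A standalone sketch with made-up tokens and probabilities:

```python
import math

def answer_prob(tokens, token_logprobs, stop=("\n", "<|endoftext|>")):
    # Sum log-probabilities of the answer tokens (everything before the
    # stop token), then exponentiate to get the joint probability.
    plist = []
    for tok, lp in zip(tokens, token_logprobs):
        if tok in stop:
            break
        plist.append(lp)
    return math.exp(sum(plist))

# Single-token answer: joint probability is just that token's probability.
print(answer_prob(['cat', '\n'], [math.log(0.5), math.log(0.9)]))
# Two-token answer: probabilities multiply.
print(answer_prob(['big', 'cat', '\n'],
                  [math.log(0.5), math.log(0.5), math.log(0.9)]))
```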
    
    def sample_make(self, ques, capt, cands, ans=None):
        line_prefix = self.__C.LINE_PREFIX
        cands = cands[:self.__C.K_CANDIDATES]
        prompt_text = line_prefix + f'Context: {capt}\n'
        prompt_text += line_prefix + f'Question: {ques}\n'
        cands_with_conf = [f'{cand["answer"]}({cand["confidence"]:.2f})' for cand in cands]
        cands = ', '.join(cands_with_conf)
        prompt_text += line_prefix + f'Candidates: {cands}\n'
        prompt_text += line_prefix + 'Answer:'
        if ans is not None:
            prompt_text += f' {ans}'
        return prompt_text
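`sample_make` renders each in-context example in a fixed four-line layout: Context, Question, Candidates (each with its stage-1 confidence), and Answer. A sketch with hypothetical inputs and an empty line prefix (the real `LINE_PREFIX` comes from the config):

```python
def sample_make_demo(ques, capt, cands, ans=None, line_prefix=''):
    # Mirrors sample_make: one line each for context, question, candidates
    # with two-decimal confidences, and the (optional) answer.
    text = line_prefix + f'Context: {capt}\n'
    text += line_prefix + f'Question: {ques}\n'
    cands_str = ', '.join(f'{c["answer"]}({c["confidence"]:.2f})' for c in cands)
    text += line_prefix + f'Candidates: {cands_str}\n'
    text += line_prefix + 'Answer:'
    if ans is not None:
        text += f' {ans}'
    return text

print(sample_make_demo(
    'What animal is this?',
    'A cat sits on a mat.',
    [{'answer': 'cat', 'confidence': 0.92}, {'answer': 'dog', 'confidence': 0.05}],
    ans='cat'))
```

Leaving `ans=None` produces the query sample: the prompt ends at `Answer:` and the model completes it.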

    def get_context(self, example_qids):
        # build the in-context example text for one test input
        prompt_text = self.__C.PROMPT_HEAD
        examples = []
        for key in example_qids:
            ques = self.trainset.get_question(key)
            caption = self.trainset.get_caption(key)
            cands = self.trainset.get_topk_candidates(key)
            gt_ans = self.trainset.get_most_answer(key)
            examples.append((ques, caption, cands, gt_ans))
            prompt_text += self.sample_make(ques, caption, cands, ans=gt_ans)
            prompt_text += '\n\n'
        return prompt_text
    
    def run(self):
        ## where logs will be saved
        Path(self.__C.LOG_PATH).parent.mkdir(parents=True, exist_ok=True)
        with open(self.__C.LOG_PATH, 'w') as f:
            f.write(str(self.__C) + '\n')
        ## where results will be saved
        Path(self.__C.RESULT_DIR).mkdir(parents=True, exist_ok=True)
        
        self.cache = {}
        self.cache_file_path = os.path.join(
            self.__C.RESULT_DIR,
            'cache.json'
        )
        if self.__C.RESUME:
            self.cache = json.load(open(self.cache_file_path, 'r'))
        
        print('Note that the accuracies printed before the final evaluation (the last one printed) are rough estimates, only meant to check that the process is running normally!\n')
        self.trainset = Qid2Data(
            self.__C, 
            self.__C.TRAIN_SPLITS,
            True
        )
        self.valset = Qid2Data(
            self.__C, 
            self.__C.EVAL_SPLITS,
            self.__C.EVAL_NOW,
            json.load(open(self.__C.EXAMPLES_PATH, 'r'))
        )

        # if 'aok' in self.__C.TASK:
        #     from evaluation.aokvqa_evaluate import AOKEvaluater as Evaluater
        # else:
        #     from evaluation.okvqa_evaluate import OKEvaluater as Evaluater
        # evaluater = Evaluater(
        #     self.valset.annotation_path,
        #     self.valset.question_path
        # )

        infer_times = self.__C.T_INFER
        N_inctx = self.__C.N_EXAMPLES
        
        print()

        for qid in progress.track(self.valset.qid_to_data, description="Working...  "):
            if qid in self.cache:
                continue
            ques = self.valset.get_question(qid)
            caption = self.valset.get_caption(qid)
            cands = self.valset.get_topk_candidates(qid, self.__C.K_CANDIDATES)

            prompt_query = self.sample_make(ques, caption, cands)
            example_qids = self.valset.get_similar_qids(qid, k=infer_times * N_inctx)
            random.shuffle(example_qids)

            prompt_info_list = []
            ans_pool = {}
            # multi-times infer
            for t in range(infer_times):
                # print(f'Infer {t}...')
                prompt_in_ctx = self.get_context(example_qids[(N_inctx * t):(N_inctx * t + N_inctx)])
                prompt_text = prompt_in_ctx + prompt_query
                gen_text, gen_prob = self.gpt3_infer(prompt_text)

                ans = self.evaluater.prep_ans(gen_text)
                if ans != '':
                    ans_pool[ans] = ans_pool.get(ans, 0.) + gen_prob

                prompt_info = {
                    'prompt': prompt_text,
                    'answer': gen_text,
                    'confidence': gen_prob
                }
                prompt_info_list.append(prompt_info)
                time.sleep(self.__C.SLEEP_PER_INFER)
            
            # vote
            if len(ans_pool) == 0:
                answer = self.valset.get_topk_candidates(qid, 1)[0]['answer']
            else:
                answer = sorted(ans_pool.items(), key=lambda x: x[1], reverse=True)[0][0]
            
            self.evaluater.add(qid, answer)
            self.cache[qid] = {
                'question_id': qid,
                'answer': answer,
                'prompt_info': prompt_info_list
            }
            json.dump(self.cache, open(self.cache_file_path, 'w'))

            ll = len(self.cache)
            if self.__C.EVAL_NOW and not self.__C.DEBUG:
                if ll > 21 and ll % 10 == 0:
                    rt_accuracy = self.valset.rt_evaluate(self.cache.values())
                    info_column.info = f'Acc: {rt_accuracy}'

        self.evaluater.save(self.__C.RESULT_PATH)
        if self.__C.EVAL_NOW:
            with open(self.__C.LOG_PATH, 'a+') as logfile:
                self.evaluater.evaluate(logfile)
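The multi-inference voting in `run` can be summarized as: each of the `T_INFER` generations contributes its probability to its normalized answer, the answer with the highest summed probability wins, and the stage-1 top-1 candidate is the fallback when every generation was empty. A standalone sketch:

```python
def vote(predictions, fallback):
    # predictions: (answer, probability) pairs; empty answers are skipped.
    pool = {}
    for ans, prob in predictions:
        if ans != '':
            pool[ans] = pool.get(ans, 0.) + prob
    if not pool:
        return fallback  # no usable generation: use the stage-1 candidate
    return max(pool.items(), key=lambda x: x[1])[0]

# 'cat' wins with 0.4 + 0.2 = 0.6 total, beating 'dog' at 0.35.
print(vote([('cat', 0.4), ('dog', 0.35), ('cat', 0.2), ('', 0.9)], 'bird'))
```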
        
def prompt_login_args(parser):
    parser.add_argument('--debug', dest='DEBUG', help='debug mode', action='store_true')
    parser.add_argument('--resume', dest='RESUME', help='resume previous run', action='store_true')
    parser.add_argument('--task', dest='TASK', help='task name, e.g., ok, aok_val, aok_test', type=str, required=True)
    parser.add_argument('--version', dest='VERSION', help='version name', type=str, required=True)
    parser.add_argument('--cfg', dest='cfg_file', help='optional config file', type=str, default='configs/prompt.yml')
    parser.add_argument('--examples_path', dest='EXAMPLES_PATH', help='answer-aware example file path, default: "assets/answer_aware_examples_for_ok.json"', type=str, default=None)
    parser.add_argument('--candidates_path', dest='CANDIDATES_PATH', help='candidates file path, default: "assets/candidates_for_ok.json"', type=str, default=None)
    parser.add_argument('--captions_path', dest='CAPTIONS_PATH', help='captions file path, default: "assets/captions_for_ok.json"', type=str, default=None)
    parser.add_argument('--openai_key', dest='OPENAI_KEY', help='openai api key', type=str, default=None)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Heuristics-enhanced Prompting')
    prompt_login_args(parser)
    args = parser.parse_args()
    __C = Cfgs(args)
    with open(args.cfg_file, 'r') as f:
        yaml_dict = yaml.load(f, Loader=yaml.FullLoader)
    __C.override_from_dict(yaml_dict)
    print(__C)

    # Runner also expects an evaluater for the target task; build one here,
    # mirroring the commented-out construction in run() (this assumes a single
    # evaluation split and that both path dicts are keyed by that split name).
    if 'aok' in __C.TASK:
        from evaluation.aokvqa_evaluate import AOKEvaluater as Evaluater
    else:
        from evaluation.okvqa_evaluate import OKEvaluater as Evaluater
    evaluater = Evaluater(
        __C.ANSWER_PATH[__C.EVAL_SPLITS[0]],
        __C.QUESTION_PATH[__C.EVAL_SPLITS[0]],
    )
    runner = Runner(__C, evaluater)
    runner.run()


================================================
FILE: prophet/stage2/utils/data_utils.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: dataset utils for stage2
# ------------------------------------------------------------------------------ #

import json
from typing import Dict
import pickle
from collections import Counter

# The following two scoring functions are rough approximations, used only to
# print accuracies during inference.
def ok_score(gt_answers):
    gt_answers = [a['answer'] for a in gt_answers]
    ans2cnt = Counter(gt_answers)
    # sort
    ans2cnt = sorted(ans2cnt.items(), key=lambda x: x[1], reverse=True)
    ans2score = {}
    for ans, cnt in ans2cnt:
        # ans2score[ans] = min(1.0, cnt / 3.0)
        if cnt == 1:
            ans2score[ans] = 0.3
        elif cnt == 2:
            ans2score[ans] = 0.6
        elif cnt == 3:
            ans2score[ans] = 0.9
        else:
            ans2score[ans] = 1.0
    return ans2score

def aok_score(gt_answers):
    gt_answers = list(gt_answers)
    ans2cnt = Counter(gt_answers)
    # sort
    ans2cnt = sorted(ans2cnt.items(), key=lambda x: x[1], reverse=True)
    ans2score = {}
    for ans, cnt in ans2cnt:
        # ans2score[ans] = min(1.0, cnt / 3.0)
        if cnt == 1:
            ans2score[ans] = 1 / 3.
        elif cnt == 2:
            ans2score[ans] = 2 / 3.
        else:
            ans2score[ans] = 1.
    return ans2score
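The rough OK-VQA score above maps annotator vote counts 1/2/3/4+ to 0.3/0.6/0.9/1.0. The same mapping, written compactly over a toy answer list (not real annotations):

```python
from collections import Counter

def ok_score_demo(answers):
    # Vote count -> soft score, mirroring ok_score above.
    return {ans: {1: 0.3, 2: 0.6, 3: 0.9}.get(cnt, 1.0)
            for ans, cnt in Counter(answers).most_common()}

print(ok_score_demo(['cat', 'cat', 'cat', 'dog']))
```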


class Qid2Data(Dict):
    def __init__(self, __C, splits, annotated=False, similar_examples=None):
        super().__init__()

        self.__C = __C
        self.annotated = annotated
        
        ques_set = []
        for split in splits:
            split_path = self.__C.QUESTION_PATH[split]
            _ques_set = json.load(open(split_path, 'r'))
            if 'questions' in _ques_set:
                _ques_set = _ques_set['questions']
            ques_set += _ques_set
        qid_to_ques = {str(q['question_id']): q for q in ques_set}

        if annotated:
            anno_set = []
            for split in splits:
                split_path = self.__C.ANSWER_PATH[split]
                _anno_set = json.load(open(split_path, 'r'))
                if 'annotations' in _anno_set:
                    _anno_set = _anno_set['annotations']
                anno_set += _anno_set
            qid_to_anno = {str(a['question_id']): a for a in anno_set}
        
        qid_to_topk = json.load(open(__C.CANDIDATES_PATH))
        # qid_to_topk = {t['question_id']: t for t in topk}

        iid_to_capt = json.load(open(__C.CAPTIONS_PATH))
        
        _score = aok_score if 'aok' in __C.TASK else ok_score
        
        qid_to_data = {}
        # ques_set = ques_set['questions']
        # anno_set = anno_set['annotations']
        for qid in qid_to_ques:
            q_item = qid_to_ques[qid]
            t_item = qid_to_topk[qid]

            iid = str(q_item['image_id'])
            caption = iid_to_capt[iid].strip()
            if caption[-1] != '.':
                caption += '.'
            
            qid_to_data[qid] = {
                'question_id': qid,
                'image_id': iid,
                'question': q_item['question'],
                # 'most_answer': most_answer,
                # 'gt_scores': ans2score,
                'topk_candidates': t_item,
                'caption': caption,
            }
            if annotated:
                a_item = qid_to_anno[qid]
                if 'answers' in a_item:
                    answers = a_item['answers']
                else:
                    answers = a_item['direct_answers']

                ans2score = _score(answers)

                most_answer = list(ans2score.keys())[0]
                if most_answer == '':
                    most_answer = list(ans2score.keys())[1]
                
                qid_to_data[qid]['most_answer'] = most_answer
                qid_to_data[qid]['gt_scores'] = ans2score

        self.qid_to_data = qid_to_data

        k = __C.K_CANDIDATES
        if annotated:
            print(f'Loaded dataset size: {len(self.qid_to_data)}, top{k} accuracy: {self.topk_accuracy(k)*100:.2f}, top1 accuracy: {self.topk_accuracy(1)*100:.2f}')
        
        if similar_examples:
            for qid in similar_examples:
                qid_to_data[qid]['similar_qids'] = similar_examples[qid]
            
            # check if all items have similar_qids
            for qid, item in self.items():
                if 'similar_qids' not in item:
                    raise ValueError(f'qid {qid} does not have similar_qids')
        
        

    def __getitem__(self, __key):
        return self.qid_to_data[__key]
    

    def get_caption(self, qid):
        caption = self[qid]['caption']
        # if with_tag:
        #     tags = self.get_tags(qid, k_tags)
        #     caption += ' ' + ', '.join(tags) + '.'
        return caption
    
    def get_question(self, qid):
        return self[qid]['question']
    
    
    def get_gt_answers(self, qid):
        if not self.annotated:
            return None
        return self[qid]['gt_scores']
    
    def get_most_answer(self, qid):
        if not self.annotated:
            return None
        return self[qid]['most_answer']

    def get_topk_candidates(self, qid, k=None):
        if k is None:
            return self[qid]['topk_candidates']
        else:
            return self[qid]['topk_candidates'][:k]
    
    def get_similar_qids(self, qid, k=None):
        similar_qids = self[qid]['similar_qids']
        if k is not None:
            similar_qids = similar_qids[:k]
        return similar_qids
    
    def evaluate_by_threshold(self, ans_set, threshold=1.0):
        if not self.annotated:
            return -1
        
        total_score = 0.0
        for item in ans_set:
            qid = item['question_id']
            topk_candidates = self.get_topk_candidates(qid)
            top1_confid = topk_candidates[0]['confidence']
            if top1_confid > threshold:
                answer = topk_candidates[0]['answer']
            else:
                answer = item['answer']
            gt_answers = self.get_gt_answers(qid)
            if answer in gt_answers:
                total_score += gt_answers[answer]
        return total_score / len(ans_set)
    
    def topk_accuracy(self, k=1, sub_set=None):
        if not self.annotated:
            return -1
        
        total_score = 0.0
        if sub_set is not None:
            qids = sub_set
        else:
            qids = list(self.qid_to_data.keys())
        for qid in qids:
            topk_candidates = self.get_topk_candidates(qid)[:k]
            gt_answers = self.get_gt_answers(qid)
            score_list = [gt_answers.get(a['answer'], 0.0) for a in topk_candidates]
            total_score += max(score_list)
        return total_score / len(qids)
    
    def rt_evaluate(self, answer_set):
        if not self.annotated:
            return ''
        score1 = self.evaluate_by_threshold(answer_set, 1.0) * 100
        score2 = self.evaluate_by_threshold(answer_set, 0.0) * 100
        score_string = f'{score2:.2f}->{score1:.2f}'
        return score_string
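`rt_evaluate` reports two numbers as `stage1 -> final`: assuming confidences lie in (0, 1], `threshold=0.0` always keeps the stage-1 top-1 candidate (its confidence is positive), while `threshold=1.0` always keeps the prompted GPT-3 answer (no confidence exceeds 1.0). A sketch of the per-question selection rule with a hypothetical candidate:

```python
def pick_answer(top1, gpt_answer, threshold):
    # Mirrors evaluate_by_threshold: trust the stage-1 candidate only when
    # its confidence clears the threshold.
    return top1['answer'] if top1['confidence'] > threshold else gpt_answer

top1 = {'answer': 'dog', 'confidence': 0.4}  # hypothetical stage-1 candidate
print(pick_answer(top1, 'cat', 0.0))  # threshold 0.0 -> stage-1 answer 'dog'
print(pick_answer(top1, 'cat', 1.0))  # threshold 1.0 -> GPT-3 answer 'cat'
```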


================================================
FILE: prophet/stage2/utils/fancy_pbar.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: customized progress bar
# ------------------------------------------------------------------------------ #

from time import sleep

from rich.table import Column
from rich.progress import *
import atexit

class RichColumn(ProgressColumn):
    def __init__(self, table_column: Optional[Column] = None) -> None:
        super().__init__(table_column)
        self.time_elapsed_column = TimeElapsedColumn()
        self.time_remaining_column = TimeRemainingColumn()
        self.m_of_n = MofNCompleteColumn()
        self._completed = 0
        self.sec_per_iter = -1
        self.info = None
    
    def render(self, task: "Task") -> Text:
        m_of_n = self.m_of_n.render(task)
        m_of_n = Text(f'{m_of_n}'.replace(' ', ''), style="red")
        elapsed = self.time_elapsed_column.render(task)
        elapsed = Text(f'{elapsed}', style="orange_red1")\
                    + Text('/', style="dark_orange")
        remaining = self.time_remaining_column.render(task)
        remaining = Text(f'{remaining}', style="yellow")
        if task.completed:
            if self._completed < task.completed:
                # do not update sec_per_iter if no new completed iterators
                self._completed = task.completed
                self.sec_per_iter = task.elapsed / task.completed
            sec_per_iter = Text(f'({self.sec_per_iter:.1f}s/iter)', style="green")
        else:
            sec_per_iter = Text('(--s/iter)', style="green")

        rendered = m_of_n + ' ' + elapsed + remaining + sec_per_iter
        if self.info is None:
            return rendered
        info = Text(f' {self.info}', style="cyan")
        return rendered + info

info_column = RichColumn()
progress = Progress(
    TextColumn("[bold]{task.description}", table_column=Column(ratio=1)), 
    BarColumn(bar_width=None, table_column=Column(ratio=8), complete_style="blue"),
    # MofNCompleteColumn(),
    info_column,
    expand=True,
    redirect_stdout=False,
    redirect_stderr=False
)
progress.__enter__()

def exit_progress():
    progress.__exit__(None, None, None)
atexit.register(exit_progress)

if __name__ == '__main__':
    # with progress:
    for n in progress.track(range(10), description="Working...  "):
        sleep(0.01)
        print(n)
        if n == 8:
            0 / 0

================================================
FILE: scripts/evaluate_file.sh
================================================
#!/bin/bash
# This script is used to evaluate a result file.

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --task)
      TASK="$2"
      shift 2;;
    --result_path)
      RESULT_PATH="$2"
      shift 2;;
    *)
      echo "Unknown argument: $1"
      exit 1;;
  esac
done

TASK=${TASK:-ok} # task name, one of ['ok', 'aok_val', 'aok_test'], default 'ok'
RESULT_PATH=${RESULT_PATH:-"preds/prophet_611_okvqa.json"} # path to the result file, default is the result from our experiments

if [ $TASK == "ok" ]; then
  python -m evaluation.okvqa_evaluate --result_path $RESULT_PATH \
    --question_path 'datasets/okvqa/OpenEnded_mscoco_val2014_questions.json' \
    --annotation_path 'datasets/okvqa/mscoco_val2014_annotations.json'
elif [ $TASK == "aok_val" ]; then
  python -m evaluation.aokvqa_evaluate --result_path $RESULT_PATH \
    --dataset_path 'datasets/aokvqa/aokvqa_v1p0_val.json' \
    --direct_answer --multiple_choice
elif [ $TASK == "aok_test" ]; then
  echo "Please submit your result to the AOKVQA leaderboard."
else
  echo "Unknown task: $TASK"
  exit 1
fi

================================================
FILE: scripts/evaluate_model.sh
================================================
#!/bin/bash
# This script is used to evaluate a finetuned model.

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --gpu)
      GPU="$2"
      shift 2;;
    --task)
      TASK="$2"
      shift 2;;
    --ckpt_path)
      CKPT_PATH="$2"
      shift 2;;
    --version)
      VERSION="$2"
      shift 2;;
    *)
      echo "Unknown argument: $1"
      exit 1;;
  esac
done

TASK=${TASK:-ok} # task name, one of ['ok', 'aok_val', 'aok_test'], default 'ok'
GPU=${GPU:-0} # GPU id(s) you want to use, default '0'
CKPT_PATH=${CKPT_PATH:-"ckpts/mcan_ft_okvqa.pkl"} # path to the pretrained model, default is the result from our experiments
VERSION=${VERSION:-"eval_finetuned_${TASK}_model"} # version name, default 'eval_finetuned_$TASK_model'

# CUDA_VISIBLE_DEVICES=$GPU \
python main.py \
    --task $TASK --run_mode finetune_test \
    --cfg configs/finetune.yml \
    --version $VERSION \
    --ckpt_path $CKPT_PATH \
    --gpu $GPU --grad_accu 2


================================================
FILE: scripts/extract_img_feats.sh
================================================
#!/bin/bash
# This script is used to extract image features.

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --gpu)
      GPU="$2"
      shift 2;;
    --dataset)
      DATASET="$2"
      shift 2;;
    --clip)
      CLIP_MODEL="$2"
      shift 2;;
    *)
      echo "Unknown argument: $1"
      exit 1;;
  esac
done

DATASET=${DATASET:-ok} # dataset name, one of ['ok', 'aok'], default 'ok'
GPU=${GPU:-0} # GPU id(s) you want to use, default '0'
CLIP_MODEL=${CLIP_MODEL:-RN50x64} # clip model name or path, default 'RN50x64'

# CUDA_VISIBLE_DEVICES=$GPU \
python tools/extract_img_feats.py \
    --dataset $DATASET --gpu $GPU \
    --clip_model $CLIP_MODEL

================================================
FILE: scripts/finetune.sh
================================================
#!/bin/bash
# This script is used to finetune the pretrained MCAN model.

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --gpu)
      GPU="$2"
      shift 2;;
    --task)
      TASK="$2"
      shift 2;;
    --pretrained_model)
      PRETRAINED_MODEL_PATH="$2"
      shift 2;;
    --version)
      VERSION="$2"
      shift 2;;
    *)
      echo "Unknown argument: $1"
      exit 1;;
  esac
done

TASK=${TASK:-ok} # task name, one of ['ok', 'aok_val', 'aok_test'], default 'ok'
GPU=${GPU:-0} # GPU id(s) you want to use, default '0'
PRETRAINED_MODEL_PATH=${PRETRAINED_MODEL_PATH:-"ckpts/mcan_pt_okvqa.pkl"} # path to the pretrained model, default is the result from our experiments
VERSION=${VERSION:-finetuning_okvqa} # version name, default 'finetuning_okvqa'

# run python script
# CUDA_VISIBLE_DEVICES=$GPU \
python main.py \
    --task $TASK --run_mode finetune \
    --cfg configs/finetune.yml \
    --version $VERSION \
    --pretrained_model $PRETRAINED_MODEL_PATH \
    --gpu $GPU --seed 99 --grad_accu 2


================================================
FILE: scripts/heuristics_gen.sh
================================================
#!/bin/bash
# This script is used to generate heuristics from a finetuned model.

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --gpu)
      GPU="$2"
      shift 2;;
    --task)
      TASK="$2"
      shift 2;;
    --ckpt_path)
      CKPT_PATH="$2"
      shift 2;;
    --candidate_num)
      CANDIDATE_NUM="$2"
      shift 2;;
    --example_num)
      EXAMPLE_NUM="$2"
      shift 2;;
    --version)
      VERSION="$2"
      shift 2;;
    *)
      echo "Unknown argument: $1"
      exit 1;;
  esac
done

TASK=${TASK:-ok} # task name, one of ['ok', 'aok_val', 'aok_test'], default 'ok'
GPU=${GPU:-0} # GPU id(s) you want to use, default '0'
CKPT_PATH=${CKPT_PATH:-"ckpts/mcan_ft_okvqa.pkl"} # path to the finetuned model checkpoint, default is the one from our experiments
CANDIDATE_NUM=${CANDIDATE_NUM:-10} # number of candidates to be generated
EXAMPLE_NUM=${EXAMPLE_NUM:-100} # number of examples to be generated
VERSION=${VERSION:-"heuristics_okvqa"} # version name, default 'heuristics_okvqa'

# CUDA_VISIBLE_DEVICES=$GPU \
python main.py \
    --task $TASK --run_mode heuristics \
    --version $VERSION \
    --cfg configs/finetune.yml \
    --ckpt_path $CKPT_PATH \
    --candidate_num $CANDIDATE_NUM \
    --example_num $EXAMPLE_NUM \
    --gpu $GPU

================================================
FILE: scripts/pretrain.sh
================================================
#!/bin/bash
# This script is used to pretrain the MCAN model.

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --gpu)
      GPU="$2"
      shift 2;;
    --task)
      TASK="$2"
      shift 2;;
    --version)
      VERSION="$2"
      shift 2;;
    *)
      echo "Unknown argument: $1"
      exit 1;;
  esac
done

TASK=${TASK:-ok} # task name, one of ['ok', 'aok_val', 'aok_test'], default 'ok'
GPU=${GPU:-0} # GPU id(s) you want to use, default '0'
VERSION=${VERSION:-pretraining_okvqa} # version name, default 'pretraining_okvqa'

# CUDA_VISIBLE_DEVICES=$GPU \
python main.py \
    --task $TASK --run_mode pretrain \
    --cfg configs/pretrain.yml \
    --version $VERSION \
    --gpu $GPU --seed 99 --grad_accu 2

================================================
FILE: scripts/prompt.sh
================================================
#!/bin/bash
# This script is used to prompt GPT-3 to generate final answers.

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --task)
      TASK="$2"
      shift 2;;
    --version)
      VERSION="$2"
      shift 2;;
    --examples_path)
      EXAMPLES_PATH="$2"
      shift 2;;
    --candidates_path)
      CANDIDATES_PATH="$2"
      shift 2;;
    --captions_path)
      CAPTIONS_PATH="$2"
      shift 2;;
    --openai_key)
      OPENAI_KEY="$2"
      shift 2;;
    *)
      echo "Unknown argument: $1"
      exit 1;;
  esac
done

TASK=${TASK:-ok} # task name, one of ['ok', 'aok_val', 'aok_test'], default 'ok'
VERSION=${VERSION:-"prompt_okvqa"} # version name, default 'prompt_okvqa'
EXAMPLES_PATH=${EXAMPLES_PATH:-"assets/answer_aware_examples_okvqa.json"} # path to the examples, default is the result from our experiments
CANDIDATES_PATH=${CANDIDATES_PATH:-"assets/candidates_okvqa.json"} # path to the candidates, default is the result from our experiments
CAPTIONS_PATH=${CAPTIONS_PATH:-"assets/captions_okvqa.json"} # path to the captions, default is the result from our experiments
OPENAI_KEY=${OPENAI_KEY:-""} # your OpenAI API key, required for GPT-3 inference

python main.py \
    --task $TASK --run_mode prompt \
    --version $VERSION \
    --cfg configs/prompt.yml \
    --examples_path $EXAMPLES_PATH \
    --candidates_path $CANDIDATES_PATH \
    --captions_path $CAPTIONS_PATH \
    --openai_key "$OPENAI_KEY"

================================================
FILE: tools/extract_img_feats.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Tool for extracting image features
# ------------------------------------------------------------------------------ #

import os, sys
sys.path.append(os.getcwd())

import glob, re, math, time, datetime
import numpy as np
import torch
from torch import nn
from PIL import Image
import clip
from tqdm import tqdm
import argparse
from pathlib import Path

from configs.task_cfgs import Cfgs
from configs.task_to_split import *
from tools.transforms import _transform


@torch.no_grad()
def _extract_feat(img_path, net, T, save_path):
    # print(img_path)
    img = Image.open(img_path)
    # W, H = img.size
    img = T(img).unsqueeze(0).cuda()
    clip_feats = net(img).cpu().numpy()[0]
    clip_feats = clip_feats.transpose(1, 2, 0)
    # print(clip_feats.shape, save_path)
    # return
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    np.savez(
        save_path,
        x=clip_feats,
    )


class ExtractModel:
    def __init__(self, encoder) -> None:
        # replace CLIP's attention-pool head with identity so the backbone
        # returns the spatial feature grid instead of a pooled embedding
        encoder.attnpool = nn.Identity()
        self.backbone = encoder

        self.backbone.cuda().eval()
    
    @torch.no_grad()
    def __call__(self, img):
        x = self.backbone(img)
        return x


def main(__C, dataset):
    # find imgs
    img_dir_list = []
    for split in SPLIT_TO_IMGS:
        if split.startswith(dataset):
            img_dir_list.append(
                __C.IMAGE_DIR[SPLIT_TO_IMGS[split]]
            )
    print('image dirs:', img_dir_list)
    img_path_list = []
    for img_dir in img_dir_list:
        # os.path.join tolerates image dirs with or without a trailing slash
        img_path_list += glob.glob(os.path.join(img_dir, '*.jpg'))
    print('total images:', len(img_path_list))

    # load model
    clip_model, _ = clip.load(__C.CLIP_VERSION, device='cpu')
    img_encoder = clip_model.visual

    model = ExtractModel(img_encoder)
    T = _transform(__C.IMG_RESOLUTION)

    for img_path in tqdm(img_path_list):
        img_path_sep = img_path.split('/')
        img_path_sep[-3] += '_feats'
        save_path = '/'.join(img_path_sep).replace('.jpg', '.npz')
        _extract_feat(img_path, model, T, save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Tool for extracting CLIP image features.')
    parser.add_argument('--dataset', dest='dataset', help='dataset name, e.g., ok, aok', type=str, required=True)
    parser.add_argument('--gpu', dest='GPU', help='gpu id', type=str, default='0')
    parser.add_argument('--clip_model', dest='CLIP_VERSION', help='clip model name or local model checkpoint path', type=str, default='RN50x64')
    parser.add_argument('--img_resolution', dest='IMG_RESOLUTION', help='image resolution', type=int, default=512)
    args = parser.parse_args()
    __C = Cfgs(args)
    main(__C, args.dataset)
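The save-path derivation and the on-disk format used by the script above can be sketched in isolation. `feat_save_path` is a hypothetical helper (the repo inlines this logic in `main`), and the `(16, 16, 4096)` shape is illustrative for the default RN50x64 / 512-px setting:

```python
import io
import numpy as np

def feat_save_path(img_path: str) -> str:
    # mirror of the inline logic in main(): the directory three levels up
    # from the image gets a '_feats' suffix, and .jpg becomes .npz
    parts = img_path.split('/')
    parts[-3] += '_feats'
    return '/'.join(parts).replace('.jpg', '.npz')

print(feat_save_path('datasets/coco2014/train2014/img_001.jpg'))
# → datasets/coco2014_feats/train2014/img_001.npz

# round-trip of the .npz payload written by _extract_feat: grid features
# stored under key 'x' with layout (H, W, C)
buf = io.BytesIO()                       # stand-in for the on-disk file
np.savez(buf, x=np.zeros((16, 16, 4096), dtype=np.float32))
buf.seek(0)
print(np.load(buf)['x'].shape)           # → (16, 16, 4096)
```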

================================================
FILE: tools/transforms.py
================================================
# ------------------------------------------------------------------------------ #
# Author: Zhenwei Shao (https://github.com/ParadoxZW)
# Description: Preprocessing images to be fed into the model, the script is
#              adapted from the code of CLIP (github.com/openai/CLIP)
# ------------------------------------------------------------------------------ #

from math import ceil
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import ImageOps

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

def Pad():
    def _pad(image):
        W, H = image.size  # PIL's Image.size is (width, height)
        if H < W:
            pad_H = ceil((W - H) / 2)
            pad_W = 0
        else:
            pad_H = 0
            pad_W = ceil((H - W) / 2)
        img = ImageOps.expand(image, border=(pad_W, pad_H, pad_W, pad_H), fill=0)
        return img
    return _pad

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def identity(x):
    return x

def _transform(n_px, pad=False, crop=False):
    return Compose([
        Pad() if pad else identity,
        Resize([n_px, n_px], interpolation=BICUBIC),
        CenterCrop(n_px) if crop else identity,
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


if __name__ == '__main__':
    # smoke test: a non-square dummy image should come out as a square tensor
    img = (np.random.rand(100, 333, 3) * 255).astype('uint8')
    img = Image.fromarray(img)
    img = _transform(32 * 14)(img)  # ToTensor already yields a torch.Tensor
    print(img.size())
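Two numeric details of the transform above can be checked in isolation: the symmetric padding amounts computed inside `Pad()`, and the `Normalize` step, which applies `(x - mean) / std` per channel with CLIP's published statistics. `pad_amounts` is a hypothetical pure-arithmetic mirror of the `_pad` closure, not part of the repo:

```python
from math import ceil

def pad_amounts(W: int, H: int):
    # same branch structure as the _pad closure: pad the short side on
    # both ends so the result is square (up to ceil's 1-px overshoot)
    if H < W:
        return 0, ceil((W - H) / 2)      # (pad_W, pad_H)
    return ceil((H - W) / 2), 0

print(pad_amounts(333, 100))             # → (0, 117)

# per-channel normalization with CLIP's mean/std, for a pure-white pixel
# (ToTensor has already scaled pixel values into [0, 1])
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
white = tuple((1.0 - m) / s for m, s in zip(mean, std))
print([round(v, 2) for v in white])      # → [1.93, 2.07, 2.15]
```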
SYMBOL INDEX (175 symbols across 27 files)

FILE: configs/path_cfgs.py
  class PATH (line 8) | class PATH:
    method __init__ (line 9) | def __init__(self):

FILE: configs/task_cfgs.py
  class Cfgs (line 16) | class Cfgs(PATH):
    method __init__ (line 18) | def __init__(self, args):
    method __repr__ (line 135) | def __repr__(self):
    method override_from_dict (line 144) | def override_from_dict(self, dict_):
    method set_silent_attr (line 148) | def set_silent_attr(self):
    method TRAIN_SPLITS (line 154) | def TRAIN_SPLITS(self):
    method EVAL_SPLITS (line 158) | def EVAL_SPLITS(self):
    method FEATURE_SPLIT (line 162) | def FEATURE_SPLIT(self):
    method EVAL_QUESTION_PATH (line 171) | def EVAL_QUESTION_PATH(self):
    method EVAL_ANSWER_PATH (line 177) | def EVAL_ANSWER_PATH(self):

FILE: configs/task_to_split.py
  class DictSafe (line 7) | class DictSafe(dict):
    method __init__ (line 9) | def __init__(self, data={}):
    method __getitem__ (line 15) | def __getitem__(self, key):

FILE: evaluation/ans_punct.py
  function process_punctuation (line 75) | def process_punctuation(inText):
  function process_digit_article (line 87) | def process_digit_article(inText):
  function prep_ans (line 103) | def prep_ans(answer):

FILE: evaluation/aok_utils/eval_predictions.py
  function eval_aokvqa (line 9) | def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True):

FILE: evaluation/aok_utils/load_aokvqa.py
  function load_aokvqa (line 5) | def load_aokvqa(aokvqa_dir, split, version='v1p0'):
  function get_coco_path (line 12) | def get_coco_path(split, image_id, coco_dir):

FILE: evaluation/aok_utils/remap_predictions.py
  function map_to_choices (line 14) | def map_to_choices(dataset, predictions, device='cpu'):

FILE: evaluation/aokvqa_evaluate.py
  class AOKEvaluater (line 12) | class AOKEvaluater:
    method __init__ (line 13) | def __init__(self, annotation_path: str, question_path: str):
    method init (line 22) | def init(self):
    method set_mode (line 25) | def set_mode(self, multiple_choice=None, map_to_mc=None):
    method prep_ans (line 31) | def prep_ans(self, answer):
    method add (line 34) | def add(self, qid, answer):
    method save (line 44) | def save(self, result_path: str):
    method evaluate (line 53) | def evaluate(self, logfile=None):
  function _evaluate (line 64) | def _evaluate(dataset, results, direct_answer=True, multiple_choice=True):

FILE: evaluation/okvqa_evaluate.py
  class OKEvaluater (line 12) | class OKEvaluater:
    method __init__ (line 13) | def __init__(self, annotation_path: str, question_path: str):
    method init (line 21) | def init(self):
    method prep_ans (line 24) | def prep_ans(self, answer):
    method add (line 27) | def add(self, qid, answer):
    method save (line 34) | def save(self, result_path: str):
    method evaluate (line 38) | def evaluate(self, logfile=None):
  function _evaluate (line 48) | def _evaluate(annotation_file: str, question_file: str, result_file: str):

FILE: evaluation/vqa_utils/vqa.py
  class VQA (line 24) | class VQA:
    method __init__ (line 25) | def __init__(self, annotation_file=None, question_file=None):
    method createIndex (line 47) | def createIndex(self):
    method info (line 65) | def info(self):
    method getQuesIds (line 73) | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
    method getImgIds (line 97) | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
    method loadQA (line 121) | def loadQA(self, ids=[]):
    method showQA (line 132) | def showQA(self, anns):
    method loadRes (line 146) | def loadRes(self, resFile, quesFile):

FILE: evaluation/vqa_utils/vqaEval.py
  class VQAEval (line 10) | class VQAEval:
    method __init__ (line 11) | def __init__(self, vqa, vqaRes, n=2):
    method evaluate (line 68) | def evaluate(self, quesIds=None):
    method processPunctuation (line 122) | def processPunctuation(self, inText):
    method processDigitArticle (line 134) | def processDigitArticle(self, inText):
    method setAccuracy (line 149) | def setAccuracy(self, accQA, accQuesType, accAnsType):
    method setEvalQA (line 154) | def setEvalQA(self, quesId, acc):
    method setEvalQuesType (line 157) | def setEvalQuesType(self, quesId, quesType, acc):
    method setEvalAnsType (line 162) | def setEvalAnsType(self, quesId, ansType, acc):
    method updateProgress (line 167) | def updateProgress(self, progress):

FILE: prophet/__init__.py
  function get_args (line 6) | def get_args():
  function get_runner (line 30) | def get_runner(__C, evaluater):

FILE: prophet/stage1/finetune.py
  class Runner (line 27) | class Runner(object):
    method __init__ (line 28) | def __init__(self, __C, evaluater):
    method train (line 32) | def train(self, train_set, eval_set=None):
    method eval (line 151) | def eval(self, dataset, net=None, eval_now=False):
    method run (line 224) | def run(self):
  function finetune_login_args (line 262) | def finetune_login_args(parser):

FILE: prophet/stage1/heuristics.py
  class Runner (line 29) | class Runner(object):
    method __init__ (line 30) | def __init__(self, __C, *args, **kwargs):
    method eval (line 36) | def eval(self, dataset):
    method run (line 109) | def run(self):
  function heuristics_login_args (line 164) | def heuristics_login_args(parser):

FILE: prophet/stage1/model/layers.py
  class AttFlat (line 14) | class AttFlat(nn.Module):
    method __init__ (line 15) | def __init__(self, __C):
    method forward (line 32) | def forward(self, x, x_mask):
  class MHAtt (line 53) | class MHAtt(nn.Module):
    method __init__ (line 54) | def __init__(self, __C):
    method forward (line 68) | def forward(self, v, k, q, mask):
    method att (line 91) | def att(self, value, key, query, mask):
  class SA_v (line 107) | class SA_v(nn.Module):
    method __init__ (line 108) | def __init__(self, __C):
    method forward (line 127) | def forward(self, *args):
    method att (line 155) | def att(self, value, key, query, mask):
  class FFN (line 171) | class FFN(nn.Module):
    method __init__ (line 172) | def __init__(self, __C):
    method forward (line 185) | def forward(self, x, *args):
  class SA (line 192) | class SA(nn.Module):
    method __init__ (line 193) | def __init__(self, __C):
    method forward (line 201) | def forward(self, x, x_mask, *args):
  class GA (line 209) | class GA(nn.Module):
    method __init__ (line 210) | def __init__(self, __C):
    method forward (line 218) | def forward(self, x, y, x_mask, y_mask, *args):

FILE: prophet/stage1/model/mcan.py
  class MCA_ED (line 17) | class MCA_ED(nn.Module):
    method __init__ (line 21) | def __init__(self, __C):
    method forward (line 29) | def forward(self, x, y, x_mask, y_mask):
  class MCAN (line 40) | class MCAN(nn.Module):
    method __init__ (line 49) | def __init__(self, __C, answer_size):
    method forward (line 79) | def forward(self, input_tuple, output_answer_latent=False):
    method make_mask (line 124) | def make_mask(self, feature):

FILE: prophet/stage1/model/mcan_for_finetune.py
  class MCANForFinetune (line 14) | class MCANForFinetune(MCAN):
    method __init__ (line 20) | def __init__(self, __C, answer_size, base_answer_size=3129):
    method parameter_init (line 26) | def parameter_init(self):
    method forward (line 30) | def forward(self, input_tuple, output_answer_latent=False):

FILE: prophet/stage1/model/net_utils.py
  class FC (line 9) | class FC(nn.Module):
    method __init__ (line 10) | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
    method forward (line 23) | def forward(self, x):
  class MLP (line 35) | class MLP(nn.Module):
    method __init__ (line 36) | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu...
    method forward (line 42) | def forward(self, x):
  function flatten (line 46) | def flatten(x):
  function unflatten (line 52) | def unflatten(x, shape):
  class Identity (line 58) | class Identity(nn.Module):
    method __init__ (line 59) | def __init__(self):
    method forward (line 62) | def forward(self, x):

FILE: prophet/stage1/model/rope2d.py
  function rotate_every_two (line 14) | def rotate_every_two(x):
  function apply_rotary_pos_emb (line 22) | def apply_rotary_pos_emb(q, k, sinu_pos):
  class RoPE2d (line 28) | class RoPE2d(nn.Module):
    method __init__ (line 29) | def __init__(self, in_dim, size):
    method forward (line 53) | def forward(self, k, q):

FILE: prophet/stage1/pretrain.py
  class Runner (line 27) | class Runner(object):
    method __init__ (line 28) | def __init__(self, __C, *args, **kwargs):
    method train (line 31) | def train(self, train_set, eval_set=None):
    method run (line 148) | def run(self):
  function pretrain_login_args (line 164) | def pretrain_login_args(parser):

FILE: prophet/stage1/utils/load_data.py
  function soft_target (line 22) | def soft_target(answers, ans_to_ix, preprocess=True):
  class CommonData (line 32) | class CommonData:
    method __init__ (line 39) | def __init__(self, __C) -> None:
  class DataSet (line 69) | class DataSet(Data.Dataset):
    method __init__ (line 70) | def __init__(self, __C, common_data, split_name_list):
    method __getitem__ (line 116) | def __getitem__(self, idx):
    method __len__ (line 161) | def __len__(self):
    method bert_tokenize (line 164) | def bert_tokenize(self, text, max_token):

FILE: prophet/stage1/utils/optim.py
  class OptimizerWrapper (line 10) | class OptimizerWrapper(object):
    method __init__ (line 15) | def __init__(self, optimizers,
    method zero_grad (line 40) | def zero_grad(self):
    method step (line 43) | def step(self, step=None, schd_step=False):
    method schedule_step (line 57) | def schedule_step(self, schd_step):
    method current_lr (line 71) | def current_lr(self):
    method state_dict (line 74) | def state_dict(self):
    method load_state_dict (line 77) | def load_state_dict(self, state_dict):
  function get_optim (line 81) | def get_optim(__C, model):
  function get_optim_for_finetune (line 105) | def get_optim_for_finetune(__C, model, new_params_name='proj1'):

FILE: prophet/stage2/prompt.py
  class Runner (line 25) | class Runner:
    method __init__ (line 26) | def __init__(self, __C, evaluater):
    method gpt3_infer (line 31) | def gpt3_infer(self, prompt_text, _retry=0):
    method sample_make (line 73) | def sample_make(self, ques, capt, cands, ans=None):
    method get_context (line 86) | def get_context(self, example_qids):
    method run (line 100) | def run(self):
  function prompt_login_args (line 200) | def prompt_login_args(parser):

FILE: prophet/stage2/utils/data_utils.py
  function ok_score (line 12) | def ok_score(gt_answers):
  function aok_score (line 30) | def aok_score(gt_answers):
  class Qid2Data (line 47) | class Qid2Data(Dict):
    method __init__ (line 48) | def __init__(self, __C, splits, annotated=False, similar_examples=None):
    method __getitem__ (line 134) | def __getitem__(self, __key):
    method get_caption (line 138) | def get_caption(self, qid):
    method get_question (line 145) | def get_question(self, qid):
    method get_gt_answers (line 149) | def get_gt_answers(self, qid):
    method get_most_answer (line 154) | def get_most_answer(self, qid):
    method get_topk_candidates (line 159) | def get_topk_candidates(self, qid, k=None):
    method get_similar_qids (line 165) | def get_similar_qids(self, qid, k=None):
    method evaluate_by_threshold (line 171) | def evaluate_by_threshold(self, ans_set, threshold=1.0):
    method topk_accuracy (line 189) | def topk_accuracy(self, k=1, sub_set=None):
    method rt_evaluate (line 205) | def rt_evaluate(self, answer_set):

FILE: prophet/stage2/utils/fancy_pbar.py
  class RichColumn (line 12) | class RichColumn(ProgressColumn):
    method __init__ (line 13) | def __init__(self, table_column: Optional[Column] = None) -> None:
    method render (line 22) | def render(self, task: "Task") -> Text:
  function exit_progress (line 57) | def exit_progress():

FILE: tools/extract_img_feats.py
  function _extract_feat (line 25) | def _extract_feat(img_path, net, T, save_path):
  class ExtractModel (line 41) | class ExtractModel:
    method __init__ (line 42) | def __init__(self, encoder) -> None:
    method __call__ (line 49) | def __call__(self, img):
  function main (line 54) | def main(__C, dataset):

FILE: tools/transforms.py
  function Pad (line 21) | def Pad():
  function _convert_image_to_rgb (line 35) | def _convert_image_to_rgb(image):
  function identity (line 38) | def identity(x):
  function _transform (line 41) | def _transform(n_px, pad=False, crop=False):
Condensed preview — 50 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (172K chars).
[
  {
    "path": ".gitignore",
    "chars": 246,
    "preview": "**/__pycache__/\ndatasets/*/\n!datasets/.gitkeep\nassets/*\n!assets/.gitkeep\nckpts/*\n!ckpts/.gitkeep\noutputs/ckpts/*\n!output"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 14990,
    "preview": "# Prophet\n\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/prompting-large-language-mod"
  },
  {
    "path": "assets/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ckpts/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "configs/finetune.yml",
    "chars": 749,
    "preview": "# Network\nIMG_RESOLUTION: 512\nIMG_FEAT_GRID: 16\nIMG_FEAT_SIZE: 4096\nBERT_VERSION: bert-large-uncased\nMAX_TOKEN: 32\nARCH_"
  },
  {
    "path": "configs/path_cfgs.py",
    "chars": 3071,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "configs/pretrain.yml",
    "chars": 561,
    "preview": "# Network\nIMG_RESOLUTION: 512\nIMG_FEAT_GRID: 16\nIMG_FEAT_SIZE: 4096\nBERT_VERSION: bert-large-uncased\nMAX_TOKEN: 32\nARCH_"
  },
  {
    "path": "configs/prompt.yml",
    "chars": 372,
    "preview": "MODEL: text-davinci-002\nTEMPERATURE: 0.\nMAX_TOKENS: 8\nSLEEP_PER_INFER: 10\n\nPROMPT_HEAD: \"Please answer the question acco"
  },
  {
    "path": "configs/task_cfgs.py",
    "chars": 6144,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "configs/task_to_split.py",
    "chars": 2203,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "datasets/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "environment.yml",
    "chars": 786,
    "preview": "name: prophet\nchannels:\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch\n  - pytorch\n  - https://mirrors."
  },
  {
    "path": "evaluation/ans_punct.py",
    "chars": 4718,
    "preview": "# --------------------------------------------------------\n# mcan-vqa (Deep Modular Co-Attention Networks)\n# Licensed un"
  },
  {
    "path": "evaluation/aok_utils/eval_predictions.py",
    "chars": 3161,
    "preview": "import argparse\nimport pathlib\nimport json\nimport glob\n\nfrom .load_aokvqa import load_aokvqa\n\n\ndef eval_aokvqa(dataset, "
  },
  {
    "path": "evaluation/aok_utils/load_aokvqa.py",
    "chars": 378,
    "preview": "import os\nimport json\n\n\ndef load_aokvqa(aokvqa_dir, split, version='v1p0'):\n    assert split in ['train', 'val', 'test',"
  },
  {
    "path": "evaluation/aok_utils/remap_predictions.py",
    "chars": 1873,
    "preview": "import os \nos.environ['CUDA_VISIBLE_DEVICES'] = '1'\nimport argparse\nimport pathlib\nimport json\nfrom tqdm import tqdm\n\nfr"
  },
  {
    "path": "evaluation/aokvqa_evaluate.py",
    "chars": 3947,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "evaluation/okvqa_evaluate.py",
    "chars": 3243,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "evaluation/vqa_utils/vqa.py",
    "chars": 7090,
    "preview": "__author__ = 'aagrawal'\n__version__ = '0.9'\n\n# Interface for accessing the VQA dataset.\n\n# This code is based on the cod"
  },
  {
    "path": "evaluation/vqa_utils/vqaEval.py",
    "chars": 8156,
    "preview": "# coding=utf-8\n\n__author__='aagrawal'\n\n# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API av"
  },
  {
    "path": "main.py",
    "chars": 762,
    "preview": "import argparse\nimport yaml\nimport torch\n\nfrom evaluation.okvqa_evaluate import OKEvaluater\nfrom evaluation.aokvqa_evalu"
  },
  {
    "path": "misc/tree.txt",
    "chars": 2645,
    "preview": "prophet\n├── assets\n│   ├── answer_aware_examples_okvqa.json\n│   ├── answer_dict_aokvqa.json\n│   ├── answer_dict_okvqa.js"
  },
  {
    "path": "outputs/ckpts/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "outputs/logs/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "outputs/results/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "preds/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "prophet/__init__.py",
    "chars": 2833,
    "preview": "__author__ = 'Zhenwei Shao'\n__version__ = '1.0'\n\nimport argparse\n\ndef get_args():\n    parser = argparse.ArgumentParser()"
  },
  {
    "path": "prophet/stage1/finetune.py",
    "chars": 10822,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/heuristics.py",
    "chars": 6726,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/model/layers.py",
    "chars": 6209,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/model/mcan.py",
    "chars": 3890,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/model/mcan_for_finetune.py",
    "chars": 1419,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/model/net_utils.py",
    "chars": 1600,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/model/rope2d.py",
    "chars": 2075,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/pretrain.py",
    "chars": 7074,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage1/utils/load_data.py",
    "chars": 6769,
    "preview": "# --------------------------------------------------------------------------------- #\n# Author: Zhenwei Shao (https://gi"
  },
  {
    "path": "prophet/stage1/utils/optim.py",
    "chars": 5244,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage2/prompt.py",
    "chars": 8897,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage2/utils/data_utils.py",
    "chars": 7173,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "prophet/stage2/utils/fancy_pbar.py",
    "chars": 2447,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "scripts/evaluate_file.sh",
    "chars": 1088,
    "preview": "#!/bin/bash\n# This script is used to evaluate a result file.\n\n# Parse arguments\nwhile [[ $# -gt 0 ]]; do\n  case \"$1\" in\n"
  },
  {
    "path": "scripts/evaluate_model.sh",
    "chars": 956,
    "preview": "#!/bin/bash\n# This script is used to evaluate a finetuned model.\n\n# Parse arguments\nwhile [[ $# -gt 0 ]]; do\n  case \"$1\""
  },
  {
    "path": "scripts/extract_img_feats.sh",
    "chars": 670,
    "preview": "#!/bin/bash\n# This script is used to extract image features.\n\n# Parse arguments\nwhile [[ $# -gt 0 ]]; do\n  case \"$1\" in\n"
  },
  {
    "path": "scripts/finetune.sh",
    "chars": 1031,
    "preview": "#!/bin/bash\n# This script is used to finetune the pretrained MCAN model.\n\n# Parse arguments\nwhile [[ $# -gt 0 ]]; do\n  c"
  },
  {
    "path": "scripts/heuristics_gen.sh",
    "chars": 1270,
    "preview": "#!/bin/bash\n# This script is used to generate heuristics from a finetuned model.\n\n# Parse arguments\nwhile [[ $# -gt 0 ]]"
  },
  {
    "path": "scripts/pretrain.sh",
    "chars": 730,
    "preview": "#!/bin/bash\n# This script is used to pretrain the MCAN model.\n\n# Parse arguments\nwhile [[ $# -gt 0 ]]; do\n  case \"$1\" in"
  },
  {
    "path": "scripts/prompt.sh",
    "chars": 1445,
    "preview": "#!/bin/bash\n# This script is used to prompt GPT-3 to generate final answers.\n\n# Parse arguments\nwhile [[ $# -gt 0 ]]; do"
  },
  {
    "path": "tools/extract_img_feats.py",
    "chars": 2849,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  },
  {
    "path": "tools/transforms.py",
    "chars": 1753,
    "preview": "# ------------------------------------------------------------------------------ #\n# Author: Zhenwei Shao (https://githu"
  }
]
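Several of the script previews above share the same bash argument-parsing skeleton: `while [[ $# -gt 0 ]]; do case "$1" in …`. A minimal runnable sketch of that pattern follows; the flag names (`--task`, `--version`) are illustrative placeholders, not the actual flags used by the repo's scripts, which are cut off in the previews.

```shell
#!/bin/bash
# Sketch of the argument-parsing loop seen in the script previews.
# Flag names below are hypothetical, not taken from the repository.

parse_args() {
  TASK=""
  VERSION=""
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --task)    TASK="$2";    shift 2 ;;   # consume flag and its value
      --version) VERSION="$2"; shift 2 ;;
      *) echo "Unknown argument: $1" >&2; return 1 ;;
    esac
  done
}

parse_args --task okvqa --version v1
echo "task=$TASK version=$VERSION"
```

Each branch shifts past both the flag and its value, so the loop terminates once all positional parameters are consumed; unrecognized flags fail fast instead of being silently ignored.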

About this extraction

This page contains the full source code of the MILVLG/prophet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 50 files (157.6 KB, roughly 43.3k tokens) and a symbol index of 175 extracted functions, classes, methods, constants, and types, intended for use with any AI tool that accepts text input.

Extracted by GitExtract, a free GitHub repository to text converter for AI, built by Nikandr Surkov.