Full Code of TencentARC/ST-LLM for AI

main 64d2231a0c50 cached
107 files
558.1 KB
135.4k tokens
774 symbols
1 requests
Download .txt
Showing preview only (592K chars total). Download the full file or copy to clipboard to get everything.
Repository: TencentARC/ST-LLM
Branch: main
Commit: 64d2231a0c50
Files: 107
Total size: 558.1 KB

Directory structure:
gitextract_1lgs5ey9/

├── LICENSE
├── PrepareVicuna.md
├── README.md
├── config/
│   ├── instructblipbase_avp.yaml
│   ├── instructblipbase_stllm_conversation.yaml
│   ├── instructblipbase_stllm_qa.yaml
│   ├── minigpt4base_avp.yaml
│   └── minigpt4base_stllm_qa.yaml
├── demo.py
├── demo_gradio.py
├── prompts/
│   └── alignment.txt
├── requirement.txt
├── script/
│   ├── inference/
│   │   ├── mvbench/
│   │   │   └── test_mvbench.sh
│   │   ├── qabench/
│   │   │   ├── anet_qa.sh
│   │   │   ├── msrvtt_qa.sh
│   │   │   ├── msvd_qa.sh
│   │   │   ├── score_anet.sh
│   │   │   ├── score_msrvtt.sh
│   │   │   └── score_msvd.sh
│   │   └── vcgbench/
│   │       ├── score_consist.sh
│   │       ├── score_context.sh
│   │       ├── score_correct.sh
│   │       ├── score_detail.sh
│   │       ├── score_temporal.sh
│   │       ├── test_consist.sh
│   │       ├── test_general.sh
│   │       └── test_temporal.sh
│   └── train/
│       └── train.sh
├── stllm/
│   ├── __init__.py
│   ├── common/
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dist_utils.py
│   │   ├── gradcam.py
│   │   ├── logger.py
│   │   ├── optims.py
│   │   ├── registry.py
│   │   └── utils.py
│   ├── configs/
│   │   ├── datasets/
│   │   │   ├── cc_sbu/
│   │   │   │   ├── align.yaml
│   │   │   │   └── defaults.yaml
│   │   │   └── laion/
│   │   │       └── defaults.yaml
│   │   ├── default.yaml
│   │   └── models/
│   │       ├── instructblip_vicuna0.yaml
│   │       ├── instructblip_vicuna0_btadapter.yaml
│   │       ├── minigpt4_vicuna0.yaml
│   │       └── minigpt4_vicuna0_btadapter.yaml
│   ├── conversation/
│   │   ├── __init__.py
│   │   ├── conversation.py
│   │   └── mvbench_conversation.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── builders/
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset_builder.py
│   │   │   └── image_text_pair_builder.py
│   │   ├── data_utils.py
│   │   └── datasets/
│   │       ├── __init__.py
│   │       ├── base_dataset.py
│   │       ├── caption_datasets.py
│   │       ├── cc_sbu_dataset.py
│   │       ├── dataloader_utils.py
│   │       ├── image_video_itdatasets.py
│   │       ├── instruction_data.py
│   │       ├── laion_dataset.py
│   │       └── utils.py
│   ├── models/
│   │   ├── Qformer.py
│   │   ├── __init__.py
│   │   ├── base_decoder.py
│   │   ├── base_model.py
│   │   ├── blip2.py
│   │   ├── blip2_outputs.py
│   │   ├── eva_btadapter.py
│   │   ├── eva_vit.py
│   │   ├── modeling_llama_mem.py
│   │   ├── peft_model.py
│   │   ├── st_llm.py
│   │   └── utils.py
│   ├── processors/
│   │   ├── __init__.py
│   │   ├── base_processor.py
│   │   ├── blip_processors.py
│   │   ├── randaugment.py
│   │   └── video_transform.py
│   ├── runners/
│   │   ├── __init__.py
│   │   └── runner_base.py
│   ├── tasks/
│   │   ├── __init__.py
│   │   ├── base_task.py
│   │   └── image_text_pretrain.py
│   ├── test/
│   │   ├── __init__.py
│   │   ├── gpt_evaluation/
│   │   │   ├── evaluate_activitynet_qa.py
│   │   │   ├── evaluate_benchmark_1_correctness.py
│   │   │   ├── evaluate_benchmark_2_detailed_orientation.py
│   │   │   ├── evaluate_benchmark_3_context.py
│   │   │   ├── evaluate_benchmark_4_temporal.py
│   │   │   └── evaluate_benchmark_5_consistency.py
│   │   ├── mvbench/
│   │   │   ├── mv_bench.py
│   │   │   └── mv_bench_infer.py
│   │   ├── qabench/
│   │   │   ├── activitynet_qa.py
│   │   │   ├── msrvtt_qa.py
│   │   │   └── msvd_qa.py
│   │   ├── vcgbench/
│   │   │   ├── videochatgpt_benchmark_consist.py
│   │   │   └── videochatgpt_benchmark_general.py
│   │   ├── video_transforms.py
│   │   └── video_utils.py
│   └── train/
│       ├── stllm_trainer.py
│       ├── train.py
│       ├── train_hf.py
│       ├── zero2.json
│       ├── zero3.json
│       └── zero3_offload.json
└── trainval.md

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: PrepareVicuna.md
================================================
## How to Prepare Vicuna Weight
Vicuna is an open-source LLAMA-based LLM that has a performance close to ChatGPT. 
We currently use the v0 version of Vicuna-13B. 

To prepare Vicuna’s weight, first download Vicuna’s **delta** weight from [https://huggingface.co/lmsys/vicuna-13b-delta-v0](https://huggingface.co/lmsys/vicuna-13b-delta-v0). 
In case you have git-lfs installed (https://git-lfs.com), this can be done by

```
git lfs install
git clone https://huggingface.co/lmsys/vicuna-13b-delta-v0  # more powerful, need at least 24G gpu memory
# or
git clone https://huggingface.co/lmsys/vicuna-7b-delta-v0  # smaller, need 12G gpu memory
```

Note that this is not directly the working weight, but the difference between the working weight and the original weight of LLAMA-13B. (Due to LLAMA’s rules, we cannot distribute the weight of LLAMA.)

Then, you need to obtain the original LLAMA-7B or LLAMA-13B weights in the HuggingFace format 
either following the instruction provided by HuggingFace 
[here](https://huggingface.co/docs/transformers/main/model_doc/llama) or from the Internet. 

When these two weights are ready, we can use tools from Vicuna’s team to create the real working weight.
First, Install their library that is compatible with v0 Vicuna by

```
pip install git+https://github.com/lm-sys/FastChat.git@v0.1.10
```

Then, run the following command to create the final working weight

```
python -m fastchat.model.apply_delta --base /path/to/llama-13bOR7b-hf/  --target /path/to/save/working/vicuna/weight/  --delta /path/to/vicuna-13bOR7b-delta-v0/
```

Now you are good to go!



================================================
FILE: README.md
================================================
<p align="center" width="100%">
<a target="_blank"><img src="example/material/stllm_logo.png" alt="ST-LLM" style="width: 50%; min-width: 150px; display: block; margin: auto;"></a>
</p>

<h2 align="center"> <a href="https://arxiv.org/abs/2404.00308">ST-LLM: Large Language Models Are Effective Temporal Learners</a></h2>

<h5 align=center>

[![hf](https://img.shields.io/badge/🤗-Hugging%20Face-blue.svg)](https://huggingface.co/farewellthree/ST_LLM_weight/tree/main)
[![arXiv](https://img.shields.io/badge/Arxiv-2311.08046-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2404.00308)
[![License](https://img.shields.io/badge/Code%20License-Apache2.0-yellow)](https://github.com/farewellthree/ST-LLM/blob/main/LICENSE)
</h5>

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-question-answering-on-mvbench)](https://paperswithcode.com/sota/video-question-answering-on-mvbench?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance)](https://paperswithcode.com/sota/video-based-generative-performance?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-1)](https://paperswithcode.com/sota/video-based-generative-performance-1?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-5)](https://paperswithcode.com/sota/video-based-generative-performance-5?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-2)](https://paperswithcode.com/sota/video-based-generative-performance-2?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-3)](https://paperswithcode.com/sota/video-based-generative-performance-3?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-4)](https://paperswithcode.com/sota/video-based-generative-performance-4?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/zeroshot-video-question-answer-on-activitynet)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/zeroshot-video-question-answer-on-msrvtt-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=st-llm-large-language-models-are-effective-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/zeroshot-video-question-answer-on-msvd-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=st-llm-large-language-models-are-effective-1)

## News :loudspeaker:

* **[2024/3/28]**  All codes and weights are available now! Welcome to watch this repository for the latest updates.

## Introduction :bulb:

- **ST-LLM** is a temporal-sensitive video large language model. Our model incorporates three key architectural: 
  - (1) Joint spatial-temporal modeling within large language models for effective video understanding.
  - (2) Dynamic masking strategy and mask video modeling for efficiency and robustness.
  - (3) Global-local input module for long video understanding.
- **ST-LLM** has established new state-of-the-art results on MVBench, VideoChatGPT Bench and VideoQA Bench:

<div align="center">
<table border="1" width="100%">
    <tr align="center">
        <th rowspan="2">Method</th><th rowspan="2">MVBench</th><th colspan="6">VcgBench</th><th colspan="3">VideoQABench</th>
    </tr>
  <tr align="center">
        <th>Avg</th><th>Correct</th><th>Detail</th><th>Context</th><th>Temporal</th><th>Consist</th><th>MSVD</th><th>MSRVTT</th><th>ANet</th>
    </tr>
  <tr align="center">
        <td>VideoLLaMA</td><td>34.1</td><td>1.96</td><td>2.18</td><td>2.16</td><td>1.82</td><td>1.79</td><td>1.98</td><td>51.6</td><td>29.6</td><td>12.4</td>
    </tr>
  <tr align="center">
        <td>LLaMA-Adapter</td><td>31.7</td><td>2.03</td><td>2.32</td><td>2.30</td><td>1.98</td><td>2.15</td><td>2.16</td><td>54.9</td><td>43.8</td><td>34.2</td>
    </tr>
  <tr align="center">
        <td>VideoChat</td><td>35.5</td><td>2.23</td><td>2.50</td><td>2.53</td><td>1.94</td><td>2.24</td><td>2.29</td><td>56.3</td><td>45.0</td><td>26.5</td>
    </tr>
  <tr align="center">
        <td>VideoChatGPT</td><td>32.7</td><td>2.38</td><td>2.40</td><td>2.52</td><td>2.62</td><td>1.98</td><td>2.37</td><td>64.9</td><td>49.3</td><td>35.7</td>
    </tr>
  <tr align="center">
        <td>MovieChat</td><td>-</td><td>2.76</td><td>2.93</td><td>3.01</td><td>2.24</td><td>2.42</td><td>2.67</td><td>74.2</td><td>52.7</td><td>45.7</td>
    </tr>
  <tr align="center">
        <td>Vista-LLaMA</td><td>-</td><td>2.44</td><td>2.64</td><td>3.18</td><td>2.26</td><td>2.31</td><td>2.57</td><td>65.3</td><td>60.5</td><td>48.3</td>
    </tr>
  <tr align="center">
        <td>LLaMA-VID</td><td>-</td><td>2.89</td><td>2.96</td><td>3.00</td><td>3.53</td><td>2.46</td><td>2.51</td><td>69.7</td><td>57.7</td><td>47.4</td>
    </tr>
  <tr align="center">
        <td>Chat-UniVi</td><td>-</td><td>2.99</td><td>2.89</td><td>2.91</td><td>3.46</td><td>2.89</td><td>2.81</td><td>65.0</td><td>54.6</td><td>45.8</td>
    </tr>
  <tr align="center">
        <td>VideoChat2</td><td>51.1</td><td>2.98</td><td>3.02</td><td>2.88</td><td>3.51</td><td>2.66</td><td>2.81</td><td>70.0</td><td>54.1</td><td>49.1</td>
    </tr>
  <tr align="center">
        <td>ST-LLM</td><td><b>54.9</b></td><td><b>3.15</b></td><td><b>3.23</b></td><td><b>3.05</b></td><td><b>3.74</b></td><td><b>2.93</b></td><td><b>2.81</b></td><td><b>74.6</b></td><td><b>63.2</b></td><td><b>50.9</b></td>
    </tr>
  
</table>
</div>

## Demo 🤗
Please download the conversation weights from [here](https://huggingface.co/farewellthree/ST_LLM_weight/tree/main/conversation_weight) and follow the instructions in [installation](README.md#Installation) first. Then, run the gradio demo:
```
CUDA_VISIBLE_DEVICES=0 python3 demo_gradio.py --ckpt-path /path/to/STLLM_conversation_weight
```
We have also prepared local scripts that are easy to modify:[demo.py](demo.py)

<div align=center>
<img src="example/material/Mabaoguo.gif" width="70%" />
</div>

<div align=center>
<img src="example/material/Driving.gif" width="70%" />
</div>

## Examples 👀
- **Video Description: for high-difficulty videos with complex scene changes, ST-LLM can accurately describe all the contents.**
<p align="center">
  <img src="example/driving.gif" width="25%" style="display:inline-block" />
  <img src="example/driving.jpg" width="65%" style="display:inline-block" /> 
</p>

- **Action Identification: ST-LLM can accurately and comprehensively describe the actions occurring in the video.**
<p align="center">
  <img src="example/cooking.gif" width="21%" style="display:inline-block" />
  <img src="example/cooking.jpg" width="68%" style="display:inline-block" /> 
</p>

<p align="center">
  <img src="example/TVshow.gif" width="21%" style="display:inline-block" />
  <img src="example/TVshow.jpg" width="68%" style="display:inline-block" /> 
</p>

<p align="center">
  <img src="example/monkey.gif" width="21%" style="display:inline-block" />
  <img src="example/monkey.jpg" width="68%" style="display:inline-block" /> 
</p>

- **Reasoning: for the challenging open-ended reasoning questions, STLLM can also provide reasonable answers.**
  <p align="center">
  <img src="example/BaoguoMa.gif" width="26%" style="display:inline-block" />
  <img src="example/baoguoma.jpg" width="66%" style="display:inline-block" /> 
</p>

## Installation 🛠️
Git clone our repository, creating a Python environment and activate it via the following command

```bash
git clone https://github.com/farewellthree/ST-LLM.git
cd ST-LLM
conda create --name stllm python=3.10
conda activate stllm
pip install -r requirement.txt
```

## Training & Validation :bar_chart:
The instructions of data, training and evaluating can be found in [trainval.md](trainval.md).

## Acknowledgement 👍
* [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT) and [MVBench](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) Great job contributing video LLM benchmark.
* [InstuctBLIP](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip) and [MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4/tree/main) The codebase and the basic image LLM we built upon.

## Citation ✏️
If you find the code and paper useful for your research, please consider staring this repo and citing our paper:
```
@article{liu2023one,
  title={One for all: Video conversation is feasible without video instruction tuning},
  author={Liu, Ruyang and Li, Chen and Ge, Yixiao and Shan, Ying and Li, Thomas H and Li, Ge},
  journal={arXiv preprint arXiv:2309.15785},
  year={2023}
}
```
```
@article{liu2023one,
  title={ST-LLM: Large Language Models Are Effective Temporal Learners},
  author={Liu, Ruyang and Li, Chen and Tang, Haoran and Ge, Yixiao and Shan, Ying and Li, Ge},
  journal={https://arxiv.org/abs/2404.00308},
  year={2023}
}
```
 


================================================
FILE: config/instructblipbase_avp.yaml
================================================
model:
  arch: st_llm_hf
  model_type: instructblip_vicuna0_btadapter
  use_grad_checkpoint: True
  max_txt_len: 256
  end_sym: "###"
  video_input: "mean"
  llama_model: '/path/to/vicuna-7b-v1.1'
  ckpt: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'
  q_former_model: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'
  qformer_text_input: True
  freeze_LLM: False

datasets:
  caption_videochatgpt:
    num_frames: 16
    #video_reader_type: 'rawframe'
  classification_k710:
    num_frames: 16
  classification_ssv2:
    num_frames: 16
  reasoning_next_qa:
    num_frames: 16
  reasoning_clevrer_qa:
    num_frames: 16
  reasoning_clevrer_mc:
    num_frames: 16
  vqa_webvid_qa:
    num_frames: 16

run:
  task: video_text_it
  bf16: True
  tf32: False
  output_dir: "./stllm/output/instructblipbase_avp"
  num_train_epochs: 2
  dataloader_num_workers: 4
  per_device_train_batch_size: 16
  per_device_eval_batch_size: 16
  gradient_accumulation_steps: 1
  evaluation_strategy: "no"
  learning_rate: 2e-5
  weight_decay: 0.
  warmup_ratio: 0.03
  lr_scheduler_type: 'cosine'
  logging_steps: 50
  model_max_length: 1024
  #save_steps: 10000 
  save_strategy: "epoch" 
  save_total_limit: 1
  deepspeed: 'stllm/train/zero3.json'

================================================
FILE: config/instructblipbase_stllm_conversation.yaml
================================================
model:
  arch: st_llm_hf
  model_type: instructblip_vicuna0
  use_grad_checkpoint: True
  max_txt_len: 256
  end_sym: "###"
  #prompt_path: "prompts/alignment.txt"
  prompt_template: '###Human: {} ###Assistant: '
  llama_model: '/path/to/vicuna-7b-v1.1'
  ckpt: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'
  q_former_model: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'
  qformer_text_input: True
  freeze_LLM: False
  video_input: "residual"
  residual_size: 16
  use_mask : True
  mvm_decode: True

datasets:
  caption_videochat:
    num_frames: 64
  conversation_videochat1:
    num_frames: 64
  caption_videochatgpt:
    num_frames: 64
    #video_reader_type: 'rawframe'
  caption_webvid:
    num_frames: 64
  vqa_webvid_qa:
    num_frames: 64

run:
  task: video_text_it
  bf16: True
  tf32: False
  output_dir: "./stllm/output/instructblipbase_stllm_conversation"
  num_train_epochs: 2
  dataloader_num_workers: 4
  per_device_train_batch_size: 16
  per_device_eval_batch_size: 16
  gradient_accumulation_steps: 1
  evaluation_strategy: "no"
  learning_rate: 2e-5
  weight_decay: 0.
  warmup_ratio: 0.03
  lr_scheduler_type: 'cosine'
  logging_steps: 50
  model_max_length: 1024
  save_strategy: "epoch" 
  save_total_limit: 1
  deepspeed: 'stllm/train/zero2.json'

================================================
FILE: config/instructblipbase_stllm_qa.yaml
================================================
model:
  arch: st_llm_hf
  model_type: instructblip_vicuna0_btadapter
  use_grad_checkpoint: True
  max_txt_len: 256
  end_sym: "###"
  video_input: "all"
  llama_model: '/path/to/vicuna-7b-v1.1'
  ckpt: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'
  q_former_model: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'
  qformer_text_input: True
  freeze_LLM: False
  use_mask : True
  mvm_decode: True

datasets:
  caption_videochatgpt:
    num_frames: 16
    #video_reader_type: 'rawframe'
  classification_k710:
    num_frames: 16
  classification_ssv2:
    num_frames: 16
  reasoning_next_qa:
    num_frames: 16
  reasoning_clevrer_qa:
    num_frames: 16
  reasoning_clevrer_mc:
    num_frames: 16
  vqa_webvid_qa:
    num_frames: 16

run:
  task: video_text_it
  bf16: True
  tf32: False
  output_dir: "./stllm/output/instructblipbase_stllm_qa"
  num_train_epochs: 2
  dataloader_num_workers: 4
  per_device_train_batch_size: 16
  per_device_eval_batch_size: 16
  gradient_accumulation_steps: 1
  evaluation_strategy: "no"
  learning_rate: 2e-5
  weight_decay: 0.
  warmup_ratio: 0.03
  lr_scheduler_type: 'cosine'
  logging_steps: 50
  model_max_length: 1024
  #save_steps: 10000 
  save_strategy: "epoch" 
  save_total_limit: 1
  deepspeed: 'stllm/train/zero3.json'

================================================
FILE: config/minigpt4base_avp.yaml
================================================
model:
  arch: st_llm_hf
  model_type: minigpt4_vicuna0_btadapter
  use_grad_checkpoint: True
  max_txt_len: 256
  end_sym: "###"
  video_input: "mean"
  llama_model: "/path/to/vicuna-7b"
  ckpt: '/Path/to/prerained_minigpt4_7b.pth'
  q_former_model: /Path/to/blip2_pretrained_flant5xxl.pth
  qformer_text_input: False
  freeze_LLM: False

datasets:
  caption_videochatgpt:
    num_frames: 16
    #video_reader_type: 'rawframe'
  classification_k710:
    num_frames: 16
  classification_ssv2:
    num_frames: 16
  reasoning_next_qa:
    num_frames: 16
  reasoning_clevrer_qa:
    num_frames: 16
  reasoning_clevrer_mc:
    num_frames: 16
  vqa_webvid_qa:
    num_frames: 16

run:
  task: video_text_it
  bf16: True
  tf32: False
  output_dir: "./stllm/output/minigpt4base_avp"
  num_train_epochs: 2
  dataloader_num_workers: 4
  per_device_train_batch_size: 16
  per_device_eval_batch_size: 16
  gradient_accumulation_steps: 1
  evaluation_strategy: "no"
  learning_rate: 2e-5
  weight_decay: 0.
  warmup_ratio: 0.03
  lr_scheduler_type: 'cosine'
  logging_steps: 50
  model_max_length: 1024
  #save_steps: 10000 
  save_strategy: "epoch" 
  save_total_limit: 1
  deepspeed: 'stllm/train/zero3.json'

================================================
FILE: config/minigpt4base_stllm_qa.yaml
================================================
model:
  arch: st_llm_hf
  model_type: minigpt4_vicuna0_btadapter
  use_grad_checkpoint: True
  max_txt_len: 256
  end_sym: "###"
  video_input: "all"
  llama_model: "/path/to/vicuna-7b"
  ckpt: '/Path/to/prerained_minigpt4_7b.pth'
  q_former_model: /Path/to/blip2_pretrained_flant5xxl.pth
  qformer_text_input: False
  freeze_LLM: False
  use_mask : True
  mvm_decode: True

datasets:
  caption_videochatgpt:
    num_frames: 16
    #video_reader_type: 'rawframe'
  classification_k710:
    num_frames: 16
  classification_ssv2:
    num_frames: 16
  reasoning_next_qa:
    num_frames: 16
  reasoning_clevrer_qa:
    num_frames: 16
  reasoning_clevrer_mc:
    num_frames: 16
  vqa_webvid_qa:
    num_frames: 16

run:
  task: video_text_it
  bf16: True
  tf32: False
  output_dir: "./stllm/output/minigpt4base_stllm_qa"
  num_train_epochs: 2
  dataloader_num_workers: 4
  per_device_train_batch_size: 16
  per_device_eval_batch_size: 16
  gradient_accumulation_steps: 1
  evaluation_strategy: "no"
  learning_rate: 2e-5
  weight_decay: 0.
  warmup_ratio: 0.03
  lr_scheduler_type: 'cosine'
  logging_steps: 50
  model_max_length: 1024
  #save_steps: 10000 
  save_strategy: "epoch" 
  save_total_limit: 1
  deepspeed: 'stllm/train/zero3.json'

================================================
FILE: demo.py
================================================
import argparse
import torch

from stllm.common.config import Config
from stllm.common.registry import registry
from stllm.conversation.conversation import Chat, CONV_instructblip_Vicuna0

# imports modules for registration
from stllm.datasets.builders import *
from stllm.models import *
from stllm.processors import *
from stllm.runners import *
from stllm.tasks import *

def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default='config/instructblipbase_stllm_conversation.yaml', help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--ckpt-path", required=True, help="path to STLLM_conversation_weight.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


# ========================================
#             Model Initialization
# ========================================

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

ckpt_path = args.ckpt_path
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_config.ckpt = ckpt_path
model_config.llama_model = ckpt_path
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
model.to(torch.float16)
CONV_VISION = CONV_instructblip_Vicuna0

chat = Chat(model, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')

chat_state = CONV_VISION.copy()
video = 'example/BaoguoMa.mp4'
prompt = 'Tell me why this video looks so funny?'
img_list = []

chat.upload_video(video, chat_state, img_list, 64, text=prompt)
chat.ask("###Human: " + prompt + " ###Assistant: ", chat_state)
llm_message = chat.answer(conv=chat_state,
                img_list=img_list,
                num_beams=5,
                do_sample=False,
                temperature=1,
                max_new_tokens=300,
                max_length=2000)[0]
print (llm_message)




================================================
FILE: demo_gradio.py
================================================
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes

import argparse
import torch

from stllm.common.config import Config
from stllm.common.registry import registry
from stllm.conversation.conversation import Chat, CONV_instructblip_Vicuna0

# imports modules for registration
from stllm.datasets.builders import *
from stllm.models import *
from stllm.processors import *
from stllm.runners import *
from stllm.tasks import *

def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default='config/instructblipbase_stllm_conversation.yaml', help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--ckpt-path", required=True, help="path to STLLM_conversation_weight.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args

# ========================================
#             Model Initialization
# ========================================

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

ckpt_path = args.ckpt_path
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_config.ckpt = ckpt_path
model_config.llama_model = ckpt_path
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
model.to(torch.float16)
CONV_VISION = CONV_instructblip_Vicuna0

chat = Chat(model, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')

# ========================================
#             Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list


def upload_video(gr_video, chat_state, num_segments, text_prompt='Watch the video and answer the question.'):
    print('gr_video: ', gr_video)
    img_list = []
    if gr_video: 
        chat_state = CONV_VISION.copy()
        chat.upload_video(gr_video, chat_state, img_list, num_segments, text=text_prompt)
        return gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list

def gradio_ask(user_message, chatbot, chat_state, gr_video, num_segments):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat_state = CONV_VISION.copy()
    img_list = []
    chat.upload_video(gr_video, chat_state, img_list, num_segments, text=user_message)
    msg = "###Human: " + user_message + " ###Assistant: "
    chat.ask(msg, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state, img_list


def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, do_sample=False, temperature=temperature, max_length=2000)[0]
    llm_message = llm_message.replace("<s>", "") # handle <s>
    chatbot[-1][1] = llm_message
    print(chat_state)
    print(f"Answer: {llm_message}")
    return chatbot, chat_state, img_list


class STLLM(gr.themes.base.Base):
    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            body_background_fill="*neutral_50",
        )


gvlabtheme = STLLM(primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        )

title = """<h1 align="center"><a href="https://github.com/farewellthree/ST-LLM"><img src="https://s21.ax1x.com/2024/03/25/pF4Wzq0.png" border="0" style="margin: 0 auto; height: 150px;" /></a> </h1>"""
description ="""
        CLICK FOR SOURCE CODE!<br><p><a href='https://github.com/farewellthree/ST-LLM'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
        """


with gr.Blocks(title="ST-LLM Chatbot!",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                with gr.Tab("Video", elem_id='video_tab'):
                    up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload").style(height=360)
            # text_prompt_input = gr.Textbox(value="Watch the video and answer the question.",show_label=False, placeholder='Input your text prompt, example: "Watch the video and answer the question."', interactive=True).style(container=False)           
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
            
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            
            num_segments = gr.Slider(
                minimum=16,
                maximum=96,
                value=64,
                step=1,
                interactive=True,
                label="Video Segments",
            )
        
        with gr.Column(visible=True)  as input_raws:
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot",label='ST-LLM')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False).style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️")     
    
    upload_button.click(upload_video, [up_video, chat_state, num_segments], [up_video, text_input, upload_button, chat_state, img_list])
    
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, up_video, num_segments], [text_input, chatbot, chat_state, img_list]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    run.click(gradio_ask, [text_input, chatbot, chat_state, up_video, num_segments], [text_input, chatbot, chat_state, img_list]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    run.click(lambda: "", None, text_input)  
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_video, text_input, upload_button, chat_state, img_list], queue=False)

demo.launch(share=True, enable_queue=True)


================================================
FILE: prompts/alignment.txt
================================================
<Img><ImageHere></Img> Describe this image in detail.
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
<Img><ImageHere></Img> Please provide a detailed description of the picture.
<Img><ImageHere></Img> Could you describe the contents of this image for me?

================================================
FILE: requirement.txt
================================================
torch==2.0.0
torchaudio==2.0.1
torchvision==0.15.1+cu118
accelerate
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
bitsandbytes==0.37.0
cchardet==2.1.7
chardet==5.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.9.0
fonttools==4.38.0
frozenlist==1.3.3
huggingface-hub==0.13.4
importlib-resources==5.12.0
kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4
openai==0.27.0
packaging==23.0
psutil==5.9.4
pycocotools==2.0.6
pyparsing==3.0.9
python-dateutil==2.8.2
pyyaml==6.0
regex==2022.10.31
tokenizers==0.13.2
tqdm==4.64.1
transformers==4.28.0
timm==0.6.13
spacy==3.5.1
webdataset==0.2.48
scikit-learn==1.2.2
scipy==1.10.1
yarl==1.8.2
zipp==3.14.0
omegaconf==2.3.0
opencv-python==4.7.0.72
iopath==0.1.10
decord==0.6.0
tenacity==8.2.2
pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio-client==0.0.8
wandb
peft==0.8.1
einops==0.7.0
imageio==2.33.1
av==11.0.0
transformers[deepspeed]
mmengine


================================================
FILE: script/inference/mvbench/test_mvbench.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/mvbench/mv_bench_infer.py \
    --cfg-path config/instructblipbase_stllm_qa.yaml \
    --ckpt-path Path/to/instructblipbase_stllm_qa \
    --anno-path Path/to/MVBench/json \
    --output_dir test_output/mvbench/ \
    --output_name instructblipbase_stllm_qa_mvbench_fps1 \
    --num-frames 0 \
    --ask_simple \
    
    



================================================
FILE: script/inference/qabench/anet_qa.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python script/inference/qabench/anet_qa.sh \
    --cfg-path config/instructblipbase_stllm_qa.yaml \
    --ckpt-path /Path/to/STLLM_QA_weight \
    --video_dir /Path/to/Anet/videos \
    --gt_file_question /Path/to/Anet/test_q.json \
    --gt_file_answers /Path/to/Anet/test_a.json \
    --output_dir test_output/qabench/ \
    --output_name stllm_instructblipbase_anetqa \
    --num-frames 16 \
    
    
    

================================================
FILE: script/inference/qabench/msrvtt_qa.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/qabench/msrvtt_qa.py \
    --cfg-path config/instructblipbase_stllm_qa.yaml \
    --ckpt-path /Path/to/STLLM_QA_weight \
    --video_dir /Path/to/MSRVTT-QA/video/ \
    --gt_file /Path/to/MSRVTT-QA/test_qa.json \
    --output_dir test_output/qabench/ \
    --output_name stllm_instructblipbase_msrvttqa \
    --num-frames 64 \
    
    
    

================================================
FILE: script/inference/qabench/msvd_qa.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/qabench/msvd_qa.py \
    --cfg-path config/instructblipbase_stllm_qa.yaml \
    --ckpt-path /Path/to/STLLM_QA_weight \
    --video_dir /Path/to/MSVD/YouTubeClips \
    --gt_file /Path/to/MSVD-QA/test_qa.json \
    --output_dir test_output/qabench/ \
    --output_name stllm_instructblipbase_msvdqa \
    --num-frames 64 \

    
    

================================================
FILE: script/inference/qabench/score_anet.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \
    --pred_path test_output/qabench/stllm_instructblipbase_anetqa.json \
    --output_dir test_output/qabench/activityQA/stllm_instructblipbase \
    --output_json test_output/qabench/activityQA/stllm_instructblipbase/activityQA.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/qabench/score_msrvtt.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \
    --pred_path test_output/qabench/stllm_instructblipbase_msrvttqa.json \
    --output_dir test_output/qabench/msrvttQA/stllm_instructblipbase \
    --output_json test_output/qabench/msrvttQA/stllm_instructblipbase/msrvttQA.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/qabench/score_msvd.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \
    --pred_path test_output/qabench/stllm_instructblipbase_msvdqa.json \
    --output_dir test_output/qabench/msvdQA/stllm_instructblipbase \
    --output_json test_output/qabench/msvdQA/stllm_instructblipbase/msvdQA.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/vcgbench/score_consist.sh
================================================
python stllm/test/gpt_evaluation/evaluate_benchmark_5_consistency.py \
    --pred_path test_output/vcgbench/stllm_instructblipbase_consist.json \
    --output_dir test_output/vcgbench/consist/stllm_instructblipbase \
    --output_json test_output/vcgbench/consist/stllm_instructblipbase/consist.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/vcgbench/score_context.sh
================================================
python stllm/test/gpt_evaluation/evaluate_benchmark_3_context.py \
    --pred_path test_output/vcgbench/stllm_instructblipbase_general.json \
    --output_dir test_output/vcgbench/context/stllm_instructblipbase \
    --output_json test_output/vcgbench/context/stllm_instructblipbase/context.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/vcgbench/score_correct.sh
================================================
python stllm/test/gpt_evaluation/evaluate_benchmark_1_correctness.py \
    --pred_path test_output/vcgbench/stllm_instructblipbase_general.json \
    --output_dir test_output/vcgbench/correctness/stllm_instructblipbase \
    --output_json test_output/vcgbench/correctness/stllm_instructblipbase/correctness.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/vcgbench/score_detail.sh
================================================
python stllm/test/gpt_evaluation/evaluate_benchmark_2_detailed_orientation.py \
    --pred_path test_output/vcgbench/stllm_instructblipbase_general.json \
    --output_dir test_output/vcgbench/detail/stllm_instructblipbase \
    --output_json test_output/vcgbench/detail/stllm_instructblipbase/detail.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/vcgbench/score_temporal.sh
================================================
python stllm/test/gpt_evaluation/evaluate_benchmark_4_temporal.py \
    --pred_path test_output/vcgbench/stllm_instructblipbase_temporal.json \
    --output_dir test_output/vcgbench/temporal/stllm_instructblipbase \
    --output_json test_output/vcgbench/temporal/stllm_instructblipbase/temporal.json \
    --api_key openai_api_key \
    --num_tasks 3

================================================
FILE: script/inference/vcgbench/test_consist.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/vcgbench/videochatgpt_benchmark_consist.py \
    --cfg-path config/instructblipbase_stllm_conversation.yaml \
    --ckpt-path /Path/to/STLLM_conversation_weight \
    --video_dir /Path/to/video_chatgpt/Test_Videos \
    --gt_file /Path/to/video_chatgpt/Benchmarking_QA/consistency_qa.json \
    --output_dir test_output/vcgbench/ \
    --output_name stllm_instructblipbase_consist \
    --num-frames 64 \
    
    

================================================
FILE: script/inference/vcgbench/test_general.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/vcgbench/videochatgpt_benchmark_general.py \
    --cfg-path config/instructblipbase_stllm_conversation.yaml \
    --ckpt-path /Path/to/STLLM_conversation_weight \
    --video_dir /Path/to/video_chatgpt/Test_Videos \
    --gt_file /Path/to/video_chatgpt/Benchmarking_QA/generic_qa.json \
    --output_dir test_output/vcgbench/ \
    --output_name stllm_instructblipbase_general \
    --num-frames 64 \
    
    

================================================
FILE: script/inference/vcgbench/test_temporal.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
python stllm/test/vcgbench/videochatgpt_benchmark_general.py \
    --cfg-path config/instructblipbase_stllm_conversation.yaml \
    --ckpt-path /Path/to/STLLM_conversation_weight \
    --video_dir /Path/to/video_chatgpt/Test_Videos \
    --gt_file /Path/to/Benchmarking_QA/temporal_qa.json \
    --output_dir test_output/vcgbench/ \
    --output_name stllm_instructblipbase_temporal \
    --num-frames 64 \
    
    

================================================
FILE: script/train/train.sh
================================================
export PYTHONPATH="./:$PYTHONPATH"
deepspeed --master_port=20000 --include=localhost:0,1,2,3,4,5,6,7 stllm/train/train_hf.py --cfg-path /Path/to/desired/config

================================================
FILE: stllm/__init__.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import sys

from omegaconf import OmegaConf

from stllm.common.registry import registry

from stllm.datasets.builders import *
from stllm.models import *
from stllm.processors import *
from stllm.tasks import *


root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])


================================================
FILE: stllm/common/__init__.py
================================================


================================================
FILE: stllm/common/config.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import json
from typing import Dict

from omegaconf import OmegaConf
from stllm.common.registry import registry


class Config:
    def __init__(self, args):
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # Validate the user-provided runner configuration
        # model and dataset configuration are supposed to be validated by the respective classes
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            if builder_cls is not None:
                dataset_config_path = builder_cls.default_config_path(
                    type=dataset_config_type
                )
                default_config = OmegaConf.load(dataset_config_path)
            else:
                default_config = {}
            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                default_config,
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        return self.config

    @property
    def run_cfg(self):
        return self.config.run

    @property
    def datasets_cfg(self):
        return self.config.datasets

    @property
    def model_cfg(self):
        return self.config.model

    def pretty_print(self):
        logging.info("\n=====  Running Parameters    =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        for dataset in datasets:
            if dataset in self.config.datasets:
                logging.info(f"\n======== {dataset} =======")
                dataset_config = self.config.datasets[dataset]
                logging.info(self._convert_node_to_json(dataset_config))
            else:
                logging.warning(f"No dataset named '{dataset}' in config. Skipping")

        logging.info(f"\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        return OmegaConf.to_container(self.config)


def node_to_dict(node):
    return OmegaConf.to_container(node)


class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined, raise error if not.
        2. when type mismatches are found, the validator will raise an error.
        3. a central place to store and display helpful messages for supported configurations.

    """

    class _Argument:
        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """
        Convert yaml config (dict-like) to list, required by argparse.
        """
        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

            if self.arguments[k].type is not None:
                try:
                    self.arguments[k].val = self.arguments[k].type(v)
                except ValueError:
                    raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")

            if self.arguments[k].choices is not None:
                assert (
                    v in self.arguments[k].choices
                ), f"""{k} must be one of {self.arguments[k].choices}."""

        return config

    def format_arguments(self):
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())


def create_runner_config_validator():
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add argumetns for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of training dataset. This is used in iteration-based runner.
        Do not support for epoch-based runner because how to define an epoch becomes tricky.
        Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        choices=["genearte", "rank"],
        help="""Inference method to use for question answering. If rank, requires a answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator


================================================
FILE: stllm/common/dist_utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if args.distributed is False:
        print("Not using distributed mode")
        return
    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(
            days=365
        ),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


================================================
FILE: stllm/common/gradcam.py
================================================
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform


def getAttMap(img, attMap, blur=True, overlap=True):
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap


================================================
FILE: stllm/common/logger.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from stllm.common import dist_utils


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )


================================================
FILE: stllm/common/optims.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from stllm.common.registry import registry


@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )

@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )

def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warmup the learning rate"""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate"""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


================================================
FILE: stllm/common/registry.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


class Registry:
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'

        Args:
            name: Key with which the builder will be registered.

        Usage:

            from stllm.common.registry import registry
            from stllm.datasets.base_dataset_builder import BaseDatasetBuilder
        """

        def wrap(builder_cls):
            from stllm.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            if name in cls.mapping["builder_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["builder_name_mapping"][name]
                    )
                )
            cls.mapping["builder_name_mapping"][name] = builder_cls
            return builder_cls

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from stllm.common.registry import registry
        """

        def wrap(task_cls):
            from stllm.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            if name in cls.mapping["task_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["task_name_mapping"][name]
                    )
                )
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from stllm.common.registry import registry
        """

        def wrap(model_cls):
            from stllm.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from stllm.common.registry import registry
        """

        def wrap(processor_cls):
            from stllm.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from stllm.common.registry import registry
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping["lr_scheduler_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["lr_scheduler_name_mapping"][name]
                    )
                )
            cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from stllm.common.registry import registry
        """

        def wrap(runner_cls):
            if name in cls.mapping["runner_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["runner_name_mapping"][name]
                    )
                )
            cls.mapping["runner_name_mapping"][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.

        Usage:

            from stllm.common.registry import registry
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        Args:
            name: Key with which the item will be registered.

        Usage::

            from stllm.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    # @classmethod
    # def get_trainer_class(cls, name):
    #     return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Useful for MMF's
                               internal operations. Default: False
        """
        original_name = name
        name = name.split(".")
        value = cls.mapping["state"]
        for subname in name:
            value = value.get(subname, default)
            if value is default:
                break

        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Args:
            name: Key which needs to be removed.
        Usage::

            from mmf.common.registry import registry

            config = registry.unregister("config")
        """
        return cls.mapping["state"].pop(name, None)


registry = Registry()


================================================
FILE: stllm/common/utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from stllm.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)


def now():
    from datetime import datetime

    return datetime.now().strftime("%Y%m%d%H%M")[:-1]


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def get_cache_path(rel_path):
    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        print(f"Error creating directory: {dir_path}")
    return is_success


def get_redirected_url(url: str):
    """
    Given a URL, returns the URL it redirects to or the
    original URL in case of no indirection
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            if response.history:
                return response.url
            else:
                return url


def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive
    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    splits = view_url.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive
    Downloading an URL from google drive requires confirmation when
    the file of the size is too big (google drive notifies that
    anti-viral checks cannot be performed on such files)
    """
    import requests

    with requests.Session() as session:

        # First get the confirmation token and append it to the URL
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Then download the content of the file
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                for chunk in iter(lambda: response.read(chunk_size), ""):
                    if not chunk:
                        break
                    pbar.update(chunk_size)
                    fh.write(chunk)


def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.
    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
                                  If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # expand redirect chain if needed
    url = get_redirected_url(url)

    # check if file is located on Google Drive
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    # download the file
    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        if url[:5] == "https":
            url = url.replace("https:", "http:")
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)


def cache_url(url: str, cache_dir: str) -> str:
    """
    This implementation downloads the remote resource and caches it locally.
    The resource will only be downloaded if not previously requested.
    """
    parsed_url = urlparse(url)
    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
    makedir(dirname)
    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        if not os.path.isfile(cached):
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached


# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
    """
    Simply create the symlinks for a given file1 to file2.
    Useful during model checkpointing to symlinks to the
    latest successful checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing in Boolean value to append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")


def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .pkl, .pickle, .npy, .json
    For the npy files, we support reading the files in mmap_mode.
    If the mmap_mode of reading is not successful, we load data without the
    mmap_mode.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data


def abspath(resource_path: str):
    """
    Make a path absolute, but take into account prefixes like
    "http://" or "manifold://"
    """
    regex = re.compile(r"^\w+://")
    if regex.match(resource_path) is None:
        return os.path.abspath(resource_path)
    else:
        return resource_path


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success


def is_url(input_url):
    """
    Check if an input string is a url. look for http(s):// and ignoring the case
    """
    is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
    return is_url


def cleanup_dir(dir):
    """
    Utility for deleting a directory. Useful for cleaning the storage space
    that contains various training artifacts like checkpoints, data etc.
    """
    if os.path.exists(dir):
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")


def get_file_size(filename):
    """
    Given a file, get the size of file in MB
    """
    size_in_mb = os.path.getsize(filename) / float(1024**2)
    return size_in_mb


================================================
FILE: stllm/configs/datasets/cc_sbu/align.yaml
================================================
datasets:
  cc_sbu_align:
    data_type: images
    build_info:
      storage: cc_sbu_align


================================================
FILE: stllm/configs/datasets/cc_sbu/defaults.yaml
================================================
datasets:
  cc_sbu:
    data_type: images
    build_info:
      storage: /path/to/cc_sbu_dataset/{00000..01255}.tar


================================================
FILE: stllm/configs/datasets/laion/defaults.yaml
================================================
datasets:
  laion:
    data_type: images
    build_info:
      storage: /path/to/laion_dataset/{00000..10488}.tar


================================================
FILE: stllm/configs/default.yaml
================================================
env:
  # For default users
  # cache_root: "cache"
  # For internal use with persistent storage
  cache_root: "/export/home/.cache/minigpt4"


================================================
FILE: stllm/configs/models/instructblip_vicuna0.yaml
================================================
model:
  arch: st_llm_hf

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True
  
  # Q-Former
  #q_former_model: '/path/to/instruct_blip_vicuna7b_trimmed.pth'
  q_former_model: 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth'
  num_query_token: 32

  # generation configs
  prompt: ""

  llama_model: '/path/to/vicuna-7b-v1.1'

preprocess:
    vis_processor:
        train:
          name: "blip2_image_train"
          image_size: 224
        eval:
          name: "blip2_image_eval"
          image_size: 224
    text_processor:
        train:
          name: "blip_caption"
        eval:
          name: "blip_caption"


================================================
FILE: stllm/configs/models/instructblip_vicuna0_btadapter.yaml
================================================
model:
  arch: st_llm_hf

  # vit encoder
  vit_model: "eva_btadapter_g"
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  #q_former_model: '/path/to/instruct_blip_vicuna7b_trimmed.pth'
  q_former_model: 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth'
  num_query_token: 32

  # generation configs
  prompt: ""

  llama_model: "/path/to/vicuna-7b-v1.1"

preprocess:
    vis_processor:
        train:
          name: "blip2_image_train"
          image_size: 224
        eval:
          name: "blip2_image_eval"
          image_size: 224
    text_processor:
        train:
          name: "blip_caption"
        eval:
          name: "blip_caption"


================================================
FILE: stllm/configs/models/minigpt4_vicuna0.yaml
================================================
model:
  arch: st_llm_hf

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  #q_former_model: "/path/to/blip2_pretrained_flant5xxl.pth"
  q_former_model: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"
  num_query_token: 32

  # generation configs
  prompt: ""

  llama_model: "/path/to/vicuna-7b"

preprocess:
    vis_processor:
        train:
          name: "blip2_image_train"
          image_size: 224
        eval:
          name: "blip2_image_eval"
          image_size: 224
    text_processor:
        train:
          name: "blip_caption"
        eval:
          name: "blip_caption"


================================================
FILE: stllm/configs/models/minigpt4_vicuna0_btadapter.yaml
================================================
model:
  arch: st_llm_hf

  # vit encoder
  vit_model: "eva_btadapter_g"
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  #q_former_model: "/path/to/blip2_pretrained_flant5xxl.pth"
  q_former_model: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"
  num_query_token: 32

  # generation configs
  prompt: ""

  llama_model: "/path/to/vicuna-7b"

preprocess:
    vis_processor:
        train:
          name: "blip2_image_train"
          image_size: 224
        eval:
          name: "blip2_image_eval"
          image_size: 224
    text_processor:
        train:
          name: "blip_caption"
        eval:
          name: "blip_caption"


================================================
FILE: stllm/conversation/__init__.py
================================================


================================================
FILE: stllm/conversation/conversation.py
================================================
import argparse
import time
import numpy as np
from PIL import Image

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

from stllm.common.registry import registry
from stllm.test.video_utils import load_video
import torchvision.transforms as T
from stllm.test.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop, 
    Stack, ToTorchFormatTensor
)
from torchvision.transforms.functional import InterpolationMode


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    instruction: bool
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            instruction=self.instruction,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False

def get_residual_index(sample_segments, total_segments, devices):
    seg_size = float(total_segments) / sample_segments
    frame_indices = np.array([
    int((seg_size / 2) + np.round(seg_size * idx))
    for idx in range(sample_segments)
    ])
    frame_indices = torch.from_numpy(frame_indices).to(devices)
    return frame_indices

CONV_VISION_Vicuna0 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    instruction=True,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

CONV_VIDEO_Vicuna0 = Conversation(
    system="Give the following video: <Video>VideoContent</Video>. "
           "You will be able to see the video once I provide it to you. Please answer my questions.",
    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    instruction=True,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

CONV_instructblip_Vicuna0 = Conversation(
    system="Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, give your answer that best addresses the question.\n",
    roles=("Human: ", "Assistant: "),
    messages=[],
    instruction=False,
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

CONV_VISION_LLama2 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("<s>[INST] ", " [/INST] "),
    messages=[],
    offset=2,
    instruction=True,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)

CONV_VIDEO_LLama2 = Conversation(
    system="Give the following video: <Img>VideoContent</Img>. "
           "You will be able to see the video once I provide it to you. Please answer my questions.",
    roles=("<s>[INST] ", " [/INST] "),
    messages=[],
    offset=2,
    instruction=True,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)

class Chat:
    def __init__(self, model, device='cuda:0'):
        self.device = device
        self.model = model
        if not hasattr(model,'llama_model'):
            if hasattr(model.model,'stllm_model'):
                self.model = model.model.stllm_model
            else:
                self.model = model.model.model.stllm_model
            self.LLM = model

        input_mean = [0.48145466, 0.4578275, 0.40821073]
        input_std = [0.26862954, 0.26130258, 0.27577711]
        self.transform = T.Compose([
            GroupScale(int(224), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(224),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(input_mean, input_std) 
        ])
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and (conv.messages[-1][1][-6:] == '</Img>' or conv.messages[-1][1][-8:] == '</Video>' 
                    or conv.messages[-1][1][-8:] == '</Frame>'):  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, system=True,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000, do_sample=True):
        conv.append_message(conv.roles[1], None)
        if conv.instruction:
            embs, attention_mask = self.get_context_emb(conv, img_list)
        else:
            embs, attention_mask = self.get_context_emb_sim(conv, img_list, system=system)
            repetition_penalty = 1.5

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        llama_model = self.LLM if hasattr(self,'LLM') else self.model.llama_model
        outputs = llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            #attention_mask=attention_mask,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.transform([raw_image]).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.transform([raw_image]).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg

    def upload_video(self, video, conv, img_list, num_frame=64, text=None):
        raw_frames = load_video(video, num_frm=num_frame) if isinstance(video,str) else video
        video_frames = self.transform(raw_frames).to(self.device) 
        bt, w, h = video_frames.size()
        video_frames = video_frames.view(bt//3,3,w,h)

        video_emb, _, _ = self.model.encode_img(video_frames, text=text)
        if self.model.video_input == 'mean':
            video_emb = video_emb.mean(dim=0, keepdim=True)
        elif self.model.video_input == 'all':
            video_emb = video_emb.view(1, -1, video_emb.size(-1))
        elif self.model.video_input == 'residual':
            T = video_emb.size(0)
            residual_size = self.model.residual_size
            residual_index = get_residual_index(residual_size, T, video_emb.device)
            global_embeds = video_emb.mean(dim=0, keepdim=True)
            local_embeds = video_emb[residual_index]
            global_embeds = global_embeds.expand((residual_size,-1,-1)).to(self.model.up_proj.weight.dtype)
            global_embeds = self.model.up_proj(self.model.non_linear_func(self.model.down_proj(global_embeds)))
            video_emb = (local_embeds + global_embeds).view(1,-1,video_emb.size(-1)).contiguous()
        
        img_list.append(video_emb)
        sign='<Video><ImageHere></Video>'
        conv.append_message(conv.roles[0], sign)
        msg = "Received."
        return msg
    
    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        if hasattr(self.model, "embed_tokens"):
            embed_tokens = self.model.embed_tokens
        elif hasattr(self.model.llama_model.model, "embed_tokens"):
            embed_tokens = self.model.llama_model.model.embed_tokens
        else:
            embed_tokens = self.model.llama_model.model.model.embed_tokens
        seg_embs = [embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs, None
    
    def get_context_emb_sim(self, conv, img_list, system=True):
        question = conv.messages[0][1]
        question = question.split('</Video> ')[1]
        system = conv.system if system else ""
        question = system + "###Human: " + question + " ###Assistant: "
        seg_tokens = self.model.llama_tokenizer(
                [question], return_tensors="pt", add_special_tokens=0 == 0).to(self.device)
        
        if hasattr(self.model, "embed_tokens"):
            embed_tokens = self.model.embed_tokens
        elif hasattr(self.model.llama_model.model, "embed_tokens"):
            embed_tokens = self.model.llama_model.model.embed_tokens
        else:
            embed_tokens = self.model.llama_model.model.model.embed_tokens
        seg_embs = embed_tokens(seg_tokens.input_ids) 
        mixed_embs = torch.cat((img_list[0],seg_embs), dim=1)
        atts_img = torch.ones(img_list[0].size()[:-1], dtype=torch.long).to(mixed_embs.device)
        attention_mask = torch.cat([atts_img, seg_tokens.attention_mask], dim=1)
        return mixed_embs, attention_mask
        




================================================
FILE: stllm/conversation/mvbench_conversation.py
================================================
import torch
import numpy as np
from transformers import StoppingCriteria, StoppingCriteriaList

def get_prompt(conv):
    ret = conv.system + conv.sep
    for role, message in conv.messages:
        if message:
            ret += role + ": " + message + conv.sep
        else:
            ret += role + ":"
    return ret

def get_prompt2(conv):
    ret = conv.system + conv.sep
    count = 0
    for role, message in conv.messages:
        count += 1
        if count == len(conv.messages):
            ret += role + ": " + message
        else:
            if message:
                ret += role + ": " + message + conv.sep
            else:
                ret += role + ":"
    return ret

def get_context_emb(conv, model, img_list, answer_prompt=None):
    if answer_prompt:
        prompt = get_prompt2(conv)
    else:
        prompt = get_prompt(conv)
    if '<VideoHere>' in prompt:
        prompt_segs = prompt.split('<VideoHere>')
    else:
        prompt_segs = prompt.split('<ImageHere>')
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."

    if hasattr(model.model,'stllm_model'):
        model = model.model.stllm_model
    else:
        model = model.model.model.stllm_model
    if hasattr(model, "embed_tokens"):
        embed_tokens = model.embed_tokens
    elif hasattr(model.llama_model.model, "embed_tokens"):
        embed_tokens = model.llama_model.model.embed_tokens
    else:
        embed_tokens = model.llama_model.model.model.embed_tokens
        
    with torch.no_grad():
        seg_tokens = [
            model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to("cuda:0").input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [embed_tokens(seg_t) for seg_t in seg_tokens]
    mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
    mixed_embs = torch.cat(mixed_embs, dim=1)
    return mixed_embs

def get_context_emb_sim(conv, model, img_list, answer_prompt=None):
    if answer_prompt:
        prompt = get_prompt2(conv)
    else:
        prompt = get_prompt(conv)
    question = prompt.split('</Video>\n')[1]
    if hasattr(model.model,'stllm_model'):
        model = model.model.stllm_model
    else:
        model = model.model.model.stllm_model

    if hasattr(model, "embed_tokens"):
        embed_tokens = model.embed_tokens
    elif hasattr(model.llama_model.model, "embed_tokens"):
        embed_tokens = model.llama_model.model.embed_tokens
    else:
        embed_tokens = model.llama_model.model.model.embed_tokens

    with torch.no_grad():
        seg_tokens = model.llama_tokenizer(
                [question], return_tensors="pt", add_special_tokens=0 == 0).to("cuda:0")
        seg_embs = embed_tokens(seg_tokens.input_ids) 
    mixed_embs = torch.cat((img_list[0],seg_embs), dim=1)
    return mixed_embs

def ask(text, conv):
    conv.messages.append([conv.roles[0], text + '\n'])       

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False
       
def answer(conv, model, img_list, ask_simple=False, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None):
    stop_words_ids = [
        torch.tensor([835]).to("cuda:0"),
        torch.tensor([2277, 29937]).to("cuda:0")]  # '###' can be encoded in two different ways.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    
    conv.messages.append([conv.roles[1], answer_prompt])
    if ask_simple:
        embs = get_context_emb_sim(conv, model, img_list, answer_prompt=answer_prompt)
    else:
        embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt)
    with torch.no_grad():
        generate_model = model if not hasattr(model,'llama_model') else model.llama_model
        outputs = generate_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
            output_token = output_token[1:]
    if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
            
    if hasattr(model,'llama_model'):
        model = model
    elif hasattr(model.model,'stllm_model'):
        model = model.model.stllm_model
    else:
        model = model.model.model.stllm_model
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    conv.messages[-1][1] = output_text
    return output_text, output_token.cpu().numpy()

class EasyDict(dict):
    """
    Get attributes

    >>> d = EasyDict({'foo':3})
    >>> d['foo']
    3
    >>> d.foo
    3
    >>> d.bar
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'bar'

    Works recursively

    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
    >>> isinstance(d.bar, dict)
    True
    >>> d.bar.x
    1

    Bullet-proof

    >>> EasyDict({})
    {}
    >>> EasyDict(d={})
    {}
    >>> EasyDict(None)
    {}
    >>> d = {'a': 1}
    >>> EasyDict(**d)
    {'a': 1}

    Set attributes

    >>> d = EasyDict()
    >>> d.foo = 3
    >>> d.foo
    3
    >>> d.bar = {'prop': 'value'}
    >>> d.bar.prop
    'value'
    >>> d
    {'foo': 3, 'bar': {'prop': 'value'}}
    >>> d.bar.prop = 'newer'
    >>> d.bar.prop
    'newer'


    Values extraction

    >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
    >>> isinstance(d.bar, list)
    True
    >>> from operator import attrgetter
    >>> map(attrgetter('x'), d.bar)
    [1, 3]
    >>> map(attrgetter('y'), d.bar)
    [2, 4]
    >>> d = EasyDict()
    >>> d.keys()
    []
    >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
    >>> d.foo
    3
    >>> d.bar.x
    1

    Still like a dict though

    >>> o = EasyDict({'clean':True})
    >>> o.items()
    [('clean', True)]

    And like a class

    >>> class Flower(EasyDict):
    ...     power = 1
    ...
    >>> f = Flower()
    >>> f.power
    1
    >>> f = Flower({'height': 12})
    >>> f.height
    12
    >>> f['power']
    1
    >>> sorted(f.keys())
    ['height', 'power']

    update and pop items
    >>> d = EasyDict(a=1, b='2')
    >>> e = EasyDict(c=3.0, a=9.0)
    >>> d.update(e)
    >>> d.c
    3.0
    >>> d['c']
    3.0
    >>> d.get('c')
    3.0
    >>> d.update(a=4, b=4)
    >>> d.b
    4
    >>> d.pop('a')
    4
    >>> d.a
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'a'
    """

    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        d = e or dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        if hasattr(self, k):
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)
    
    

================================================
FILE: stllm/datasets/__init__.py
================================================


================================================
FILE: stllm/datasets/builders/__init__.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from stllm.datasets.builders.base_dataset_builder import load_dataset_config
from stllm.datasets.builders.image_text_pair_builder import (
    CCSBUBuilder,
    LaionBuilder,
    CCSBUAlignBuilder
)
from stllm.common.registry import registry

__all__ = [
    "CCSBUBuilder",
    "LaionBuilder",
    "CCSBUAlignBuilder"
]


def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """
    Example

    >>> dataset = load_dataset("coco_caption", cfg=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])

    """
    if cfg_path is None:
        cfg = None
    else:
        cfg = load_dataset_config(cfg_path)

    try:
        builder = registry.get_builder_class(name)(cfg)
    except TypeError:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join([str(k) for k in dataset_zoo.get_names()])
        )
        exit(1)

    if vis_path is not None:
        if data_type is None:
            # use default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    dataset = builder.build_datasets()
    return dataset


class DatasetZoo:
    def __init__(self) -> None:
        self.dataset_zoo = {
            k: list(v.DATASET_CONFIG_DICT.keys())
            for k, v in sorted(registry.mapping["builder_name_mapping"].items())
        }

    def get_names(self):
        return list(self.dataset_zoo.keys())


dataset_zoo = DatasetZoo()


================================================
FILE: stllm/datasets/builders/base_dataset_builder.py
================================================
"""
 This file is from
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os
import shutil
import warnings

from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url

import stllm.common.utils as utils
from stllm.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from stllm.common.registry import registry
from stllm.processors.base_processor import BaseProcessor



class BaseDatasetBuilder:
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

    def build_datasets(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        self._download_ann()
        self._download_vis()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the vision-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # if only dirname is provided, suffix with basename of URL.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)

                    download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):

        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create by split datasets inheriting torch.utils.data.Datasets.

        # build() can be dataset-specific. Overwrite to customize.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = dict()
        for split in ann_info.keys():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"

            # processors
            vis_processor = (
                self.vis_processors["train"]
                if is_train
                else self.vis_processors["eval"]
            )
            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            # annotation path
            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]

            abs_ann_paths = []
            for ann_path in ann_paths:
                if not os.path.isabs(ann_path):
                    ann_path = utils.get_cache_path(ann_path)
                abs_ann_paths.append(ann_path)
            ann_paths = abs_ann_paths

            # visual data storage path
            vis_path = os.path.join(vis_info.storage, split)

            if not os.path.isabs(vis_path):
                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn("storage path {} does not exist.".format(vis_path))

            # create datasets
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets


def load_dataset_config(cfg_path):
    cfg = OmegaConf.load(cfg_path).datasets
    cfg = cfg[list(cfg.keys())[0]]

    return cfg


================================================
FILE: stllm/datasets/builders/image_text_pair_builder.py
================================================
import os
import logging
import warnings

from stllm.common.registry import registry
from stllm.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from stllm.datasets.datasets.laion_dataset import LaionDataset
from stllm.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset


@registry.register_builder("cc_sbu")
class CCSBUBuilder(BaseDatasetBuilder):
    train_dataset_cls = CCSBUDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info

        datasets = dict()
        split = "train"

        # create datasets
        # [NOTE] return inner_datasets (wds.DataPipeline)
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            vis_processor=self.vis_processors[split],
            text_processor=self.text_processors[split],
            location=build_info.storage,
        ).inner_dataset

        return datasets


@registry.register_builder("laion")
class LaionBuilder(BaseDatasetBuilder):
    train_dataset_cls = LaionDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info

        datasets = dict()
        split = "train"

        # create datasets
        # [NOTE] return inner_datasets (wds.DataPipeline)
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            vis_processor=self.vis_processors[split],
            text_processor=self.text_processors[split],
            location=build_info.storage,
        ).inner_dataset

        return datasets


@registry.register_builder("cc_sbu_align")
class CCSBUAlignBuilder(BaseDatasetBuilder):
    train_dataset_cls = CCSBUAlignDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/cc_sbu/align.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        storage_path = build_info.storage

        datasets = dict()

        if not os.path.exists(storage_path):
            warnings.warn("storage path {} does not exist.".format(storage_path))

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
            vis_root=os.path.join(storage_path, 'image'),
        )

        return datasets



================================================
FILE: stllm/datasets/data_utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm

import decord
from decord import VideoReader
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset

from stllm.common.registry import registry
from stllm.datasets.datasets.base_dataset import ConcatDataset


decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")


class ChainDataset(wds.DataPipeline):
    r"""Dataset for chaining multiple :class:`DataPipeline` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: List[wds.DataPipeline]) -> None:
        super().__init__()
        self.datasets = datasets
        self.prob = []
        self.names = []
        for dataset in self.datasets:
            if hasattr(dataset, 'name'):
                self.names.append(dataset.name)
            else:
                self.names.append('Unknown')
            if hasattr(dataset, 'sample_ratio'):
                self.prob.append(dataset.sample_ratio)
            else:
                self.prob.append(1)
                logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")

    def __iter__(self):
        datastreams = [iter(dataset) for dataset in self.datasets]
        while True:
            select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
            yield next(select_datastream)


def apply_to_sample(f, sample):
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples


def reorg_datasets_by_split(datasets):
    """
    Organizes datasets by split.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by name.

    Returns:
        Dict of datasets by split {split_name: List[Datasets]}.
    """
    # if len(datasets) == 1:
    #     return datasets[list(datasets.keys())[0]]
    # else:
    reorg_datasets = dict()

    # reorganize by split
    for _, dataset in datasets.items():
        for split_name, dataset_split in dataset.items():
            if split_name not in reorg_datasets:
                reorg_datasets[split_name] = [dataset_split]
            else:
                reorg_datasets[split_name].append(dataset_split)

    return reorg_datasets


def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
    generic IterableDataset because it requires creating separate samplers.

    Now only supports conctenating training datasets and assuming validation and testing
    have only a single dataset. This is because metrics should not be computed on the concatenated
    datasets.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by split.

    Returns:
        Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
        "val" and "test" remain the same.

        If the input training datasets contain both map-style and DataPipeline datasets, returns
        a tuple, where the first element is a concatenated map-style dataset and the second
        element is a chained DataPipeline dataset.

    """
    # concatenate datasets in the same split
    for split_name in datasets:
        if split_name != "train":
            assert (
                len(datasets[split_name]) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = datasets[split_name][0]
        else:
            iterable_datasets, map_datasets = [], []
            for dataset in datasets[split_name]:
                if isinstance(dataset, wds.DataPipeline):
                    logging.info(
                        "Dataset {} is IterableDataset, can't be concatenated.".format(
                            dataset
                        )
                    )
                    iterable_datasets.append(dataset)
                elif isinstance(dataset, IterableDataset):
                    raise NotImplementedError(
                        "Do not support concatenation of generic IterableDataset."
                    )
                else:
                    map_datasets.append(dataset)

            # if len(iterable_datasets) > 0:
            # concatenate map-style datasets and iterable-style datasets separately
            if len(iterable_datasets) > 1:
                chained_datasets = (
                    ChainDataset(iterable_datasets)
                )
            elif len(iterable_datasets) == 1:
                chained_datasets = iterable_datasets[0]
            else:
                chained_datasets = None

            concat_datasets = (
                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
            )

            train_datasets = concat_datasets, chained_datasets
            train_datasets = tuple([x for x in train_datasets if x is not None])
            train_datasets = (
                train_datasets[0] if len(train_datasets) == 1 else train_datasets
            )

            datasets[split_name] = train_datasets

    return datasets



================================================
FILE: stllm/datasets/datasets/__init__.py
================================================


================================================
FILE: stllm/datasets/datasets/base_dataset.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
from typing import Iterable

from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate


class BaseDataset(Dataset):
    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths:
            jfile = json.load(open(ann_path, "r"))
            if 'annotations' in jfile:
                self.annotation.extend(jfile['annotations'])
            else:
                self.annotation.extend(jfile)

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)


class ConcatDataset(ConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        # TODO For now only supports datasets with same underlying collater implementations

        all_keys = set()
        for s in samples:
            all_keys.update(s)

        shared_keys = all_keys
        for s in samples:
            shared_keys = shared_keys & set(s.keys())

        samples_shared_keys = []
        for s in samples:
            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})

        return self.datasets[0].collater(samples_shared_keys)


================================================
FILE: stllm/datasets/datasets/caption_datasets.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict

from stllm.datasets.datasets.base_dataset import BaseDataset
from PIL import Image


class __DisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )


class CaptionDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann["image_id"]
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{:0>12}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

        return {
            "image": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }


class CaptionEvalDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        return {
            "image": image,
            "image_id": ann["image_id"],
            "instance_id": ann["instance_id"],
        }


================================================
FILE: stllm/datasets/datasets/cc_sbu_dataset.py
================================================
import os
import pickle
from PIL import Image
import webdataset as wds
from stllm.datasets.datasets.base_dataset import BaseDataset
from stllm.datasets.datasets.caption_datasets import CaptionDataset


class CCSBUDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "answer": self.text_processor(sample[1]["caption"]),
        }


class CCSBUAlignDataset(CaptionDataset):

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = ann["caption"]

        return {
            "image": image,
            "answer": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }


================================================
FILE: stllm/datasets/datasets/dataloader_utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch
from stllm.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader
import torch.distributed as dist

class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Args:
        loaders (List[Loader]): List of Iterator loaders.
        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
    """

    def __init__(self, loaders, ratios=None):
        # assert all loaders has __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            ratios = [float(ratio) / sum(ratios) for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        # random sample from each loader by ratio
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
        return next(self.loaders[loader_idx])

class MetaLoader(object):
    """ wraps multiple data loader """
    def __init__(self, loaders, ratios=None):
        """Iterates over multiple dataloaders, it ensures all processes
        work on data from the same dataloader. This loader will end when
        the shorter dataloader raises StopIteration exception.

        loaders: List, [dataloader]
        """
        self.loaders = loaders
        self.iter_order = self.build_iter()

    def build_iter(self):
        iter_order = []

        for n, l in enumerate(self.loaders):
            iter_order.extend([n]*len(l))

        random.shuffle(iter_order)
        iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)

        # sync
        if dist.is_available():
            # make sure all processes have the same order so that
            # each step they will have data from the same loader
            dist.broadcast(iter_order, src=0)
        return iter_order

    def __len__(self):
        return len(self.iter_order)

    def __iter__(self):
        """ this iterator will run indefinitely """
        for i, loader_idx in enumerate(self.iter_order):
            batch = next(self.loaders[loader_idx])
            if i==len(self)-1:
                self.iter_order = self.build_iter()
            yield batch

class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            is_tuple = isinstance(batch, tuple)
            if is_tuple:
                task, batch = batch

            if is_tuple:
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        method = self.loader.__getattribute__(name)
        return method


def record_cuda_stream(batch):
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, list) or isinstance(batch, tuple):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    else:
        pass


class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)


================================================
FILE: stllm/datasets/datasets/image_video_itdatasets.py
================================================
import logging
import os
import random
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from stllm.datasets.datasets.instruction_data import available_corpus, train_transform

import json
from os.path import basename
import numpy as np

from .utils import load_anno, pre_text, VIDEO_READER_FUNCS, load_image_from_path

try:
    from mmengine import fileio 
    has_client = True
except ImportError:
    has_client = False

logger = logging.getLogger(__name__)


class ImageVideoBaseDataset(Dataset):
    """Base class that implements the image and video loading methods"""

    media_type = "video"

    def __init__(self):
        assert self.media_type in ["image", "video", "only_video"]
        self.data_root = None
        self.anno_list = (
            None  # list(dict), each dict contains {"image": str, # image or video path}
        )
        self.transform = None
        self.video_reader = None
        self.num_tries = None

        self.client = None
        if has_client:
            self.client = fileio

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def get_anno(self, index):
        """obtain the annotation for one media (video or image)

        Args:
            index (int): The media index.

        Returns: dict.
            - "image": the filename, video also use "image".
            - "caption": The caption for this file.

        """
        anno = self.anno_list[index]
        if self.data_root is not None:
            anno["image"] = os.path.join(self.data_root, anno["image"])
        return anno

    def load_and_transform_media_data(self, index, data_path):
        if self.media_type == "image":
            return self.load_and_transform_media_data_image(index, data_path)
        else:
            return self.load_and_transform_media_data_video(index, data_path)

    def load_and_transform_media_data_image(self, index, data_path):
        image = load_image_from_path(data_path, client=self.client)
        image = self.transform(image)
        return image, index

    def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None):
        for _ in range(self.num_tries):
            try:
                max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
                frames, frame_indices, sec = self.video_reader(
                    data_path, self.num_frames, self.sample_type, 
                    max_num_frames=max_num_frames, client=self.client, clip=clip
                )
            except Exception as e:
                logger.warning(
                    f"Caught exception {e} when loading video {data_path}, "
                    f"randomly sample a new video as replacement"
                )
                index = random.randint(0, len(self) - 1)
                ann = self.get_anno(index)
                data_path = ann["image"]
                continue
            # shared aug for video frames
            frames = self.transform(frames)
            if return_fps:
                #sec = [str(round(f / fps, 1)) for f in frame_indices]
                return frames, index, sec
            else:
                return frames, index
        else:
            raise RuntimeError(
                f"Failed to fetch video after {self.num_tries} tries. "
                f"This might indicate that you have many corrupted videos."
            )

class PTImgTrainDataset(ImageVideoBaseDataset):
    media_type = "image"

    def __init__(self, ann_file, transform, pre_text=True):
        super().__init__()

        if len(ann_file) == 3 and ann_file[2] == "video":
            self.media_type = "video"  
        else:
            self.media_type = "image"
        self.label_file, self.data_root = ann_file[:2]

        logger.info('Load json file')
        with open(self.label_file, 'r') as f:
            self.anno = json.load(f)
        self.num_examples = len(self.anno)

        self.transform = transform
        self.pre_text = pre_text
        logger.info(f"Pre-process text: {pre_text}")

    def get_anno(self, index):
        filename = self.anno[index][self.media_type]
        caption = self.anno[index]["caption"]
        anno = {"image": os.path.join(self.data_root, filename), "caption": caption}
        return anno

    def __len__(self):
        return self.num_examples

    def __getitem__(self, index):
        try:
            ann = self.get_anno(index)
            image, index = self.load_and_transform_media_data(index, ann["image"])
            caption = pre_text(ann["caption"], pre_text=self.pre_text)
            return image, caption, index
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading image {ann['image']}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)

class PTVidTrainDataset(PTImgTrainDataset):
    media_type = "video"

    def __init__(
        self,
        ann_file,
        transform,
        num_frames=4,
        video_reader_type="decord",
        sample_type="rand",
        num_tries=3,
        pre_text=True
    ):
        super().__init__(ann_file, transform, pre_text=pre_text)
        self.num_frames = num_frames
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
        self.sample_type = sample_type
        self.num_tries = num_tries

class ITImgTrainDataset(ImageVideoBaseDataset):
    media_type = "image"

    def __init__(
        self, ann_file, transform, simple=False,
        system="", role=("Human", "Assistant"),
        start_token="<Image>", end_token="</Image>",
        random_shuffle=True, # if True, shuffle the QA list
    ):
        super().__init__()

        if len(ann_file) == 3 and ann_file[2] == "video":
            self.media_type = "video"  
        else:
            self.media_type = "image"
        self.label_file, self.data_root = ann_file[:2]

        logger.info('Load json file')
        with open(self.label_file, 'r') as f:
            self.anno = json.load(f)
        self.num_examples = len(self.anno)
        self.transform = transform

        # prompt parameters
        if system:
            assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token."
        # currently not support add start_token and end_token in the system, since the msg should be added properly
        self.begin_signal = "###"
        self.end_signal = " "
        self.start_token = start_token
        self.end_token = end_token
        self.system = system
        self.role = role
        self.random_shuffle = random_shuffle
        self.simple = simple
        # instruction location and number
        logger.info(f"Random shuffle: {self.random_shuffle}")

    def get_anno(self, index):
        filename = self.anno[index][self.media_type]
        qa = self.anno[index]["QA"]
        if "num_frames" in self.anno[index]:
            self.max_num_frames = self.anno[index]["num_frames"]
        if "start" in self.anno[index] and "end" in self.anno[index]:
            anno = {
                "image": os.path.join(self.data_root, filename), "qa": qa,
                "start": self.anno[index]["start"], "end": self.anno[index]["end"],
            }
        else:
            anno = {"image": os.path.join(self.data_root, filename), "qa": qa}
        return anno

    def __len__(self):
        return self.num_examples
    
    def process_qa(self, qa, msg=""):
        cur_instruction = ""
        # randomly shuffle qa for conversation
        if self.random_shuffle and len(qa) > 1:
            random.shuffle(qa)
        if "i" in qa[0].keys() and qa[0]["i"] != "":
            cur_instruction = qa[0]["i"] + self.end_signal

        conversation = self.system
        # add instruction as system message

        # rstrip() for the extra " " in msg
        if not self.simple:
            if cur_instruction:
                conversation += cur_instruction
            conversation += (
                self.begin_signal + self.role[0] + ": " + 
                self.start_token + '<ImageHere>' + self.end_token + msg.rstrip() + ' ' + 
                qa[0]["q"] + self.end_signal + self.begin_signal + self.role[1] + ": "
            )
        else:
            conversation += '<ImageHere>'
            conversation += (
                self.begin_signal + self.role[0] + ": " + cur_instruction + msg.rstrip() + 
                qa[0]["q"] + self.end_signal + self.begin_signal + self.role[1] + ": "
            )
        
        return conversation, qa[0]["a"]

    def __getitem__(self, index):
        try:
            ann = self.get_anno(index)
            image, index = self.load_and_transform_media_data_image(index, ann["image"])
            instruction, answer = self.process_qa(ann["qa"])
            return {
                "image": image,
                "answer": answer,
                "image_id": index,
                "instruction_input": instruction
            }
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading image {ann['image']}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)

class ITVidTrainDataset(ITImgTrainDataset):
    media_type = "video"

    def __init__(
        self, ann_file, transform, simple=False,
        num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
        system="", role=("Human", "Assistant"),
        start_token="<Video>", end_token="</Video>",
        add_second_msg=False,
        random_shuffle=True,
    ):
        super().__init__(
            ann_file, transform, 
            system=system, role=role,
            start_token=start_token, end_token=end_token,
            random_shuffle=random_shuffle,
            simple=simple,
        )
        self.num_frames = num_frames
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
        self.sample_type = sample_type
        self.num_tries = num_tries
        self.add_second_msg = add_second_msg

        logger.info(f"Use {video_reader_type} for data in {ann_file}")
        if add_second_msg:
            logger.info(f"Add second message: The video contains X frames sampled at T seconds.")

    def __getitem__(self, index):
        try:
            ann = self.get_anno(index)
            msg = ""
            clip = None
            if "start" in ann and "end" in ann:
                clip = [ann["start"], ann["end"]]
            video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip)
            if self.add_second_msg:
                # " " should be added in the start and end
                msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
            instruction, answer = self.process_qa(ann["qa"], msg)
            return {
                "image": video,
                "answer": answer,
                "image_id": index,
                "instruction_input": instruction,
                "video_len": sec
            }
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading video {ann['image']}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)
 
if __name__ == "__main__":
    pass



================================================
FILE: stllm/datasets/datasets/instruction_data.py
================================================
from torchvision import transforms
from torchvision.transforms import InterpolationMode

mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
normalize = transforms.Normalize(mean, std)
type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            224,
            scale=(0.5, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        ),
        #transforms.RandomHorizontalFlip(),
        type_transform,
        normalize,
    ]
)

anno_root_it = '/Path/to/MVBench/VideoChat2-IT'

# ============== pretraining datasets=================
available_corpus = dict(
    # image
    llava_full=[
        f"{anno_root_it}/image/llava/llava_full.json", 
        "your_data_path/coco_caption",
    ],
    caption_coco=[
        f"{anno_root_it}/image/caption/coco/train.json", 
        "your_data_path/coco_caption",
    ],
    caption_llava=[
        f"{anno_root_it}/image/caption/llava/train.json", 
        "your_data_path/coco_caption",
    ],
    caption_minigpt4=[
        f"{anno_root_it}/image/caption/minigpt4/train.json", 
        "your_data_path/minigpt4/image",
    ],
    caption_paragraph_captioning=[
        f"{anno_root_it}/image/caption/paragraph_captioning/train.json", 
        "your_data_path/m3it/image-paragraph-captioning",
    ],
    caption_textcaps=[
        f"{anno_root_it}/image/caption/textcaps/train.json", 
        "your_data_path/m3it/textcap",
    ],
    classification_imagenet=[
        f"{anno_root_it}/image/classification/imagenet/train.json", 
        "your_data_path/m3it/imagenet",
    ],
    classification_coco_itm=[
        f"{anno_root_it}/image/classification/coco_itm/train.json", 
        "your_data_path/m3it/coco-itm",
    ],
    conversation_llava=[
        f"{anno_root_it}/image/conversation/llava/train.json", 
        "your_data_path/coco_caption",
    ],
    reasoning_clevr=[
        f"{anno_root_it}/image/reasoning/clevr/train.json", 
        "your_data_path/m3it/clevr",
    ],
    reasoning_visual_mrc=[
        f"{anno_root_it}/image/reasoning/visual_mrc/train.json", 
        "your_data_path/m3it/visual-mrc",
    ],
    reasoning_llava=[
        f"{anno_root_it}/image/reasoning/llava/train.json", 
        "your_data_path/coco_caption",
    ],
    vqa_vqav2=[
        f"{anno_root_it}/image/vqa/vqav2/train.json", 
        "your_data_path/m3it/vqa-v2",
    ],
    vqa_gqa=[
        f"{anno_root_it}/image/vqa/gqa/train.json", 
        "your_data_path/m3it/gqa",
    ],
    vqa_okvqa=[
        f"{anno_root_it}/image/vqa/okvqa/train.json", 
        "your_data_path/m3it/okvqa",
    ],
    vqa_a_okvqa=[
        f"{anno_root_it}/image/vqa/a_okvqa/train.json", 
        "your_data_path/m3it/a-okvqa",
    ],
    vqa_viquae=[
        f"{anno_root_it}/image/vqa/viquae/train.json", 
        "your_data_path/m3it/viquae",
    ],
    vqa_ocr_vqa=[
        f"{anno_root_it}/image/vqa/ocr_vqa/train.json", 
        "your_data_path/m3it/ocr-vqa",
    ],
    vqa_text_vqa=[
        f"{anno_root_it}/image/vqa/text_vqa/train.json", 
        "your_data_path/m3it/text-vqa",
    ],
    vqa_st_vqa=[
        f"{anno_root_it}/image/vqa/st_vqa/train.json", 
        "your_data_path/m3it/st-vqa",
    ],
    vqa_docvqa=[
        f"{anno_root_it}/image/vqa/docvqa/train.json", 
        "your_data_path/m3it/docvqa",
    ],
    # video
    caption_textvr=[
        f"{anno_root_it}/video/caption/textvr/train.json", 
        "your_data_path/TextVR/Video",
        "video"
    ],
    caption_videochat=[
        f"{anno_root_it}/video/caption/videochat/train.json", 
        "your_data_path/WebVid10M",
        "video"
    ],
    caption_webvid=[
        f"{anno_root_it}/video/caption/webvid/train.json", 
        "your_data_path/WebVid2M",
        "video"
    ],
    caption_youcook2=[
        f"{anno_root_it}/video/caption/youcook2/train.json", 
        "your_data_path/youcook2/split_videos",
        "video"
    ],
    classification_k710=[
        f"{anno_root_it}/video/classification/k710/train.json", 
        "",
        "video"
    ],
    classification_ssv2=[
        f"{anno_root_it}/video/classification/ssv2/train.json", 
        "your_data_path/video_pub/ssv2_video",
        "video"
    ],
    conversation_videochat1=[
        f"{anno_root_it}/video/conversation/videochat1/train_flat.json", 
        "your_data_path/WebVid10M",
        "video"
    ],
    conversation_videochat2=[
        f"{anno_root_it}/video/conversation/videochat2/train.json", 
        "your_data_path/internvid",
        "video"
    ],
    caption_videochatgpt=[
        f"{anno_root_it}/video/conversation/videochatgpt/train_full_flat.json", 
        "your_data_path/ANet/ANet_320p_fps30",
        "video"
    ],
    reasoning_next_qa=[
        f"{anno_root_it}/video/reasoning/next_qa/train.json", 
        "your_data_path/nextqa",
        "video"
    ],
    reasoning_clevrer_qa=[
        f"{anno_root_it}/video/reasoning/clevrer_qa/train.json", 
        "your_data_path/clevrer/video_train",
        "video"
    ],
    reasoning_clevrer_mc=[
        f"{anno_root_it}/video/reasoning/clevrer_mc/train.json",  
        "your_data_path/clevrer/video_train",
        "video"
    ],
    vqa_ego_qa=[
        f"{anno_root_it}/video/vqa/ego_qa/train.json", 
        "your_data_path/EgoQA/split_videos",
        "video"
    ],
    vqa_tgif_frame_qa=[
        f"{anno_root_it}/video/vqa/tgif_frame_qa/train.json", 
        "your_data_path/tgif",
        "video"
    ],
    vqa_tgif_transition_qa=[
        f"{anno_root_it}/video/vqa/tgif_transition_qa/train.json", 
        "your_data_path/tgif",
        "video"
    ],
    vqa_webvid_qa=[
        f"{anno_root_it}/video/vqa/webvid_qa/train.json", 
        "your_data_path/WebVid2M",
        "video"
    ],
)




================================================
FILE: stllm/datasets/datasets/laion_dataset.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import webdataset as wds
from stllm.datasets.datasets.base_dataset import BaseDataset


class LaionDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "answer": self.text_processor(sample[1]["caption"]),
        }



================================================
FILE: stllm/datasets/datasets/utils.py
================================================
#from utils.distributed import is_main_process, get_rank, get_world_size
import logging
import torch.distributed as dist
import torch
import io
import os
import json
import re
import copy
import numpy as np
from os.path import join
from tqdm import trange
from PIL import Image
from PIL import ImageFile
from torchvision.transforms import PILToTensor
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

import random
import av
import cv2
import decord
import imageio
from decord import VideoReader
import torch
import math
decord.bridge.set_bridge("torch")

import logging
logger = logging.getLogger(__name__)

def load_image_from_path(image_path, client):
    if image_path.startswith('s3') or image_path.startswith('p2'):
        value = client.get(image_path)
        img_bytes = np.frombuffer(value, dtype=np.uint8)
        buff = io.BytesIO(img_bytes)
        image = Image.open(buff).convert('RGB')
    else:
        image = Image.open(image_path).convert('RGB')  # PIL Image
    image = PILToTensor()(image).unsqueeze(0)  # (1, C, H, W), torch.uint8
    return image

def load_anno(ann_file_list):
    """[summary]

    Args:
        ann_file_list (List[List[str, str]] or List[str, str]):
            the latter will be automatically converted to the former.
            Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video'])
            which specifies the data type, video or image

    Returns:
        List(dict): each dict is {
            image: str or List[str],  # image_path,
            caption: str or List[str]  # caption text string
        }
    """
    if isinstance(ann_file_list[0], str):
        ann_file_list = [ann_file_list]

    ann = []
    for d in ann_file_list:
        data_root = d[1]
        fp = d[0]
        is_video = len(d) == 3 and d[2] == "video"
        cur_ann = json.load(open(fp, "r"))
        iterator = trange(len(cur_ann), desc=f"Loading {fp}") \
            if is_main_process() else range(len(cur_ann))
        for idx in iterator:
            key = "video" if is_video else "image"
            # unified to have the same key for data path
            if isinstance(cur_ann[idx][key], str):
                cur_ann[idx]["image"] = join(data_root, cur_ann[idx][key])
            else:  # list
                cur_ann[idx]["image"] = [join(data_root, e) for e in cur_ann[idx][key]]
        ann += cur_ann
    return ann


def pre_text(text, max_l=None, pre_text=True):
    if pre_text:
        text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
        text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

        text = re.sub(r"\s{2,}", ' ', text)
        text = text.rstrip('\n').strip(' ')

        if max_l:  # truncate
            words = text.split(' ')
            if len(words) > max_l:
                text = ' '.join(words[:max_l])
    else:
        pass
    return text


logger = logging.getLogger(__name__)


def collect_result(result, result_dir, filename, is_json=True, is_list=True):
    if is_json:
        result_file = os.path.join(
            result_dir, '%s_rank%d.json' % (filename, get_rank()))
        final_result_file = os.path.join(result_dir, '%s.json' % filename)
        json.dump(result, open(result_file, 'w'))
    else:
        result_file = os.path.join(
            result_dir, '%s_rank%d.pth' % (filename, get_rank()))
        final_result_file = os.path.join(result_dir, '%s.pth' % filename)
        torch.save(result, result_file)

    dist.barrier()

    result = None
    if is_main_process():
        # combine results from all processes
        if is_list:
            result = []
        else:
            result = {}
        for rank in range(get_world_size()):
            if is_json:
                result_file = os.path.join(
                    result_dir, '%s_rank%d.json' % (filename, rank))
                res = json.load(open(result_file, 'r'))
            else:
                result_file = os.path.join(
                    result_dir, '%s_rank%d.pth' % (filename, rank))
                res = torch.load(result_file)
            if is_list:
                result += res
            else:
                result.update(res)

    return result


def sync_save_result(result, result_dir, filename, is_json=True, is_list=True):
    """gather results from multiple GPUs"""
    if is_json:
        result_file = os.path.join(
            result_dir, "dist_res", '%s_rank%d.json' % (filename, get_rank()))
        final_result_file = os.path.join(result_dir, '%s.json' % filename)
        os.makedirs(os.path.dirname(result_file), exist_ok=True)
        json.dump(result, open(result_file, 'w'))
    else:
        result_file = os.path.join(
            result_dir, "dist_res", '%s_rank%d.pth' % (filename, get_rank()))
        os.makedirs(os.path.dirname(result_file), exist_ok=True)
        final_result_file = os.path.join(result_dir, '%s.pth' % filename)
        torch.save(result, result_file)

    dist.barrier()

    if is_main_process():
        # combine results from all processes
        if is_list:
            result = []
        else:
            result = {}
        for rank in range(get_world_size()):
            if is_json:
                result_file = os.path.join(
                    result_dir, "dist_res", '%s_rank%d.json' % (filename, rank))
                res = json.load(open(result_file, 'r'))
            else:
                result_file = os.path.join(
                    result_dir, "dist_res", '%s_rank%d.pth' % (filename, rank))
                res = torch.load(result_file)
            if is_list:
                result += res
            else:
                result.update(res)
        if is_json:
            json.dump(result, open(final_result_file, 'w'))
        else:
            torch.save(result, final_result_file)

        logger.info('result file saved to %s' % final_result_file)
    dist.barrier()
    return final_result_file, result


def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None):
    """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray)
    into a (n+1)-d array, only allow the first dim has variable lengths.
    Args:
        sequences: list(n-d tensor or list)
        dtype: np.dtype or torch.dtype
        device:
        fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length.
            return will be of shape [len(sequences), fixed_length, ...]
    Returns:
        padded_seqs: ((n+1)-d tensor) padded with zeros
        mask: (2d tensor) of the same shape as the first two dims of padded_seqs,
              1 indicate valid, 0 otherwise
    Examples:
        >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
        >>> pad_sequences_1d(test_data_list, dtype=torch.long)
        >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)]
        >>> pad_sequences_1d(test_data_3d, dtype=torch.float)
        >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
        >>> pad_sequences_1d(test_data_list, dtype=np.float32)
        >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)]
        >>> pad_sequences_1d(test_data_3d, dtype=np.float32)
    """
    if isinstance(sequences[0], list):
        if "torch" in str(dtype):
            sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences]
        else:
            sequences = [np.asarray(s, dtype=dtype) for s in sequences]

    extra_dims = sequences[0].shape[1:]  # the extra dims should be the same for all elements
    lengths = [len(seq) for seq in sequences]
    if fixed_length is not None:
        max_length = fixed_length
    else:
        max_length = max(lengths)
    if isinstance(sequences[0], torch.Tensor):
        assert "torch" in str(dtype), "dtype and input type does not match"
        padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device)
        mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device)
    else:  # np
        assert "numpy" in str(dtype), "dtype and input type does not match"
        padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype)
        mask = np.zeros((len(sequences), max_length), dtype=np.float32)

    for idx, seq in enumerate(sequences):
        end = lengths[idx]
        padded_seqs[idx, :end] = seq
        mask[idx, :end] = 1
    return padded_seqs, mask  # , lengths

def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
    """
    Converts a present time with the given time base and start_pts offset to seconds.

    Returns:
        time_in_seconds (float): The corresponding time in seconds.

    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
    """
    if pts == math.inf:
        return math.inf

    return int(pts - start_pts) * time_base


def get_pyav_video_duration(video_reader):
    video_stream = video_reader.streams.video[0]
    video_duration = pts_to_secs(
        video_stream.duration,
        video_stream.time_base,
        video_stream.start_time
    )
    return float(video_duration)


def get_frame_indices_by_fps():
    pass

def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ["rand", "middle"]: # uniform sampling
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except:
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError
    return frame_indices

def read_frames_av(
        video_path, num_frames, sample='rand', fix_start=None, 
        max_num_frames=-1, client=None, clip=None,
    ):
    reader = av.open(video_path)
    frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
    vlen = len(frames)
    duration = get_pyav_video_duration(reader)
    fps = vlen / float(duration)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = torch.stack([frames[idx] for idx in frame_indices])  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, fps

def read_frames_gif(
        video_path, num_frames, sample='rand', fix_start=None, 
        max_num_frames=-1, client=None, clip=None,
    ):
    if video_path.startswith('s3') or video_path.startswith('p2'):
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)
    vlen = len(gif)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        max_num_frames=max_num_frames
    )
    frames = []
    for index, frame in enumerate(gif):
        # for index in frame_idxs:
        if index in frame_indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            frame = torch.from_numpy(frame).byte()
            # # (H x W x C) to (C x H x W)
            frame = frame.permute(2, 0, 1)
            frames.append(frame)
    frames = torch.stack(frames)  # .float() / 255
    return frames, frame_indices, 25. # for tgif

def read_frames_decord(
        video_path, num_frames, sample='rand', fix_start=None, 
        max_num_frames=-1, client=None, clip=None
    ):
    if video_path.startswith('s3') or video_path.startswith('p2'):
        video_bytes = client.get(video_path)
        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
    else:
        video_reader =
Download .txt
gitextract_1lgs5ey9/

├── LICENSE
├── PrepareVicuna.md
├── README.md
├── config/
│   ├── instructblipbase_avp.yaml
│   ├── instructblipbase_stllm_conversation.yaml
│   ├── instructblipbase_stllm_qa.yaml
│   ├── minigpt4base_avp.yaml
│   └── minigpt4base_stllm_qa.yaml
├── demo.py
├── demo_gradio.py
├── prompts/
│   └── alignment.txt
├── requirement.txt
├── script/
│   ├── inference/
│   │   ├── mvbench/
│   │   │   └── test_mvbench.sh
│   │   ├── qabench/
│   │   │   ├── anet_qa.sh
│   │   │   ├── msrvtt_qa.sh
│   │   │   ├── msvd_qa.sh
│   │   │   ├── score_anet.sh
│   │   │   ├── score_msrvtt.sh
│   │   │   └── score_msvd.sh
│   │   └── vcgbench/
│   │       ├── score_consist.sh
│   │       ├── score_context.sh
│   │       ├── score_correct.sh
│   │       ├── score_detail.sh
│   │       ├── score_temporal.sh
│   │       ├── test_consist.sh
│   │       ├── test_general.sh
│   │       └── test_temporal.sh
│   └── train/
│       └── train.sh
├── stllm/
│   ├── __init__.py
│   ├── common/
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dist_utils.py
│   │   ├── gradcam.py
│   │   ├── logger.py
│   │   ├── optims.py
│   │   ├── registry.py
│   │   └── utils.py
│   ├── configs/
│   │   ├── datasets/
│   │   │   ├── cc_sbu/
│   │   │   │   ├── align.yaml
│   │   │   │   └── defaults.yaml
│   │   │   └── laion/
│   │   │       └── defaults.yaml
│   │   ├── default.yaml
│   │   └── models/
│   │       ├── instructblip_vicuna0.yaml
│   │       ├── instructblip_vicuna0_btadapter.yaml
│   │       ├── minigpt4_vicuna0.yaml
│   │       └── minigpt4_vicuna0_btadapter.yaml
│   ├── conversation/
│   │   ├── __init__.py
│   │   ├── conversation.py
│   │   └── mvbench_conversation.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── builders/
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset_builder.py
│   │   │   └── image_text_pair_builder.py
│   │   ├── data_utils.py
│   │   └── datasets/
│   │       ├── __init__.py
│   │       ├── base_dataset.py
│   │       ├── caption_datasets.py
│   │       ├── cc_sbu_dataset.py
│   │       ├── dataloader_utils.py
│   │       ├── image_video_itdatasets.py
│   │       ├── instruction_data.py
│   │       ├── laion_dataset.py
│   │       └── utils.py
│   ├── models/
│   │   ├── Qformer.py
│   │   ├── __init__.py
│   │   ├── base_decoder.py
│   │   ├── base_model.py
│   │   ├── blip2.py
│   │   ├── blip2_outputs.py
│   │   ├── eva_btadapter.py
│   │   ├── eva_vit.py
│   │   ├── modeling_llama_mem.py
│   │   ├── peft_model.py
│   │   ├── st_llm.py
│   │   └── utils.py
│   ├── processors/
│   │   ├── __init__.py
│   │   ├── base_processor.py
│   │   ├── blip_processors.py
│   │   ├── randaugment.py
│   │   └── video_transform.py
│   ├── runners/
│   │   ├── __init__.py
│   │   └── runner_base.py
│   ├── tasks/
│   │   ├── __init__.py
│   │   ├── base_task.py
│   │   └── image_text_pretrain.py
│   ├── test/
│   │   ├── __init__.py
│   │   ├── gpt_evaluation/
│   │   │   ├── evaluate_activitynet_qa.py
│   │   │   ├── evaluate_benchmark_1_correctness.py
│   │   │   ├── evaluate_benchmark_2_detailed_orientation.py
│   │   │   ├── evaluate_benchmark_3_context.py
│   │   │   ├── evaluate_benchmark_4_temporal.py
│   │   │   └── evaluate_benchmark_5_consistency.py
│   │   ├── mvbench/
│   │   │   ├── mv_bench.py
│   │   │   └── mv_bench_infer.py
│   │   ├── qabench/
│   │   │   ├── activitynet_qa.py
│   │   │   ├── msrvtt_qa.py
│   │   │   └── msvd_qa.py
│   │   ├── vcgbench/
│   │   │   ├── videochatgpt_benchmark_consist.py
│   │   │   └── videochatgpt_benchmark_general.py
│   │   ├── video_transforms.py
│   │   └── video_utils.py
│   └── train/
│       ├── stllm_trainer.py
│       ├── train.py
│       ├── train_hf.py
│       ├── zero2.json
│       ├── zero3.json
│       └── zero3_offload.json
└── trainval.md
Download .txt
SYMBOL INDEX (774 symbols across 61 files)

FILE: demo.py
  function parse_args (line 15) | def parse_args():

FILE: demo_gradio.py
  function parse_args (line 18) | def parse_args():
  function gradio_reset (line 57) | def gradio_reset(chat_state, img_list):
  function upload_video (line 65) | def upload_video(gr_video, chat_state, num_segments, text_prompt='Watch ...
  function gradio_ask (line 73) | def gradio_ask(user_message, chatbot, chat_state, gr_video, num_segments):
  function gradio_answer (line 85) | def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
  class STLLM (line 94) | class STLLM(gr.themes.base.Base):
    method __init__ (line 95) | def __init__(

FILE: stllm/common/config.py
  class Config (line 16) | class Config:
    method __init__ (line 17) | def __init__(self, args):
    method _validate_runner_config (line 43) | def _validate_runner_config(self, runner_config):
    method _build_opt_list (line 52) | def _build_opt_list(self, opts):
    method build_model_config (line 57) | def build_model_config(config, **kwargs):
    method build_runner_config (line 84) | def build_runner_config(config):
    method build_dataset_config (line 88) | def build_dataset_config(config):
    method _convert_to_dot_list (line 117) | def _convert_to_dot_list(self, opts):
    method get_config (line 131) | def get_config(self):
    method run_cfg (line 135) | def run_cfg(self):
    method datasets_cfg (line 139) | def datasets_cfg(self):
    method model_cfg (line 143) | def model_cfg(self):
    method pretty_print (line 146) | def pretty_print(self):
    method _convert_node_to_json (line 164) | def _convert_node_to_json(self, node):
    method to_dict (line 168) | def to_dict(self):
  function node_to_dict (line 172) | def node_to_dict(node):
  class ConfigValidator (line 176) | class ConfigValidator:
    class _Argument (line 190) | class _Argument:
      method __init__ (line 191) | def __init__(self, name, choices=None, type=None, help=None):
      method __str__ (line 198) | def __str__(self):
    method __init__ (line 208) | def __init__(self, description):
    method __getitem__ (line 215) | def __getitem__(self, key):
    method __str__ (line 220) | def __str__(self) -> str:
    method add_argument (line 223) | def add_argument(self, *args, **kwargs):
    method validate (line 229) | def validate(self, config=None):
    method format_arguments (line 251) | def format_arguments(self):
    method format_help (line 254) | def format_help(self):
    method print_help (line 259) | def print_help(self):
  function create_runner_config_validator (line 264) | def create_runner_config_validator():

FILE: stllm/common/dist_utils.py
  function setup_for_distributed (line 17) | def setup_for_distributed(is_master):
  function is_dist_avail_and_initialized (line 33) | def is_dist_avail_and_initialized():
  function get_world_size (line 41) | def get_world_size():
  function get_rank (line 47) | def get_rank():
  function is_main_process (line 53) | def is_main_process():
  function init_distributed_mode (line 57) | def init_distributed_mode(args):
  function get_dist_info (line 96) | def get_dist_info():
  function main_process (line 110) | def main_process(func):
  function download_cached_file (line 120) | def download_cached_file(url, check_hash=True, progress=False):

FILE: stllm/common/gradcam.py
  function getAttMap (line 7) | def getAttMap(img, attMap, blur=True, overlap=True):

FILE: stllm/common/logger.py
  class SmoothedValue (line 19) | class SmoothedValue(object):
    method __init__ (line 24) | def __init__(self, window_size=20, fmt=None):
    method update (line 32) | def update(self, value, n=1):
    method synchronize_between_processes (line 37) | def synchronize_between_processes(self):
    method median (line 51) | def median(self):
    method avg (line 56) | def avg(self):
    method global_avg (line 61) | def global_avg(self):
    method max (line 65) | def max(self):
    method value (line 69) | def value(self):
    method __str__ (line 72) | def __str__(self):
  class MetricLogger (line 82) | class MetricLogger(object):
    method __init__ (line 83) | def __init__(self, delimiter="\t"):
    method update (line 87) | def update(self, **kwargs):
    method __getattr__ (line 94) | def __getattr__(self, attr):
    method __str__ (line 103) | def __str__(self):
    method global_avg (line 109) | def global_avg(self):
    method synchronize_between_processes (line 115) | def synchronize_between_processes(self):
    method add_meter (line 119) | def add_meter(self, name, meter):
    method log_every (line 122) | def log_every(self, iterable, print_freq, header=None):
  class AttrDict (line 184) | class AttrDict(dict):
    method __init__ (line 185) | def __init__(self, *args, **kwargs):
  function setup_logger (line 190) | def setup_logger():

FILE: stllm/common/optims.py
  class LinearWarmupStepLRScheduler (line 14) | class LinearWarmupStepLRScheduler:
    method __init__ (line 15) | def __init__(
    method step (line 37) | def step(self, cur_epoch, cur_step):
  class LinearWarmupCosineLRScheduler (line 56) | class LinearWarmupCosineLRScheduler:
    method __init__ (line 57) | def __init__(
    method step (line 78) | def step(self, cur_epoch, cur_step):
  function cosine_lr_schedule (line 97) | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
  function warmup_lr_schedule (line 105) | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
  function step_lr_schedule (line 111) | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):

FILE: stllm/common/registry.py
  class Registry (line 9) | class Registry:
    method register_builder (line 22) | def register_builder(cls, name):
    method register_task (line 54) | def register_task(cls, name):
    method register_model (line 83) | def register_model(cls, name):
    method register_processor (line 112) | def register_processor(cls, name):
    method register_lr_scheduler (line 141) | def register_lr_scheduler(cls, name):
    method register_runner (line 165) | def register_runner(cls, name):
    method register_path (line 189) | def register_path(cls, name, path):
    method register (line 205) | def register(cls, name, obj):
    method get_builder_class (line 232) | def get_builder_class(cls, name):
    method get_model_class (line 236) | def get_model_class(cls, name):
    method get_task_class (line 240) | def get_task_class(cls, name):
    method get_processor_class (line 244) | def get_processor_class(cls, name):
    method get_lr_scheduler_class (line 248) | def get_lr_scheduler_class(cls, name):
    method get_runner_class (line 252) | def get_runner_class(cls, name):
    method list_runners (line 256) | def list_runners(cls):
    method list_models (line 260) | def list_models(cls):
    method list_tasks (line 264) | def list_tasks(cls):
    method list_processors (line 268) | def list_processors(cls):
    method list_lr_schedulers (line 272) | def list_lr_schedulers(cls):
    method list_datasets (line 276) | def list_datasets(cls):
    method get_path (line 280) | def get_path(cls, name):
    method get (line 284) | def get(cls, name, default=None, no_warning=False):
    method unregister (line 315) | def unregister(cls, name):

FILE: stllm/common/utils.py
  function now (line 35) | def now():
  function is_url (line 41) | def is_url(url_or_filename):
  function get_cache_path (line 46) | def get_cache_path(rel_path):
  function get_abs_path (line 50) | def get_abs_path(rel_path):
  function load_json (line 54) | def load_json(filename):
  function makedir (line 64) | def makedir(dir_path):
  function get_redirected_url (line 78) | def get_redirected_url(url: str):
  function to_google_drive_download_url (line 93) | def to_google_drive_download_url(view_url: str) -> str:
  function download_google_drive_url (line 108) | def download_google_drive_url(url: str, output_path: str, output_file_na...
  function _get_google_drive_file_id (line 141) | def _get_google_drive_file_id(url: str) -> Optional[str]:
  function _urlretrieve (line 154) | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
  function download_url (line 167) | def download_url(
  function download_and_extract_archive (line 221) | def download_and_extract_archive(
  function cache_url (line 242) | def cache_url(url: str, cache_dir: str) -> str:
  function create_file_symlink (line 261) | def create_file_symlink(file1, file2):
  function save_file (line 275) | def save_file(data, filename, append_to_json=True, verbose=True):
  function load_file (line 313) | def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
  function abspath (line 374) | def abspath(resource_path: str):
  function makedir (line 386) | def makedir(dir_path):
  function is_url (line 400) | def is_url(input_url):
  function cleanup_dir (line 408) | def cleanup_dir(dir):
  function get_file_size (line 419) | def get_file_size(filename):

FILE: stllm/conversation/conversation.py
  class SeparatorStyle (line 24) | class SeparatorStyle(Enum):
  class Conversation (line 31) | class Conversation:
    method get_prompt (line 46) | def get_prompt(self):
    method append_message (line 67) | def append_message(self, role, message):
    method to_gradio_chatbot (line 70) | def to_gradio_chatbot(self):
    method copy (line 79) | def copy(self):
    method dict (line 92) | def dict(self):
  class StoppingCriteriaSub (line 105) | class StoppingCriteriaSub(StoppingCriteria):
    method __init__ (line 107) | def __init__(self, stops=[], encounters=1):
    method __call__ (line 111) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  function get_residual_index (line 118) | def get_residual_index(sample_segments, total_segments, devices):
  class Chat (line 181) | class Chat:
    method __init__ (line 182) | def __init__(self, model, device='cuda:0'):
    method ask (line 205) | def ask(self, text, conv):
    method answer (line 213) | def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_...
    method upload_img (line 255) | def upload_img(self, image, conv, img_list):
    method upload_video (line 274) | def upload_video(self, video, conv, img_list, num_frame=64, text=None):
    method get_context_emb (line 301) | def get_context_emb(self, conv, img_list):
    method get_context_emb_sim (line 322) | def get_context_emb_sim(self, conv, img_list, system=True):

FILE: stllm/conversation/mvbench_conversation.py
  function get_prompt (line 5) | def get_prompt(conv):
  function get_prompt2 (line 14) | def get_prompt2(conv):
  function get_context_emb (line 28) | def get_context_emb(conv, model, img_list, answer_prompt=None):
  function get_context_emb_sim (line 62) | def get_context_emb_sim(conv, model, img_list, answer_prompt=None):
  function ask (line 87) | def ask(text, conv):
  class StoppingCriteriaSub (line 90) | class StoppingCriteriaSub(StoppingCriteria):
    method __init__ (line 91) | def __init__(self, stops=[], encounters=1):
    method __call__ (line 94) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  function answer (line 100) | def answer(conv, model, img_list, ask_simple=False, do_sample=True, max_...
  class EasyDict (line 144) | class EasyDict(dict):
    method __init__ (line 256) | def __init__(self, d=None, **kwargs):
    method __setattr__ (line 268) | def __setattr__(self, name, value):
    method update (line 278) | def update(self, e=None, **f):
    method pop (line 284) | def pop(self, k, d=None):

FILE: stllm/datasets/builders/__init__.py
  function load_dataset (line 23) | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
  class DatasetZoo (line 61) | class DatasetZoo:
    method __init__ (line 62) | def __init__(self) -> None:
    method get_names (line 68) | def get_names(self):

FILE: stllm/datasets/builders/base_dataset_builder.py
  class BaseDatasetBuilder (line 25) | class BaseDatasetBuilder:
    method __init__ (line 28) | def __init__(self, cfg=None):
    method build_datasets (line 45) | def build_datasets(self):
    method build_processors (line 61) | def build_processors(self):
    method _build_proc_from_cfg (line 80) | def _build_proc_from_cfg(cfg):
    method default_config_path (line 88) | def default_config_path(cls, type="default"):
    method _download_data (line 91) | def _download_data(self):
    method _download_ann (line 95) | def _download_ann(self):
    method _download_vis (line 152) | def _download_vis(self):
    method build (line 166) | def build(self):
  function load_dataset_config (line 232) | def load_dataset_config(cfg_path):

FILE: stllm/datasets/builders/image_text_pair_builder.py
  class CCSBUBuilder (line 12) | class CCSBUBuilder(BaseDatasetBuilder):
    method _download_ann (line 17) | def _download_ann(self):
    method _download_vis (line 20) | def _download_vis(self):
    method build (line 23) | def build(self):
  class LaionBuilder (line 44) | class LaionBuilder(BaseDatasetBuilder):
    method _download_ann (line 49) | def _download_ann(self):
    method _download_vis (line 52) | def _download_vis(self):
    method build (line 55) | def build(self):
  class CCSBUAlignBuilder (line 76) | class CCSBUAlignBuilder(BaseDatasetBuilder):
    method build_datasets (line 83) | def build_datasets(self):

FILE: stllm/datasets/data_utils.py
  class ChainDataset (line 33) | class ChainDataset(wds.DataPipeline):
    method __init__ (line 43) | def __init__(self, datasets: List[wds.DataPipeline]) -> None:
    method __iter__ (line 59) | def __iter__(self):
  function apply_to_sample (line 66) | def apply_to_sample(f, sample):
  function move_to_cuda (line 83) | def move_to_cuda(sample):
  function prepare_sample (line 90) | def prepare_sample(samples, cuda_enabled=True):
  function reorg_datasets_by_split (line 99) | def reorg_datasets_by_split(datasets):
  function concat_datasets (line 125) | def concat_datasets(datasets):

FILE: stllm/datasets/datasets/base_dataset.py
  class BaseDataset (line 15) | class BaseDataset(Dataset):
    method __init__ (line 16) | def __init__(
    method __len__ (line 38) | def __len__(self):
    method collater (line 41) | def collater(self, samples):
    method set_processors (line 44) | def set_processors(self, vis_processor, text_processor):
    method _add_instance_ids (line 48) | def _add_instance_ids(self, key="instance_id"):
  class ConcatDataset (line 53) | class ConcatDataset(ConcatDataset):
    method __init__ (line 54) | def __init__(self, datasets: Iterable[Dataset]) -> None:
    method collater (line 57) | def collater(self, samples):

FILE: stllm/datasets/datasets/caption_datasets.py
  class __DisplMixin (line 15) | class __DisplMixin:
    method displ_item (line 16) | def displ_item(self, index):
  class CaptionDataset (line 28) | class CaptionDataset(BaseDataset, __DisplMixin):
    method __init__ (line 29) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
    method __getitem__ (line 44) | def __getitem__(self, index):
  class CaptionEvalDataset (line 63) | class CaptionEvalDataset(BaseDataset, __DisplMixin):
    method __init__ (line 64) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
    method __getitem__ (line 72) | def __getitem__(self, index):

FILE: stllm/datasets/datasets/cc_sbu_dataset.py
  class CCSBUDataset (line 9) | class CCSBUDataset(BaseDataset):
    method __init__ (line 10) | def __init__(self, vis_processor, text_processor, location):
    method to_dict (line 23) | def to_dict(self, sample):
  class CCSBUAlignDataset (line 30) | class CCSBUAlignDataset(CaptionDataset):
    method __getitem__ (line 32) | def __getitem__(self, index):

FILE: stllm/datasets/datasets/dataloader_utils.py
  class MultiIterLoader (line 15) | class MultiIterLoader:
    method __init__ (line 24) | def __init__(self, loaders, ratios=None):
    method __next__ (line 40) | def __next__(self):
  class MetaLoader (line 45) | class MetaLoader(object):
    method __init__ (line 47) | def __init__(self, loaders, ratios=None):
    method build_iter (line 57) | def build_iter(self):
    method __len__ (line 73) | def __len__(self):
    method __iter__ (line 76) | def __iter__(self):
  class PrefetchLoader (line 84) | class PrefetchLoader(object):
    method __init__ (line 92) | def __init__(self, loader):
    method __iter__ (line 96) | def __iter__(self):
    method __len__ (line 111) | def __len__(self):
    method preload (line 114) | def preload(self, it):
    method next (line 139) | def next(self, it):
    method __getattr__ (line 147) | def __getattr__(self, name):
  function record_cuda_stream (line 152) | def record_cuda_stream(batch):
  class IterLoader (line 165) | class IterLoader:
    method __init__ (line 173) | def __init__(self, dataloader: DataLoader, use_distributed: bool = Fal...
    method epoch (line 180) | def epoch(self) -> int:
    method __next__ (line 183) | def __next__(self):
    method __iter__ (line 196) | def __iter__(self):
    method __len__ (line 199) | def __len__(self):

FILE: stllm/datasets/datasets/image_video_itdatasets.py
  class ImageVideoBaseDataset (line 25) | class ImageVideoBaseDataset(Dataset):
    method __init__ (line 30) | def __init__(self):
    method __getitem__ (line 44) | def __getitem__(self, index):
    method __len__ (line 47) | def __len__(self):
    method get_anno (line 50) | def get_anno(self, index):
    method load_and_transform_media_data (line 66) | def load_and_transform_media_data(self, index, data_path):
    method load_and_transform_media_data_image (line 72) | def load_and_transform_media_data_image(self, index, data_path):
    method load_and_transform_media_data_video (line 77) | def load_and_transform_media_data_video(self, index, data_path, return...
  class PTImgTrainDataset (line 107) | class PTImgTrainDataset(ImageVideoBaseDataset):
    method __init__ (line 110) | def __init__(self, ann_file, transform, pre_text=True):
    method get_anno (line 128) | def get_anno(self, index):
    method __len__ (line 134) | def __len__(self):
    method __getitem__ (line 137) | def __getitem__(self, index):
  class PTVidTrainDataset (line 148) | class PTVidTrainDataset(PTImgTrainDataset):
    method __init__ (line 151) | def __init__(
  class ITImgTrainDataset (line 168) | class ITImgTrainDataset(ImageVideoBaseDataset):
    method __init__ (line 171) | def __init__(
    method get_anno (line 206) | def get_anno(self, index):
    method __len__ (line 220) | def __len__(self):
    method process_qa (line 223) | def process_qa(self, qa, msg=""):
    method __getitem__ (line 252) | def __getitem__(self, index):
  class ITVidTrainDataset (line 268) | class ITVidTrainDataset(ITImgTrainDataset):
    method __init__ (line 271) | def __init__(
    method __getitem__ (line 297) | def __getitem__(self, index):

FILE: stllm/datasets/datasets/laion_dataset.py
  class LaionDataset (line 12) | class LaionDataset(BaseDataset):
    method __init__ (line 13) | def __init__(self, vis_processor, text_processor, location):
    method to_dict (line 26) | def to_dict(self, sample):

FILE: stllm/datasets/datasets/utils.py
  function load_image_from_path (line 32) | def load_image_from_path(image_path, client):
  function load_anno (line 43) | def load_anno(ann_file_list):
  function pre_text (line 80) | def pre_text(text, max_l=None, pre_text=True):
  function collect_result (line 100) | def collect_result(result, result_dir, filename, is_json=True, is_list=T...
  function sync_save_result (line 138) | def sync_save_result(result, result_dir, filename, is_json=True, is_list...
  function pad_sequences_1d (line 184) | def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("c...
  function pts_to_secs (line 234) | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
  function get_pyav_video_duration (line 249) | def get_pyav_video_duration(video_reader):
  function get_frame_indices_by_fps (line 259) | def get_frame_indices_by_fps():
  function get_frame_indices (line 262) | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, i...
  function read_frames_av (line 302) | def read_frames_av(
  function read_frames_gif (line 319) | def read_frames_gif(
  function read_frames_decord (line 345) | def read_frames_decord(
  function read_frames_rawframes (line 375) | def read_frames_rawframes(

FILE: stllm/models/Qformer.py
  class BertEmbeddings (line 51) | class BertEmbeddings(nn.Module):
    method __init__ (line 54) | def __init__(self, config):
    method forward (line 78) | def forward(
  class BertSelfAttention (line 111) | class BertSelfAttention(nn.Module):
    method __init__ (line 112) | def __init__(self, config, is_cross_attention):
    method save_attn_gradients (line 149) | def save_attn_gradients(self, attn_gradients):
    method get_attn_gradients (line 152) | def get_attn_gradients(self):
    method save_attention_map (line 155) | def save_attention_map(self, attention_map):
    method get_attention_map (line 158) | def get_attention_map(self):
    method transpose_for_scores (line 161) | def transpose_for_scores(self, x):
    method forward (line 169) | def forward(
  class BertSelfOutput (line 278) | class BertSelfOutput(nn.Module):
    method __init__ (line 279) | def __init__(self, config):
    method forward (line 285) | def forward(self, hidden_states, input_tensor):
  class BertAttention (line 292) | class BertAttention(nn.Module):
    method __init__ (line 293) | def __init__(self, config, is_cross_attention=False):
    method prune_heads (line 299) | def prune_heads(self, heads):
    method forward (line 322) | def forward(
  class BertIntermediate (line 349) | class BertIntermediate(nn.Module):
    method __init__ (line 350) | def __init__(self, config):
    method forward (line 358) | def forward(self, hidden_states):
  class BertOutput (line 364) | class BertOutput(nn.Module):
    method __init__ (line 365) | def __init__(self, config):
    method forward (line 371) | def forward(self, hidden_states, input_tensor):
  class BertLayer (line 378) | class BertLayer(nn.Module):
    method __init__ (line 379) | def __init__(self, config, layer_num):
    method forward (line 402) | def forward(
    method feed_forward_chunk (line 476) | def feed_forward_chunk(self, attention_output):
    method feed_forward_chunk_query (line 481) | def feed_forward_chunk_query(self, attention_output):
  class BertEncoder (line 487) | class BertEncoder(nn.Module):
    method __init__ (line 488) | def __init__(self, config):
    method forward (line 495) | def forward(
  class BertPooler (line 592) | class BertPooler(nn.Module):
    method __init__ (line 593) | def __init__(self, config):
    method forward (line 598) | def forward(self, hidden_states):
  class BertPredictionHeadTransform (line 607) | class BertPredictionHeadTransform(nn.Module):
    method __init__ (line 608) | def __init__(self, config):
    method forward (line 617) | def forward(self, hidden_states):
  class BertLMPredictionHead (line 624) | class BertLMPredictionHead(nn.Module):
    method __init__ (line 625) | def __init__(self, config):
    method forward (line 638) | def forward(self, hidden_states):
  class BertOnlyMLMHead (line 644) | class BertOnlyMLMHead(nn.Module):
    method __init__ (line 645) | def __init__(self, config):
    method forward (line 649) | def forward(self, sequence_output):
  class BertPreTrainedModel (line 654) | class BertPreTrainedModel(PreTrainedModel):
    method _init_weights (line 664) | def _init_weights(self, module):
  class BertModel (line 677) | class BertModel(BertPreTrainedModel):
    method __init__ (line 687) | def __init__(self, config, add_pooling_layer=False):
    method get_input_embeddings (line 699) | def get_input_embeddings(self):
    method set_input_embeddings (line 702) | def set_input_embeddings(self, value):
    method _prune_heads (line 705) | def _prune_heads(self, heads_to_prune):
    method get_extended_attention_mask (line 713) | def get_extended_attention_mask(
    method forward (line 804) | def forward(
  class BertLMHeadModel (line 968) | class BertLMHeadModel(BertPreTrainedModel):
    method __init__ (line 973) | def __init__(self, config):
    method get_output_embeddings (line 981) | def get_output_embeddings(self):
    method set_output_embeddings (line 984) | def set_output_embeddings(self, new_embeddings):
    method forward (line 987) | def forward(
    method prepare_inputs_for_generation (line 1097) | def prepare_inputs_for_generation(
    method _reorder_cache (line 1120) | def _reorder_cache(self, past, beam_idx):
  class BertForMaskedLM (line 1131) | class BertForMaskedLM(BertPreTrainedModel):
    method __init__ (line 1136) | def __init__(self, config):
    method get_output_embeddings (line 1144) | def get_output_embeddings(self):
    method set_output_embeddings (line 1147) | def set_output_embeddings(self, new_embeddings):
    method forward (line 1150) | def forward(

FILE: stllm/models/__init__.py
  function load_model (line 27) | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint...
  function load_preprocess (line 61) | def load_preprocess(config):
  function load_model_and_preprocess (line 113) | def load_model_and_preprocess(name, model_type, is_eval=False, device="c...
  class ModelZoo (line 161) | class ModelZoo:
    method __init__ (line 172) | def __init__(self) -> None:
    method __str__ (line 178) | def __str__(self) -> str:
    method __iter__ (line 193) | def __iter__(self):
    method __len__ (line 196) | def __len__(self):

FILE: stllm/models/base_decoder.py
  class DropPath (line 10) | class DropPath(nn.Module):
    method __init__ (line 13) | def __init__(self, drop_prob=None):
    method forward (line 17) | def forward(self, x):
    method extra_repr (line 20) | def extra_repr(self) -> str:
  class Mlp (line 24) | class Mlp(nn.Module):
    method __init__ (line 25) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 34) | def forward(self, x):
  class Attention (line 44) | class Attention(nn.Module):
    method __init__ (line 45) | def __init__(
    method forward (line 68) | def forward(self, x):
  class Block (line 91) | class Block(nn.Module):
    method __init__ (line 93) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
    method forward (line 113) | def forward(self, x):
  class PretrainVisionTransformerDecoder (line 123) | class PretrainVisionTransformerDecoder(nn.Module):
    method __init__ (line 126) | def __init__(self, embed_dim=4096, depth=2, num_heads=32, mlp_ratio=2....
    method _init_weights (line 147) | def _init_weights(self, m):
    method get_num_layers (line 156) | def get_num_layers(self):
    method no_weight_decay (line 160) | def no_weight_decay(self):
    method get_classifier (line 163) | def get_classifier(self):
    method reset_classifier (line 166) | def reset_classifier(self, num_classes, global_pool=''):
    method forward (line 170) | def forward(self, x, return_token_num = 0):

FILE: stllm/models/base_model.py
  class BaseModel (line 19) | class BaseModel(nn.Module):
    method __init__ (line 22) | def __init__(self):
    method device (line 26) | def device(self):
    method load_checkpoint (line 29) | def load_checkpoint(self, url_or_filename):
    method from_pretrained (line 59) | def from_pretrained(cls, model_type):
    method default_config_path (line 75) | def default_config_path(cls, model_type):
    method load_checkpoint_from_config (line 81) | def load_checkpoint_from_config(self, cfg, **kwargs):
    method before_evaluation (line 102) | def before_evaluation(self, **kwargs):
    method show_n_params (line 105) | def show_n_params(self, return_str=True):
  class BaseEncoder (line 121) | class BaseEncoder(nn.Module):
    method __init__ (line 126) | def __init__(self):
    method forward_features (line 129) | def forward_features(self, samples, **kwargs):
    method device (line 133) | def device(self):
  class SharedQueueMixin (line 137) | class SharedQueueMixin:
    method _dequeue_and_enqueue (line 139) | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
  class MomentumDistilationMixin (line 161) | class MomentumDistilationMixin:
    method copy_params (line 163) | def copy_params(self):
    method _momentum_update (line 172) | def _momentum_update(self):
  class GatherLayer (line 182) | class GatherLayer(torch.autograd.Function):
    method forward (line 189) | def forward(ctx, x):
    method backward (line 197) | def backward(ctx, *grads):
  function all_gather_with_grad (line 203) | def all_gather_with_grad(tensors):
  function concat_all_gather (line 221) | def concat_all_gather(tensor):
  function tile (line 239) | def tile(x, dim, n_tile):

FILE: stllm/models/blip2.py
  class Blip2Base (line 29) | class Blip2Base(BaseModel):
    method init_tokenizer (line 31) | def init_tokenizer(cls, truncation_side="right"):
    method maybe_autocast (line 36) | def maybe_autocast(self, dtype=torch.float16):
    method init_Qformer (line 47) | def init_Qformer(cls, num_query_token, vision_width, cross_attention_f...
    method init_vision_encoder (line 62) | def init_vision_encoder(
    method load_from_pretrained (line 76) | def load_from_pretrained(self, url_or_filename):
  function disabled_train (line 97) | def disabled_train(self, mode=True):
  class LayerNorm (line 103) | class LayerNorm(nn.LayerNorm):
    method forward (line 106) | def forward(self, x: torch.Tensor):
  function compute_sim_matrix (line 112) | def compute_sim_matrix(model, data_loader, **kwargs):

FILE: stllm/models/blip2_outputs.py
  class BlipSimilarity (line 20) | class BlipSimilarity(ModelOutput):
  class BlipIntermediateOutput (line 32) | class BlipIntermediateOutput(ModelOutput):
  class BlipOutput (line 73) | class BlipOutput(ModelOutput):
  class BlipOutputFeatures (line 89) | class BlipOutputFeatures(ModelOutput):

FILE: stllm/models/eva_btadapter.py
  function constant_init (line 40) | def constant_init(module, val, bias=0):
  class EVAVisionTransformer_BTAdapter (line 46) | class EVAVisionTransformer_BTAdapter(nn.Module):
    method __init__ (line 49) | def __init__(self, depth=4, mask_rate=0):
    method init_weights (line 89) | def init_weights(self):
    method fix_init_weight (line 101) | def fix_init_weight(self):
    method get_cast_dtype (line 112) | def get_cast_dtype(self) -> torch.dtype:
    method _init_weights (line 115) | def _init_weights(self, m):
    method get_num_layers (line 124) | def get_num_layers(self):
    method lock (line 127) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 133) | def set_grad_checkpointing(self, enable=True):
    method no_weight_decay (line 137) | def no_weight_decay(self):
    method get_classifier (line 140) | def get_classifier(self):
    method reset_classifier (line 143) | def reset_classifier(self, num_classes, global_pool=''):
    method forward_features (line 147) | def forward_features(self, x, mask=None):
    method forward_branch (line 186) | def forward_branch(self, x, branch_x, num_layer, mask=None):
    method init_input (line 209) | def init_input(self, x, mask=None):
    method forward (line 233) | def forward(self, x, return_all_features=False):
  class BTAdapter_Spatial (line 257) | class BTAdapter_Spatial(Block):
    method __init__ (line 258) | def __init__(self, d_model, n_head, drop_num=0.1):
    method forward (line 261) | def forward(self, x, T):
  class BTAdapter_Temp (line 283) | class BTAdapter_Temp(nn.Module):
    method __init__ (line 284) | def __init__(self, d_model, n_head, drop_num=0.1, norm_layer=partial(n...
    method forward (line 295) | def forward(self, x, T):
  function create_eva_btadapter (line 312) | def create_eva_btadapter(precision="fp16"):

FILE: stllm/models/eva_vit.py
  function _cfg (line 20) | def _cfg(url='', **kwargs):
  class DropPath (line 30) | class DropPath(nn.Module):
    method __init__ (line 33) | def __init__(self, drop_prob=None):
    method forward (line 37) | def forward(self, x):
    method extra_repr (line 40) | def extra_repr(self) -> str:
  class Mlp (line 44) | class Mlp(nn.Module):
    method __init__ (line 45) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 54) | def forward(self, x):
  class Attention (line 64) | class Attention(nn.Module):
    method __init__ (line 65) | def __init__(
    method forward (line 118) | def forward(self, x, rel_pos_bias=None):
  class Block (line 151) | class Block(nn.Module):
    method __init__ (line 153) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
    method forward (line 173) | def forward(self, x, rel_pos_bias=None):
  class PatchEmbed (line 183) | class PatchEmbed(nn.Module):
    method __init__ (line 186) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
    method forward (line 198) | def forward(self, x, **kwargs):
  class RelativePositionBias (line 207) | class RelativePositionBias(nn.Module):
    method __init__ (line 209) | def __init__(self, window_size, num_heads):
    method forward (line 238) | def forward(self):
  class VisionTransformer (line 246) | class VisionTransformer(nn.Module):
    method __init__ (line 249) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
    method fix_init_weight (line 300) | def fix_init_weight(self):
    method _init_weights (line 308) | def _init_weights(self, m):
    method get_classifier (line 317) | def get_classifier(self):
    method reset_classifier (line 320) | def reset_classifier(self, num_classes, global_pool=''):
    method forward_features (line 324) | def forward_features(self, x):
    method forward (line 349) | def forward(self, x):
    method get_intermediate_layers (line 354) | def get_intermediate_layers(self, x):
  function interpolate_pos_embed (line 373) | def interpolate_pos_embed(model, checkpoint_model):
  function convert_weights_to_fp16 (line 397) | def convert_weights_to_fp16(model: nn.Module):
  function create_eva_vit_g (line 415) | def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=Fals...

FILE: stllm/models/modeling_llama_mem.py
  function _make_causal_mask (line 29) | def _make_causal_mask(
  function _expand_mask (line 47) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  class LlamaRMSNorm (line 61) | class LlamaRMSNorm(nn.Module):
    method __init__ (line 62) | def __init__(self, hidden_size, eps=1e-6):
    method forward (line 70) | def forward(self, hidden_states):
  class LlamaRotaryEmbedding (line 81) | class LlamaRotaryEmbedding(torch.nn.Module):
    method __init__ (line 82) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method forward (line 96) | def forward(self, x, seq_len=None):
  function rotate_half (line 113) | def rotate_half(x):
  function apply_rotary_pos_emb (line 120) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  class LlamaMLP (line 130) | class LlamaMLP(nn.Module):
    method __init__ (line 131) | def __init__(
    method forward (line 143) | def forward(self, x):
  class LlamaAttention (line 147) | class LlamaAttention(nn.Module):
    method __init__ (line 150) | def __init__(self, config: LlamaConfig):
    method _shape (line 169) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 172) | def forward(
  class LlamaDecoderLayer (line 251) | class LlamaDecoderLayer(nn.Module):
    method __init__ (line 252) | def __init__(self, config: LlamaConfig):
    method forward (line 264) | def forward(
  class LlamaPreTrainedModel (line 340) | class LlamaPreTrainedModel(PreTrainedModel):
    method _init_weights (line 347) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 358) | def _set_gradient_checkpointing(self, module, value=False):
  class LlamaModel (line 431) | class LlamaModel(LlamaPreTrainedModel):
    method __init__ (line 439) | def __init__(self, config: LlamaConfig):
    method get_input_embeddings (line 452) | def get_input_embeddings(self):
    method set_input_embeddings (line 455) | def set_input_embeddings(self, value):
    method _prepare_decoder_attention_mask (line 458) | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
    method forward (line 464) | def forward(
  class LlamaForCausalLM (line 597) | class LlamaForCausalLM(LlamaPreTrainedModel):
    method __init__ (line 598) | def __init__(self, config):
    method get_input_embeddings (line 607) | def get_input_embeddings(self):
    method set_input_embeddings (line 610) | def set_input_embeddings(self, value):
    method get_output_embeddings (line 613) | def get_output_embeddings(self):
    method set_output_embeddings (line 616) | def set_output_embeddings(self, new_embeddings):
    method set_decoder (line 619) | def set_decoder(self, decoder):
    method get_decoder (line 622) | def get_decoder(self):
    method forward (line 627) | def forward(
    method prepare_inputs_for_generation (line 715) | def prepare_inputs_for_generation(
    method _reorder_cache (line 748) | def _reorder_cache(past_key_values, beam_idx):

FILE: stllm/models/peft_model.py
  function forward (line 26) | def forward(
  function replace_peftmodel_with_sample_input (line 101) | def replace_peftmodel_with_sample_input():

FILE: stllm/models/st_llm.py
  class StllmConfig (line 31) | class StllmConfig(LlamaConfig):
  class Linear_Decoder (line 35) | class Linear_Decoder(nn.Module):
    method __init__ (line 36) | def __init__(self, output_dim=4096, embed_dim=4096):
    method forward (line 41) | def forward(self, x):
  class STLLMLlamaModel (line 45) | class STLLMLlamaModel(LlamaModel):
    method __init__ (line 47) | def __init__(self, config: LlamaConfig):  # TODO: Remove unused params
    method initialize_vision_modules (line 50) | def initialize_vision_modules(self, cfg):
    method forward (line 56) | def forward(self, samples=None, inputs_embeds=None, **kwargs):
  class STLLMForCausalLM (line 95) | class STLLMForCausalLM(LlamaForCausalLM, BaseModel):
    method __init__ (line 104) | def __init__(self, config):
    method get_model (line 113) | def get_model(self):
    method forward (line 116) | def forward(self, samples=None, inputs_embeds=None, **kwargs):
    method get_state_dict (line 150) | def get_state_dict(self, path, prefix='pytorch_model'):
    method from_config (line 161) | def from_config(cls, cfg):
  class STLLMModel (line 205) | class STLLMModel(Blip2Base):
    method __init__ (line 209) | def __init__(
    method encode_img (line 321) | def encode_img(self, image, text=None):
    method prompt_wrap (line 379) | def prompt_wrap(self, img_embeds, atts_img, prompts):
    method concat_emb_input_output (line 409) | def concat_emb_input_output(self, input_embs, input_atts, output_embs,...
    method get_residual_index (line 434) | def get_residual_index(self, sample_segments, total_segments, devices):
    method forward (line 447) | def forward(self, samples):
    method from_config (line 549) | def from_config(cls, cfg):

FILE: stllm/models/utils.py
  function RandomMaskingGenerator (line 4) | def RandomMaskingGenerator(num_patches, mask_ratio, batch, device='cuda'):
  function get_sinusoid_encoding_table (line 18) | def get_sinusoid_encoding_table(n_position, d_hid):

FILE: stllm/processors/__init__.py
  function load_processor (line 25) | def load_processor(name, cfg=None):

FILE: stllm/processors/base_processor.py
  class BaseProcessor (line 11) | class BaseProcessor:
    method __init__ (line 12) | def __init__(self):
    method __call__ (line 16) | def __call__(self, item):
    method from_config (line 20) | def from_config(cls, cfg=None):
    method build (line 23) | def build(self, **kwargs):

FILE: stllm/processors/blip_processors.py
  class BlipImageBaseProcessor (line 19) | class BlipImageBaseProcessor(BaseProcessor):
    method __init__ (line 20) | def __init__(self, mean=None, std=None):
  class BlipCaptionProcessor (line 30) | class BlipCaptionProcessor(BaseProcessor):
    method __init__ (line 31) | def __init__(self, prompt="", max_words=50):
    method __call__ (line 35) | def __call__(self, caption):
    method from_config (line 41) | def from_config(cls, cfg=None):
    method pre_caption (line 50) | def pre_caption(self, caption):
  class Blip2ImageTrainProcessor (line 73) | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    method __init__ (line 74) | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5,...
    method __call__ (line 89) | def __call__(self, item):
    method from_config (line 93) | def from_config(cls, cfg=None):
  class Blip2VideoTrainProcessor (line 114) | class Blip2VideoTrainProcessor(BaseProcessor):
    method __init__ (line 115) | def __init__(self, num_frames=16, test_mode=True):
    method __call__ (line 125) | def __call__(self, item):
    method from_config (line 129) | def from_config(cls, cfg=None):
  class Blip2ImageEvalProcessor (line 139) | class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    method __init__ (line 140) | def __init__(self, image_size=224, mean=None, std=None):
    method __call__ (line 153) | def __call__(self, item):
    method from_config (line 157) | def from_config(cls, cfg=None):

FILE: stllm/processors/randaugment.py
  function identity_func (line 15) | def identity_func(img):
  function autocontrast_func (line 19) | def autocontrast_func(img, cutoff=0):
  function equalize_func (line 52) | def equalize_func(img):
  function rotate_func (line 76) | def rotate_func(img, degree, fill=(0, 0, 0)):
  function solarize_func (line 87) | def solarize_func(img, thresh=128):
  function color_func (line 97) | def color_func(img, factor):
  function contrast_func (line 115) | def contrast_func(img, factor):
  function brightness_func (line 129) | def brightness_func(img, factor):
  function sharpness_func (line 138) | def sharpness_func(img, factor):
  function shear_x_func (line 159) | def shear_x_func(img, factor, fill=(0, 0, 0)):
  function translate_x_func (line 168) | def translate_x_func(img, offset, fill=(0, 0, 0)):
  function translate_y_func (line 180) | def translate_y_func(img, offset, fill=(0, 0, 0)):
  function posterize_func (line 192) | def posterize_func(img, bits):
  function shear_y_func (line 200) | def shear_y_func(img, factor, fill=(0, 0, 0)):
  function cutout_func (line 209) | def cutout_func(img, pad_size, replace=(0, 0, 0)):
  function enhance_level_to_args (line 223) | def enhance_level_to_args(MAX_LEVEL):
  function shear_level_to_args (line 230) | def shear_level_to_args(MAX_LEVEL, replace_value):
  function translate_level_to_args (line 240) | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
  function cutout_level_to_args (line 250) | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
  function solarize_level_to_args (line 258) | def solarize_level_to_args(MAX_LEVEL):
  function none_level_to_args (line 266) | def none_level_to_args(level):
  function posterize_level_to_args (line 270) | def posterize_level_to_args(MAX_LEVEL):
  function rotate_level_to_args (line 278) | def rotate_level_to_args(MAX_LEVEL, replace_value):
  class RandomAugment (line 326) | class RandomAugment(object):
    method __init__ (line 327) | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
    method get_random_ops (line 336) | def get_random_ops(self):
    method __call__ (line 340) | def __call__(self, img):
  class VideoRandomAugment (line 352) | class VideoRandomAugment(object):
    method __init__ (line 353) | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
    method get_random_ops (line 363) | def get_random_ops(self):
    method __call__ (line 367) | def __call__(self, frames):
    method _aug (line 386) | def _aug(self, img, ops, apply_or_not):

FILE: stllm/processors/video_transform.py
  class SampleFrames (line 3) | class SampleFrames:
    method __init__ (line 40) | def __init__(self,
    method _get_train_clips (line 61) | def _get_train_clips(self, num_frames: int,
    method _get_test_clips (line 106) | def _get_test_clips(self, num_frames: int,
    method _sample_clips (line 147) | def _sample_clips(self, num_frames: int, ori_clip_len: float) -> np.ar...
    method _get_ori_clip_len (line 163) | def _get_ori_clip_len(self, fps_scale_ratio: float) -> float:
    method __call__ (line 180) | def __call__(self, x):

FILE: stllm/runners/runner_base.py
  class RunnerBase (line 39) | class RunnerBase:
    method __init__ (line 47) | def __init__(self, cfg, task, model, datasets, job_id):
    method device (line 69) | def device(self):
    method use_distributed (line 76) | def use_distributed(self):
    method model (line 80) | def model(self):
    method optimizer (line 100) | def optimizer(self):
    method scaler (line 133) | def scaler(self):
    method lr_scheduler (line 143) | def lr_scheduler(self):
    method dataloaders (line 185) | def dataloaders(self) -> dict:
    method cuda_enabled (line 282) | def cuda_enabled(self):
    method max_epoch (line 286) | def max_epoch(self):
    method log_freq (line 290) | def log_freq(self):
    method init_lr (line 295) | def init_lr(self):
    method min_lr (line 299) | def min_lr(self):
    method accum_grad_iters (line 303) | def accum_grad_iters(self):
    method valid_splits (line 307) | def valid_splits(self):
    method test_splits (line 316) | def test_splits(self):
    method train_splits (line 322) | def train_splits(self):
    method evaluate_only (line 331) | def evaluate_only(self):
    method use_dist_eval_sampler (line 338) | def use_dist_eval_sampler(self):
    method resume_ckpt_path (line 342) | def resume_ckpt_path(self):
    method train_loader (line 346) | def train_loader(self):
    method setup_output_dir (line 351) | def setup_output_dir(self):
    method train (line 366) | def train(self):
    method evaluate (line 426) | def evaluate(self, cur_epoch="best", skip_reload=False):
    method train_epoch (line 437) | def train_epoch(self, epoch):
    method eval_epoch (line 454) | def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
    method unwrap_dist_model (line 488) | def unwrap_dist_model(self, model):
    method create_loaders (line 494) | def create_loaders(
    method _save_checkpoint (line 584) | def _save_checkpoint(self, cur_epoch, is_best=False):
    method _reload_best_model (line 611) | def _reload_best_model(self, model):
    method _load_checkpoint (line 631) | def _load_checkpoint(self, url_or_filename):
    method log_stats (line 656) | def log_stats(self, stats, split_name):
    method log_config (line 665) | def log_config(self):

FILE: stllm/tasks/__init__.py
  function setup_task (line 13) | def setup_task(cfg):

FILE: stllm/tasks/base_task.py
  class BaseTask (line 19) | class BaseTask:
    method __init__ (line 20) | def __init__(self, **kwargs):
    method setup_task (line 26) | def setup_task(cls, **kwargs):
    method build_model (line 29) | def build_model(self, cfg):
    method build_datasets (line 35) | def build_datasets(self, cfg):
    method train_step (line 67) | def train_step(self, model, samples):
    method valid_step (line 71) | def valid_step(self, model, samples):
    method before_evaluation (line 74) | def before_evaluation(self, model, dataset, **kwargs):
    method after_evaluation (line 77) | def after_evaluation(self, **kwargs):
    method inference_step (line 80) | def inference_step(self):
    method evaluation (line 83) | def evaluation(self, model, data_loader, cuda_enabled=True):
    method train_epoch (line 102) | def train_epoch(
    method train_iters (line 127) | def train_iters(
    method _train_inner_loop (line 155) | def _train_inner_loop(
    method save_result (line 249) | def save_result(result, result_dir, filename, remove_duplicate=""):

FILE: stllm/tasks/image_text_pretrain.py
  class ImageTextPretrainTask (line 14) | class ImageTextPretrainTask(BaseTask):
    method __init__ (line 15) | def __init__(self):
    method evaluation (line 18) | def evaluation(self, model, data_loader, cuda_enabled=True):
  class VideoTextItTask (line 22) | class VideoTextItTask(ImageTextPretrainTask):
    method __init__ (line 23) | def __init__(self):
    method build_datasets (line 26) | def build_datasets(self, cfg):
  function get_media_type (line 51) | def get_media_type(dataset_info):

FILE: stllm/test/gpt_evaluation/evaluate_activitynet_qa.py
  function parse_args (line 9) | def parse_args():
  function annotate (line 20) | def annotate(prediction_set, caption_files, output_dir):
  function main (line 74) | def main():

FILE: stllm/test/gpt_evaluation/evaluate_benchmark_1_correctness.py
  function parse_args (line 9) | def parse_args():
  function annotate (line 20) | def annotate(prediction_set, caption_files, output_dir):
  function main (line 76) | def main():

FILE: stllm/test/gpt_evaluation/evaluate_benchmark_2_detailed_orientation.py
  function parse_args (line 9) | def parse_args():
  function annotate (line 20) | def annotate(prediction_set, caption_files, output_dir):
  function main (line 75) | def main():

FILE: stllm/test/gpt_evaluation/evaluate_benchmark_3_context.py
  function parse_args (line 9) | def parse_args():
  function annotate (line 20) | def annotate(prediction_set, caption_files, output_dir):
  function main (line 75) | def main():

FILE: stllm/test/gpt_evaluation/evaluate_benchmark_4_temporal.py
  function parse_args (line 9) | def parse_args():
  function annotate (line 20) | def annotate(prediction_set, caption_files, output_dir):
  function main (line 74) | def main():

FILE: stllm/test/gpt_evaluation/evaluate_benchmark_5_consistency.py
  function parse_args (line 9) | def parse_args():
  function annotate (line 20) | def annotate(prediction_set, caption_files, output_dir):
  function main (line 80) | def main():

FILE: stllm/test/mvbench/mv_bench.py
  class MVBench_dataset (line 49) | class MVBench_dataset(Dataset):
    method __init__ (line 50) | def __init__(self, data_dir, data_list=data_list, num_segments=8, reso...
    method __str__ (line 87) | def __str__(self):
    method __len__ (line 109) | def __len__(self):
    method get_index (line 112) | def get_index(self, bound, fps, max_frame, first_idx=0):
    method read_video (line 141) | def read_video(self, video_path, bound=None):
    method read_gif (line 156) | def read_gif(self, video_path, bound=None, fps=25):
    method read_frame (line 171) | def read_frame(self, video_path, bound=None, fps=3):
    method qa_template (line 187) | def qa_template(self, data):
    method __getitem__ (line 200) | def __getitem__(self, idx):
  function get_residual_index (line 220) | def get_residual_index(sample_segments, total_segments, devices):
  function infer_mvbench (line 229) | def infer_mvbench(
  function check_ans (line 285) | def check_ans(pred, gt):

FILE: stllm/test/mvbench/mv_bench_infer.py
  function parse_args (line 20) | def parse_args():
  function run_inference (line 44) | def run_inference(args):

FILE: stllm/test/qabench/activitynet_qa.py
  function parse_args (line 24) | def parse_args():
  function run_inference (line 51) | def run_inference(args):

FILE: stllm/test/qabench/msrvtt_qa.py
  function parse_args (line 23) | def parse_args():
  function run_inference (line 48) | def run_inference(args):

FILE: stllm/test/qabench/msvd_qa.py
  function parse_args (line 23) | def parse_args():
  function run_inference (line 48) | def run_inference(args):

FILE: stllm/test/vcgbench/videochatgpt_benchmark_consist.py
  function parse_args (line 23) | def parse_args():
  function run_inference (line 48) | def run_inference(args):

FILE: stllm/test/vcgbench/videochatgpt_benchmark_general.py
  function parse_args (line 23) | def parse_args():
  function run_inference (line 48) | def run_inference(args):

FILE: stllm/test/video_transforms.py
  class GroupRandomCrop (line 10) | class GroupRandomCrop(object):
    method __init__ (line 11) | def __init__(self, size):
    method __call__ (line 17) | def __call__(self, img_group):
  class MultiGroupRandomCrop (line 37) | class MultiGroupRandomCrop(object):
    method __init__ (line 38) | def __init__(self, size, groups=1):
    method __call__ (line 45) | def __call__(self, img_group):
  class GroupCenterCrop (line 66) | class GroupCenterCrop(object):
    method __init__ (line 67) | def __init__(self, size):
    method __call__ (line 70) | def __call__(self, img_group):
  class GroupRandomHorizontalFlip (line 74) | class GroupRandomHorizontalFlip(object):
    method __init__ (line 78) | def __init__(self, is_flow=False):
    method __call__ (line 81) | def __call__(self, img_group, is_flow=False):
  class GroupNormalize (line 94) | class GroupNormalize(object):
    method __init__ (line 95) | def __init__(self, mean, std):
    method __call__ (line 99) | def __call__(self, tensor):
  class GroupScale (line 110) | class GroupScale(object):
    method __init__ (line 119) | def __init__(self, size, interpolation=Image.BILINEAR):
    method __call__ (line 122) | def __call__(self, img_group):
  class GroupOverSample (line 126) | class GroupOverSample(object):
    method __init__ (line 127) | def __init__(self, crop_size, scale_size=None, flip=True):
    method __call__ (line 137) | def __call__(self, img_group):
  class GroupFullResSample (line 167) | class GroupFullResSample(object):
    method __init__ (line 168) | def __init__(self, crop_size, scale_size=None, flip=True):
    method __call__ (line 178) | def __call__(self, img_group):
  class GroupMultiScaleCrop (line 214) | class GroupMultiScaleCrop(object):
    method __init__ (line 216) | def __init__(self, input_size, scales=None, max_distort=1,
    method __call__ (line 226) | def __call__(self, img_group):
    method _sample_crop_size (line 243) | def _sample_crop_size(self, im_size):
    method _sample_fix_offset (line 272) | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
    method fill_fix_offset (line 278) | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
  class GroupRandomSizedCrop (line 303) | class GroupRandomSizedCrop(object):
    method __init__ (line 311) | def __init__(self, size, interpolation=Image.BILINEAR):
    method __call__ (line 315) | def __call__(self, img_group):
  class ConvertDataFormat (line 353) | class ConvertDataFormat(object):
    method __init__ (line 354) | def __init__(self, model_type):
    method __call__ (line 357) | def __call__(self, images):
  class Stack (line 367) | class Stack(object):
    method __init__ (line 369) | def __init__(self, roll=False):
    method __call__ (line 372) | def __call__(self, img_group):
  class ToTorchFormatTensor (line 386) | class ToTorchFormatTensor(object):
    method __init__ (line 390) | def __init__(self, div=True):
    method __call__ (line 393) | def __call__(self, pic):
  class IdentityTransform (line 409) | class IdentityTransform(object):
    method __call__ (line 411) | def __call__(self, data):

FILE: stllm/test/video_utils.py
  function load_video (line 11) | def load_video(vis_path, n_clips=1, num_frm=100):
  function load_video_rawframes (line 50) | def load_video_rawframes(vis_path, total_frame_num, n_clips=1, num_frm=1...
  function get_seq_frames (line 75) | def get_seq_frames(total_num_frames, desired_num_frames):
  function get_frames_from_raw (line 101) | def get_frames_from_raw(directory, frame_idx, filename_tmpl="{:0>6}.jpg"...

FILE: stllm/train/stllm_trainer.py
  function maybe_zero_3 (line 29) | def maybe_zero_3(param, ignore_status=False, name=None):
  function get_mm_adapter_state_maybe_zero_3 (line 43) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
  function split_to_even_chunks (line 49) | def split_to_even_chunks(indices, lengths, num_chunks):
  function get_modality_length_grouped_indices (line 71) | def get_modality_length_grouped_indices(lengths, batch_size, world_size,...
  function get_length_grouped_indices (line 99) | def get_length_grouped_indices(lengths, batch_size, world_size, generato...
  class LengthGroupedSampler (line 110) | class LengthGroupedSampler(Sampler):
    method __init__ (line 116) | def __init__(
    method __len__ (line 133) | def __len__(self):
    method __iter__ (line 136) | def __iter__(self):
  class STLLMTrainer (line 144) | class STLLMTrainer(Trainer):
    method _get_train_sampler (line 146) | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    method get_train_dataloader (line 161) | def get_train_dataloader(self):
    method create_optimizer (line 218) | def create_optimizer(self):
    method compute_loss (line 307) | def compute_loss(self, model, inputs, return_outputs=False):

FILE: stllm/train/train.py
  function parse_args (line 36) | def parse_args():
  function setup_seeds (line 55) | def setup_seeds(config):
  function get_runner_class (line 66) | def get_runner_class(cfg):
  function main (line 75) | def main():

FILE: stllm/train/train_hf.py
  function rank0_print (line 51) | def rank0_print(*args):
  function parse_args (line 55) | def parse_args():
  class ModelArguments (line 75) | class ModelArguments:
  class DataArguments (line 79) | class DataArguments:
  class TrainingArguments (line 84) | class TrainingArguments(transformers.TrainingArguments):
  function maybe_zero_3 (line 110) | def maybe_zero_3(param, ignore_status=False, name=None):
  function get_peft_state_maybe_zero_3 (line 125) | def get_peft_state_maybe_zero_3(named_params, bias):
  function get_peft_state_non_lora_maybe_zero_3 (line 150) | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only...
  function get_mm_adapter_state_maybe_zero_3 (line 158) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
  function find_all_linear_names (line 164) | def find_all_linear_names(model):
  function safe_save_model_for_hf_trainer (line 180) | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
  function merge_dict_to_argv (line 205) | def merge_dict_to_argv(input_dict):
  class DefaultDataCollator (line 218) | class DefaultDataCollator(object):
    method __call__ (line 219) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
  function train (line 222) | def train():
Condensed preview — 107 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (599K chars).
[
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "PrepareVicuna.md",
    "chars": 1602,
    "preview": "## How to Prepare Vicuna Weight\nVicuna is an open-source LLAMA-based LLM that has a performance close to ChatGPT. \nWe cu"
  },
  {
    "path": "README.md",
    "chars": 9944,
    "preview": "<p align=\"center\" width=\"100%\">\n<a target=\"_blank\"><img src=\"example/material/stllm_logo.png\" alt=\"ST-LLM\" style=\"width:"
  },
  {
    "path": "config/instructblipbase_avp.yaml",
    "chars": 1226,
    "preview": "model:\n  arch: st_llm_hf\n  model_type: instructblip_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  e"
  },
  {
    "path": "config/instructblipbase_stllm_conversation.yaml",
    "chars": 1272,
    "preview": "model:\n  arch: st_llm_hf\n  model_type: instructblip_vicuna0\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_sym: \"#"
  },
  {
    "path": "config/instructblipbase_stllm_qa.yaml",
    "chars": 1267,
    "preview": "model:\n  arch: st_llm_hf\n  model_type: instructblip_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  e"
  },
  {
    "path": "config/minigpt4base_avp.yaml",
    "chars": 1199,
    "preview": "model:\n  arch: st_llm_hf\n  model_type: minigpt4_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_s"
  },
  {
    "path": "config/minigpt4base_stllm_qa.yaml",
    "chars": 1240,
    "preview": "model:\n  arch: st_llm_hf\n  model_type: minigpt4_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_s"
  },
  {
    "path": "demo.py",
    "chars": 2258,
    "preview": "import argparse\nimport torch\n\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stl"
  },
  {
    "path": "demo_gradio.py",
    "chars": 8554,
    "preview": "import gradio as gr\nfrom gradio.themes.utils import colors, fonts, sizes\n\nimport argparse\nimport torch\n\nfrom stllm.commo"
  },
  {
    "path": "prompts/alignment.txt",
    "chars": 286,
    "preview": "<Img><ImageHere></Img> Describe this image in detail.\n<Img><ImageHere></Img> Take a look at this image and describe what"
  },
  {
    "path": "requirement.txt",
    "chars": 937,
    "preview": "torch==2.0.0\ntorchaudio==2.0.1\ntorchvision==0.15.1+cu118\naccelerate\naiohttp==3.8.4\naiosignal==1.3.1\nasync-timeout==4.0.2"
  },
  {
    "path": "script/inference/mvbench/test_mvbench.sh",
    "chars": 377,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/mvbench/mv_bench_infer.py \\\n    --cfg-path config/instructblipbase_"
  },
  {
    "path": "script/inference/qabench/anet_qa.sh",
    "chars": 444,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython script/inference/qabench/anet_qa.sh \\\n    --cfg-path config/instructblipbase_s"
  },
  {
    "path": "script/inference/qabench/msrvtt_qa.sh",
    "chars": 394,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/qabench/msrvtt_qa.py \\\n    --cfg-path config/instructblipbase_stllm"
  },
  {
    "path": "script/inference/qabench/msvd_qa.sh",
    "chars": 385,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/qabench/msvd_qa.py \\\n    --cfg-path config/instructblipbase_stllm_q"
  },
  {
    "path": "script/inference/qabench/score_anet.sh",
    "chars": 381,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \\\n    --pred_path test_ou"
  },
  {
    "path": "script/inference/qabench/score_msrvtt.sh",
    "chars": 377,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \\\n    --pred_path test_ou"
  },
  {
    "path": "script/inference/qabench/score_msvd.sh",
    "chars": 369,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \\\n    --pred_path test_ou"
  },
  {
    "path": "script/inference/vcgbench/score_consist.sh",
    "chars": 350,
    "preview": "python stllm/test/gpt_evaluation/evaluate_benchmark_5_consistency.py \\\n    --pred_path test_output/vcgbench/stllm_instru"
  },
  {
    "path": "script/inference/vcgbench/score_context.sh",
    "chars": 346,
    "preview": "python stllm/test/gpt_evaluation/evaluate_benchmark_3_context.py \\\n    --pred_path test_output/vcgbench/stllm_instructbl"
  },
  {
    "path": "script/inference/vcgbench/score_correct.sh",
    "chars": 362,
    "preview": "python stllm/test/gpt_evaluation/evaluate_benchmark_1_correctness.py \\\n    --pred_path test_output/vcgbench/stllm_instru"
  },
  {
    "path": "script/inference/vcgbench/score_detail.sh",
    "chars": 356,
    "preview": "python stllm/test/gpt_evaluation/evaluate_benchmark_2_detailed_orientation.py \\\n    --pred_path test_output/vcgbench/stl"
  },
  {
    "path": "script/inference/vcgbench/score_temporal.sh",
    "chars": 351,
    "preview": "python stllm/test/gpt_evaluation/evaluate_benchmark_4_temporal.py \\\n    --pred_path test_output/vcgbench/stllm_instructb"
  },
  {
    "path": "script/inference/vcgbench/test_consist.sh",
    "chars": 467,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/vcgbench/videochatgpt_benchmark_consist.py \\\n    --cfg-path config/"
  },
  {
    "path": "script/inference/vcgbench/test_general.sh",
    "chars": 463,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/vcgbench/videochatgpt_benchmark_general.py \\\n    --cfg-path config/"
  },
  {
    "path": "script/inference/vcgbench/test_temporal.sh",
    "chars": 451,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/vcgbench/videochatgpt_benchmark_general.py \\\n    --cfg-path config/"
  },
  {
    "path": "script/train/train.sh",
    "chars": 159,
    "preview": "export PYTHONPATH=\"./:$PYTHONPATH\"\ndeepspeed --master_port=20000 --include=localhost:0,1,2,3,4,5,6,7 stllm/train/train_h"
  },
  {
    "path": "stllm/__init__.py",
    "chars": 936,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/common/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "stllm/common/config.py",
    "chars": 15229,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/common/dist_utils.py",
    "chars": 3715,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/common/gradcam.py",
    "chars": 815,
    "preview": "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom scipy.ndimage import filters\nfrom skimage import transform "
  },
  {
    "path": "stllm/common/logger.py",
    "chars": 5998,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/common/optims.py",
    "chars": 3509,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/common/registry.py",
    "chars": 9876,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/common/utils.py",
    "chars": 13804,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/configs/datasets/cc_sbu/align.yaml",
    "chars": 92,
    "preview": "datasets:\n  cc_sbu_align:\n    data_type: images\n    build_info:\n      storage: cc_sbu_align\n"
  },
  {
    "path": "stllm/configs/datasets/cc_sbu/defaults.yaml",
    "chars": 116,
    "preview": "datasets:\n  cc_sbu:\n    data_type: images\n    build_info:\n      storage: /path/to/cc_sbu_dataset/{00000..01255}.tar\n"
  },
  {
    "path": "stllm/configs/datasets/laion/defaults.yaml",
    "chars": 114,
    "preview": "datasets:\n  laion:\n    data_type: images\n    build_info:\n      storage: /path/to/laion_dataset/{00000..10488}.tar\n"
  },
  {
    "path": "stllm/configs/default.yaml",
    "chars": 141,
    "preview": "env:\n  # For default users\n  # cache_root: \"cache\"\n  # For internal use with persistent storage\n  cache_root: \"/export/h"
  },
  {
    "path": "stllm/configs/models/instructblip_vicuna0.yaml",
    "chars": 793,
    "preview": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_preci"
  },
  {
    "path": "stllm/configs/models/instructblip_vicuna0_btadapter.yaml",
    "chars": 822,
    "preview": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  vit_model: \"eva_btadapter_g\"\n  image_size: 224\n  drop_path_rate: 0\n  use_gra"
  },
  {
    "path": "stllm/configs/models/minigpt4_vicuna0.yaml",
    "chars": 771,
    "preview": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_preci"
  },
  {
    "path": "stllm/configs/models/minigpt4_vicuna0_btadapter.yaml",
    "chars": 802,
    "preview": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  vit_model: \"eva_btadapter_g\"\n  image_size: 224\n  drop_path_rate: 0\n  use_gra"
  },
  {
    "path": "stllm/conversation/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "stllm/conversation/conversation.py",
    "chars": 13559,
    "preview": "import argparse\nimport time\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nfrom transformers import AutoTokenize"
  },
  {
    "path": "stllm/conversation/mvbench_conversation.py",
    "chars": 8679,
    "preview": "import torch\nimport numpy as np\nfrom transformers import StoppingCriteria, StoppingCriteriaList\n\ndef get_prompt(conv):\n "
  },
  {
    "path": "stllm/datasets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "stllm/datasets/builders/__init__.py",
    "chars": 1888,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/datasets/builders/base_dataset_builder.py",
    "chars": 8093,
    "preview": "\"\"\"\n This file is from\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-C"
  },
  {
    "path": "stllm/datasets/builders/image_text_pair_builder.py",
    "chars": 2988,
    "preview": "import os\nimport logging\nimport warnings\n\nfrom stllm.common.registry import registry\nfrom stllm.datasets.builders.base_d"
  },
  {
    "path": "stllm/datasets/data_utils.py",
    "chars": 6275,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/datasets/datasets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "stllm/datasets/datasets/base_dataset.py",
    "chars": 2200,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/datasets/datasets/caption_datasets.py",
    "chars": 2598,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/datasets/datasets/cc_sbu_dataset.py",
    "chars": 1612,
    "preview": "import os\nimport pickle\nfrom PIL import Image\nimport webdataset as wds\nfrom stllm.datasets.datasets.base_dataset import "
  },
  {
    "path": "stllm/datasets/datasets/dataloader_utils.py",
    "chars": 6595,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/datasets/datasets/image_video_itdatasets.py",
    "chars": 11603,
    "preview": "import logging\nimport os\nimport random\nfrom tqdm import tqdm\nfrom torch.utils.data import Dataset\nfrom torchvision impor"
  },
  {
    "path": "stllm/datasets/datasets/instruction_data.py",
    "chars": 5830,
    "preview": "from torchvision import transforms\nfrom torchvision.transforms import InterpolationMode\n\nmean = (0.48145466, 0.4578275, "
  },
  {
    "path": "stllm/datasets/datasets/laion_dataset.py",
    "chars": 1167,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/datasets/datasets/utils.py",
    "chars": 15792,
    "preview": "#from utils.distributed import is_main_process, get_rank, get_world_size\nimport logging\nimport torch.distributed as dist"
  },
  {
    "path": "stllm/models/Qformer.py",
    "chars": 48386,
    "preview": "\"\"\"\n * Copyright (c) 2023, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For "
  },
  {
    "path": "stllm/models/__init__.py",
    "chars": 5743,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/models/base_decoder.py",
    "chars": 7039,
    "preview": "from functools import partial\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom"
  },
  {
    "path": "stllm/models/base_model.py",
    "chars": 7860,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/models/blip2.py",
    "chars": 7958,
    "preview": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/models/blip2_outputs.py",
    "chars": 4153,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/models/eva_btadapter.py",
    "chars": 11690,
    "preview": "# --------------------------------------------------------\n# Adapted from  https://github.com/microsoft/unilm/tree/maste"
  },
  {
    "path": "stllm/models/eva_vit.py",
    "chars": 19604,
    "preview": "# Based on EVA, BEIT, timm and DeiT code bases\n# https://github.com/baaivision/EVA\n# https://github.com/rwightman/pytorc"
  },
  {
    "path": "stllm/models/modeling_llama_mem.py",
    "chars": 32829,
    "preview": "# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_l"
  },
  {
    "path": "stllm/models/peft_model.py",
    "chars": 3947,
    "preview": "# coding=utf-8\n# Copyright 2023-present the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (t"
  },
  {
    "path": "stllm/models/st_llm.py",
    "chars": 26818,
    "preview": "import logging\nimport random\nimport re\nimport os\nimport math\nimport einops\nimport ast\n\nimport torch\nfrom torch.cuda.amp "
  },
  {
    "path": "stllm/models/utils.py",
    "chars": 1099,
    "preview": "import numpy as np\nimport torch\n\ndef RandomMaskingGenerator(num_patches, mask_ratio, batch, device='cuda'):\n    num_mask"
  },
  {
    "path": "stllm/processors/__init__.py",
    "chars": 814,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/processors/base_processor.py",
    "chars": 610,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/processors/blip_processors.py",
    "chars": 4805,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/processors/randaugment.py",
    "chars": 11298,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/processors/video_transform.py",
    "chars": 8806,
    "preview": "import numpy as np\n\nclass SampleFrames:\n    \"\"\"Sample frames from the video.\n\n    Required Keys:\n\n        - total_frames"
  },
  {
    "path": "stllm/runners/__init__.py",
    "chars": 303,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/runners/runner_base.py",
    "chars": 23482,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/tasks/__init__.py",
    "chars": 767,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/tasks/base_task.py",
    "chars": 8944,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/tasks/image_text_pretrain.py",
    "chars": 2026,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/test/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_activitynet_qa.py",
    "chars": 7879,
    "preview": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_1_correctness.py",
    "chars": 7767,
    "preview": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_2_detailed_orientation.py",
    "chars": 7848,
    "preview": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_3_context.py",
    "chars": 7940,
    "preview": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_4_temporal.py",
    "chars": 7704,
    "preview": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_5_consistency.py",
    "chars": 8390,
    "preview": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():"
  },
  {
    "path": "stllm/test/mvbench/mv_bench.py",
    "chars": 11935,
    "preview": "import os \nimport json\nimport math\nimport numpy as np\nimport cv2\nimport io\nimport imageio\nfrom mmengine.fileio import Fi"
  },
  {
    "path": "stllm/test/mvbench/mv_bench_infer.py",
    "chars": 4566,
    "preview": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm i"
  },
  {
    "path": "stllm/test/qabench/activitynet_qa.py",
    "chars": 5070,
    "preview": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm i"
  },
  {
    "path": "stllm/test/qabench/msrvtt_qa.py",
    "chars": 4600,
    "preview": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm i"
  },
  {
    "path": "stllm/test/qabench/msvd_qa.py",
    "chars": 4720,
    "preview": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm i"
  },
  {
    "path": "stllm/test/vcgbench/videochatgpt_benchmark_consist.py",
    "chars": 4712,
    "preview": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm i"
  },
  {
    "path": "stllm/test/vcgbench/videochatgpt_benchmark_general.py",
    "chars": 4547,
    "preview": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm i"
  },
  {
    "path": "stllm/test/video_transforms.py",
    "chars": 14445,
    "preview": "import torchvision\nimport random\nfrom PIL import Image, ImageOps\nimport numpy as np\nimport numbers\nimport math\nimport to"
  },
  {
    "path": "stllm/test/video_utils.py",
    "chars": 4448,
    "preview": "import os\nimport copy\nimport numpy as np\nfrom PIL import Image\nimport decord\nfrom decord import VideoReader, cpu\nfrom tr"
  },
  {
    "path": "stllm/train/stllm_trainer.py",
    "chars": 14467,
    "preview": "import os\nimport torch\nfrom stllm.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset\nfrom"
  },
  {
    "path": "stllm/train/train.py",
    "chars": 2659,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "stllm/train/train_hf.py",
    "chars": 9372,
    "preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
  },
  {
    "path": "stllm/train/zero2.json",
    "chars": 556,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_"
  },
  {
    "path": "stllm/train/zero3.json",
    "chars": 801,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_"
  },
  {
    "path": "stllm/train/zero3_offload.json",
    "chars": 1389,
    "preview": "{\n    \"fp16\": {\n      \"enabled\": \"auto\",\n      \"loss_scale\": 0,\n      \"loss_scale_window\": 1000,\n      \"initial_scale_po"
  },
  {
    "path": "trainval.md",
    "chars": 3110,
    "preview": "## 1. Prepare the Pretrained Weights\nAlthough some weights can be downloaded dynamically at runtime, it is recommended t"
  }
]

About this extraction

This page contains the full source code of the TencentARC/ST-LLM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 107 files (558.1 KB), approximately 135.4k tokens, and a symbol index with 774 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!