[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "PrepareVicuna.md",
    "content": "## How to Prepare Vicuna Weight\nVicuna is an open-source LLAMA-based LLM that has a performance close to ChatGPT. \nWe currently use the v0 version of Vicuna-13B. \n\nTo prepare Vicuna’s weight, first download Vicuna’s **delta** weight from [https://huggingface.co/lmsys/vicuna-13b-delta-v0](https://huggingface.co/lmsys/vicuna-13b-delta-v0). \nIn case you have git-lfs installed (https://git-lfs.com), this can be done by\n\n```\ngit lfs install\ngit clone https://huggingface.co/lmsys/vicuna-13b-delta-v0  # more powerful, need at least 24G gpu memory\n# or\ngit clone https://huggingface.co/lmsys/vicuna-7b-delta-v0  # smaller, need 12G gpu memory\n```\n\nNote that this is not directly the working weight, but the difference between the working weight and the original weight of LLAMA-13B. (Due to LLAMA’s rules, we cannot distribute the weight of LLAMA.)\n\nThen, you need to obtain the original LLAMA-7B or LLAMA-13B weights in the HuggingFace format \neither following the instruction provided by HuggingFace \n[here](https://huggingface.co/docs/transformers/main/model_doc/llama) or from the Internet. \n\nWhen these two weights are ready, we can use tools from Vicuna’s team to create the real working weight.\nFirst, Install their library that is compatible with v0 Vicuna by\n\n```\npip install git+https://github.com/lm-sys/FastChat.git@v0.1.10\n```\n\nThen, run the following command to create the final working weight\n\n```\npython -m fastchat.model.apply_delta --base /path/to/llama-13bOR7b-hf/  --target /path/to/save/working/vicuna/weight/  --delta /path/to/vicuna-13bOR7b-delta-v0/\n```\n\nNow you are good to go!\n\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\" width=\"100%\">\n<a target=\"_blank\"><img src=\"example/material/stllm_logo.png\" alt=\"ST-LLM\" style=\"width: 50%; min-width: 150px; display: block; margin: auto;\"></a>\n</p>\n\n<h2 align=\"center\"> <a href=\"https://arxiv.org/abs/2404.00308\">ST-LLM: Large Language Models Are Effective Temporal Learners</a></h2>\n\n<h5 align=center>\n\n[![hf](https://img.shields.io/badge/🤗-Hugging%20Face-blue.svg)](https://huggingface.co/farewellthree/ST_LLM_weight/tree/main)\n[![arXiv](https://img.shields.io/badge/Arxiv-2311.08046-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2404.00308)\n[![License](https://img.shields.io/badge/Code%20License-Apache2.0-yellow)](https://github.com/farewellthree/ST-LLM/blob/main/LICENSE)\n</h5>\n\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-question-answering-on-mvbench)](https://paperswithcode.com/sota/video-question-answering-on-mvbench?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance)](https://paperswithcode.com/sota/video-based-generative-performance?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-1)](https://paperswithcode.com/sota/video-based-generative-performance-1?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-5)](https://paperswithcode.com/sota/video-based-generative-performance-5?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-model
s-are-effective-1/video-based-generative-performance-2)](https://paperswithcode.com/sota/video-based-generative-performance-2?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-3)](https://paperswithcode.com/sota/video-based-generative-performance-3?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/video-based-generative-performance-4)](https://paperswithcode.com/sota/video-based-generative-performance-4?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/zeroshot-video-question-answer-on-activitynet)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/zeroshot-video-question-answer-on-msrvtt-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=st-llm-large-language-models-are-effective-1)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/st-llm-large-language-models-are-effective-1/zeroshot-video-question-answer-on-msvd-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=st-llm-large-language-models-are-effective-1)\n\n## News :loudspeaker:\n\n* **[2024/3/28]**  All codes and weights are available now! Welcome to watch this repository for the latest updates.\n\n## Introduction :bulb:\n\n- **ST-LLM** is a temporal-sensitive video large language model. 
Our model incorporates three key architectural designs: \n  - (1) Joint spatial-temporal modeling within large language models for effective video understanding.\n  - (2) A dynamic masking strategy and masked video modeling for efficiency and robustness.\n  - (3) A global-local input module for long video understanding.\n- **ST-LLM** has established new state-of-the-art results on MVBench, VideoChatGPT Bench and VideoQA Bench:\n\n<div align=\"center\">\n<table border=\"1\" width=\"100%\">\n    <tr align=\"center\">\n        <th rowspan=\"2\">Method</th><th rowspan=\"2\">MVBench</th><th colspan=\"6\">VcgBench</th><th colspan=\"3\">VideoQABench</th>\n    </tr>\n  <tr align=\"center\">\n        <th>Avg</th><th>Correct</th><th>Detail</th><th>Context</th><th>Temporal</th><th>Consist</th><th>MSVD</th><th>MSRVTT</th><th>ANet</th>\n    </tr>\n  <tr align=\"center\">\n        <td>VideoLLaMA</td><td>34.1</td><td>1.96</td><td>2.18</td><td>2.16</td><td>1.82</td><td>1.79</td><td>1.98</td><td>51.6</td><td>29.6</td><td>12.4</td>\n    </tr>\n  <tr align=\"center\">\n        <td>LLaMA-Adapter</td><td>31.7</td><td>2.03</td><td>2.32</td><td>2.30</td><td>1.98</td><td>2.15</td><td>2.16</td><td>54.9</td><td>43.8</td><td>34.2</td>\n    </tr>\n  <tr align=\"center\">\n        <td>VideoChat</td><td>35.5</td><td>2.23</td><td>2.50</td><td>2.53</td><td>1.94</td><td>2.24</td><td>2.29</td><td>56.3</td><td>45.0</td><td>26.5</td>\n    </tr>\n  <tr align=\"center\">\n        <td>VideoChatGPT</td><td>32.7</td><td>2.38</td><td>2.40</td><td>2.52</td><td>2.62</td><td>1.98</td><td>2.37</td><td>64.9</td><td>49.3</td><td>35.7</td>\n    </tr>\n  <tr align=\"center\">\n        <td>MovieChat</td><td>-</td><td>2.76</td><td>2.93</td><td>3.01</td><td>2.24</td><td>2.42</td><td>2.67</td><td>74.2</td><td>52.7</td><td>45.7</td>\n    </tr>\n  <tr align=\"center\">\n        <td>Vista-LLaMA</td><td>-</td><td>2.44</td><td>2.64</td><td>3.18</td><td>2.26</td><td>2.31</td><td>2.57</td><td>65.3</td><td>60.5</td><td>48.3</td>\n    
</tr>\n  <tr align=\"center\">\n        <td>LLaMA-VID</td><td>-</td><td>2.89</td><td>2.96</td><td>3.00</td><td>3.53</td><td>2.46</td><td>2.51</td><td>69.7</td><td>57.7</td><td>47.4</td>\n    </tr>\n  <tr align=\"center\">\n        <td>Chat-UniVi</td><td>-</td><td>2.99</td><td>2.89</td><td>2.91</td><td>3.46</td><td>2.89</td><td>2.81</td><td>65.0</td><td>54.6</td><td>45.8</td>\n    </tr>\n  <tr align=\"center\">\n        <td>VideoChat2</td><td>51.1</td><td>2.98</td><td>3.02</td><td>2.88</td><td>3.51</td><td>2.66</td><td>2.81</td><td>70.0</td><td>54.1</td><td>49.1</td>\n    </tr>\n  <tr align=\"center\">\n        <td>ST-LLM</td><td><b>54.9</b></td><td><b>3.15</b></td><td><b>3.23</b></td><td><b>3.05</b></td><td><b>3.74</b></td><td><b>2.93</b></td><td><b>2.81</b></td><td><b>74.6</b></td><td><b>63.2</b></td><td><b>50.9</b></td>\n    </tr>\n  \n</table>\n</div>\n\n## Demo 🤗\nPlease download the conversation weights from [here](https://huggingface.co/farewellthree/ST_LLM_weight/tree/main/conversation_weight) and follow the instructions in [installation](README.md#Installation) first. 
Then, run the gradio demo:\n```\nCUDA_VISIBLE_DEVICES=0 python3 demo_gradio.py --ckpt-path /path/to/STLLM_conversation_weight\n```\nWe have also prepared a local script that is easy to modify: [demo.py](demo.py)\n\n<div align=center>\n<img src=\"example/material/Mabaoguo.gif\" width=\"70%\" />\n</div>\n\n<div align=center>\n<img src=\"example/material/Driving.gif\" width=\"70%\" />\n</div>\n\n## Examples 👀\n- **Video Description: for difficult videos with complex scene changes, ST-LLM can accurately describe all the content.**\n<p align=\"center\">\n  <img src=\"example/driving.gif\" width=\"25%\" style=\"display:inline-block\" />\n  <img src=\"example/driving.jpg\" width=\"65%\" style=\"display:inline-block\" /> \n</p>\n\n- **Action Identification: ST-LLM can accurately and comprehensively describe the actions occurring in the video.**\n<p align=\"center\">\n  <img src=\"example/cooking.gif\" width=\"21%\" style=\"display:inline-block\" />\n  <img src=\"example/cooking.jpg\" width=\"68%\" style=\"display:inline-block\" /> \n</p>\n\n<p align=\"center\">\n  <img src=\"example/TVshow.gif\" width=\"21%\" style=\"display:inline-block\" />\n  <img src=\"example/TVshow.jpg\" width=\"68%\" style=\"display:inline-block\" /> \n</p>\n\n<p align=\"center\">\n  <img src=\"example/monkey.gif\" width=\"21%\" style=\"display:inline-block\" />\n  <img src=\"example/monkey.jpg\" width=\"68%\" style=\"display:inline-block\" /> \n</p>\n\n- **Reasoning: for challenging open-ended reasoning questions, ST-LLM can also provide reasonable answers.**\n  <p align=\"center\">\n  <img src=\"example/BaoguoMa.gif\" width=\"26%\" style=\"display:inline-block\" />\n  <img src=\"example/baoguoma.jpg\" width=\"66%\" style=\"display:inline-block\" /> \n</p>\n\n## Installation 🛠️\nClone our repository, create a Python environment, and activate it with the following commands:\n\n```bash\ngit clone https://github.com/farewellthree/ST-LLM.git\ncd ST-LLM\nconda create --name stllm 
python=3.10\nconda activate stllm\npip install -r requirement.txt\n```\n\n## Training & Validation :bar_chart:\nInstructions for data preparation, training, and evaluation can be found in [trainval.md](trainval.md).\n\n## Acknowledgement 👍\n* [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT) and [MVBench](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2): great work contributing video LLM benchmarks.\n* [InstructBLIP](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip) and [MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4/tree/main): the codebase and the base image LLMs we built upon.\n\n## Citation ✏️\nIf you find the code and paper useful for your research, please consider starring this repo and citing our papers:\n```\n@article{liu2023one,\n  title={One for all: Video conversation is feasible without video instruction tuning},\n  author={Liu, Ruyang and Li, Chen and Ge, Yixiao and Shan, Ying and Li, Thomas H and Li, Ge},\n  journal={arXiv preprint arXiv:2309.15785},\n  year={2023}\n}\n```\n```\n@article{liu2024stllm,\n  title={ST-LLM: Large Language Models Are Effective Temporal Learners},\n  author={Liu, Ruyang and Li, Chen and Tang, Haoran and Ge, Yixiao and Shan, Ying and Li, Ge},\n  journal={arXiv preprint arXiv:2404.00308},\n  year={2024}\n}\n```\n \n"
  },
  {
    "path": "config/instructblipbase_avp.yaml",
    "content": "model:\n  arch: st_llm_hf\n  model_type: instructblip_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_sym: \"###\"\n  video_input: \"mean\"\n  llama_model: '/path/to/vicuna-7b-v1.1'\n  ckpt: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'\n  q_former_model: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'\n  qformer_text_input: True\n  freeze_LLM: False\n\ndatasets:\n  caption_videochatgpt:\n    num_frames: 16\n    #video_reader_type: 'rawframe'\n  classification_k710:\n    num_frames: 16\n  classification_ssv2:\n    num_frames: 16\n  reasoning_next_qa:\n    num_frames: 16\n  reasoning_clevrer_qa:\n    num_frames: 16\n  reasoning_clevrer_mc:\n    num_frames: 16\n  vqa_webvid_qa:\n    num_frames: 16\n\nrun:\n  task: video_text_it\n  bf16: True\n  tf32: False\n  output_dir: \"./stllm/output/instructblipbase_avp\"\n  num_train_epochs: 2\n  dataloader_num_workers: 4\n  per_device_train_batch_size: 16\n  per_device_eval_batch_size: 16\n  gradient_accumulation_steps: 1\n  evaluation_strategy: \"no\"\n  learning_rate: 2e-5\n  weight_decay: 0.\n  warmup_ratio: 0.03\n  lr_scheduler_type: 'cosine'\n  logging_steps: 50\n  model_max_length: 1024\n  #save_steps: 10000 \n  save_strategy: \"epoch\" \n  save_total_limit: 1\n  deepspeed: 'stllm/train/zero3.json'"
  },
  {
    "path": "config/instructblipbase_stllm_conversation.yaml",
    "content": "model:\n  arch: st_llm_hf\n  model_type: instructblip_vicuna0\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_sym: \"###\"\n  #prompt_path: \"prompts/alignment.txt\"\n  prompt_template: '###Human: {} ###Assistant: '\n  llama_model: '/path/to/vicuna-7b-v1.1'\n  ckpt: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'\n  q_former_model: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'\n  qformer_text_input: True\n  freeze_LLM: False\n  video_input: \"residual\"\n  residual_size: 16\n  use_mask : True\n  mvm_decode: True\n\ndatasets:\n  caption_videochat:\n    num_frames: 64\n  conversation_videochat1:\n    num_frames: 64\n  caption_videochatgpt:\n    num_frames: 64\n    #video_reader_type: 'rawframe'\n  caption_webvid:\n    num_frames: 64\n  vqa_webvid_qa:\n    num_frames: 64\n\nrun:\n  task: video_text_it\n  bf16: True\n  tf32: False\n  output_dir: \"./stllm/output/instructblipbase_stllm_conversation\"\n  num_train_epochs: 2\n  dataloader_num_workers: 4\n  per_device_train_batch_size: 16\n  per_device_eval_batch_size: 16\n  gradient_accumulation_steps: 1\n  evaluation_strategy: \"no\"\n  learning_rate: 2e-5\n  weight_decay: 0.\n  warmup_ratio: 0.03\n  lr_scheduler_type: 'cosine'\n  logging_steps: 50\n  model_max_length: 1024\n  save_strategy: \"epoch\" \n  save_total_limit: 1\n  deepspeed: 'stllm/train/zero2.json'"
  },
  {
    "path": "config/instructblipbase_stllm_qa.yaml",
    "content": "model:\n  arch: st_llm_hf\n  model_type: instructblip_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_sym: \"###\"\n  video_input: \"all\"\n  llama_model: '/path/to/vicuna-7b-v1.1'\n  ckpt: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'\n  q_former_model: '/Path/to/instruct_blip_vicuna7b_trimmed.pth'\n  qformer_text_input: True\n  freeze_LLM: False\n  use_mask : True\n  mvm_decode: True\n\ndatasets:\n  caption_videochatgpt:\n    num_frames: 16\n    #video_reader_type: 'rawframe'\n  classification_k710:\n    num_frames: 16\n  classification_ssv2:\n    num_frames: 16\n  reasoning_next_qa:\n    num_frames: 16\n  reasoning_clevrer_qa:\n    num_frames: 16\n  reasoning_clevrer_mc:\n    num_frames: 16\n  vqa_webvid_qa:\n    num_frames: 16\n\nrun:\n  task: video_text_it\n  bf16: True\n  tf32: False\n  output_dir: \"./stllm/output/instructblipbase_stllm_qa\"\n  num_train_epochs: 2\n  dataloader_num_workers: 4\n  per_device_train_batch_size: 16\n  per_device_eval_batch_size: 16\n  gradient_accumulation_steps: 1\n  evaluation_strategy: \"no\"\n  learning_rate: 2e-5\n  weight_decay: 0.\n  warmup_ratio: 0.03\n  lr_scheduler_type: 'cosine'\n  logging_steps: 50\n  model_max_length: 1024\n  #save_steps: 10000 \n  save_strategy: \"epoch\" \n  save_total_limit: 1\n  deepspeed: 'stllm/train/zero3.json'"
  },
  {
    "path": "config/minigpt4base_avp.yaml",
    "content": "model:\n  arch: st_llm_hf\n  model_type: minigpt4_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_sym: \"###\"\n  video_input: \"mean\"\n  llama_model: \"/path/to/vicuna-7b\"\n  ckpt: '/Path/to/prerained_minigpt4_7b.pth'\n  q_former_model: /Path/to/blip2_pretrained_flant5xxl.pth\n  qformer_text_input: False\n  freeze_LLM: False\n\ndatasets:\n  caption_videochatgpt:\n    num_frames: 16\n    #video_reader_type: 'rawframe'\n  classification_k710:\n    num_frames: 16\n  classification_ssv2:\n    num_frames: 16\n  reasoning_next_qa:\n    num_frames: 16\n  reasoning_clevrer_qa:\n    num_frames: 16\n  reasoning_clevrer_mc:\n    num_frames: 16\n  vqa_webvid_qa:\n    num_frames: 16\n\nrun:\n  task: video_text_it\n  bf16: True\n  tf32: False\n  output_dir: \"./stllm/output/minigpt4base_avp\"\n  num_train_epochs: 2\n  dataloader_num_workers: 4\n  per_device_train_batch_size: 16\n  per_device_eval_batch_size: 16\n  gradient_accumulation_steps: 1\n  evaluation_strategy: \"no\"\n  learning_rate: 2e-5\n  weight_decay: 0.\n  warmup_ratio: 0.03\n  lr_scheduler_type: 'cosine'\n  logging_steps: 50\n  model_max_length: 1024\n  #save_steps: 10000 \n  save_strategy: \"epoch\" \n  save_total_limit: 1\n  deepspeed: 'stllm/train/zero3.json'"
  },
  {
    "path": "config/minigpt4base_stllm_qa.yaml",
    "content": "model:\n  arch: st_llm_hf\n  model_type: minigpt4_vicuna0_btadapter\n  use_grad_checkpoint: True\n  max_txt_len: 256\n  end_sym: \"###\"\n  video_input: \"all\"\n  llama_model: \"/path/to/vicuna-7b\"\n  ckpt: '/Path/to/prerained_minigpt4_7b.pth'\n  q_former_model: /Path/to/blip2_pretrained_flant5xxl.pth\n  qformer_text_input: False\n  freeze_LLM: False\n  use_mask : True\n  mvm_decode: True\n\ndatasets:\n  caption_videochatgpt:\n    num_frames: 16\n    #video_reader_type: 'rawframe'\n  classification_k710:\n    num_frames: 16\n  classification_ssv2:\n    num_frames: 16\n  reasoning_next_qa:\n    num_frames: 16\n  reasoning_clevrer_qa:\n    num_frames: 16\n  reasoning_clevrer_mc:\n    num_frames: 16\n  vqa_webvid_qa:\n    num_frames: 16\n\nrun:\n  task: video_text_it\n  bf16: True\n  tf32: False\n  output_dir: \"./stllm/output/minigpt4base_stllm_qa\"\n  num_train_epochs: 2\n  dataloader_num_workers: 4\n  per_device_train_batch_size: 16\n  per_device_eval_batch_size: 16\n  gradient_accumulation_steps: 1\n  evaluation_strategy: \"no\"\n  learning_rate: 2e-5\n  weight_decay: 0.\n  warmup_ratio: 0.03\n  lr_scheduler_type: 'cosine'\n  logging_steps: 50\n  model_max_length: 1024\n  #save_steps: 10000 \n  save_strategy: \"epoch\" \n  save_total_limit: 1\n  deepspeed: 'stllm/train/zero3.json'"
  },
  {
    "path": "demo.py",
    "content": "import argparse\nimport torch\n\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stllm.conversation.conversation import Chat, CONV_instructblip_Vicuna0\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Demo\")\n    parser.add_argument(\"--cfg-path\", default='config/instructblipbase_stllm_conversation.yaml', help=\"path to configuration file.\")\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to STLLM_conversation_weight.\")\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n    args = parser.parse_args()\n    return args\n\n\n# ========================================\n#             Model Initialization\n# ========================================\n\nprint('Initializing Chat')\nargs = parse_args()\ncfg = Config(args)\n\nckpt_path = args.ckpt_path\nmodel_config = cfg.model_cfg\nmodel_config.device_8bit = args.gpu_id\nmodel_config.ckpt = ckpt_path\nmodel_config.llama_model = ckpt_path\nmodel_cls = registry.get_model_class(model_config.arch)\nmodel = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\nmodel.to(torch.float16)\nCONV_VISION = CONV_instructblip_Vicuna0\n\nchat = Chat(model, device='cuda:{}'.format(args.gpu_id))\nprint('Initialization Finished')\n\nchat_state = CONV_VISION.copy()\nvideo = 'example/BaoguoMa.mp4'\nprompt = 'Tell me why this video looks so funny?'\nimg_list = []\n\nchat.upload_video(video, chat_state, 
img_list, 64, text=prompt)\nchat.ask(\"###Human: \" + prompt + \" ###Assistant: \", chat_state)\nllm_message = chat.answer(conv=chat_state,\n                img_list=img_list,\n                num_beams=5,\n                do_sample=False,\n                temperature=1,\n                max_new_tokens=300,\n                max_length=2000)[0]\nprint(llm_message)\n"
  },
  {
    "path": "demo_gradio.py",
    "content": "import gradio as gr\nfrom gradio.themes.utils import colors, fonts, sizes\n\nimport argparse\nimport torch\n\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stllm.conversation.conversation import Chat, CONV_instructblip_Vicuna0\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Demo\")\n    parser.add_argument(\"--cfg-path\", default='config/instructblipbase_stllm_conversation.yaml', help=\"path to configuration file.\")\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to STLLM_conversation_weight.\")\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n    args = parser.parse_args()\n    return args\n\n# ========================================\n#             Model Initialization\n# ========================================\n\nprint('Initializing Chat')\nargs = parse_args()\ncfg = Config(args)\n\nckpt_path = args.ckpt_path\nmodel_config = cfg.model_cfg\nmodel_config.device_8bit = args.gpu_id\nmodel_config.ckpt = ckpt_path\nmodel_config.llama_model = ckpt_path\nmodel_cls = registry.get_model_class(model_config.arch)\nmodel = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\nmodel.to(torch.float16)\nCONV_VISION = CONV_instructblip_Vicuna0\n\nchat = Chat(model, device='cuda:{}'.format(args.gpu_id))\nprint('Initialization Finished')\n\n# ========================================\n#             Gradio Setting\n# 
========================================\ndef gradio_reset(chat_state, img_list):\n    if chat_state is not None:\n        chat_state.messages = []\n    if img_list is not None:\n        img_list = []\n    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value=\"Upload & Start Chat\", interactive=True), chat_state, img_list\n\n\ndef upload_video(gr_video, chat_state, num_segments, text_prompt='Watch the video and answer the question.'):\n    print('gr_video: ', gr_video)\n    img_list = []\n    if gr_video: \n        chat_state = CONV_VISION.copy()\n        chat.upload_video(gr_video, chat_state, img_list, num_segments, text=text_prompt)\n        return gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value=\"Start Chatting\", interactive=False), chat_state, img_list\n\ndef gradio_ask(user_message, chatbot, chat_state, gr_video, num_segments):\n    if len(user_message) == 0:\n        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state\n    chat_state = CONV_VISION.copy()\n    img_list = []\n    chat.upload_video(gr_video, chat_state, img_list, num_segments, text=user_message)\n    msg = \"###Human: \" + user_message + \" ###Assistant: \"\n    chat.ask(msg, chat_state)\n    chatbot = chatbot + [[user_message, None]]\n    return '', chatbot, chat_state, img_list\n\n\ndef gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):\n    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, do_sample=False, temperature=temperature, max_length=2000)[0]\n    llm_message = llm_message.replace(\"<s>\", \"\") # handle <s>\n    chatbot[-1][1] = llm_message\n    print(chat_state)\n    print(f\"Answer: {llm_message}\")\n    return chatbot, chat_state, img_list\n\n\nclass STLLM(gr.themes.base.Base):\n    def __init__(\n        
self,\n        *,\n        primary_hue=colors.blue,\n        secondary_hue=colors.sky,\n        neutral_hue=colors.gray,\n        spacing_size=sizes.spacing_md,\n        radius_size=sizes.radius_sm,\n        text_size=sizes.text_md,\n        font=(\n            fonts.GoogleFont(\"Noto Sans\"),\n            \"ui-sans-serif\",\n            \"sans-serif\",\n        ),\n        font_mono=(\n            fonts.GoogleFont(\"IBM Plex Mono\"),\n            \"ui-monospace\",\n            \"monospace\",\n        ),\n    ):\n        super().__init__(\n            primary_hue=primary_hue,\n            secondary_hue=secondary_hue,\n            neutral_hue=neutral_hue,\n            spacing_size=spacing_size,\n            radius_size=radius_size,\n            text_size=text_size,\n            font=font,\n            font_mono=font_mono,\n        )\n        super().set(\n            body_background_fill=\"*neutral_50\",\n        )\n\n\ngvlabtheme = STLLM(primary_hue=colors.blue,\n        secondary_hue=colors.sky,\n        neutral_hue=colors.gray,\n        spacing_size=sizes.spacing_md,\n        radius_size=sizes.radius_sm,\n        text_size=sizes.text_md,\n        )\n\ntitle = \"\"\"<h1 align=\"center\"><a href=\"https://github.com/farewellthree/ST-LLM\"><img src=\"https://s21.ax1x.com/2024/03/25/pF4Wzq0.png\" border=\"0\" style=\"margin: 0 auto; height: 150px;\" /></a> </h1>\"\"\"\ndescription =\"\"\"\n        CLICK FOR SOURCE CODE!<br><p><a href='https://github.com/farewellthree/ST-LLM'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>\n        \"\"\"\n\n\nwith gr.Blocks(title=\"ST-LLM Chatbot!\",theme=gvlabtheme,css=\"#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}\") as demo:\n    gr.Markdown(title)\n    gr.Markdown(description)\n\n    with gr.Row():\n        with gr.Column(scale=0.5, visible=True) as video_upload:\n            with gr.Column(elem_id=\"image\", scale=0.5) as img_part:\n          
      with gr.Tab(\"Video\", elem_id='video_tab'):\n                    up_video = gr.Video(interactive=True, include_audio=True, elem_id=\"video_upload\").style(height=360)\n            # text_prompt_input = gr.Textbox(value=\"Watch the video and answer the question.\",show_label=False, placeholder='Input your text prompt, example: \"Watch the video and answer the question.\"', interactive=True).style(container=False)           \n            upload_button = gr.Button(value=\"Upload & Start Chat\", interactive=True, variant=\"primary\")\n            clear = gr.Button(\"Restart\")\n            \n            num_beams = gr.Slider(\n                minimum=1,\n                maximum=10,\n                value=5,\n                step=1,\n                interactive=True,\n                label=\"beam search numbers\",\n            )\n            \n            temperature = gr.Slider(\n                minimum=0.1,\n                maximum=2.0,\n                value=1.0,\n                step=0.1,\n                interactive=True,\n                label=\"Temperature\",\n            )\n            \n            num_segments = gr.Slider(\n                minimum=16,\n                maximum=96,\n                value=64,\n                step=1,\n                interactive=True,\n                label=\"Video Segments\",\n            )\n        \n        with gr.Column(visible=True)  as input_raws:\n            chat_state = gr.State()\n            img_list = gr.State()\n            chatbot = gr.Chatbot(elem_id=\"chatbot\",label='ST-LLM')\n            with gr.Row():\n                with gr.Column(scale=0.7):\n                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False).style(container=False)\n                with gr.Column(scale=0.15, min_width=0):\n                    run = gr.Button(\"💭Send\")\n                with gr.Column(scale=0.15, min_width=0):\n                    clear = 
gr.Button(\"🔄Clear️\")     \n    \n    upload_button.click(upload_video, [up_video, chat_state, num_segments], [up_video, text_input, upload_button, chat_state, img_list])\n    \n    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, up_video, num_segments], [text_input, chatbot, chat_state, img_list]).then(\n        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]\n    )\n    run.click(gradio_ask, [text_input, chatbot, chat_state, up_video, num_segments], [text_input, chatbot, chat_state, img_list]).then(\n        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]\n    )\n    run.click(lambda: \"\", None, text_input)  \n    clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_video, text_input, upload_button, chat_state, img_list], queue=False)\n\ndemo.launch(share=True, enable_queue=True)\n"
  },
  {
    "path": "prompts/alignment.txt",
    "content": "<Img><ImageHere></Img> Describe this image in detail.\n<Img><ImageHere></Img> Take a look at this image and describe what you notice.\n<Img><ImageHere></Img> Please provide a detailed description of the picture.\n<Img><ImageHere></Img> Could you describe the contents of this image for me?"
  },
  {
    "path": "requirement.txt",
    "content": "torch==2.0.0\ntorchaudio==2.0.1\ntorchvision==0.15.1+cu118\naccelerate\naiohttp==3.8.4\naiosignal==1.3.1\nasync-timeout==4.0.2\nattrs==22.2.0\nbitsandbytes==0.37.0\ncchardet==2.1.7\nchardet==5.1.0\ncontourpy==1.0.7\ncycler==0.11.0\nfilelock==3.9.0\nfonttools==4.38.0\nfrozenlist==1.3.3\nhuggingface-hub==0.13.4\nimportlib-resources==5.12.0\nkiwisolver==1.4.4\nmatplotlib==3.7.0\nmultidict==6.0.4\nopenai==0.27.0\npackaging==23.0\npsutil==5.9.4\npycocotools==2.0.6\npyparsing==3.0.9\npython-dateutil==2.8.2\npyyaml==6.0\nregex==2022.10.31\ntokenizers==0.13.2\ntqdm==4.64.1\ntransformers==4.28.0\ntimm==0.6.13\nspacy==3.5.1\nwebdataset==0.2.48\nscikit-learn==1.2.2\nscipy==1.10.1\nyarl==1.8.2\nzipp==3.14.0\nomegaconf==2.3.0\nopencv-python==4.7.0.72\niopath==0.1.10\ndecord==0.6.0\ntenacity==8.2.2\npycocoevalcap\nsentence-transformers\numap-learn\nnotebook\ngradio==3.24.1\ngradio-client==0.0.8\nwandb\npeft==0.8.1\neinops==0.7.0\nimageio==2.33.1\nav==11.0.0\ntransformers[deepspeed]\nmmengine\n"
  },
  {
    "path": "script/inference/mvbench/test_mvbench.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/mvbench/mv_bench_infer.py \\\n    --cfg-path config/instructblipbase_stllm_qa.yaml \\\n    --ckpt-path Path/to/instructblipbase_stllm_qa \\\n    --anno-path Path/to/MVBench/json \\\n    --output_dir test_output/mvbench/ \\\n    --output_name instructblipbase_stllm_qa_mvbench_fps1 \\\n    --num-frames 0 \\\n    --ask_simple \\\n    \n    \n\n"
  },
  {
    "path": "script/inference/qabench/anet_qa.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython script/inference/qabench/anet_qa.sh \\\n    --cfg-path config/instructblipbase_stllm_qa.yaml \\\n    --ckpt-path /Path/to/STLLM_QA_weight \\\n    --video_dir /Path/to/Anet/videos \\\n    --gt_file_question /Path/to/Anet/test_q.json \\\n    --gt_file_answers /Path/to/Anet/test_a.json \\\n    --output_dir test_output/qabench/ \\\n    --output_name stllm_instructblipbase_anetqa \\\n    --num-frames 16 \\\n    \n    \n    "
  },
  {
    "path": "script/inference/qabench/msrvtt_qa.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/qabench/msrvtt_qa.py \\\n    --cfg-path config/instructblipbase_stllm_qa.yaml \\\n    --ckpt-path /Path/to/STLLM_QA_weight \\\n    --video_dir /Path/to/MSRVTT-QA/video/ \\\n    --gt_file /Path/to/MSRVTT-QA/test_qa.json \\\n    --output_dir test_output/qabench/ \\\n    --output_name stllm_instructblipbase_msrvttqa \\\n    --num-frames 64 \\\n    \n    \n    "
  },
  {
    "path": "script/inference/qabench/msvd_qa.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/qabench/msvd_qa.py \\\n    --cfg-path config/instructblipbase_stllm_qa.yaml \\\n    --ckpt-path /Path/to/STLLM_QA_weight \\\n    --video_dir /Path/to/MSVD/YouTubeClips \\\n    --gt_file /Path/to/MSVD-QA/test_qa.json \\\n    --output_dir test_output/qabench/ \\\n    --output_name stllm_instructblipbase_msvdqa \\\n    --num-frames 64 \\\n\n    \n    "
  },
  {
    "path": "script/inference/qabench/score_anet.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \\\n    --pred_path test_output/qabench/stllm_instructblipbase_anetqa.json \\\n    --output_dir test_output/qabench/activityQA/stllm_instructblipbase \\\n    --output_json test_output/qabench/activityQA/stllm_instructblipbase/activityQA.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/qabench/score_msrvtt.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \\\n    --pred_path test_output/qabench/stllm_instructblipbase_msrvttqa.json \\\n    --output_dir test_output/qabench/msrvttQA/stllm_instructblipbase \\\n    --output_json test_output/qabench/msrvttQA/stllm_instructblipbase/msrvttQA.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/qabench/score_msvd.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/gpt_evaluation/evaluate_activitynet_qa.py \\\n    --pred_path test_output/qabench/stllm_instructblipbase_msvdqa.json \\\n    --output_dir test_output/qabench/msvdQA/stllm_instructblipbase \\\n    --output_json test_output/qabench/msvdQA/stllm_instructblipbase/msvdQA.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/vcgbench/score_consist.sh",
    "content": "python stllm/test/gpt_evaluation/evaluate_benchmark_5_consistency.py \\\n    --pred_path test_output/vcgbench/stllm_instructblipbase_consist.json \\\n    --output_dir test_output/vcgbench/consist/stllm_instructblipbase \\\n    --output_json test_output/vcgbench/consist/stllm_instructblipbase/consist.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/vcgbench/score_context.sh",
    "content": "python stllm/test/gpt_evaluation/evaluate_benchmark_3_context.py \\\n    --pred_path test_output/vcgbench/stllm_instructblipbase_general.json \\\n    --output_dir test_output/vcgbench/context/stllm_instructblipbase \\\n    --output_json test_output/vcgbench/context/stllm_instructblipbase/context.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/vcgbench/score_correct.sh",
    "content": "python stllm/test/gpt_evaluation/evaluate_benchmark_1_correctness.py \\\n    --pred_path test_output/vcgbench/stllm_instructblipbase_general.json \\\n    --output_dir test_output/vcgbench/correctness/stllm_instructblipbase \\\n    --output_json test_output/vcgbench/correctness/stllm_instructblipbase/correctness.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/vcgbench/score_detail.sh",
    "content": "python stllm/test/gpt_evaluation/evaluate_benchmark_2_detailed_orientation.py \\\n    --pred_path test_output/vcgbench/stllm_instructblipbase_general.json \\\n    --output_dir test_output/vcgbench/detail/stllm_instructblipbase \\\n    --output_json test_output/vcgbench/detail/stllm_instructblipbase/detail.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/vcgbench/score_temporal.sh",
    "content": "python stllm/test/gpt_evaluation/evaluate_benchmark_4_temporal.py \\\n    --pred_path test_output/vcgbench/stllm_instructblipbase_temporal.json \\\n    --output_dir test_output/vcgbench/temporal/stllm_instructblipbase \\\n    --output_json test_output/vcgbench/temporal/stllm_instructblipbase/temporal.json \\\n    --api_key openai_api_key \\\n    --num_tasks 3"
  },
  {
    "path": "script/inference/vcgbench/test_consist.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/vcgbench/videochatgpt_benchmark_consist.py \\\n    --cfg-path config/instructblipbase_stllm_conversation.yaml \\\n    --ckpt-path /Path/to/STLLM_conversation_weight \\\n    --video_dir /Path/to/video_chatgpt/Test_Videos \\\n    --gt_file /Path/to/video_chatgpt/Benchmarking_QA/consistency_qa.json \\\n    --output_dir test_output/vcgbench/ \\\n    --output_name stllm_instructblipbase_consist \\\n    --num-frames 64 \\\n    \n    "
  },
  {
    "path": "script/inference/vcgbench/test_general.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/vcgbench/videochatgpt_benchmark_general.py \\\n    --cfg-path config/instructblipbase_stllm_conversation.yaml \\\n    --ckpt-path /Path/to/STLLM_conversation_weight \\\n    --video_dir /Path/to/video_chatgpt/Test_Videos \\\n    --gt_file /Path/to/video_chatgpt/Benchmarking_QA/generic_qa.json \\\n    --output_dir test_output/vcgbench/ \\\n    --output_name stllm_instructblipbase_general \\\n    --num-frames 64 \\\n    \n    "
  },
  {
    "path": "script/inference/vcgbench/test_temporal.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\npython stllm/test/vcgbench/videochatgpt_benchmark_general.py \\\n    --cfg-path config/instructblipbase_stllm_conversation.yaml \\\n    --ckpt-path /Path/to/STLLM_conversation_weight \\\n    --video_dir /Path/to/video_chatgpt/Test_Videos \\\n    --gt_file /Path/to/Benchmarking_QA/temporal_qa.json \\\n    --output_dir test_output/vcgbench/ \\\n    --output_name stllm_instructblipbase_temporal \\\n    --num-frames 64 \\\n    \n    "
  },
  {
    "path": "script/train/train.sh",
    "content": "export PYTHONPATH=\"./:$PYTHONPATH\"\ndeepspeed --master_port=20000 --include=localhost:0,1,2,3,4,5,6,7 stllm/train/train_hf.py --cfg-path /Path/to/desired/config"
  },
  {
    "path": "stllm/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nimport sys\n\nfrom omegaconf import OmegaConf\n\nfrom stllm.common.registry import registry\n\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.tasks import *\n\n\nroot_dir = os.path.dirname(os.path.abspath(__file__))\ndefault_cfg = OmegaConf.load(os.path.join(root_dir, \"configs/default.yaml\"))\n\nregistry.register_path(\"library_root\", root_dir)\nrepo_root = os.path.join(root_dir, \"..\")\nregistry.register_path(\"repo_root\", repo_root)\ncache_root = os.path.join(repo_root, default_cfg.env.cache_root)\nregistry.register_path(\"cache_root\", cache_root)\n\nregistry.register(\"MAX_INT\", sys.maxsize)\nregistry.register(\"SPLIT_NAMES\", [\"train\", \"val\", \"test\"])\n"
  },
  {
    "path": "stllm/common/__init__.py",
    "content": ""
  },
  {
    "path": "stllm/common/config.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport json\nfrom typing import Dict\n\nfrom omegaconf import OmegaConf\nfrom stllm.common.registry import registry\n\n\nclass Config:\n    def __init__(self, args):\n        self.config = {}\n\n        self.args = args\n\n        # Register the config and configuration for setup\n        registry.register(\"configuration\", self)\n\n        user_config = self._build_opt_list(self.args.options)\n\n        config = OmegaConf.load(self.args.cfg_path)\n\n        runner_config = self.build_runner_config(config)\n        model_config = self.build_model_config(config, **user_config)\n        dataset_config = self.build_dataset_config(config)\n\n        # Validate the user-provided runner configuration\n        # model and dataset configuration are supposed to be validated by the respective classes\n        # [TODO] validate the model/dataset configuration\n        # self._validate_runner_config(runner_config)\n\n        # Override the default configuration with user options.\n        self.config = OmegaConf.merge(\n            runner_config, model_config, dataset_config, user_config\n        )\n\n    def _validate_runner_config(self, runner_config):\n        \"\"\"\n        This method validates the configuration, such that\n            1) all the user specified options are valid;\n            2) no type mismatches between the user specified options and the config.\n        \"\"\"\n        runner_config_validator = create_runner_config_validator()\n        runner_config_validator.validate(runner_config)\n\n    def _build_opt_list(self, opts):\n        opts_dot_list = self._convert_to_dot_list(opts)\n        return OmegaConf.from_dotlist(opts_dot_list)\n\n    @staticmethod\n    def 
build_model_config(config, **kwargs):\n        model = config.get(\"model\", None)\n        assert model is not None, \"Missing model configuration file.\"\n\n        model_cls = registry.get_model_class(model.arch)\n        assert model_cls is not None, f\"Model '{model.arch}' has not been registered.\"\n\n        model_type = kwargs.get(\"model.model_type\", None)\n        if not model_type:\n            model_type = model.get(\"model_type\", None)\n        # else use the model type selected by user.\n\n        assert model_type is not None, \"Missing model_type.\"\n\n        model_config_path = model_cls.default_config_path(model_type=model_type)\n\n        model_config = OmegaConf.create()\n        # hierarchy override, customized config > default config\n        model_config = OmegaConf.merge(\n            model_config,\n            OmegaConf.load(model_config_path),\n            {\"model\": config[\"model\"]},\n        )\n\n        return model_config\n\n    @staticmethod\n    def build_runner_config(config):\n        return {\"run\": config.run}\n\n    @staticmethod\n    def build_dataset_config(config):\n        datasets = config.get(\"datasets\", None)\n        if datasets is None:\n            raise KeyError(\n                \"Expecting 'datasets' as the root key for dataset configuration.\"\n            )\n\n        dataset_config = OmegaConf.create()\n\n        for dataset_name in datasets:\n            builder_cls = registry.get_builder_class(dataset_name)\n\n            dataset_config_type = datasets[dataset_name].get(\"type\", \"default\")\n            if builder_cls is not None:\n                dataset_config_path = builder_cls.default_config_path(\n                    type=dataset_config_type\n                )\n                default_config = OmegaConf.load(dataset_config_path)\n            else:\n                default_config = {}\n            # hierarchy override, customized config > default config\n            dataset_config = 
OmegaConf.merge(\n                dataset_config,\n                default_config,\n                {\"datasets\": {dataset_name: config[\"datasets\"][dataset_name]}},\n            )\n\n        return dataset_config\n\n    def _convert_to_dot_list(self, opts):\n        if opts is None:\n            opts = []\n\n        if len(opts) == 0:\n            return opts\n\n        has_equal = opts[0].find(\"=\") != -1\n\n        if has_equal:\n            return opts\n\n        return [(opt + \"=\" + value) for opt, value in zip(opts[0::2], opts[1::2])]\n\n    def get_config(self):\n        return self.config\n\n    @property\n    def run_cfg(self):\n        return self.config.run\n\n    @property\n    def datasets_cfg(self):\n        return self.config.datasets\n\n    @property\n    def model_cfg(self):\n        return self.config.model\n\n    def pretty_print(self):\n        logging.info(\"\\n=====  Running Parameters    =====\")\n        logging.info(self._convert_node_to_json(self.config.run))\n\n        logging.info(\"\\n======  Dataset Attributes  ======\")\n        datasets = self.config.datasets\n\n        for dataset in datasets:\n            if dataset in self.config.datasets:\n                logging.info(f\"\\n======== {dataset} =======\")\n                dataset_config = self.config.datasets[dataset]\n                logging.info(self._convert_node_to_json(dataset_config))\n            else:\n                logging.warning(f\"No dataset named '{dataset}' in config. 
Skipping\")\n\n        logging.info(f\"\\n======  Model Attributes  ======\")\n        logging.info(self._convert_node_to_json(self.config.model))\n\n    def _convert_node_to_json(self, node):\n        container = OmegaConf.to_container(node, resolve=True)\n        return json.dumps(container, indent=4, sort_keys=True)\n\n    def to_dict(self):\n        return OmegaConf.to_container(self.config)\n\n\ndef node_to_dict(node):\n    return OmegaConf.to_container(node)\n\n\nclass ConfigValidator:\n    \"\"\"\n    This is a preliminary implementation to centralize and validate the configuration.\n    May be altered in the future.\n\n    A helper class to validate configurations from yaml file.\n\n    This serves the following purposes:\n        1. Ensure all the options in the yaml are defined, raise error if not.\n        2. when type mismatches are found, the validator will raise an error.\n        3. a central place to store and display helpful messages for supported configurations.\n\n    \"\"\"\n\n    class _Argument:\n        def __init__(self, name, choices=None, type=None, help=None):\n            self.name = name\n            self.val = None\n            self.choices = choices\n            self.type = type\n            self.help = help\n\n        def __str__(self):\n            s = f\"{self.name}={self.val}\"\n            if self.type is not None:\n                s += f\", ({self.type})\"\n            if self.choices is not None:\n                s += f\", choices: {self.choices}\"\n            if self.help is not None:\n                s += f\", ({self.help})\"\n            return s\n\n    def __init__(self, description):\n        self.description = description\n\n        self.arguments = dict()\n\n        self.parsed_args = None\n\n    def __getitem__(self, key):\n        assert self.parsed_args is not None, \"No arguments parsed yet.\"\n\n        return self.parsed_args[key]\n\n    def __str__(self) -> str:\n        return self.format_help()\n\n    def 
add_argument(self, *args, **kwargs):\n        \"\"\"\n        Assume the first argument is the name of the argument.\n        \"\"\"\n        self.arguments[args[0]] = self._Argument(*args, **kwargs)\n\n    def validate(self, config=None):\n        \"\"\"\n        Convert yaml config (dict-like) to list, required by argparse.\n        \"\"\"\n        for k, v in config.items():\n            assert (\n                k in self.arguments\n            ), f\"\"\"{k} is not a valid argument. Support arguments are {self.format_arguments()}.\"\"\"\n\n            if self.arguments[k].type is not None:\n                try:\n                    self.arguments[k].val = self.arguments[k].type(v)\n                except ValueError:\n                    raise ValueError(f\"{k} is not a valid {self.arguments[k].type}.\")\n\n            if self.arguments[k].choices is not None:\n                assert (\n                    v in self.arguments[k].choices\n                ), f\"\"\"{k} must be one of {self.arguments[k].choices}.\"\"\"\n\n        return config\n\n    def format_arguments(self):\n        return str([f\"{k}\" for k in sorted(self.arguments.keys())])\n\n    def format_help(self):\n        # description + key-value pair string for each argument\n        help_msg = str(self.description)\n        return help_msg + \", available arguments: \" + self.format_arguments()\n\n    def print_help(self):\n        # display help message\n        print(self.format_help())\n\n\ndef create_runner_config_validator():\n    validator = ConfigValidator(description=\"Runner configurations\")\n\n    validator.add_argument(\n        \"runner\",\n        type=str,\n        choices=[\"runner_base\", \"runner_iter\"],\n        help=\"\"\"Runner to use. The \"runner_base\" uses epoch-based training while iter-based\n            runner runs based on iters. 
Default: runner_base\"\"\",\n    )\n    # add arguments for training dataset ratios\n    validator.add_argument(\n        \"train_dataset_ratios\",\n        type=Dict[str, float],\n        help=\"\"\"Ratios of training datasets. This is used by the iteration-based runner.\n        Not supported by the epoch-based runner, because defining an epoch becomes tricky.\n        Default: None\"\"\",\n    )\n    validator.add_argument(\n        \"max_iters\",\n        type=float,\n        help=\"Maximum number of iterations to run.\",\n    )\n    validator.add_argument(\n        \"max_epoch\",\n        type=int,\n        help=\"Maximum number of epochs to run.\",\n    )\n    # add arguments for iters_per_inner_epoch\n    validator.add_argument(\n        \"iters_per_inner_epoch\",\n        type=float,\n        help=\"Number of iterations per inner epoch. This is required when runner is runner_iter.\",\n    )\n    lr_scheds_choices = registry.list_lr_schedulers()\n    validator.add_argument(\n        \"lr_sched\",\n        type=str,\n        choices=lr_scheds_choices,\n        help=\"Learning rate scheduler to use, from {}\".format(lr_scheds_choices),\n    )\n    task_choices = registry.list_tasks()\n    validator.add_argument(\n        \"task\",\n        type=str,\n        choices=task_choices,\n        help=\"Task to use, from {}\".format(task_choices),\n    )\n    # add arguments for init_lr\n    validator.add_argument(\n        \"init_lr\",\n        type=float,\n        help=\"Initial learning rate. 
This will be the learning rate after warmup and before decay.\",\n    )\n    # add arguments for min_lr\n    validator.add_argument(\n        \"min_lr\",\n        type=float,\n        help=\"Minimum learning rate (after decay).\",\n    )\n    # add arguments for warmup_lr\n    validator.add_argument(\n        \"warmup_lr\",\n        type=float,\n        help=\"Starting learning rate for warmup.\",\n    )\n    # add arguments for learning rate decay rate\n    validator.add_argument(\n        \"lr_decay_rate\",\n        type=float,\n        help=\"Learning rate decay rate. Required if using a decaying learning rate scheduler.\",\n    )\n    # add arguments for weight decay\n    validator.add_argument(\n        \"weight_decay\",\n        type=float,\n        help=\"Weight decay rate.\",\n    )\n    # add arguments for training batch size\n    validator.add_argument(\n        \"batch_size_train\",\n        type=int,\n        help=\"Training batch size.\",\n    )\n    # add arguments for evaluation batch size\n    validator.add_argument(\n        \"batch_size_eval\",\n        type=int,\n        help=\"Evaluation batch size, including validation and testing.\",\n    )\n    # add arguments for number of workers for data loading\n    validator.add_argument(\n        \"num_workers\",\n        help=\"Number of workers for data loading.\",\n    )\n    # add arguments for warm up steps\n    validator.add_argument(\n        \"warmup_steps\",\n        type=int,\n        help=\"Number of warmup steps. 
Required if a warmup schedule is used.\",\n    )\n    # add arguments for random seed\n    validator.add_argument(\n        \"seed\",\n        type=int,\n        help=\"Random seed.\",\n    )\n    # add arguments for output directory\n    validator.add_argument(\n        \"output_dir\",\n        type=str,\n        help=\"Output directory to save checkpoints and logs.\",\n    )\n    # add arguments for whether only use evaluation\n    validator.add_argument(\n        \"evaluate\",\n        help=\"Whether to only evaluate the model. If true, training will not be performed.\",\n    )\n    # add arguments for splits used for training, e.g. [\"train\", \"val\"]\n    validator.add_argument(\n        \"train_splits\",\n        type=list,\n        help=\"Splits to use for training.\",\n    )\n    # add arguments for splits used for validation, e.g. [\"val\"]\n    validator.add_argument(\n        \"valid_splits\",\n        type=list,\n        help=\"Splits to use for validation. If not provided, will skip the validation.\",\n    )\n    # add arguments for splits used for testing, e.g. [\"test\"]\n    validator.add_argument(\n        \"test_splits\",\n        type=list,\n        help=\"Splits to use for testing. If not provided, will skip the testing.\",\n    )\n    # add arguments for accumulating gradient for iterations\n    validator.add_argument(\n        \"accum_grad_iters\",\n        type=int,\n        help=\"Number of iterations to accumulate gradient for.\",\n    )\n\n    # ====== distributed training ======\n    validator.add_argument(\n        \"device\",\n        type=str,\n        choices=[\"cpu\", \"cuda\"],\n        help=\"Device to use. 
Supports 'cuda' or 'cpu' for now.\",\n    )\n    validator.add_argument(\n        \"world_size\",\n        type=int,\n        help=\"Number of processes participating in the job.\",\n    )\n    validator.add_argument(\"dist_url\", type=str)\n    validator.add_argument(\"distributed\", type=bool)\n    # add arguments to opt using distributed sampler during evaluation or not\n    validator.add_argument(\n        \"use_dist_eval_sampler\",\n        type=bool,\n        help=\"Whether to use distributed sampler during evaluation or not.\",\n    )\n\n    # ====== task specific ======\n    # generation task specific arguments\n    # add arguments for maximal length of text output\n    validator.add_argument(\n        \"max_len\",\n        type=int,\n        help=\"Maximal length of text output.\",\n    )\n    # add arguments for minimal length of text output\n    validator.add_argument(\n        \"min_len\",\n        type=int,\n        help=\"Minimal length of text output.\",\n    )\n    # add arguments number of beams\n    validator.add_argument(\n        \"num_beams\",\n        type=int,\n        help=\"Number of beams used for beam search.\",\n    )\n\n    # vqa task specific arguments\n    # add arguments for number of answer candidates\n    validator.add_argument(\n        \"num_ans_candidates\",\n        type=int,\n        help=\"\"\"For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.\"\"\",\n    )\n    # add arguments for inference method\n    validator.add_argument(\n        \"inference_method\",\n        type=str,\n        choices=[\"generate\", \"rank\"],\n        help=\"\"\"Inference method to use for question answering. If rank, requires an answer list.\"\"\",\n    )\n\n    # ====== model specific ======\n    validator.add_argument(\n        \"k_test\",\n        type=int,\n        help=\"Number of top k most similar samples from ITC/VTC selection to be tested.\",\n    )\n\n    return validator\n"
  },
  {
    "path": "stllm/common/dist_utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport functools\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport timm.models.hub as timm_hub\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    import builtins as __builtin__\n\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop(\"force\", False)\n        if is_master or force:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef init_distributed_mode(args):\n    if args.distributed is False:\n        print(\"Not using distributed mode\")\n        return\n    elif \"RANK\" in os.environ and \"WORLD_SIZE\" in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ[\"WORLD_SIZE\"])\n        args.gpu = int(os.environ[\"LOCAL_RANK\"])\n    elif \"SLURM_PROCID\" in os.environ:\n        args.rank = int(os.environ[\"SLURM_PROCID\"])\n        args.gpu = args.rank % torch.cuda.device_count()\n    else:\n        print(\"Not using distributed mode\")\n        args.distributed = False\n        return\n\n    args.distributed = True\n\n    torch.cuda.set_device(args.gpu)\n    args.dist_backend = \"nccl\"\n    print(\n        
\"| distributed init (rank {}, world {}): {}\".format(\n            args.rank, args.world_size, args.dist_url\n        ),\n        flush=True,\n    )\n    torch.distributed.init_process_group(\n        backend=args.dist_backend,\n        init_method=args.dist_url,\n        world_size=args.world_size,\n        rank=args.rank,\n        timeout=datetime.timedelta(\n            days=365\n        ),  # allow auto-downloading and de-compressing\n    )\n    torch.distributed.barrier()\n    setup_for_distributed(args.rank == 0)\n\n\ndef get_dist_info():\n    if torch.__version__ < \"1.0\":\n        initialized = dist._initialized\n    else:\n        initialized = dist.is_initialized()\n    if initialized:\n        rank = dist.get_rank()\n        world_size = dist.get_world_size()\n    else:  # non-distributed training\n        rank = 0\n        world_size = 1\n    return rank, world_size\n\n\ndef main_process(func):\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        rank, _ = get_dist_info()\n        if rank == 0:\n            return func(*args, **kwargs)\n\n    return wrapper\n\n\ndef download_cached_file(url, check_hash=True, progress=False):\n    \"\"\"\n    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.\n    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.\n    \"\"\"\n\n    def get_cached_file_path():\n        # a hack to sync the file path across processes\n        parts = torch.hub.urlparse(url)\n        filename = os.path.basename(parts.path)\n        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)\n\n        return cached_file\n\n    if is_main_process():\n        timm_hub.download_cached_file(url, check_hash, progress)\n\n    if is_dist_avail_and_initialized():\n        dist.barrier()\n\n    return get_cached_file_path()\n"
  },
  {
    "path": "stllm/common/gradcam.py",
    "content": "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom scipy.ndimage import filters\nfrom skimage import transform as skimage_transform\n\n\ndef getAttMap(img, attMap, blur=True, overlap=True):\n    attMap -= attMap.min()\n    if attMap.max() > 0:\n        attMap /= attMap.max()\n    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode=\"constant\")\n    if blur:\n        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))\n        attMap -= attMap.min()\n        attMap /= attMap.max()\n    cmap = plt.get_cmap(\"jet\")\n    attMapV = cmap(attMap)\n    attMapV = np.delete(attMapV, 3, 2)\n    if overlap:\n        attMap = (\n            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img\n            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV\n        )\n    return attMap\n"
  },
  {
    "path": "stllm/common/logger.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport logging\nimport time\nfrom collections import defaultdict, deque\n\nimport torch\nimport torch.distributed as dist\n\nfrom stllm.common import dist_utils\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not dist_utils.is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=\"cuda\")\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n         
   avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value,\n        )\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\n            \"'{}' object has no attribute '{}'\".format(type(self).__name__, attr)\n        )\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\"{}: {}\".format(name, str(meter)))\n        return self.delimiter.join(loss_str)\n\n    def global_avg(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\"{}: {:.4f}\".format(name, meter.global_avg))\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = \"\"\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt=\"{avg:.4f}\")\n        data_time = SmoothedValue(fmt=\"{avg:.4f}\")\n        space_fmt = \":\" + str(len(str(len(iterable)))) + \"d\"\n        log_msg = [\n            header,\n            \"[{0\" + space_fmt + \"}/{1}]\",\n            \"eta: {eta}\",\n            \"{meters}\",\n      
      \"time: {time}\",\n            \"data: {data}\",\n        ]\n        if torch.cuda.is_available():\n            log_msg.append(\"max mem: {memory:.0f}\")\n        log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(\n                        log_msg.format(\n                            i,\n                            len(iterable),\n                            eta=eta_string,\n                            meters=str(self),\n                            time=str(iter_time),\n                            data=str(data_time),\n                            memory=torch.cuda.max_memory_allocated() / MB,\n                        )\n                    )\n                else:\n                    print(\n                        log_msg.format(\n                            i,\n                            len(iterable),\n                            eta=eta_string,\n                            meters=str(self),\n                            time=str(iter_time),\n                            data=str(data_time),\n                        )\n                    )\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print(\n            \"{} Total time: {} ({:.4f} s / it)\".format(\n                header, total_time_str, total_time / len(iterable)\n            )\n        )\n\n\nclass AttrDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n     
   self.__dict__ = self\n\n\ndef setup_logger():\n    logging.basicConfig(\n        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,\n        format=\"%(asctime)s [%(levelname)s] %(message)s\",\n        handlers=[logging.StreamHandler()],\n    )\n"
  },
  {
    "path": "stllm/common/optims.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport math\n\nfrom stllm.common.registry import registry\n\n\n@registry.register_lr_scheduler(\"linear_warmup_step_lr\")\nclass LinearWarmupStepLRScheduler:\n    def __init__(\n        self,\n        optimizer,\n        max_epoch,\n        min_lr,\n        init_lr,\n        decay_rate=1,\n        warmup_start_lr=-1,\n        warmup_steps=0,\n        **kwargs\n    ):\n        self.optimizer = optimizer\n\n        self.max_epoch = max_epoch\n        self.min_lr = min_lr\n\n        self.decay_rate = decay_rate\n\n        self.init_lr = init_lr\n        self.warmup_steps = warmup_steps\n        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr\n\n    def step(self, cur_epoch, cur_step):\n        if cur_epoch == 0:\n            warmup_lr_schedule(\n                step=cur_step,\n                optimizer=self.optimizer,\n                max_step=self.warmup_steps,\n                init_lr=self.warmup_start_lr,\n                max_lr=self.init_lr,\n            )\n        else:\n            step_lr_schedule(\n                epoch=cur_epoch,\n                optimizer=self.optimizer,\n                init_lr=self.init_lr,\n                min_lr=self.min_lr,\n                decay_rate=self.decay_rate,\n            )\n\n@registry.register_lr_scheduler(\"linear_warmup_cosine_lr\")\nclass LinearWarmupCosineLRScheduler:\n    def __init__(\n        self,\n        optimizer,\n        max_epoch,\n        iters_per_epoch,\n        min_lr,\n        init_lr,\n        warmup_steps=0,\n        warmup_start_lr=-1,\n        **kwargs\n    ):\n        self.optimizer = optimizer\n\n        self.max_epoch = max_epoch\n        self.iters_per_epoch = iters_per_epoch\n        self.min_lr = min_lr\n\n        
self.init_lr = init_lr\n        self.warmup_steps = warmup_steps\n        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr\n\n    def step(self, cur_epoch, cur_step):\n        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step\n        if total_cur_step < self.warmup_steps:\n            # use the global step so that warmup can span epoch boundaries\n            warmup_lr_schedule(\n                step=total_cur_step,\n                optimizer=self.optimizer,\n                max_step=self.warmup_steps,\n                init_lr=self.warmup_start_lr,\n                max_lr=self.init_lr,\n            )\n        else:\n            cosine_lr_schedule(\n                epoch=total_cur_step,\n                optimizer=self.optimizer,\n                max_epoch=self.max_epoch * self.iters_per_epoch,\n                init_lr=self.init_lr,\n                min_lr=self.min_lr,\n            )\n\ndef cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):\n    \"\"\"Decay the learning rate from init_lr to min_lr along a half cosine\"\"\"\n    lr = (init_lr - min_lr) * 0.5 * (\n        1.0 + math.cos(math.pi * epoch / max_epoch)\n    ) + min_lr\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n\ndef warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):\n    \"\"\"Linearly warm up the learning rate from init_lr, capped at max_lr\"\"\"\n    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n\ndef step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):\n    \"\"\"Decay the learning rate by decay_rate per epoch, floored at min_lr\"\"\"\n    lr = max(min_lr, init_lr * (decay_rate**epoch))\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n"
  },
  {
    "path": "stllm/common/registry.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\n\nclass Registry:\n    mapping = {\n        \"builder_name_mapping\": {},\n        \"task_name_mapping\": {},\n        \"processor_name_mapping\": {},\n        \"model_name_mapping\": {},\n        \"lr_scheduler_name_mapping\": {},\n        \"runner_name_mapping\": {},\n        \"state\": {},\n        \"paths\": {},\n    }\n\n    @classmethod\n    def register_builder(cls, name):\n        r\"\"\"Register a dataset builder to registry with key 'name'\n\n        Args:\n            name: Key with which the builder will be registered.\n\n        Usage:\n\n            from stllm.common.registry import registry\n            from stllm.datasets.base_dataset_builder import BaseDatasetBuilder\n        \"\"\"\n\n        def wrap(builder_cls):\n            from stllm.datasets.builders.base_dataset_builder import BaseDatasetBuilder\n\n            assert issubclass(\n                builder_cls, BaseDatasetBuilder\n            ), \"All builders must inherit BaseDatasetBuilder class, found {}\".format(\n                builder_cls\n            )\n            if name in cls.mapping[\"builder_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"builder_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"builder_name_mapping\"][name] = builder_cls\n            return builder_cls\n\n        return wrap\n\n    @classmethod\n    def register_task(cls, name):\n        r\"\"\"Register a task to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from stllm.common.registry import registry\n        
\"\"\"\n\n        def wrap(task_cls):\n            from stllm.tasks.base_task import BaseTask\n\n            assert issubclass(\n                task_cls, BaseTask\n            ), \"All tasks must inherit BaseTask class\"\n            if name in cls.mapping[\"task_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"task_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"task_name_mapping\"][name] = task_cls\n            return task_cls\n\n        return wrap\n\n    @classmethod\n    def register_model(cls, name):\n        r\"\"\"Register a task to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from stllm.common.registry import registry\n        \"\"\"\n\n        def wrap(model_cls):\n            from stllm.models import BaseModel\n\n            assert issubclass(\n                model_cls, BaseModel\n            ), \"All models must inherit BaseModel class\"\n            if name in cls.mapping[\"model_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"model_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"model_name_mapping\"][name] = model_cls\n            return model_cls\n\n        return wrap\n\n    @classmethod\n    def register_processor(cls, name):\n        r\"\"\"Register a processor to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from stllm.common.registry import registry\n        \"\"\"\n\n        def wrap(processor_cls):\n            from stllm.processors import BaseProcessor\n\n            assert issubclass(\n                processor_cls, BaseProcessor\n            ), 
\"All processors must inherit BaseProcessor class\"\n            if name in cls.mapping[\"processor_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"processor_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"processor_name_mapping\"][name] = processor_cls\n            return processor_cls\n\n        return wrap\n\n    @classmethod\n    def register_lr_scheduler(cls, name):\n        r\"\"\"Register a model to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from stllm.common.registry import registry\n        \"\"\"\n\n        def wrap(lr_sched_cls):\n            if name in cls.mapping[\"lr_scheduler_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"lr_scheduler_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"lr_scheduler_name_mapping\"][name] = lr_sched_cls\n            return lr_sched_cls\n\n        return wrap\n\n    @classmethod\n    def register_runner(cls, name):\n        r\"\"\"Register a model to registry with key 'name'\n\n        Args:\n            name: Key with which the task will be registered.\n\n        Usage:\n\n            from stllm.common.registry import registry\n        \"\"\"\n\n        def wrap(runner_cls):\n            if name in cls.mapping[\"runner_name_mapping\"]:\n                raise KeyError(\n                    \"Name '{}' already registered for {}.\".format(\n                        name, cls.mapping[\"runner_name_mapping\"][name]\n                    )\n                )\n            cls.mapping[\"runner_name_mapping\"][name] = runner_cls\n            return runner_cls\n\n        return wrap\n\n    @classmethod\n    def 
register_path(cls, name, path):\n        r\"\"\"Register a path to registry with key 'name'\n\n        Args:\n            name: Key with which the path will be registered.\n\n        Usage:\n\n            from stllm.common.registry import registry\n        \"\"\"\n        assert isinstance(path, str), \"All path must be str.\"\n        if name in cls.mapping[\"paths\"]:\n            raise KeyError(\"Name '{}' already registered.\".format(name))\n        cls.mapping[\"paths\"][name] = path\n\n    @classmethod\n    def register(cls, name, obj):\n        r\"\"\"Register an item to registry with key 'name'\n\n        Args:\n            name: Key with which the item will be registered.\n\n        Usage::\n\n            from stllm.common.registry import registry\n\n            registry.register(\"config\", {})\n        \"\"\"\n        path = name.split(\".\")\n        current = cls.mapping[\"state\"]\n\n        for part in path[:-1]:\n            if part not in current:\n                current[part] = {}\n            current = current[part]\n\n        current[path[-1]] = obj\n\n    # @classmethod\n    # def get_trainer_class(cls, name):\n    #     return cls.mapping[\"trainer_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_builder_class(cls, name):\n        return cls.mapping[\"builder_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_model_class(cls, name):\n        return cls.mapping[\"model_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_task_class(cls, name):\n        return cls.mapping[\"task_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_processor_class(cls, name):\n        return cls.mapping[\"processor_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_lr_scheduler_class(cls, name):\n        return cls.mapping[\"lr_scheduler_name_mapping\"].get(name, None)\n\n    @classmethod\n    def get_runner_class(cls, name):\n        return cls.mapping[\"runner_name_mapping\"].get(name, 
None)\n\n    @classmethod\n    def list_runners(cls):\n        return sorted(cls.mapping[\"runner_name_mapping\"].keys())\n\n    @classmethod\n    def list_models(cls):\n        return sorted(cls.mapping[\"model_name_mapping\"].keys())\n\n    @classmethod\n    def list_tasks(cls):\n        return sorted(cls.mapping[\"task_name_mapping\"].keys())\n\n    @classmethod\n    def list_processors(cls):\n        return sorted(cls.mapping[\"processor_name_mapping\"].keys())\n\n    @classmethod\n    def list_lr_schedulers(cls):\n        return sorted(cls.mapping[\"lr_scheduler_name_mapping\"].keys())\n\n    @classmethod\n    def list_datasets(cls):\n        return sorted(cls.mapping[\"builder_name_mapping\"].keys())\n\n    @classmethod\n    def get_path(cls, name):\n        return cls.mapping[\"paths\"].get(name, None)\n\n    @classmethod\n    def get(cls, name, default=None, no_warning=False):\n        r\"\"\"Get an item from registry with key 'name'\n\n        Args:\n            name (string): Key whose value needs to be retrieved.\n            default: If passed and key is not in registry, default value will\n                     be returned with a warning. Default: None\n            no_warning (bool): If passed as True, warning when key doesn't exist\n                               will not be generated. Useful for\n                               internal operations. 
Default: False\n        \"\"\"\n        original_name = name\n        name = name.split(\".\")\n        value = cls.mapping[\"state\"]\n        for subname in name:\n            value = value.get(subname, default)\n            if value is default:\n                break\n\n        if (\n            \"writer\" in cls.mapping[\"state\"]\n            and value == default\n            and no_warning is False\n        ):\n            cls.mapping[\"state\"][\"writer\"].warning(\n                \"Key {} is not present in registry, returning default value \"\n                \"of {}\".format(original_name, default)\n            )\n        return value\n\n    @classmethod\n    def unregister(cls, name):\n        r\"\"\"Remove an item from registry with key 'name'\n\n        Args:\n            name: Key which needs to be removed.\n        Usage::\n\n            from stllm.common.registry import registry\n\n            config = registry.unregister(\"config\")\n        \"\"\"\n        return cls.mapping[\"state\"].pop(name, None)\n\n\nregistry = Registry()\n"
  },
  {
    "path": "stllm/common/utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport io\nimport json\nimport logging\nimport os\nimport pickle\nimport re\nimport shutil\nimport urllib\nimport urllib.error\nimport urllib.request\nfrom typing import Optional\nfrom urllib.parse import urlparse\n\nimport numpy as np\nimport pandas as pd\nimport yaml\nfrom iopath.common.download import download\nfrom iopath.common.file_io import file_lock, g_pathmgr\nfrom stllm.common.registry import registry\nfrom torch.utils.model_zoo import tqdm\nfrom torchvision.datasets.utils import (\n    check_integrity,\n    download_file_from_google_drive,\n    extract_archive,\n)\n\n\ndef now():\n    from datetime import datetime\n\n    return datetime.now().strftime(\"%Y%m%d%H%M\")[:-1]\n\n\ndef is_url(url_or_filename):\n    parsed = urlparse(url_or_filename)\n    return parsed.scheme in (\"http\", \"https\")\n\n\ndef get_cache_path(rel_path):\n    return os.path.expanduser(os.path.join(registry.get_path(\"cache_root\"), rel_path))\n\n\ndef get_abs_path(rel_path):\n    return os.path.join(registry.get_path(\"library_root\"), rel_path)\n\n\ndef load_json(filename):\n    with open(filename, \"r\") as f:\n        return json.load(f)\n\n\n# The following are adapted from torchvision and vissl\n# torchvision: https://github.com/pytorch/vision\n# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py\n\n\ndef makedir(dir_path):\n    \"\"\"\n    Create the directory if it does not exist.\n    \"\"\"\n    is_success = False\n    try:\n        if not g_pathmgr.exists(dir_path):\n            g_pathmgr.mkdirs(dir_path)\n        is_success = True\n    except BaseException:\n        print(f\"Error creating directory: {dir_path}\")\n    return is_success\n\n\ndef get_redirected_url(url: str):\n    
\"\"\"\n    Given a URL, returns the URL it redirects to or the\n    original URL in case of no indirection\n    \"\"\"\n    import requests\n\n    with requests.Session() as session:\n        with session.get(url, stream=True, allow_redirects=True) as response:\n            if response.history:\n                return response.url\n            else:\n                return url\n\n\ndef to_google_drive_download_url(view_url: str) -> str:\n    \"\"\"\n    Utility function to transform a view URL of google drive\n    to a download URL for google drive\n    Example input:\n        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view\n    Example output:\n        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp\n    \"\"\"\n    splits = view_url.split(\"/\")\n    assert splits[-1] == \"view\"\n    file_id = splits[-2]\n    return f\"https://drive.google.com/uc?export=download&id={file_id}\"\n\n\ndef download_google_drive_url(url: str, output_path: str, output_file_name: str):\n    \"\"\"\n    Download a file from google drive\n    Downloading an URL from google drive requires confirmation when\n    the file of the size is too big (google drive notifies that\n    anti-viral checks cannot be performed on such files)\n    \"\"\"\n    import requests\n\n    with requests.Session() as session:\n\n        # First get the confirmation token and append it to the URL\n        with session.get(url, stream=True, allow_redirects=True) as response:\n            for k, v in response.cookies.items():\n                if k.startswith(\"download_warning\"):\n                    url = url + \"&confirm=\" + v\n\n        # Then download the content of the file\n        with session.get(url, stream=True, verify=True) as response:\n            makedir(output_path)\n            path = os.path.join(output_path, output_file_name)\n            total_size = int(response.headers.get(\"Content-length\", 0))\n            with open(path, \"wb\") as 
file:\n                from tqdm import tqdm\n\n                with tqdm(total=total_size) as progress_bar:\n                    for block in response.iter_content(\n                        chunk_size=io.DEFAULT_BUFFER_SIZE\n                    ):\n                        file.write(block)\n                        progress_bar.update(len(block))\n\n\ndef _get_google_drive_file_id(url: str) -> Optional[str]:\n    parts = urlparse(url)\n\n    if re.match(r\"(drive|docs)[.]google[.]com\", parts.netloc) is None:\n        return None\n\n    match = re.match(r\"/file/d/(?P<id>[^/]*)\", parts.path)\n    if match is None:\n        return None\n\n    return match.group(\"id\")\n\n\ndef _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:\n    with open(filename, \"wb\") as fh:\n        with urllib.request.urlopen(\n            urllib.request.Request(url, headers={\"User-Agent\": \"vissl\"})\n        ) as response:\n            with tqdm(total=response.length) as pbar:\n                # response.read() returns bytes, so the end-of-stream sentinel is b\"\"\n                for chunk in iter(lambda: response.read(chunk_size), b\"\"):\n                    # advance by the bytes actually read; the last chunk may be short\n                    pbar.update(len(chunk))\n                    fh.write(chunk)\n\n\ndef download_url(\n    url: str,\n    root: str,\n    filename: Optional[str] = None,\n    md5: Optional[str] = None,\n) -> None:\n    \"\"\"Download a file from a url and place it in root.\n    Args:\n        url (str): URL to download file from\n        root (str): Directory to place downloaded file in\n        filename (str, optional): Name to save the file under.\n                                  If None, use the basename of the URL.\n        md5 (str, optional): MD5 checksum of the download. 
If None, do not check\n    \"\"\"\n    root = os.path.expanduser(root)\n    if not filename:\n        filename = os.path.basename(url)\n    fpath = os.path.join(root, filename)\n\n    makedir(root)\n\n    # check if file is already present locally\n    if check_integrity(fpath, md5):\n        print(\"Using downloaded and verified file: \" + fpath)\n        return\n\n    # expand redirect chain if needed\n    url = get_redirected_url(url)\n\n    # check if file is located on Google Drive\n    file_id = _get_google_drive_file_id(url)\n    if file_id is not None:\n        return download_file_from_google_drive(file_id, root, filename, md5)\n\n    # download the file\n    try:\n        print(\"Downloading \" + url + \" to \" + fpath)\n        _urlretrieve(url, fpath)\n    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]\n        if url[:5] == \"https\":\n            url = url.replace(\"https:\", \"http:\")\n            print(\n                \"Failed download. Trying https -> http instead.\"\n                \" Downloading \" + url + \" to \" + fpath\n            )\n            _urlretrieve(url, fpath)\n        else:\n            raise e\n\n    # check integrity of downloaded file\n    if not check_integrity(fpath, md5):\n        raise RuntimeError(\"File not found or corrupted.\")\n\n\ndef download_and_extract_archive(\n    url: str,\n    download_root: str,\n    extract_root: Optional[str] = None,\n    filename: Optional[str] = None,\n    md5: Optional[str] = None,\n    remove_finished: bool = False,\n) -> None:\n    download_root = os.path.expanduser(download_root)\n    if extract_root is None:\n        extract_root = download_root\n    if not filename:\n        filename = os.path.basename(url)\n\n    download_url(url, download_root, filename, md5)\n\n    archive = os.path.join(download_root, filename)\n    print(\"Extracting {} to {}\".format(archive, extract_root))\n    extract_archive(archive, extract_root, remove_finished)\n\n\ndef 
cache_url(url: str, cache_dir: str) -> str:\n    \"\"\"\n    This implementation downloads the remote resource and caches it locally.\n    The resource will only be downloaded if not previously requested.\n    \"\"\"\n    parsed_url = urlparse(url)\n    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip(\"/\")))\n    makedir(dirname)\n    filename = url.split(\"/\")[-1]\n    cached = os.path.join(dirname, filename)\n    with file_lock(cached):\n        if not os.path.isfile(cached):\n            logging.info(f\"Downloading {url} to {cached} ...\")\n            cached = download(url, dirname, filename=filename)\n    logging.info(f\"URL {url} cached in {cached}\")\n    return cached\n\n\n# TODO (prigoyal): convert this into RAII-style API\ndef create_file_symlink(file1, file2):\n    \"\"\"\n    Simply create the symlinks for a given file1 to file2.\n    Useful during model checkpointing to symlinks to the\n    latest successful checkpoint.\n    \"\"\"\n    try:\n        if g_pathmgr.exists(file2):\n            g_pathmgr.rm(file2)\n        g_pathmgr.symlink(file1, file2)\n    except Exception as e:\n        logging.info(f\"Could NOT create symlink. 
Error: {e}\")\n\n\ndef save_file(data, filename, append_to_json=True, verbose=True):\n    \"\"\"\n    Common i/o utility to handle saving data to various file formats.\n    Supported:\n        .pkl, .pickle, .npy, .json, .yaml\n    Specifically for .json, users have the option to either append (default)\n    or rewrite by passing a Boolean value to append_to_json.\n    \"\"\"\n    if verbose:\n        logging.info(f\"Saving data to file: {filename}\")\n    file_ext = os.path.splitext(filename)[1]\n    if file_ext in [\".pkl\", \".pickle\"]:\n        with g_pathmgr.open(filename, \"wb\") as fopen:\n            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)\n    elif file_ext == \".npy\":\n        with g_pathmgr.open(filename, \"wb\") as fopen:\n            np.save(fopen, data)\n    elif file_ext == \".json\":\n        if append_to_json:\n            with g_pathmgr.open(filename, \"a\") as fopen:\n                fopen.write(json.dumps(data, sort_keys=True) + \"\\n\")\n                fopen.flush()\n        else:\n            with g_pathmgr.open(filename, \"w\") as fopen:\n                fopen.write(json.dumps(data, sort_keys=True) + \"\\n\")\n                fopen.flush()\n    elif file_ext == \".yaml\":\n        with g_pathmgr.open(filename, \"w\") as fopen:\n            dump = yaml.dump(data)\n            fopen.write(dump)\n            fopen.flush()\n    else:\n        raise Exception(f\"Saving {file_ext} is not supported yet\")\n\n    if verbose:\n        logging.info(f\"Saved data to file: {filename}\")\n\n\ndef load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):\n    \"\"\"\n    Common i/o utility to handle loading data from various file formats.\n    Supported:\n        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv\n    For .npy files, reading in mmap_mode is supported.\n    If reading with mmap_mode fails, the data is loaded without\n    mmap_mode.\n    \"\"\"\n    if verbose:\n        logging.info(f\"Loading data from file: 
{filename}\")\n\n    file_ext = os.path.splitext(filename)[1]\n    if file_ext == \".txt\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = fopen.readlines()\n    elif file_ext in [\".pkl\", \".pickle\"]:\n        with g_pathmgr.open(filename, \"rb\") as fopen:\n            data = pickle.load(fopen, encoding=\"latin1\")\n    elif file_ext == \".npy\":\n        if mmap_mode:\n            try:\n                with g_pathmgr.open(filename, \"rb\") as fopen:\n                    data = np.load(\n                        fopen,\n                        allow_pickle=allow_pickle,\n                        encoding=\"latin1\",\n                        mmap_mode=mmap_mode,\n                    )\n            except ValueError as e:\n                logging.info(\n                    f\"Could not mmap {filename}: {e}. Trying without g_pathmgr\"\n                )\n                data = np.load(\n                    filename,\n                    allow_pickle=allow_pickle,\n                    encoding=\"latin1\",\n                    mmap_mode=mmap_mode,\n                )\n                logging.info(\"Successfully loaded without g_pathmgr\")\n            except Exception:\n                logging.info(\"Could not mmap without g_pathmgr. 
Trying without mmap\")\n                with g_pathmgr.open(filename, \"rb\") as fopen:\n                    data = np.load(fopen, allow_pickle=allow_pickle, encoding=\"latin1\")\n        else:\n            with g_pathmgr.open(filename, \"rb\") as fopen:\n                data = np.load(fopen, allow_pickle=allow_pickle, encoding=\"latin1\")\n    elif file_ext == \".json\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = json.load(fopen)\n    elif file_ext == \".yaml\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = yaml.load(fopen, Loader=yaml.FullLoader)\n    elif file_ext == \".csv\":\n        with g_pathmgr.open(filename, \"r\") as fopen:\n            data = pd.read_csv(fopen)\n    else:\n        raise Exception(f\"Reading from {file_ext} is not supported yet\")\n    return data\n\n\ndef abspath(resource_path: str):\n    \"\"\"\n    Make a path absolute, but leave paths with a scheme prefix such as\n    \"http://\" or \"manifold://\" unchanged.\n    \"\"\"\n    regex = re.compile(r\"^\\w+://\")\n    if regex.match(resource_path) is None:\n        return os.path.abspath(resource_path)\n    else:\n        return resource_path\n\n\ndef makedir(dir_path):\n    \"\"\"\n    Create the directory if it does not exist.\n    Returns True on success and False if an error occurred.\n    \"\"\"\n    is_success = False\n    try:\n        if not g_pathmgr.exists(dir_path):\n            g_pathmgr.mkdirs(dir_path)\n        is_success = True\n    except BaseException:\n        logging.info(f\"Error creating directory: {dir_path}\")\n    return is_success\n\n\ndef is_url(input_url):\n    \"\"\"\n    Check if an input string is a URL by looking for a leading http(s)://,\n    ignoring case.\n    \"\"\"\n    is_url = re.match(r\"^(?:http)s?://\", input_url, re.IGNORECASE) is not None\n    return is_url\n\n\ndef cleanup_dir(dir):\n    \"\"\"\n    Utility for deleting a directory. 
Useful for cleaning the storage space\n    that contains various training artifacts like checkpoints, data etc.\n    \"\"\"\n    if os.path.exists(dir):\n        logging.info(f\"Deleting directory: {dir}\")\n        shutil.rmtree(dir)\n        # only report a deletion when the directory actually existed\n        logging.info(f\"Deleted contents of directory: {dir}\")\n\n\ndef get_file_size(filename):\n    \"\"\"\n    Given a file, return its size in MB\n    \"\"\"\n    size_in_mb = os.path.getsize(filename) / float(1024**2)\n    return size_in_mb\n"
  },
  {
    "path": "stllm/configs/datasets/cc_sbu/align.yaml",
    "content": "datasets:\n  cc_sbu_align:\n    data_type: images\n    build_info:\n      storage: cc_sbu_align\n"
  },
  {
    "path": "stllm/configs/datasets/cc_sbu/defaults.yaml",
    "content": "datasets:\n  cc_sbu:\n    data_type: images\n    build_info:\n      storage: /path/to/cc_sbu_dataset/{00000..01255}.tar\n"
  },
  {
    "path": "stllm/configs/datasets/laion/defaults.yaml",
    "content": "datasets:\n  laion:\n    data_type: images\n    build_info:\n      storage: /path/to/laion_dataset/{00000..10488}.tar\n"
  },
  {
    "path": "stllm/configs/default.yaml",
    "content": "env:\n  # For default users\n  # cache_root: \"cache\"\n  # For internal use with persistent storage\n  cache_root: \"/export/home/.cache/minigpt4\"\n"
  },
  {
    "path": "stllm/configs/models/instructblip_vicuna0.yaml",
    "content": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n  freeze_qformer: True\n  \n  # Q-Former\n  #q_former_model: '/path/to/instruct_blip_vicuna7b_trimmed.pth'\n  q_former_model: 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth'\n  num_query_token: 32\n\n  # generation configs\n  prompt: \"\"\n\n  llama_model: '/path/to/vicuna-7b-v1.1'\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip2_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip2_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "stllm/configs/models/instructblip_vicuna0_btadapter.yaml",
    "content": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  vit_model: \"eva_btadapter_g\"\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n  freeze_qformer: True\n\n  # Q-Former\n  #q_former_model: '/path/to/instruct_blip_vicuna7b_trimmed.pth'\n  q_former_model: 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth'\n  num_query_token: 32\n\n  # generation configs\n  prompt: \"\"\n\n  llama_model: \"/path/to/vicuna-7b-v1.1\"\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip2_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip2_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "stllm/configs/models/minigpt4_vicuna0.yaml",
    "content": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n  freeze_qformer: True\n\n  # Q-Former\n  #q_former_model: \"/path/to/blip2_pretrained_flant5xxl.pth\"\n  q_former_model: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth\"\n  num_query_token: 32\n\n  # generation configs\n  prompt: \"\"\n\n  llama_model: \"/path/to/vicuna-7b\"\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip2_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip2_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "stllm/configs/models/minigpt4_vicuna0_btadapter.yaml",
    "content": "model:\n  arch: st_llm_hf\n\n  # vit encoder\n  vit_model: \"eva_btadapter_g\"\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_precision: \"fp16\"\n  freeze_vit: True\n  freeze_qformer: True\n\n  # Q-Former\n  #q_former_model: \"/path/to/blip2_pretrained_flant5xxl.pth\"\n  q_former_model: \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth\"\n  num_query_token: 32\n\n  # generation configs\n  prompt: \"\"\n\n  llama_model: \"/path/to/vicuna-7b\"\n\npreprocess:\n    vis_processor:\n        train:\n          name: \"blip2_image_train\"\n          image_size: 224\n        eval:\n          name: \"blip2_image_eval\"\n          image_size: 224\n    text_processor:\n        train:\n          name: \"blip_caption\"\n        eval:\n          name: \"blip_caption\"\n"
  },
  {
    "path": "stllm/conversation/__init__.py",
    "content": ""
  },
  {
    "path": "stllm/conversation/conversation.py",
    "content": "import argparse\nimport time\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer\nfrom transformers import StoppingCriteria, StoppingCriteriaList\n\nimport dataclasses\nfrom enum import auto, Enum\nfrom typing import List, Tuple, Any\n\nfrom stllm.common.registry import registry\nfrom stllm.test.video_utils import load_video\nimport torchvision.transforms as T\nfrom stllm.test.video_transforms import (\n    GroupNormalize, GroupScale, GroupCenterCrop, \n    Stack, ToTorchFormatTensor\n)\nfrom torchvision.transforms.functional import InterpolationMode\n\n\nclass SeparatorStyle(Enum):\n    \"\"\"Different separator style.\"\"\"\n    SINGLE = auto()\n    TWO = auto()\n\n\n@dataclasses.dataclass\nclass Conversation:\n    \"\"\"A class that keeps all conversation history.\"\"\"\n    system: str\n    roles: List[str]\n    messages: List[List[str]]\n    offset: int\n    # system_img: List[Image.Image] = []\n    instruction: bool\n    sep_style: SeparatorStyle = SeparatorStyle.SINGLE\n    sep: str = \"###\"\n    sep2: str = None\n\n    skip_next: bool = False\n    conv_id: Any = None\n\n    def get_prompt(self):\n        if self.sep_style == SeparatorStyle.SINGLE:\n            ret = self.system + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + message + self.sep\n                else:\n                    ret += role\n            return ret\n        elif self.sep_style == SeparatorStyle.TWO:\n            seps = [self.sep, self.sep2]\n            ret = self.system + seps[0]\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + message + seps[i % 2]\n                else:\n                    ret += role\n            return ret\n        else:\n            raise ValueError(f\"Invalid style: {self.sep_style}\")\n\n    def 
append_message(self, role, message):\n        self.messages.append([role, message])\n\n    def to_gradio_chatbot(self):\n        ret = []\n        for i, (role, msg) in enumerate(self.messages[self.offset:]):\n            if i % 2 == 0:\n                ret.append([msg, None])\n            else:\n                ret[-1][-1] = msg\n        return ret\n\n    def copy(self):\n        return Conversation(\n            system=self.system,\n            # system_img=self.system_img,\n            roles=self.roles,\n            messages=[[x, y] for x, y in self.messages],\n            offset=self.offset,\n            instruction=self.instruction,\n            sep_style=self.sep_style,\n            sep=self.sep,\n            sep2=self.sep2,\n            conv_id=self.conv_id)\n\n    def dict(self):\n        return {\n            \"system\": self.system,\n            # \"system_img\": self.system_img,\n            \"roles\": self.roles,\n            \"messages\": self.messages,\n            \"offset\": self.offset,\n            \"sep\": self.sep,\n            \"sep2\": self.sep2,\n            \"conv_id\": self.conv_id,\n        }\n\n\nclass StoppingCriteriaSub(StoppingCriteria):\n\n    def __init__(self, stops=[], encounters=1):\n        super().__init__()\n        self.stops = stops\n\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):\n        for stop in self.stops:\n            if torch.all((stop == input_ids[0][-len(stop):])).item():\n                return True\n\n        return False\n\ndef get_residual_index(sample_segments, total_segments, devices):\n    seg_size = float(total_segments) / sample_segments\n    frame_indices = np.array([\n    int((seg_size / 2) + np.round(seg_size * idx))\n    for idx in range(sample_segments)\n    ])\n    frame_indices = torch.from_numpy(frame_indices).to(devices)\n    return frame_indices\n\nCONV_VISION_Vicuna0 = Conversation(\n    system=\"Give the following image: <Img>ImageContent</Img>. 
\"\n           \"You will be able to see the image once I provide it to you. Please answer my questions.\",\n    roles=(\"Human: \", \"Assistant: \"),\n    messages=[],\n    offset=2,\n    instruction=True,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nCONV_VIDEO_Vicuna0 = Conversation(\n    system=\"Give the following video: <Video>VideoContent</Video>. \"\n           \"You will be able to see the video once I provide it to you. Please answer my questions.\",\n    roles=(\"Human: \", \"Assistant: \"),\n    messages=[],\n    offset=2,\n    instruction=True,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nCONV_instructblip_Vicuna0 = Conversation(\n    system=\"Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, give your answer that best addresses the question.\\n\",\n    roles=(\"Human: \", \"Assistant: \"),\n    messages=[],\n    instruction=False,\n    offset=2,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nCONV_VISION_LLama2 = Conversation(\n    system=\"Give the following image: <Img>ImageContent</Img>. \"\n           \"You will be able to see the image once I provide it to you. Please answer my questions.\",\n    roles=(\"<s>[INST] \", \" [/INST] \"),\n    messages=[],\n    offset=2,\n    instruction=True,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"\",\n)\n\nCONV_VIDEO_LLama2 = Conversation(\n    system=\"Give the following video: <Img>VideoContent</Img>. \"\n           \"You will be able to see the video once I provide it to you. 
Please answer my questions.\",\n    roles=(\"<s>[INST] \", \" [/INST] \"),\n    messages=[],\n    offset=2,\n    instruction=True,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"\",\n)\n\nclass Chat:\n    def __init__(self, model, device='cuda:0'):\n        self.device = device\n        self.model = model\n        if not hasattr(model,'llama_model'):\n            if hasattr(model.model,'stllm_model'):\n                self.model = model.model.stllm_model\n            else:\n                self.model = model.model.model.stllm_model\n            self.LLM = model\n\n        input_mean = [0.48145466, 0.4578275, 0.40821073]\n        input_std = [0.26862954, 0.26130258, 0.27577711]\n        self.transform = T.Compose([\n            GroupScale(int(224), interpolation=InterpolationMode.BICUBIC),\n            GroupCenterCrop(224),\n            Stack(),\n            ToTorchFormatTensor(),\n            GroupNormalize(input_mean, input_std) \n        ])\n        stop_words_ids = [torch.tensor([835]).to(self.device),\n                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.\n        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])\n\n    def ask(self, text, conv):\n        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \\\n                and (conv.messages[-1][1][-6:] == '</Img>' or conv.messages[-1][1][-8:] == '</Video>' \n                    or conv.messages[-1][1][-8:] == '</Frame>'):  # last message is image.\n            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])\n        else:\n            conv.append_message(conv.roles[0], text)\n\n    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, system=True,\n               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000, do_sample=True):\n        conv.append_message(conv.roles[1], None)\n        if conv.instruction:\n   
         embs, attention_mask = self.get_context_emb(conv, img_list)\n        else:\n            embs, attention_mask = self.get_context_emb_sim(conv, img_list, system=system)\n            repetition_penalty = 1.5\n\n        current_max_len = embs.shape[1] + max_new_tokens\n        if current_max_len - max_length > 0:\n            print('Warning: The number of tokens in current conversation exceeds the max length. '\n                  'The model will not see the contexts outside the range.')\n        begin_idx = max(0, current_max_len - max_length)\n\n        embs = embs[:, begin_idx:]\n\n        llama_model = self.LLM if hasattr(self,'LLM') else self.model.llama_model\n        outputs = llama_model.generate(\n            inputs_embeds=embs,\n            max_new_tokens=max_new_tokens,\n            #attention_mask=attention_mask,\n            stopping_criteria=self.stopping_criteria,\n            num_beams=num_beams,\n            do_sample=do_sample,\n            min_length=min_length,\n            top_p=top_p,\n            repetition_penalty=repetition_penalty,\n            length_penalty=length_penalty,\n            temperature=temperature,\n        )\n        output_token = outputs[0]\n        if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it\n            output_token = output_token[1:]\n        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. 
remove it\n            output_token = output_token[1:]\n        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)\n        output_text = output_text.split('###')[0]  # remove the stop sign '###'\n        output_text = output_text.split('Assistant:')[-1].strip()\n        conv.messages[-1][1] = output_text\n        return output_text, output_token.cpu().numpy()\n\n    def upload_img(self, image, conv, img_list):\n        if isinstance(image, str):  # is a image path\n            raw_image = Image.open(image).convert('RGB')\n            image = self.transform([raw_image]).to(self.device)\n        elif isinstance(image, Image.Image):\n            raw_image = image\n            image = self.transform([raw_image]).to(self.device)\n        elif isinstance(image, torch.Tensor):\n            if len(image.shape) == 3:\n                image = image.unsqueeze(0)\n            image = image.to(self.device)\n\n        image_emb, _ = self.model.encode_img(image)\n        img_list.append(image_emb)\n        conv.append_message(conv.roles[0], \"<Img><ImageHere></Img>\")\n        msg = \"Received.\"\n        # self.conv.append_message(self.conv.roles[1], msg)\n        return msg\n\n    def upload_video(self, video, conv, img_list, num_frame=64, text=None):\n        raw_frames = load_video(video, num_frm=num_frame) if isinstance(video,str) else video\n        video_frames = self.transform(raw_frames).to(self.device) \n        bt, w, h = video_frames.size()\n        video_frames = video_frames.view(bt//3,3,w,h)\n\n        video_emb, _, _ = self.model.encode_img(video_frames, text=text)\n        if self.model.video_input == 'mean':\n            video_emb = video_emb.mean(dim=0, keepdim=True)\n        elif self.model.video_input == 'all':\n            video_emb = video_emb.view(1, -1, video_emb.size(-1))\n        elif self.model.video_input == 'residual':\n            T = video_emb.size(0)\n            residual_size = self.model.residual_size\n      
      residual_index = get_residual_index(residual_size, T, video_emb.device)\n            global_embeds = video_emb.mean(dim=0, keepdim=True)\n            local_embeds = video_emb[residual_index]\n            global_embeds = global_embeds.expand((residual_size,-1,-1)).to(self.model.up_proj.weight.dtype)\n            global_embeds = self.model.up_proj(self.model.non_linear_func(self.model.down_proj(global_embeds)))\n            video_emb = (local_embeds + global_embeds).view(1,-1,video_emb.size(-1)).contiguous()\n        \n        img_list.append(video_emb)\n        sign='<Video><ImageHere></Video>'\n        conv.append_message(conv.roles[0], sign)\n        msg = \"Received.\"\n        return msg\n    \n    def get_context_emb(self, conv, img_list):\n        prompt = conv.get_prompt()\n        prompt_segs = prompt.split('<ImageHere>')\n        assert len(prompt_segs) == len(img_list) + 1, \"Unmatched numbers of image placeholders and images.\"\n        seg_tokens = [\n            self.model.llama_tokenizer(\n                seg, return_tensors=\"pt\", add_special_tokens=i == 0).to(self.device).input_ids\n            # only add bos to the first seg\n            for i, seg in enumerate(prompt_segs)\n        ]\n        if hasattr(self.model, \"embed_tokens\"):\n            embed_tokens = self.model.embed_tokens\n        elif hasattr(self.model.llama_model.model, \"embed_tokens\"):\n            embed_tokens = self.model.llama_model.model.embed_tokens\n        else:\n            embed_tokens = self.model.llama_model.model.model.embed_tokens\n        seg_embs = [embed_tokens(seg_t) for seg_t in seg_tokens]\n        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]\n        mixed_embs = torch.cat(mixed_embs, dim=1)\n        return mixed_embs, None\n    \n    def get_context_emb_sim(self, conv, img_list, system=True):\n        question = conv.messages[0][1]\n        question = question.split('</Video> ')[1]\n        system = 
conv.system if system else \"\"\n        question = system + \"###Human: \" + question + \" ###Assistant: \"\n        seg_tokens = self.model.llama_tokenizer(\n                [question], return_tensors=\"pt\", add_special_tokens=0 == 0).to(self.device)\n        \n        if hasattr(self.model, \"embed_tokens\"):\n            embed_tokens = self.model.embed_tokens\n        elif hasattr(self.model.llama_model.model, \"embed_tokens\"):\n            embed_tokens = self.model.llama_model.model.embed_tokens\n        else:\n            embed_tokens = self.model.llama_model.model.model.embed_tokens\n        seg_embs = embed_tokens(seg_tokens.input_ids) \n        mixed_embs = torch.cat((img_list[0],seg_embs), dim=1)\n        atts_img = torch.ones(img_list[0].size()[:-1], dtype=torch.long).to(mixed_embs.device)\n        attention_mask = torch.cat([atts_img, seg_tokens.attention_mask], dim=1)\n        return mixed_embs, attention_mask\n        \n\n\n"
  },
  {
    "path": "stllm/conversation/mvbench_conversation.py",
    "content": "import torch\nimport numpy as np\nfrom transformers import StoppingCriteria, StoppingCriteriaList\n\ndef get_prompt(conv):\n    ret = conv.system + conv.sep\n    for role, message in conv.messages:\n        if message:\n            ret += role + \": \" + message + conv.sep\n        else:\n            ret += role + \":\"\n    return ret\n\ndef get_prompt2(conv):\n    ret = conv.system + conv.sep\n    count = 0\n    for role, message in conv.messages:\n        count += 1\n        if count == len(conv.messages):\n            ret += role + \": \" + message\n        else:\n            if message:\n                ret += role + \": \" + message + conv.sep\n            else:\n                ret += role + \":\"\n    return ret\n\ndef get_context_emb(conv, model, img_list, answer_prompt=None):\n    if answer_prompt:\n        prompt = get_prompt2(conv)\n    else:\n        prompt = get_prompt(conv)\n    if '<VideoHere>' in prompt:\n        prompt_segs = prompt.split('<VideoHere>')\n    else:\n        prompt_segs = prompt.split('<ImageHere>')\n    assert len(prompt_segs) == len(img_list) + 1, \"Unmatched numbers of image placeholders and images.\"\n\n    if hasattr(model.model,'stllm_model'):\n        model = model.model.stllm_model\n    else:\n        model = model.model.model.stllm_model\n    if hasattr(model, \"embed_tokens\"):\n        embed_tokens = model.embed_tokens\n    elif hasattr(model.llama_model.model, \"embed_tokens\"):\n        embed_tokens = model.llama_model.model.embed_tokens\n    else:\n        embed_tokens = model.llama_model.model.model.embed_tokens\n        \n    with torch.no_grad():\n        seg_tokens = [\n            model.llama_tokenizer(\n                seg, return_tensors=\"pt\", add_special_tokens=i == 0).to(\"cuda:0\").input_ids\n            # only add bos to the first seg\n            for i, seg in enumerate(prompt_segs)\n        ]\n        seg_embs = [embed_tokens(seg_t) for seg_t in seg_tokens]\n    mixed_embs = [emb for pair 
in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]\n    mixed_embs = torch.cat(mixed_embs, dim=1)\n    return mixed_embs\n\ndef get_context_emb_sim(conv, model, img_list, answer_prompt=None):\n    if answer_prompt:\n        prompt = get_prompt2(conv)\n    else:\n        prompt = get_prompt(conv)\n    question = prompt.split('</Video>\\n')[1]\n    if hasattr(model.model,'stllm_model'):\n        model = model.model.stllm_model\n    else:\n        model = model.model.model.stllm_model\n\n    if hasattr(model, \"embed_tokens\"):\n        embed_tokens = model.embed_tokens\n    elif hasattr(model.llama_model.model, \"embed_tokens\"):\n        embed_tokens = model.llama_model.model.embed_tokens\n    else:\n        embed_tokens = model.llama_model.model.model.embed_tokens\n\n    with torch.no_grad():\n        seg_tokens = model.llama_tokenizer(\n                [question], return_tensors=\"pt\", add_special_tokens=0 == 0).to(\"cuda:0\")\n        seg_embs = embed_tokens(seg_tokens.input_ids) \n    mixed_embs = torch.cat((img_list[0],seg_embs), dim=1)\n    return mixed_embs\n\ndef ask(text, conv):\n    conv.messages.append([conv.roles[0], text + '\\n'])       \n\nclass StoppingCriteriaSub(StoppingCriteria):\n    def __init__(self, stops=[], encounters=1):\n        super().__init__()\n        self.stops = stops\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):\n        for stop in self.stops:\n            if torch.all((stop == input_ids[0][-len(stop):])).item():\n                return True\n        return False\n       \ndef answer(conv, model, img_list, ask_simple=False, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,\n               repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None):\n    stop_words_ids = [\n        torch.tensor([835]).to(\"cuda:0\"),\n        torch.tensor([2277, 29937]).to(\"cuda:0\")]  # '###' can be encoded in two different ways.\n    stopping_criteria = 
StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])\n    \n    conv.messages.append([conv.roles[1], answer_prompt])\n    if ask_simple:\n        embs = get_context_emb_sim(conv, model, img_list, answer_prompt=answer_prompt)\n    else:\n        embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt)\n    with torch.no_grad():\n        generate_model = model if not hasattr(model,'llama_model') else model.llama_model\n        outputs = generate_model.generate(\n            inputs_embeds=embs,\n            max_new_tokens=max_new_tokens,\n            stopping_criteria=stopping_criteria,\n            num_beams=num_beams,\n            do_sample=do_sample,\n            min_length=min_length,\n            top_p=top_p,\n            repetition_penalty=repetition_penalty,\n            length_penalty=length_penalty,\n            temperature=temperature,\n        )\n    output_token = outputs[0]\n    if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it\n            output_token = output_token[1:]\n    if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. 
remove it\n            output_token = output_token[1:]\n            \n    if hasattr(model,'llama_model'):\n        model = model\n    elif hasattr(model.model,'stllm_model'):\n        model = model.model.stllm_model\n    else:\n        model = model.model.model.stllm_model\n    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)\n    output_text = output_text.split('###')[0]  # remove the stop sign '###'\n    output_text = output_text.split('Assistant:')[-1].strip()\n    conv.messages[-1][1] = output_text\n    return output_text, output_token.cpu().numpy()\n\nclass EasyDict(dict):\n    \"\"\"\n    Get attributes\n\n    >>> d = EasyDict({'foo':3})\n    >>> d['foo']\n    3\n    >>> d.foo\n    3\n    >>> d.bar\n    Traceback (most recent call last):\n    ...\n    AttributeError: 'EasyDict' object has no attribute 'bar'\n\n    Works recursively\n\n    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})\n    >>> isinstance(d.bar, dict)\n    True\n    >>> d.bar.x\n    1\n\n    Bullet-proof\n\n    >>> EasyDict({})\n    {}\n    >>> EasyDict(d={})\n    {}\n    >>> EasyDict(None)\n    {}\n    >>> d = {'a': 1}\n    >>> EasyDict(**d)\n    {'a': 1}\n\n    Set attributes\n\n    >>> d = EasyDict()\n    >>> d.foo = 3\n    >>> d.foo\n    3\n    >>> d.bar = {'prop': 'value'}\n    >>> d.bar.prop\n    'value'\n    >>> d\n    {'foo': 3, 'bar': {'prop': 'value'}}\n    >>> d.bar.prop = 'newer'\n    >>> d.bar.prop\n    'newer'\n\n\n    Values extraction\n\n    >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})\n    >>> isinstance(d.bar, list)\n    True\n    >>> from operator import attrgetter\n    >>> map(attrgetter('x'), d.bar)\n    [1, 3]\n    >>> map(attrgetter('y'), d.bar)\n    [2, 4]\n    >>> d = EasyDict()\n    >>> d.keys()\n    []\n    >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))\n    >>> d.foo\n    3\n    >>> d.bar.x\n    1\n\n    Still like a dict though\n\n    >>> o = EasyDict({'clean':True})\n    >>> o.items()\n    [('clean', 
True)]\n\n    And like a class\n\n    >>> class Flower(EasyDict):\n    ...     power = 1\n    ...\n    >>> f = Flower()\n    >>> f.power\n    1\n    >>> f = Flower({'height': 12})\n    >>> f.height\n    12\n    >>> f['power']\n    1\n    >>> sorted(f.keys())\n    ['height', 'power']\n\n    update and pop items\n    >>> d = EasyDict(a=1, b='2')\n    >>> e = EasyDict(c=3.0, a=9.0)\n    >>> d.update(e)\n    >>> d.c\n    3.0\n    >>> d['c']\n    3.0\n    >>> d.get('c')\n    3.0\n    >>> d.update(a=4, b=4)\n    >>> d.b\n    4\n    >>> d.pop('a')\n    4\n    >>> d.a\n    Traceback (most recent call last):\n    ...\n    AttributeError: 'EasyDict' object has no attribute 'a'\n    \"\"\"\n\n    def __init__(self, d=None, **kwargs):\n        if d is None:\n            d = {}\n        if kwargs:\n            d.update(**kwargs)\n        for k, v in d.items():\n            setattr(self, k, v)\n        # Class attributes\n        for k in self.__class__.__dict__.keys():\n            if not (k.startswith(\"__\") and k.endswith(\"__\")) and not k in (\"update\", \"pop\"):\n                setattr(self, k, getattr(self, k))\n\n    def __setattr__(self, name, value):\n        if isinstance(value, (list, tuple)):\n            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]\n        elif isinstance(value, dict) and not isinstance(value, self.__class__):\n            value = self.__class__(value)\n        super(EasyDict, self).__setattr__(name, value)\n        super(EasyDict, self).__setitem__(name, value)\n\n    __setitem__ = __setattr__\n\n    def update(self, e=None, **f):\n        d = e or dict()\n        d.update(f)\n        for k in d:\n            setattr(self, k, d[k])\n\n    def pop(self, k, d=None):\n        if hasattr(self, k):\n            delattr(self, k)\n        return super(EasyDict, self).pop(k, d)\n    \n    "
  },
  {
    "path": "stllm/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "stllm/datasets/builders/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom stllm.datasets.builders.base_dataset_builder import load_dataset_config\nfrom stllm.datasets.builders.image_text_pair_builder import (\n    CCSBUBuilder,\n    LaionBuilder,\n    CCSBUAlignBuilder\n)\nfrom stllm.common.registry import registry\n\n__all__ = [\n    \"CCSBUBuilder\",\n    \"LaionBuilder\",\n    \"CCSBUAlignBuilder\"\n]\n\n\ndef load_dataset(name, cfg_path=None, vis_path=None, data_type=None):\n    \"\"\"\n    Example\n\n    >>> dataset = load_dataset(\"coco_caption\", cfg_path=None)\n    >>> splits = dataset.keys()\n    >>> print([len(dataset[split]) for split in splits])\n\n    \"\"\"\n    if cfg_path is None:\n        cfg = None\n    else:\n        cfg = load_dataset_config(cfg_path)\n\n    try:\n        builder = registry.get_builder_class(name)(cfg)\n    except TypeError:\n        print(\n            f\"Dataset {name} not found. 
Available datasets:\\n\"\n            + \", \".join([str(k) for k in dataset_zoo.get_names()])\n        )\n        exit(1)\n\n    if vis_path is not None:\n        if data_type is None:\n            # use default data type in the config\n            data_type = builder.config.data_type\n\n        assert (\n            data_type in builder.config.build_info\n        ), f\"Invalid data_type {data_type} for {name}.\"\n\n        builder.config.build_info.get(data_type).storage = vis_path\n\n    dataset = builder.build_datasets()\n    return dataset\n\n\nclass DatasetZoo:\n    def __init__(self) -> None:\n        self.dataset_zoo = {\n            k: list(v.DATASET_CONFIG_DICT.keys())\n            for k, v in sorted(registry.mapping[\"builder_name_mapping\"].items())\n        }\n\n    def get_names(self):\n        return list(self.dataset_zoo.keys())\n\n\ndataset_zoo = DatasetZoo()\n"
  },
  {
    "path": "stllm/datasets/builders/base_dataset_builder.py",
    "content": "\"\"\"\n This file is from\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\nimport shutil\nimport warnings\n\nfrom omegaconf import OmegaConf\nimport torch.distributed as dist\nfrom torchvision.datasets.utils import download_url\n\nimport stllm.common.utils as utils\nfrom stllm.common.dist_utils import is_dist_avail_and_initialized, is_main_process\nfrom stllm.common.registry import registry\nfrom stllm.processors.base_processor import BaseProcessor\n\n\n\nclass BaseDatasetBuilder:\n    train_dataset_cls, eval_dataset_cls = None, None\n\n    def __init__(self, cfg=None):\n        super().__init__()\n\n        if cfg is None:\n            # help to create datasets from default config.\n            self.config = load_dataset_config(self.default_config_path())\n        elif isinstance(cfg, str):\n            self.config = load_dataset_config(cfg)\n        else:\n            # when called from task.build_dataset()\n            self.config = cfg\n\n        self.data_type = self.config.data_type\n\n        self.vis_processors = {\"train\": BaseProcessor(), \"eval\": BaseProcessor()}\n        self.text_processors = {\"train\": BaseProcessor(), \"eval\": BaseProcessor()}\n\n    def build_datasets(self):\n        # download, split, etc...\n        # only called on 1 GPU/TPU in distributed\n\n        if is_main_process():\n            self._download_data()\n\n        if is_dist_avail_and_initialized():\n            dist.barrier()\n\n        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.\n        logging.info(\"Building datasets...\")\n        datasets = self.build()  # dataset['train'/'val'/'test']\n\n        return datasets\n\n    def build_processors(self):\n        vis_proc_cfg = 
self.config.get(\"vis_processor\")\n        txt_proc_cfg = self.config.get(\"text_processor\")\n\n        if vis_proc_cfg is not None:\n            vis_train_cfg = vis_proc_cfg.get(\"train\")\n            vis_eval_cfg = vis_proc_cfg.get(\"eval\")\n\n            self.vis_processors[\"train\"] = self._build_proc_from_cfg(vis_train_cfg)\n            self.vis_processors[\"eval\"] = self._build_proc_from_cfg(vis_eval_cfg)\n\n        if txt_proc_cfg is not None:\n            txt_train_cfg = txt_proc_cfg.get(\"train\")\n            txt_eval_cfg = txt_proc_cfg.get(\"eval\")\n\n            self.text_processors[\"train\"] = self._build_proc_from_cfg(txt_train_cfg)\n            self.text_processors[\"eval\"] = self._build_proc_from_cfg(txt_eval_cfg)\n\n    @staticmethod\n    def _build_proc_from_cfg(cfg):\n        return (\n            registry.get_processor_class(cfg.name).from_config(cfg)\n            if cfg is not None\n            else None\n        )\n\n    @classmethod\n    def default_config_path(cls, type=\"default\"):\n        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])\n\n    def _download_data(self):\n        self._download_ann()\n        self._download_vis()\n\n    def _download_ann(self):\n        \"\"\"\n        Download annotation files if necessary.\n        All the vision-language datasets should have annotations of unified format.\n\n        storage_path can be:\n          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.\n          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.\n\n        Local annotation paths should be relative.\n        \"\"\"\n        anns = self.config.build_info.annotations\n\n        splits = anns.keys()\n\n        cache_root = registry.get_path(\"cache_root\")\n\n        for split in splits:\n            info = anns[split]\n\n            urls, storage_paths = info.get(\"url\", None), info.storage\n\n            if isinstance(urls, str):\n     
           urls = [urls]\n            if isinstance(storage_paths, str):\n                storage_paths = [storage_paths]\n\n            assert len(urls) == len(storage_paths)\n\n            for url_or_filename, storage_path in zip(urls, storage_paths):\n                # if storage_path is relative, make it full by prefixing with cache_root.\n                if not os.path.isabs(storage_path):\n                    storage_path = os.path.join(cache_root, storage_path)\n\n                dirname = os.path.dirname(storage_path)\n                if not os.path.exists(dirname):\n                    os.makedirs(dirname)\n\n                if os.path.isfile(url_or_filename):\n                    src, dst = url_or_filename, storage_path\n                    if not os.path.exists(dst):\n                        shutil.copyfile(src=src, dst=dst)\n                    else:\n                        logging.info(\"Using existing file {}.\".format(dst))\n                else:\n                    if os.path.isdir(storage_path):\n                        # if only dirname is provided, suffix with basename of URL.\n                        raise ValueError(\n                            \"Expecting storage_path to be a file path, got directory {}\".format(\n                                storage_path\n                            )\n                        )\n                    else:\n                        filename = os.path.basename(storage_path)\n\n                    download_url(url=url_or_filename, root=dirname, filename=filename)\n\n    def _download_vis(self):\n\n        storage_path = self.config.build_info.get(self.data_type).storage\n        storage_path = utils.get_cache_path(storage_path)\n\n        if not os.path.exists(storage_path):\n            warnings.warn(\n                f\"\"\"\n                The specified path {storage_path} for visual inputs does not exist.\n                Please provide a correct path to the visual inputs or\n                refer to 
datasets/download_scripts/README.md for downloading instructions.\n                \"\"\"\n            )\n\n    def build(self):\n        \"\"\"\n        Create by split datasets inheriting torch.utils.data.Datasets.\n\n        # build() can be dataset-specific. Overwrite to customize.\n        \"\"\"\n        self.build_processors()\n\n        build_info = self.config.build_info\n\n        ann_info = build_info.annotations\n        vis_info = build_info.get(self.data_type)\n\n        datasets = dict()\n        for split in ann_info.keys():\n            if split not in [\"train\", \"val\", \"test\"]:\n                continue\n\n            is_train = split == \"train\"\n\n            # processors\n            vis_processor = (\n                self.vis_processors[\"train\"]\n                if is_train\n                else self.vis_processors[\"eval\"]\n            )\n            text_processor = (\n                self.text_processors[\"train\"]\n                if is_train\n                else self.text_processors[\"eval\"]\n            )\n\n            # annotation path\n            ann_paths = ann_info.get(split).storage\n            if isinstance(ann_paths, str):\n                ann_paths = [ann_paths]\n\n            abs_ann_paths = []\n            for ann_path in ann_paths:\n                if not os.path.isabs(ann_path):\n                    ann_path = utils.get_cache_path(ann_path)\n                abs_ann_paths.append(ann_path)\n            ann_paths = abs_ann_paths\n\n            # visual data storage path\n            vis_path = os.path.join(vis_info.storage, split)\n\n            if not os.path.isabs(vis_path):\n                # vis_path = os.path.join(utils.get_cache_path(), vis_path)\n                vis_path = utils.get_cache_path(vis_path)\n\n            if not os.path.exists(vis_path):\n                warnings.warn(\"storage path {} does not exist.\".format(vis_path))\n\n            # create datasets\n            dataset_cls = 
self.train_dataset_cls if is_train else self.eval_dataset_cls\n            datasets[split] = dataset_cls(\n                vis_processor=vis_processor,\n                text_processor=text_processor,\n                ann_paths=ann_paths,\n                vis_root=vis_path,\n            )\n\n        return datasets\n\n\ndef load_dataset_config(cfg_path):\n    cfg = OmegaConf.load(cfg_path).datasets\n    cfg = cfg[list(cfg.keys())[0]]\n\n    return cfg\n"
  },
  {
    "path": "stllm/datasets/builders/image_text_pair_builder.py",
    "content": "import os\nimport logging\nimport warnings\n\nfrom stllm.common.registry import registry\nfrom stllm.datasets.builders.base_dataset_builder import BaseDatasetBuilder\nfrom stllm.datasets.datasets.laion_dataset import LaionDataset\nfrom stllm.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset\n\n\n@registry.register_builder(\"cc_sbu\")\nclass CCSBUBuilder(BaseDatasetBuilder):\n    train_dataset_cls = CCSBUDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/cc_sbu/defaults.yaml\"}\n\n    def _download_ann(self):\n        pass\n\n    def _download_vis(self):\n        pass\n\n    def build(self):\n        self.build_processors()\n\n        build_info = self.config.build_info\n\n        datasets = dict()\n        split = \"train\"\n\n        # create datasets\n        # [NOTE] return inner_datasets (wds.DataPipeline)\n        dataset_cls = self.train_dataset_cls\n        datasets[split] = dataset_cls(\n            vis_processor=self.vis_processors[split],\n            text_processor=self.text_processors[split],\n            location=build_info.storage,\n        ).inner_dataset\n\n        return datasets\n\n\n@registry.register_builder(\"laion\")\nclass LaionBuilder(BaseDatasetBuilder):\n    train_dataset_cls = LaionDataset\n\n    DATASET_CONFIG_DICT = {\"default\": \"configs/datasets/laion/defaults.yaml\"}\n\n    def _download_ann(self):\n        pass\n\n    def _download_vis(self):\n        pass\n\n    def build(self):\n        self.build_processors()\n\n        build_info = self.config.build_info\n\n        datasets = dict()\n        split = \"train\"\n\n        # create datasets\n        # [NOTE] return inner_datasets (wds.DataPipeline)\n        dataset_cls = self.train_dataset_cls\n        datasets[split] = dataset_cls(\n            vis_processor=self.vis_processors[split],\n            text_processor=self.text_processors[split],\n            location=build_info.storage,\n        ).inner_dataset\n\n        
return datasets\n\n\n@registry.register_builder(\"cc_sbu_align\")\nclass CCSBUAlignBuilder(BaseDatasetBuilder):\n    train_dataset_cls = CCSBUAlignDataset\n\n    DATASET_CONFIG_DICT = {\n        \"default\": \"configs/datasets/cc_sbu/align.yaml\",\n    }\n\n    def build_datasets(self):\n        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.\n        logging.info(\"Building datasets...\")\n        self.build_processors()\n\n        build_info = self.config.build_info\n        storage_path = build_info.storage\n\n        datasets = dict()\n\n        if not os.path.exists(storage_path):\n            warnings.warn(\"storage path {} does not exist.\".format(storage_path))\n\n        # create datasets\n        dataset_cls = self.train_dataset_cls\n        datasets['train'] = dataset_cls(\n            vis_processor=self.vis_processors[\"train\"],\n            text_processor=self.text_processors[\"train\"],\n            ann_paths=[os.path.join(storage_path, 'filter_cap.json')],\n            vis_root=os.path.join(storage_path, 'image'),\n        )\n\n        return datasets\n\n"
  },
  {
    "path": "stllm/datasets/data_utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport gzip\nimport logging\nimport os\nimport random as rnd\nimport tarfile\nimport zipfile\nimport random\nfrom typing import List\nfrom tqdm import tqdm\n\nimport decord\nfrom decord import VideoReader\nimport webdataset as wds\nimport numpy as np\nimport torch\nfrom torch.utils.data.dataset import IterableDataset\n\nfrom stllm.common.registry import registry\nfrom stllm.datasets.datasets.base_dataset import ConcatDataset\n\n\ndecord.bridge.set_bridge(\"torch\")\nMAX_INT = registry.get(\"MAX_INT\")\n\n\nclass ChainDataset(wds.DataPipeline):\n    r\"\"\"Dataset for chaining multiple :class:`DataPipeline` s.\n\n    This class is useful to assemble different existing dataset streams. The\n    chaining operation is done on-the-fly, so concatenating large-scale\n    datasets with this class will be efficient.\n\n    Args:\n        datasets (iterable of IterableDataset): datasets to be chained together\n    \"\"\"\n    def __init__(self, datasets: List[wds.DataPipeline]) -> None:\n        super().__init__()\n        self.datasets = datasets\n        self.prob = []\n        self.names = []\n        for dataset in self.datasets:\n            if hasattr(dataset, 'name'):\n                self.names.append(dataset.name)\n            else:\n                self.names.append('Unknown')\n            if hasattr(dataset, 'sample_ratio'):\n                self.prob.append(dataset.sample_ratio)\n            else:\n                self.prob.append(1)\n                logging.info(\"A datapipeline does not define sample_ratio; defaulting to 1.\")\n\n    def __iter__(self):\n        datastreams = [iter(dataset) for dataset in self.datasets]\n        while True:\n            select_datastream = 
random.choices(datastreams, weights=self.prob, k=1)[0]\n            yield next(select_datastream)\n\n\ndef apply_to_sample(f, sample):\n    if len(sample) == 0:\n        return {}\n\n    def _apply(x):\n        if torch.is_tensor(x):\n            return f(x)\n        elif isinstance(x, dict):\n            return {key: _apply(value) for key, value in x.items()}\n        elif isinstance(x, list):\n            return [_apply(item) for item in x]\n        else:\n            return x\n\n    return _apply(sample)\n\n\ndef move_to_cuda(sample):\n    def _move_to_cuda(tensor):\n        return tensor.cuda()\n\n    return apply_to_sample(_move_to_cuda, sample)\n\n\ndef prepare_sample(samples, cuda_enabled=True):\n    if cuda_enabled:\n        samples = move_to_cuda(samples)\n\n    # TODO fp16 support\n\n    return samples\n\n\ndef reorg_datasets_by_split(datasets):\n    \"\"\"\n    Organizes datasets by split.\n\n    Args:\n        datasets: dict of torch.utils.data.Dataset objects by name.\n\n    Returns:\n        Dict of datasets by split {split_name: List[Datasets]}.\n    \"\"\"\n    # if len(datasets) == 1:\n    #     return datasets[list(datasets.keys())[0]]\n    # else:\n    reorg_datasets = dict()\n\n    # reorganize by split\n    for _, dataset in datasets.items():\n        for split_name, dataset_split in dataset.items():\n            if split_name not in reorg_datasets:\n                reorg_datasets[split_name] = [dataset_split]\n            else:\n                reorg_datasets[split_name].append(dataset_split)\n\n    return reorg_datasets\n\n\ndef concat_datasets(datasets):\n    \"\"\"\n    Concatenates multiple datasets into a single dataset.\n\n    It supports map-style datasets and DataPipeline from WebDataset. It currently does not support\n    generic IterableDataset because that requires creating separate samplers.\n\n    Only training datasets can be concatenated; validation and testing are assumed to\n    have only a single dataset. 
This is because metrics should not be computed on the concatenated\n    datasets.\n\n    Args:\n        datasets: dict of torch.utils.data.Dataset objects by split.\n\n    Returns:\n        Dict of concatenated datasets by split, \"train\" is the concatenation of multiple datasets,\n        \"val\" and \"test\" remain the same.\n\n        If the input training datasets contain both map-style and DataPipeline datasets, returns\n        a tuple, where the first element is a concatenated map-style dataset and the second\n        element is a chained DataPipeline dataset.\n\n    \"\"\"\n    # concatenate datasets in the same split\n    for split_name in datasets:\n        if split_name != \"train\":\n            assert (\n                len(datasets[split_name]) == 1\n            ), \"Do not support multiple {} datasets.\".format(split_name)\n            datasets[split_name] = datasets[split_name][0]\n        else:\n            iterable_datasets, map_datasets = [], []\n            for dataset in datasets[split_name]:\n                if isinstance(dataset, wds.DataPipeline):\n                    logging.info(\n                        \"Dataset {} is IterableDataset, can't be concatenated.\".format(\n                            dataset\n                        )\n                    )\n                    iterable_datasets.append(dataset)\n                elif isinstance(dataset, IterableDataset):\n                    raise NotImplementedError(\n                        \"Do not support concatenation of generic IterableDataset.\"\n                    )\n                else:\n                    map_datasets.append(dataset)\n\n            # if len(iterable_datasets) > 0:\n            # concatenate map-style datasets and iterable-style datasets separately\n            if len(iterable_datasets) > 1:\n                chained_datasets = (\n                    ChainDataset(iterable_datasets)\n                )\n            elif len(iterable_datasets) == 1:\n                
chained_datasets = iterable_datasets[0]\n            else:\n                chained_datasets = None\n\n            concat_datasets = (\n                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None\n            )\n\n            train_datasets = concat_datasets, chained_datasets\n            train_datasets = tuple([x for x in train_datasets if x is not None])\n            train_datasets = (\n                train_datasets[0] if len(train_datasets) == 1 else train_datasets\n            )\n\n            datasets[split_name] = train_datasets\n\n    return datasets\n\n"
  },
  {
    "path": "stllm/datasets/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "stllm/datasets/datasets/base_dataset.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport json\nfrom typing import Iterable\n\nfrom torch.utils.data import Dataset, ConcatDataset\nfrom torch.utils.data.dataloader import default_collate\n\n\nclass BaseDataset(Dataset):\n    def __init__(\n        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]\n    ):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_paths (list of string): paths to the annotation files\n        \"\"\"\n        self.vis_root = vis_root\n\n        self.annotation = []\n        for ann_path in ann_paths:\n            # close the annotation file as soon as it has been parsed\n            with open(ann_path, \"r\") as f:\n                jfile = json.load(f)\n            if 'annotations' in jfile:\n                self.annotation.extend(jfile['annotations'])\n            else:\n                self.annotation.extend(jfile)\n\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n        self._add_instance_ids()\n\n    def __len__(self):\n        return len(self.annotation)\n\n    def collater(self, samples):\n        return default_collate(samples)\n\n    def set_processors(self, vis_processor, text_processor):\n        self.vis_processor = vis_processor\n        self.text_processor = text_processor\n\n    def _add_instance_ids(self, key=\"instance_id\"):\n        for idx, ann in enumerate(self.annotation):\n            ann[key] = str(idx)\n\n\nclass ConcatDataset(ConcatDataset):\n    def __init__(self, datasets: Iterable[Dataset]) -> None:\n        super().__init__(datasets)\n\n    def collater(self, samples):\n        # TODO For now only supports datasets with same underlying collater implementations\n\n        all_keys = set()\n        for s in samples:\n            all_keys.update(s)\n\n        shared_keys = 
all_keys\n        for s in samples:\n            shared_keys = shared_keys & set(s.keys())\n\n        samples_shared_keys = []\n        for s in samples:\n            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})\n\n        return self.datasets[0].collater(samples_shared_keys)\n"
  },
  {
    "path": "stllm/datasets/datasets/caption_datasets.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport os\nfrom collections import OrderedDict\n\nfrom stllm.datasets.datasets.base_dataset import BaseDataset\nfrom PIL import Image\n\n\nclass __DisplMixin:\n    def displ_item(self, index):\n        sample, ann = self.__getitem__(index), self.annotation[index]\n\n        return OrderedDict(\n            {\n                \"file\": ann[\"image\"],\n                \"caption\": ann[\"caption\"],\n                \"image\": sample[\"image\"],\n            }\n        )\n\n\nclass CaptionDataset(BaseDataset, __DisplMixin):\n    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n        self.img_ids = {}\n        n = 0\n        for ann in self.annotation:\n            img_id = ann[\"image_id\"]\n            if img_id not in self.img_ids.keys():\n                self.img_ids[img_id] = n\n                n += 1\n\n    def __getitem__(self, index):\n\n        # TODO this assumes image input, not general enough\n        ann = self.annotation[index]\n\n        img_file = '{:0>12}.jpg'.format(ann[\"image_id\"])\n        image_path = os.path.join(self.vis_root, img_file)\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        caption = self.text_processor(ann[\"caption\"])\n\n        return {\n            \"image\": image,\n            \"text_input\": caption,\n            \"image_id\": self.img_ids[ann[\"image_id\"]],\n        }\n\n\nclass CaptionEvalDataset(BaseDataset, __DisplMixin):\n    def 
__init__(self, vis_processor, text_processor, vis_root, ann_paths):\n        \"\"\"\n        vis_root (string): Root directory of images (e.g. coco/images/)\n        ann_root (string): directory to store the annotation file\n        split (string): val or test\n        \"\"\"\n        super().__init__(vis_processor, text_processor, vis_root, ann_paths)\n\n    def __getitem__(self, index):\n\n        ann = self.annotation[index]\n\n        image_path = os.path.join(self.vis_root, ann[\"image\"])\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n\n        return {\n            \"image\": image,\n            \"image_id\": ann[\"image_id\"],\n            \"instance_id\": ann[\"instance_id\"],\n        }\n"
  },
  {
    "path": "stllm/datasets/datasets/cc_sbu_dataset.py",
    "content": "import os\nimport pickle\nfrom PIL import Image\nimport webdataset as wds\nfrom stllm.datasets.datasets.base_dataset import BaseDataset\nfrom stllm.datasets.datasets.caption_datasets import CaptionDataset\n\n\nclass CCSBUDataset(BaseDataset):\n    def __init__(self, vis_processor, text_processor, location):\n        super().__init__(vis_processor=vis_processor, text_processor=text_processor)\n\n        self.inner_dataset = wds.DataPipeline(\n            wds.ResampledShards(location),\n            wds.tarfile_to_samples(handler=wds.warn_and_continue),\n            wds.shuffle(1000, handler=wds.warn_and_continue),\n            wds.decode(\"pilrgb\", handler=wds.warn_and_continue),\n            wds.to_tuple(\"jpg\", \"json\", handler=wds.warn_and_continue),\n            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),\n            wds.map(self.to_dict, handler=wds.warn_and_continue),\n        )\n\n    def to_dict(self, sample):\n        return {\n            \"image\": sample[0],\n            \"answer\": self.text_processor(sample[1][\"caption\"]),\n        }\n\n\nclass CCSBUAlignDataset(CaptionDataset):\n\n    def __getitem__(self, index):\n\n        # TODO this assumes image input, not general enough\n        ann = self.annotation[index]\n\n        img_file = '{}.jpg'.format(ann[\"image_id\"])\n        image_path = os.path.join(self.vis_root, img_file)\n        image = Image.open(image_path).convert(\"RGB\")\n\n        image = self.vis_processor(image)\n        caption = ann[\"caption\"]\n\n        return {\n            \"image\": image,\n            \"answer\": caption,\n            \"image_id\": self.img_ids[ann[\"image_id\"]],\n        }\n"
  },
  {
    "path": "stllm/datasets/datasets/dataloader_utils.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport time\nimport random\nimport torch\nfrom stllm.datasets.data_utils import move_to_cuda\nfrom torch.utils.data import DataLoader\nimport torch.distributed as dist\n\nclass MultiIterLoader:\n    \"\"\"\n    A simple wrapper for iterating over multiple iterators.\n\n    Args:\n        loaders (List[Loader]): List of Iterator loaders.\n        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.\n    \"\"\"\n\n    def __init__(self, loaders, ratios=None):\n        # assert all loaders has __next__ method\n        for loader in loaders:\n            assert hasattr(\n                loader, \"__next__\"\n            ), \"Loader {} has no __next__ method.\".format(loader)\n\n        if ratios is None:\n            ratios = [1.0] * len(loaders)\n        else:\n            assert len(ratios) == len(loaders)\n            ratios = [float(ratio) / sum(ratios) for ratio in ratios]\n\n        self.loaders = loaders\n        self.ratios = ratios\n\n    def __next__(self):\n        # random sample from each loader by ratio\n        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]\n        return next(self.loaders[loader_idx])\n\nclass MetaLoader(object):\n    \"\"\" wraps multiple data loader \"\"\"\n    def __init__(self, loaders, ratios=None):\n        \"\"\"Iterates over multiple dataloaders, it ensures all processes\n        work on data from the same dataloader. 
This loader will end when\n        the shorter dataloader raises StopIteration exception.\n\n        loaders: List, [dataloader]\n        \"\"\"\n        self.loaders = loaders\n        self.iter_order = self.build_iter()\n\n    def build_iter(self):\n        iter_order = []\n\n        for n, l in enumerate(self.loaders):\n            iter_order.extend([n]*len(l))\n\n        random.shuffle(iter_order)\n        iter_order = torch.Tensor(iter_order).to(torch.device(\"cuda\")).to(torch.uint8)\n\n        # sync\n        if dist.is_available():\n            # make sure all processes have the same order so that\n            # each step they will have data from the same loader\n            dist.broadcast(iter_order, src=0)\n        return iter_order\n\n    def __len__(self):\n        return len(self.iter_order)\n\n    def __iter__(self):\n        \"\"\" this iterator will run indefinitely \"\"\"\n        for i, loader_idx in enumerate(self.iter_order):\n            batch = next(self.loaders[loader_idx])\n            if i==len(self)-1:\n                self.iter_order = self.build_iter()\n            yield batch\n\nclass PrefetchLoader(object):\n    \"\"\"\n    Modified from https://github.com/ChenRocks/UNITER.\n\n    overlap compute and cuda data transfer\n    (copied and then modified from nvidia apex)\n    \"\"\"\n\n    def __init__(self, loader):\n        self.loader = loader\n        self.stream = torch.cuda.Stream()\n\n    def __iter__(self):\n        loader_it = iter(self.loader)\n        self.preload(loader_it)\n        batch = self.next(loader_it)\n        while batch is not None:\n            is_tuple = isinstance(batch, tuple)\n            if is_tuple:\n                task, batch = batch\n\n            if is_tuple:\n                yield task, batch\n            else:\n                yield batch\n            batch = self.next(loader_it)\n\n    def __len__(self):\n        return len(self.loader)\n\n    def preload(self, it):\n        try:\n            
self.batch = next(it)\n        except StopIteration:\n            self.batch = None\n            return\n        # if record_stream() doesn't work, another option is to make sure\n        # device inputs are created on the main stream.\n        # self.next_input_gpu = torch.empty_like(self.next_input,\n        #                                        device='cuda')\n        # self.next_target_gpu = torch.empty_like(self.next_target,\n        #                                         device='cuda')\n        # Need to make sure the memory allocated for next_* is not still in use\n        # by the main stream at the time we start copying to next_*:\n        # self.stream.wait_stream(torch.cuda.current_stream())\n        with torch.cuda.stream(self.stream):\n            self.batch = move_to_cuda(self.batch)\n            # more code for the alternative if record_stream() doesn't work:\n            # copy_ will record the use of the pinned source tensor in this\n            # side stream.\n            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)\n            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)\n            # self.next_input = self.next_input_gpu\n            # self.next_target = self.next_target_gpu\n\n    def next(self, it):\n        torch.cuda.current_stream().wait_stream(self.stream)\n        batch = self.batch\n        if batch is not None:\n            record_cuda_stream(batch)\n        self.preload(it)\n        return batch\n\n    def __getattr__(self, name):\n        method = self.loader.__getattribute__(name)\n        return method\n\n\ndef record_cuda_stream(batch):\n    if isinstance(batch, torch.Tensor):\n        batch.record_stream(torch.cuda.current_stream())\n    elif isinstance(batch, list) or isinstance(batch, tuple):\n        for t in batch:\n            record_cuda_stream(t)\n    elif isinstance(batch, dict):\n        for t in batch.values():\n            record_cuda_stream(t)\n    else:\n        
pass\n\n\nclass IterLoader:\n    \"\"\"\n    A wrapper that turns a DataLoader into an infinite iterator.\n\n    Modified from:\n        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py\n    \"\"\"\n\n    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):\n        self._dataloader = dataloader\n        self.iter_loader = iter(self._dataloader)\n        self._use_distributed = use_distributed\n        self._epoch = 0\n\n    @property\n    def epoch(self) -> int:\n        return self._epoch\n\n    def __next__(self):\n        try:\n            data = next(self.iter_loader)\n        except StopIteration:\n            self._epoch += 1\n            if hasattr(self._dataloader.sampler, \"set_epoch\") and self._use_distributed:\n                self._dataloader.sampler.set_epoch(self._epoch)\n            time.sleep(2)  # Prevent possible deadlock during epoch transition\n            self.iter_loader = iter(self._dataloader)\n            data = next(self.iter_loader)\n\n        return data\n\n    def __iter__(self):\n        return self\n\n    def __len__(self):\n        return len(self._dataloader)\n"
  },
  {
    "path": "stllm/datasets/datasets/image_video_itdatasets.py",
    "content": "import logging\nimport os\nimport random\nfrom tqdm import tqdm\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms import InterpolationMode\nfrom stllm.datasets.datasets.instruction_data import available_corpus, train_transform\n\nimport json\nfrom os.path import basename\nimport numpy as np\n\nfrom .utils import load_anno, pre_text, VIDEO_READER_FUNCS, load_image_from_path\n\ntry:\n    from mmengine import fileio \n    has_client = True\nexcept ImportError:\n    has_client = False\n\nlogger = logging.getLogger(__name__)\n\n\nclass ImageVideoBaseDataset(Dataset):\n    \"\"\"Base class that implements the image and video loading methods\"\"\"\n\n    media_type = \"video\"\n\n    def __init__(self):\n        assert self.media_type in [\"image\", \"video\", \"only_video\"]\n        self.data_root = None\n        self.anno_list = (\n            None  # list(dict), each dict contains {\"image\": str, # image or video path}\n        )\n        self.transform = None\n        self.video_reader = None\n        self.num_tries = None\n\n        self.client = None\n        if has_client:\n            self.client = fileio\n\n    def __getitem__(self, index):\n        raise NotImplementedError\n\n    def __len__(self):\n        raise NotImplementedError\n\n    def get_anno(self, index):\n        \"\"\"obtain the annotation for one media (video or image)\n\n        Args:\n            index (int): The media index.\n\n        Returns: dict.\n            - \"image\": the filename, video also use \"image\".\n            - \"caption\": The caption for this file.\n\n        \"\"\"\n        anno = self.anno_list[index]\n        if self.data_root is not None:\n            anno[\"image\"] = os.path.join(self.data_root, anno[\"image\"])\n        return anno\n\n    def load_and_transform_media_data(self, index, data_path):\n        if self.media_type == \"image\":\n            return 
self.load_and_transform_media_data_image(index, data_path)\n        else:\n            return self.load_and_transform_media_data_video(index, data_path)\n\n    def load_and_transform_media_data_image(self, index, data_path):\n        image = load_image_from_path(data_path, client=self.client)\n        image = self.transform(image)\n        return image, index\n\n    def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None):\n        for _ in range(self.num_tries):\n            try:\n                max_num_frames = self.max_num_frames if hasattr(self, \"max_num_frames\") else -1\n                frames, frame_indices, sec = self.video_reader(\n                    data_path, self.num_frames, self.sample_type, \n                    max_num_frames=max_num_frames, client=self.client, clip=clip\n                )\n            except Exception as e:\n                logger.warning(\n                    f\"Caught exception {e} when loading video {data_path}, \"\n                    f\"randomly sample a new video as replacement\"\n                )\n                index = random.randint(0, len(self) - 1)\n                ann = self.get_anno(index)\n                data_path = ann[\"image\"]\n                continue\n            # shared aug for video frames\n            frames = self.transform(frames)\n            if return_fps:\n                #sec = [str(round(f / fps, 1)) for f in frame_indices]\n                return frames, index, sec\n            else:\n                return frames, index\n        else:\n            raise RuntimeError(\n                f\"Failed to fetch video after {self.num_tries} tries. 
\"\n                f\"This might indicate that you have many corrupted videos.\"\n            )\n\nclass PTImgTrainDataset(ImageVideoBaseDataset):\n    media_type = \"image\"\n\n    def __init__(self, ann_file, transform, pre_text=True):\n        super().__init__()\n\n        if len(ann_file) == 3 and ann_file[2] == \"video\":\n            self.media_type = \"video\"  \n        else:\n            self.media_type = \"image\"\n        self.label_file, self.data_root = ann_file[:2]\n\n        logger.info('Load json file')\n        with open(self.label_file, 'r') as f:\n            self.anno = json.load(f)\n        self.num_examples = len(self.anno)\n\n        self.transform = transform\n        self.pre_text = pre_text\n        logger.info(f\"Pre-process text: {pre_text}\")\n\n    def get_anno(self, index):\n        filename = self.anno[index][self.media_type]\n        caption = self.anno[index][\"caption\"]\n        anno = {\"image\": os.path.join(self.data_root, filename), \"caption\": caption}\n        return anno\n\n    def __len__(self):\n        return self.num_examples\n\n    def __getitem__(self, index):\n        try:\n            ann = self.get_anno(index)\n            image, index = self.load_and_transform_media_data(index, ann[\"image\"])\n            caption = pre_text(ann[\"caption\"], pre_text=self.pre_text)\n            return image, caption, index\n        except Exception as e:\n            logger.warning(f\"Caught exception {e} when loading image {ann['image']}\")\n            index = np.random.randint(0, len(self))\n            return self.__getitem__(index)\n\nclass PTVidTrainDataset(PTImgTrainDataset):\n    media_type = \"video\"\n\n    def __init__(\n        self,\n        ann_file,\n        transform,\n        num_frames=4,\n        video_reader_type=\"decord\",\n        sample_type=\"rand\",\n        num_tries=3,\n        pre_text=True\n    ):\n        super().__init__(ann_file, transform, pre_text=pre_text)\n        self.num_frames = 
num_frames\n        self.video_reader_type = video_reader_type\n        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]\n        self.sample_type = sample_type\n        self.num_tries = num_tries\n\nclass ITImgTrainDataset(ImageVideoBaseDataset):\n    media_type = \"image\"\n\n    def __init__(\n        self, ann_file, transform, simple=False,\n        system=\"\", role=(\"Human\", \"Assistant\"),\n        start_token=\"<Image>\", end_token=\"</Image>\",\n        random_shuffle=True, # if True, shuffle the QA list\n    ):\n        super().__init__()\n\n        if len(ann_file) == 3 and ann_file[2] == \"video\":\n            self.media_type = \"video\"  \n        else:\n            self.media_type = \"image\"\n        self.label_file, self.data_root = ann_file[:2]\n\n        logger.info('Load json file')\n        with open(self.label_file, 'r') as f:\n            self.anno = json.load(f)\n        self.num_examples = len(self.anno)\n        self.transform = transform\n\n        # prompt parameters\n        if system:\n            assert system[-1] == \" \", \"' ' should be add in the end of system, thus '###' will be tokenized into one token.\"\n        # currently not support add start_token and end_token in the system, since the msg should be added properly\n        self.begin_signal = \"###\"\n        self.end_signal = \" \"\n        self.start_token = start_token\n        self.end_token = end_token\n        self.system = system\n        self.role = role\n        self.random_shuffle = random_shuffle\n        self.simple = simple\n        # instruction location and number\n        logger.info(f\"Random shuffle: {self.random_shuffle}\")\n\n    def get_anno(self, index):\n        filename = self.anno[index][self.media_type]\n        qa = self.anno[index][\"QA\"]\n        if \"num_frames\" in self.anno[index]:\n            self.max_num_frames = self.anno[index][\"num_frames\"]\n        if \"start\" in self.anno[index] and \"end\" in self.anno[index]:\n       
     anno = {\n                \"image\": os.path.join(self.data_root, filename), \"qa\": qa,\n                \"start\": self.anno[index][\"start\"], \"end\": self.anno[index][\"end\"],\n            }\n        else:\n            anno = {\"image\": os.path.join(self.data_root, filename), \"qa\": qa}\n        return anno\n\n    def __len__(self):\n        return self.num_examples\n    \n    def process_qa(self, qa, msg=\"\"):\n        cur_instruction = \"\"\n        # randomly shuffle qa for conversation\n        if self.random_shuffle and len(qa) > 1:\n            random.shuffle(qa)\n        if \"i\" in qa[0].keys() and qa[0][\"i\"] != \"\":\n            cur_instruction = qa[0][\"i\"] + self.end_signal\n\n        conversation = self.system\n        # add instruction as system message\n\n        # rstrip() for the extra \" \" in msg\n        if not self.simple:\n            if cur_instruction:\n                conversation += cur_instruction\n            conversation += (\n                self.begin_signal + self.role[0] + \": \" + \n                self.start_token + '<ImageHere>' + self.end_token + msg.rstrip() + ' ' + \n                qa[0][\"q\"] + self.end_signal + self.begin_signal + self.role[1] + \": \"\n            )\n        else:\n            conversation += '<ImageHere>'\n            conversation += (\n                self.begin_signal + self.role[0] + \": \" + cur_instruction + msg.rstrip() + \n                qa[0][\"q\"] + self.end_signal + self.begin_signal + self.role[1] + \": \"\n            )\n        \n        return conversation, qa[0][\"a\"]\n\n    def __getitem__(self, index):\n        try:\n            ann = self.get_anno(index)\n            image, index = self.load_and_transform_media_data_image(index, ann[\"image\"])\n            instruction, answer = self.process_qa(ann[\"qa\"])\n            return {\n                \"image\": image,\n                \"answer\": answer,\n                \"image_id\": index,\n                
\"instruction_input\": instruction\n            }\n        except Exception as e:\n            logger.warning(f\"Caught exception {e} when loading image {ann['image']}\")\n            index = np.random.randint(0, len(self))\n            return self.__getitem__(index)\n\nclass ITVidTrainDataset(ITImgTrainDataset):\n    media_type = \"video\"\n\n    def __init__(\n        self, ann_file, transform, simple=False,\n        num_frames=4, video_reader_type=\"decord\", sample_type=\"rand\", num_tries=3,\n        system=\"\", role=(\"Human\", \"Assistant\"),\n        start_token=\"<Video>\", end_token=\"</Video>\",\n        add_second_msg=False,\n        random_shuffle=True,\n    ):\n        super().__init__(\n            ann_file, transform, \n            system=system, role=role,\n            start_token=start_token, end_token=end_token,\n            random_shuffle=random_shuffle,\n            simple=simple,\n        )\n        self.num_frames = num_frames\n        self.video_reader_type = video_reader_type\n        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]\n        self.sample_type = sample_type\n        self.num_tries = num_tries\n        self.add_second_msg = add_second_msg\n\n        logger.info(f\"Use {video_reader_type} for data in {ann_file}\")\n        if add_second_msg:\n            logger.info(f\"Add second message: The video contains X frames sampled at T seconds.\")\n\n    def __getitem__(self, index):\n        try:\n            ann = self.get_anno(index)\n            msg = \"\"\n            clip = None\n            if \"start\" in ann and \"end\" in ann:\n                clip = [ann[\"start\"], ann[\"end\"]]\n            video, index, sec = self.load_and_transform_media_data_video(index, ann[\"image\"], return_fps=True, clip=clip)\n            if self.add_second_msg:\n                # \" \" should be added in the start and end\n                msg = f\" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. 
\"\n            instruction, answer = self.process_qa(ann[\"qa\"], msg)\n            return {\n                \"image\": video,\n                \"answer\": answer,\n                \"image_id\": index,\n                \"instruction_input\": instruction,\n                \"video_len\": sec\n            }\n        except Exception as e:\n            logger.warning(f\"Caught exception {e} when loading video {ann['image']}\")\n            index = np.random.randint(0, len(self))\n            return self.__getitem__(index)\n \nif __name__ == \"__main__\":\n    pass\n\n"
  },
  {
    "path": "stllm/datasets/datasets/instruction_data.py",
    "content": "from torchvision import transforms\nfrom torchvision.transforms import InterpolationMode\n\nmean = (0.48145466, 0.4578275, 0.40821073)\nstd = (0.26862954, 0.26130258, 0.27577711)\nnormalize = transforms.Normalize(mean, std)\ntype_transform = transforms.Lambda(lambda x: x.float().div(255.0))\ntrain_transform = transforms.Compose(\n    [\n        transforms.RandomResizedCrop(\n            224,\n            scale=(0.5, 1.0),\n            interpolation=InterpolationMode.BICUBIC,\n        ),\n        #transforms.RandomHorizontalFlip(),\n        type_transform,\n        normalize,\n    ]\n)\n\nanno_root_it = '/Path/to/MVBench/VideoChat2-IT'\n\n# ============== pretraining datasets=================\navailable_corpus = dict(\n    # image\n    llava_full=[\n        f\"{anno_root_it}/image/llava/llava_full.json\", \n        \"your_data_path/coco_caption\",\n    ],\n    caption_coco=[\n        f\"{anno_root_it}/image/caption/coco/train.json\", \n        \"your_data_path/coco_caption\",\n    ],\n    caption_llava=[\n        f\"{anno_root_it}/image/caption/llava/train.json\", \n        \"your_data_path/coco_caption\",\n    ],\n    caption_minigpt4=[\n        f\"{anno_root_it}/image/caption/minigpt4/train.json\", \n        \"your_data_path/minigpt4/image\",\n    ],\n    caption_paragraph_captioning=[\n        f\"{anno_root_it}/image/caption/paragraph_captioning/train.json\", \n        \"your_data_path/m3it/image-paragraph-captioning\",\n    ],\n    caption_textcaps=[\n        f\"{anno_root_it}/image/caption/textcaps/train.json\", \n        \"your_data_path/m3it/textcap\",\n    ],\n    classification_imagenet=[\n        f\"{anno_root_it}/image/classification/imagenet/train.json\", \n        \"your_data_path/m3it/imagenet\",\n    ],\n    classification_coco_itm=[\n        f\"{anno_root_it}/image/classification/coco_itm/train.json\", \n        \"your_data_path/m3it/coco-itm\",\n    ],\n    conversation_llava=[\n        
f\"{anno_root_it}/image/conversation/llava/train.json\", \n        \"your_data_path/coco_caption\",\n    ],\n    reasoning_clevr=[\n        f\"{anno_root_it}/image/reasoning/clevr/train.json\", \n        \"your_data_path/m3it/clevr\",\n    ],\n    reasoning_visual_mrc=[\n        f\"{anno_root_it}/image/reasoning/visual_mrc/train.json\", \n        \"your_data_path/m3it/visual-mrc\",\n    ],\n    reasoning_llava=[\n        f\"{anno_root_it}/image/reasoning/llava/train.json\", \n        \"your_data_path/coco_caption\",\n    ],\n    vqa_vqav2=[\n        f\"{anno_root_it}/image/vqa/vqav2/train.json\", \n        \"your_data_path/m3it/vqa-v2\",\n    ],\n    vqa_gqa=[\n        f\"{anno_root_it}/image/vqa/gqa/train.json\", \n        \"your_data_path/m3it/gqa\",\n    ],\n    vqa_okvqa=[\n        f\"{anno_root_it}/image/vqa/okvqa/train.json\", \n        \"your_data_path/m3it/okvqa\",\n    ],\n    vqa_a_okvqa=[\n        f\"{anno_root_it}/image/vqa/a_okvqa/train.json\", \n        \"your_data_path/m3it/a-okvqa\",\n    ],\n    vqa_viquae=[\n        f\"{anno_root_it}/image/vqa/viquae/train.json\", \n        \"your_data_path/m3it/viquae\",\n    ],\n    vqa_ocr_vqa=[\n        f\"{anno_root_it}/image/vqa/ocr_vqa/train.json\", \n        \"your_data_path/m3it/ocr-vqa\",\n    ],\n    vqa_text_vqa=[\n        f\"{anno_root_it}/image/vqa/text_vqa/train.json\", \n        \"your_data_path/m3it/text-vqa\",\n    ],\n    vqa_st_vqa=[\n        f\"{anno_root_it}/image/vqa/st_vqa/train.json\", \n        \"your_data_path/m3it/st-vqa\",\n    ],\n    vqa_docvqa=[\n        f\"{anno_root_it}/image/vqa/docvqa/train.json\", \n        \"your_data_path/m3it/docvqa\",\n    ],\n    # video\n    caption_textvr=[\n        f\"{anno_root_it}/video/caption/textvr/train.json\", \n        \"your_data_path/TextVR/Video\",\n        \"video\"\n    ],\n    caption_videochat=[\n        f\"{anno_root_it}/video/caption/videochat/train.json\", \n        \"your_data_path/WebVid10M\",\n        \"video\"\n    ],\n    
caption_webvid=[\n        f\"{anno_root_it}/video/caption/webvid/train.json\", \n        \"your_data_path/WebVid2M\",\n        \"video\"\n    ],\n    caption_youcook2=[\n        f\"{anno_root_it}/video/caption/youcook2/train.json\", \n        \"your_data_path/youcook2/split_videos\",\n        \"video\"\n    ],\n    classification_k710=[\n        f\"{anno_root_it}/video/classification/k710/train.json\", \n        \"\",\n        \"video\"\n    ],\n    classification_ssv2=[\n        f\"{anno_root_it}/video/classification/ssv2/train.json\", \n        \"your_data_path/video_pub/ssv2_video\",\n        \"video\"\n    ],\n    conversation_videochat1=[\n        f\"{anno_root_it}/video/conversation/videochat1/train_flat.json\", \n        \"your_data_path/WebVid10M\",\n        \"video\"\n    ],\n    conversation_videochat2=[\n        f\"{anno_root_it}/video/conversation/videochat2/train.json\", \n        \"your_data_path/internvid\",\n        \"video\"\n    ],\n    caption_videochatgpt=[\n        f\"{anno_root_it}/video/conversation/videochatgpt/train_full_flat.json\", \n        \"your_data_path/ANet/ANet_320p_fps30\",\n        \"video\"\n    ],\n    reasoning_next_qa=[\n        f\"{anno_root_it}/video/reasoning/next_qa/train.json\", \n        \"your_data_path/nextqa\",\n        \"video\"\n    ],\n    reasoning_clevrer_qa=[\n        f\"{anno_root_it}/video/reasoning/clevrer_qa/train.json\", \n        \"your_data_path/clevrer/video_train\",\n        \"video\"\n    ],\n    reasoning_clevrer_mc=[\n        f\"{anno_root_it}/video/reasoning/clevrer_mc/train.json\",  \n        \"your_data_path/clevrer/video_train\",\n        \"video\"\n    ],\n    vqa_ego_qa=[\n        f\"{anno_root_it}/video/vqa/ego_qa/train.json\", \n        \"your_data_path/EgoQA/split_videos\",\n        \"video\"\n    ],\n    vqa_tgif_frame_qa=[\n        f\"{anno_root_it}/video/vqa/tgif_frame_qa/train.json\", \n        \"your_data_path/tgif\",\n        \"video\"\n    ],\n    vqa_tgif_transition_qa=[\n        
f\"{anno_root_it}/video/vqa/tgif_transition_qa/train.json\", \n        \"your_data_path/tgif\",\n        \"video\"\n    ],\n    vqa_webvid_qa=[\n        f\"{anno_root_it}/video/vqa/webvid_qa/train.json\", \n        \"your_data_path/WebVid2M\",\n        \"video\"\n    ],\n)\n\n\n"
  },
  {
    "path": "stllm/datasets/datasets/laion_dataset.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport webdataset as wds\nfrom stllm.datasets.datasets.base_dataset import BaseDataset\n\n\nclass LaionDataset(BaseDataset):\n    def __init__(self, vis_processor, text_processor, location):\n        super().__init__(vis_processor=vis_processor, text_processor=text_processor)\n\n        self.inner_dataset = wds.DataPipeline(\n            wds.ResampledShards(location),\n            wds.tarfile_to_samples(handler=wds.warn_and_continue),\n            wds.shuffle(1000, handler=wds.warn_and_continue),\n            wds.decode(\"pilrgb\", handler=wds.warn_and_continue),\n            wds.to_tuple(\"jpg\", \"json\", handler=wds.warn_and_continue),\n            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),\n            wds.map(self.to_dict, handler=wds.warn_and_continue),\n        )\n\n    def to_dict(self, sample):\n        return {\n            \"image\": sample[0],\n            \"answer\": self.text_processor(sample[1][\"caption\"]),\n        }\n\n"
  },
  {
    "path": "stllm/datasets/datasets/utils.py",
    "content": "#from utils.distributed import is_main_process, get_rank, get_world_size\nimport logging\nimport torch.distributed as dist\nimport torch\nimport io\nimport os\nimport json\nimport re\nimport copy\nimport numpy as np\nfrom os.path import join\nfrom tqdm import trange\nfrom PIL import Image\nfrom PIL import ImageFile\nfrom torchvision.transforms import PILToTensor\nImageFile.LOAD_TRUNCATED_IMAGES = True\nImage.MAX_IMAGE_PIXELS = None\n\nimport random\nimport av\nimport cv2\nimport decord\nimport imageio\nfrom decord import VideoReader\nimport torch\nimport math\ndecord.bridge.set_bridge(\"torch\")\n\nimport logging\nlogger = logging.getLogger(__name__)\n\ndef load_image_from_path(image_path, client):\n    if image_path.startswith('s3') or image_path.startswith('p2'):\n        value = client.get(image_path)\n        img_bytes = np.frombuffer(value, dtype=np.uint8)\n        buff = io.BytesIO(img_bytes)\n        image = Image.open(buff).convert('RGB')\n    else:\n        image = Image.open(image_path).convert('RGB')  # PIL Image\n    image = PILToTensor()(image).unsqueeze(0)  # (1, C, H, W), torch.uint8\n    return image\n\ndef load_anno(ann_file_list):\n    \"\"\"[summary]\n\n    Args:\n        ann_file_list (List[List[str, str]] or List[str, str]):\n            the latter will be automatically converted to the former.\n            Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video'])\n            which specifies the data type, video or image\n\n    Returns:\n        List(dict): each dict is {\n            image: str or List[str],  # image_path,\n            caption: str or List[str]  # caption text string\n        }\n    \"\"\"\n    if isinstance(ann_file_list[0], str):\n        ann_file_list = [ann_file_list]\n\n    ann = []\n    for d in ann_file_list:\n        data_root = d[1]\n        fp = d[0]\n        is_video = len(d) == 3 and d[2] == \"video\"\n        cur_ann = json.load(open(fp, \"r\"))\n        iterator = 
trange(len(cur_ann), desc=f\"Loading {fp}\") \\\n            if is_main_process() else range(len(cur_ann))\n        for idx in iterator:\n            key = \"video\" if is_video else \"image\"\n            # unified to have the same key for data path\n            if isinstance(cur_ann[idx][key], str):\n                cur_ann[idx][\"image\"] = join(data_root, cur_ann[idx][key])\n            else:  # list\n                cur_ann[idx][\"image\"] = [join(data_root, e) for e in cur_ann[idx][key]]\n        ann += cur_ann\n    return ann\n\n\ndef pre_text(text, max_l=None, pre_text=True):\n    if pre_text:\n        text = re.sub(r\"([,.'!?\\\"()*#:;~])\", '', text.lower())\n        text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')\n\n        text = re.sub(r\"\\s{2,}\", ' ', text)\n        text = text.rstrip('\\n').strip(' ')\n\n        if max_l:  # truncate\n            words = text.split(' ')\n            if len(words) > max_l:\n                text = ' '.join(words[:max_l])\n    else:\n        pass\n    return text\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef collect_result(result, result_dir, filename, is_json=True, is_list=True):\n    if is_json:\n        result_file = os.path.join(\n            result_dir, '%s_rank%d.json' % (filename, get_rank()))\n        final_result_file = os.path.join(result_dir, '%s.json' % filename)\n        json.dump(result, open(result_file, 'w'))\n    else:\n        result_file = os.path.join(\n            result_dir, '%s_rank%d.pth' % (filename, get_rank()))\n        final_result_file = os.path.join(result_dir, '%s.pth' % filename)\n        torch.save(result, result_file)\n\n    dist.barrier()\n\n    result = None\n    if is_main_process():\n        # combine results from all processes\n        if is_list:\n            result = []\n        else:\n            result = {}\n        for rank in range(get_world_size()):\n            if is_json:\n                result_file = os.path.join(\n                    
result_dir, '%s_rank%d.json' % (filename, rank))\n                res = json.load(open(result_file, 'r'))\n            else:\n                result_file = os.path.join(\n                    result_dir, '%s_rank%d.pth' % (filename, rank))\n                res = torch.load(result_file)\n            if is_list:\n                result += res\n            else:\n                result.update(res)\n\n    return result\n\n\ndef sync_save_result(result, result_dir, filename, is_json=True, is_list=True):\n    \"\"\"gather results from multiple GPUs\"\"\"\n    if is_json:\n        result_file = os.path.join(\n            result_dir, \"dist_res\", '%s_rank%d.json' % (filename, get_rank()))\n        final_result_file = os.path.join(result_dir, '%s.json' % filename)\n        os.makedirs(os.path.dirname(result_file), exist_ok=True)\n        json.dump(result, open(result_file, 'w'))\n    else:\n        result_file = os.path.join(\n            result_dir, \"dist_res\", '%s_rank%d.pth' % (filename, get_rank()))\n        os.makedirs(os.path.dirname(result_file), exist_ok=True)\n        final_result_file = os.path.join(result_dir, '%s.pth' % filename)\n        torch.save(result, result_file)\n\n    dist.barrier()\n\n    if is_main_process():\n        # combine results from all processes\n        if is_list:\n            result = []\n        else:\n            result = {}\n        for rank in range(get_world_size()):\n            if is_json:\n                result_file = os.path.join(\n                    result_dir, \"dist_res\", '%s_rank%d.json' % (filename, rank))\n                res = json.load(open(result_file, 'r'))\n            else:\n                result_file = os.path.join(\n                    result_dir, \"dist_res\", '%s_rank%d.pth' % (filename, rank))\n                res = torch.load(result_file)\n            if is_list:\n                result += res\n            else:\n                result.update(res)\n        if is_json:\n            json.dump(result, 
open(final_result_file, 'w'))\n        else:\n            torch.save(result, final_result_file)\n\n        logger.info('result file saved to %s' % final_result_file)\n    dist.barrier()\n    return final_result_file, result\n\n\ndef pad_sequences_1d(sequences, dtype=torch.long, device=torch.device(\"cpu\"), fixed_length=None):\n    \"\"\" Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray)\n    into a (n+1)-d array, only allow the first dim has variable lengths.\n    Args:\n        sequences: list(n-d tensor or list)\n        dtype: np.dtype or torch.dtype\n        device:\n        fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length.\n            return will be of shape [len(sequences), fixed_length, ...]\n    Returns:\n        padded_seqs: ((n+1)-d tensor) padded with zeros\n        mask: (2d tensor) of the same shape as the first two dims of padded_seqs,\n              1 indicate valid, 0 otherwise\n    Examples:\n        >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]\n        >>> pad_sequences_1d(test_data_list, dtype=torch.long)\n        >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)]\n        >>> pad_sequences_1d(test_data_3d, dtype=torch.float)\n        >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]\n        >>> pad_sequences_1d(test_data_list, dtype=np.float32)\n        >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)]\n        >>> pad_sequences_1d(test_data_3d, dtype=np.float32)\n    \"\"\"\n    if isinstance(sequences[0], list):\n        if \"torch\" in str(dtype):\n            sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences]\n        else:\n            sequences = [np.asarray(s, dtype=dtype) for s in sequences]\n\n    extra_dims = sequences[0].shape[1:]  # the extra dims should be the same for all elements\n    lengths = [len(seq) for seq in sequences]\n    if fixed_length 
is not None:\n        max_length = fixed_length\n    else:\n        max_length = max(lengths)\n    if isinstance(sequences[0], torch.Tensor):\n        assert \"torch\" in str(dtype), \"dtype and input type does not match\"\n        padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device)\n        mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device)\n    else:  # np\n        assert \"numpy\" in str(dtype), \"dtype and input type does not match\"\n        padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype)\n        mask = np.zeros((len(sequences), max_length), dtype=np.float32)\n\n    for idx, seq in enumerate(sequences):\n        end = lengths[idx]\n        padded_seqs[idx, :end] = seq\n        mask[idx, :end] = 1\n    return padded_seqs, mask  # , lengths\n\ndef pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:\n    \"\"\"\n    Converts a present time with the given time base and start_pts offset to seconds.\n\n    Returns:\n        time_in_seconds (float): The corresponding time in seconds.\n\n    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64\n    \"\"\"\n    if pts == math.inf:\n        return math.inf\n\n    return int(pts - start_pts) * time_base\n\n\ndef get_pyav_video_duration(video_reader):\n    video_stream = video_reader.streams.video[0]\n    video_duration = pts_to_secs(\n        video_stream.duration,\n        video_stream.time_base,\n        video_stream.start_time\n    )\n    return float(video_duration)\n\n\ndef get_frame_indices_by_fps():\n    pass\n\ndef get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):\n    if sample in [\"rand\", \"middle\"]: # uniform sampling\n        acc_samples = min(num_frames, vlen)\n        # split the video into `acc_samples` intervals, and sample from each interval.\n        intervals = np.linspace(start=0, 
stop=vlen, num=acc_samples + 1).astype(int)\n        ranges = []\n        for idx, interv in enumerate(intervals[:-1]):\n            ranges.append((interv, intervals[idx + 1] - 1))\n        if sample == 'rand':\n            try:\n                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]\n            except:\n                frame_indices = np.random.permutation(vlen)[:acc_samples]\n                frame_indices.sort()\n                frame_indices = list(frame_indices)\n        elif fix_start is not None:\n            frame_indices = [x[0] + fix_start for x in ranges]\n        elif sample == 'middle':\n            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]\n        else:\n            raise NotImplementedError\n\n        if len(frame_indices) < num_frames:  # padded with last frame\n            padded_frame_indices = [frame_indices[-1]] * num_frames\n            padded_frame_indices[:len(frame_indices)] = frame_indices\n            frame_indices = padded_frame_indices\n    elif \"fps\" in sample:  # fps0.5, sequentially sample frames at 0.5 fps\n        output_fps = float(sample[3:])\n        duration = float(vlen) / input_fps\n        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents\n        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)\n        frame_indices = np.around(frame_seconds * input_fps).astype(int)\n        frame_indices = [e for e in frame_indices if e < vlen]\n        if max_num_frames > 0 and len(frame_indices) > max_num_frames:\n            frame_indices = frame_indices[:max_num_frames]\n            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)\n    else:\n        raise ValueError\n    return frame_indices\n\ndef read_frames_av(\n        video_path, num_frames, sample='rand', fix_start=None, \n        max_num_frames=-1, client=None, clip=None,\n    ):\n    reader = av.open(video_path)\n   
 frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]\n    vlen = len(frames)\n    duration = get_pyav_video_duration(reader)\n    fps = vlen / float(duration)\n    frame_indices = get_frame_indices(\n        num_frames, vlen, sample=sample, fix_start=fix_start,\n        input_fps=fps, max_num_frames=max_num_frames\n    )\n    frames = torch.stack([frames[idx] for idx in frame_indices])  # (T, H, W, C), torch.uint8\n    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8\n    return frames, frame_indices, fps\n\ndef read_frames_gif(\n        video_path, num_frames, sample='rand', fix_start=None, \n        max_num_frames=-1, client=None, clip=None,\n    ):\n    if video_path.startswith('s3') or video_path.startswith('p2'):\n        video_bytes = client.get(video_path)\n        gif = imageio.get_reader(io.BytesIO(video_bytes))\n    else:\n        gif = imageio.get_reader(video_path)\n    vlen = len(gif)\n    frame_indices = get_frame_indices(\n        num_frames, vlen, sample=sample, fix_start=fix_start,\n        max_num_frames=max_num_frames\n    )\n    frames = []\n    for index, frame in enumerate(gif):\n        # for index in frame_idxs:\n        if index in frame_indices:\n            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)\n            frame = torch.from_numpy(frame).byte()\n            # # (H x W x C) to (C x H x W)\n            frame = frame.permute(2, 0, 1)\n            frames.append(frame)\n    frames = torch.stack(frames)  # .float() / 255\n    return frames, frame_indices, 25. 
# for tgif\n\ndef read_frames_decord(\n        video_path, num_frames, sample='rand', fix_start=None, \n        max_num_frames=-1, client=None, clip=None\n    ):\n    if video_path.startswith('s3') or video_path.startswith('p2'):\n        video_bytes = client.get(video_path)\n        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)\n    else:\n        video_reader = VideoReader(video_path, num_threads=1)\n    vlen = len(video_reader)\n    fps = video_reader.get_avg_fps()\n    duration = vlen / float(fps)\n\n    if clip:\n        start, end = clip\n        duration = end - start\n        vlen = int(duration * fps)\n        start_index = int(start * fps)\n\n    frame_indices = get_frame_indices(\n        num_frames, vlen, sample=sample, fix_start=fix_start,\n        input_fps=fps, max_num_frames=max_num_frames\n    )\n    if clip:\n        frame_indices = [f + start_index for f in frame_indices]\n\n    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8\n    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8\n    return frames, frame_indices, float(duration)\n\ndef read_frames_rawframes(\n        video_path, num_frames, sample='rand', fix_start=None, \n        max_num_frames=-1, client=None, clip=None\n    ):\n    file_client = client.FileClient('disk')\n    fps = 5\n    filename_tmpl=\"{:0>6}.jpg\"\n    offset=1\n    frame_indices = get_frame_indices(\n        num_frames, max_num_frames, sample=sample, fix_start=fix_start,\n        input_fps=fps, max_num_frames=-1\n    )\n    imgs = list()\n    cache = {}\n    for i, frame_idx in enumerate(frame_indices):\n        # Avoid loading duplicated frames\n        if frame_idx in cache:\n            imgs.append(copy.deepcopy(imgs[cache[frame_idx]]))  \n            continue\n        else:\n            cache[frame_idx] = i\n        frame_idx += offset\n        filepath = os.path.join(video_path, filename_tmpl.format(frame_idx))\n        try:\n            img_bytes = 
file_client.get(filepath)\n        except Exception:\n            # Fall back to the next frame index if this frame cannot be read.\n            filepath = os.path.join(video_path, filename_tmpl.format(frame_idx+1))\n            img_bytes = file_client.get(filepath)\n        # Get frame with channel order RGB directly.\n        import mmcv\n        cur_frame = mmcv.imfrombytes(img_bytes, channel_order='rgb')\n        imgs.append(cur_frame)\n    frames = np.concatenate([img[np.newaxis, ...] for img in imgs], axis=0)\n    frames = torch.from_numpy(frames)\n    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8\n    return frames, frame_indices, float(max_num_frames / fps)\n\nVIDEO_READER_FUNCS = {\n    'av': read_frames_av,\n    'decord': read_frames_decord,\n    'gif': read_frames_gif,\n    'rawframe': read_frames_rawframes,\n}\n"
  },
  {
    "path": "stllm/models/Qformer.py",
    "content": "\"\"\"\n * Copyright (c) 2023, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n * By Junnan Li\n * Based on huggingface code base\n * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert\n\"\"\"\n\nimport math\nimport os\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple, Dict, Any\n\nimport torch\nfrom torch import Tensor, device, dtype, nn\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\nimport torch.nn.functional as F\n\nfrom transformers.activations import ACT2FN\nfrom transformers.file_utils import (\n    ModelOutput,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n    MaskedLMOutput,\n    MultipleChoiceModelOutput,\n    NextSentencePredictorOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom transformers.utils import logging\nfrom transformers.models.bert.configuration_bert import BertConfig\n\nlogger = logging.get_logger(__name__)\n\n\nclass BertEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word and position embeddings.\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(\n            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id\n        )\n        self.position_embeddings = nn.Embedding(\n            config.max_position_embeddings, config.hidden_size\n        )\n\n        # self.LayerNorm is not snake-cased to stick 
with TensorFlow model variable name and be able to load\n        # any TensorFlow checkpoint file\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\n            \"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1))\n        )\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n\n        self.config = config\n\n    def forward(\n        self,\n        input_ids=None,\n        position_ids=None,\n        query_embeds=None,\n        past_key_values_length=0,\n    ):\n        if input_ids is not None:\n            seq_length = input_ids.size()[1]\n        else:\n            seq_length = 0\n\n        if position_ids is None:\n            position_ids = self.position_ids[\n                :, past_key_values_length : seq_length + past_key_values_length\n            ].clone()\n\n        if input_ids is not None:\n            embeddings = self.word_embeddings(input_ids)\n            if self.position_embedding_type == \"absolute\":\n                position_embeddings = self.position_embeddings(position_ids)\n                embeddings = embeddings + position_embeddings\n\n            if query_embeds is not None:\n                embeddings = torch.cat((query_embeds, embeddings), dim=1)\n        else:\n            embeddings = query_embeds\n\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass BertSelfAttention(nn.Module):\n    def __init__(self, config, is_cross_attention):\n        super().__init__()\n        self.config = config\n        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(\n            config, \"embedding_size\"\n        ):\n          
  raise ValueError(\n                \"The hidden size (%d) is not a multiple of the number of attention \"\n                \"heads (%d)\" % (config.hidden_size, config.num_attention_heads)\n            )\n\n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        if is_cross_attention:\n            self.key = nn.Linear(config.encoder_width, self.all_head_size)\n            self.value = nn.Linear(config.encoder_width, self.all_head_size)\n        else:\n            self.key = nn.Linear(config.hidden_size, self.all_head_size)\n            self.value = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n        if (\n            self.position_embedding_type == \"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            self.max_position_embeddings = config.max_position_embeddings\n            self.distance_embedding = nn.Embedding(\n                2 * config.max_position_embeddings - 1, self.attention_head_size\n            )\n        self.save_attention = False\n\n    def save_attn_gradients(self, attn_gradients):\n        self.attn_gradients = attn_gradients\n\n    def get_attn_gradients(self):\n        return self.attn_gradients\n\n    def save_attention_map(self, attention_map):\n        self.attention_map = attention_map\n\n    def get_attention_map(self):\n        return self.attention_map\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (\n            self.num_attention_heads,\n            self.attention_head_size,\n        )\n  
      x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n\n        # If this is instantiated as a cross-attention module, the keys\n        # and values come from an encoder; the attention mask needs to be\n        # such that the encoder's padding tokens are not attended to.\n        is_cross_attention = encoder_hidden_states is not None\n\n        if is_cross_attention:\n            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))\n            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))\n            attention_mask = encoder_attention_mask\n        elif past_key_value is not None:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)\n            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)\n        else:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n\n        mixed_query_layer = self.query(hidden_states)\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        past_key_value = (key_layer, value_layer)\n\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n\n        if (\n            self.position_embedding_type == \"relative_key\"\n            or self.position_embedding_type == \"relative_key_query\"\n        ):\n            seq_length = hidden_states.size()[1]\n            position_ids_l = 
torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(-1, 1)\n            position_ids_r = torch.arange(\n                seq_length, dtype=torch.long, device=hidden_states.device\n            ).view(1, -1)\n            distance = position_ids_l - position_ids_r\n            positional_embedding = self.distance_embedding(\n                distance + self.max_position_embeddings - 1\n            )\n            positional_embedding = positional_embedding.to(\n                dtype=query_layer.dtype\n            )  # fp16 compatibility\n\n            if self.position_embedding_type == \"relative_key\":\n                relative_position_scores = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                attention_scores = attention_scores + relative_position_scores\n            elif self.position_embedding_type == \"relative_key_query\":\n                relative_position_scores_query = torch.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding\n                )\n                relative_position_scores_key = torch.einsum(\n                    \"bhrd,lrd->bhlr\", key_layer, positional_embedding\n                )\n                attention_scores = (\n                    attention_scores\n                    + relative_position_scores_query\n                    + relative_position_scores_key\n                )\n\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        if attention_mask is not None:\n            # Apply the attention mask (precomputed for all layers in BertModel's forward() function)\n            attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n\n        if is_cross_attention and self.save_attention:\n            
self.save_attention_map(attention_probs)\n            attention_probs.register_hook(self.save_attn_gradients)\n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        attention_probs_dropped = self.dropout(attention_probs)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attention_probs_dropped = attention_probs_dropped * head_mask\n\n        context_layer = torch.matmul(attention_probs_dropped, value_layer)\n\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        outputs = (\n            (context_layer, attention_probs) if output_attentions else (context_layer,)\n        )\n\n        outputs = outputs + (past_key_value,)\n        return outputs\n\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertAttention(nn.Module):\n    def __init__(self, config, is_cross_attention=False):\n        super().__init__()\n        self.self = BertSelfAttention(config, is_cross_attention)\n        self.output = BertSelfOutput(config)\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            
heads,\n            self.self.num_attention_heads,\n            self.self.attention_head_size,\n            self.pruned_heads,\n        )\n\n        # Prune linear layers\n        self.self.query = prune_linear_layer(self.self.query, index)\n        self.self.key = prune_linear_layer(self.self.key, index)\n        self.self.value = prune_linear_layer(self.self.value, index)\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n\n        # Update hyper params and store pruned heads\n        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n        self.self.all_head_size = (\n            self.self.attention_head_size * self.self.num_attention_heads\n        )\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        self_outputs = self.self(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            encoder_hidden_states,\n            encoder_attention_mask,\n            past_key_value,\n            output_attentions,\n        )\n        attention_output = self.output(self_outputs[0], hidden_states)\n\n        outputs = (attention_output,) + self_outputs[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\nclass BertIntermediate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n        if isinstance(config.hidden_act, str):\n            self.intermediate_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.intermediate_act_fn = config.hidden_act\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states 
= self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\nclass BertOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertLayer(nn.Module):\n    def __init__(self, config, layer_num):\n        super().__init__()\n        self.config = config\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.attention = BertAttention(config)\n        self.layer_num = layer_num\n        if (\n            self.config.add_cross_attention\n            and layer_num % self.config.cross_attention_freq == 0\n        ):\n            self.crossattention = BertAttention(\n                config, is_cross_attention=self.config.add_cross_attention\n            )\n            self.has_cross_attention = True\n        else:\n            self.has_cross_attention = False\n        self.intermediate = BertIntermediate(config)\n        self.output = BertOutput(config)\n\n        self.intermediate_query = BertIntermediate(config)\n        self.output_query = BertOutput(config)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n        query_length=0,\n    ):\n        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2\n        self_attn_past_key_value = (\n            
past_key_value[:2] if past_key_value is not None else None\n        )\n        self_attention_outputs = self.attention(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            output_attentions=output_attentions,\n            past_key_value=self_attn_past_key_value,\n        )\n        attention_output = self_attention_outputs[0]\n        outputs = self_attention_outputs[1:-1]\n\n        present_key_value = self_attention_outputs[-1]\n\n        if query_length > 0:\n            query_attention_output = attention_output[:, :query_length, :]\n\n            if self.has_cross_attention:\n                assert (\n                    encoder_hidden_states is not None\n                ), \"encoder_hidden_states must be given for cross-attention layers\"\n                cross_attention_outputs = self.crossattention(\n                    query_attention_output,\n                    attention_mask,\n                    head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    output_attentions=output_attentions,\n                )\n                query_attention_output = cross_attention_outputs[0]\n                outputs = (\n                    outputs + cross_attention_outputs[1:-1]\n                )  # add cross attentions if we output attention weights\n\n            layer_output = apply_chunking_to_forward(\n                self.feed_forward_chunk_query,\n                self.chunk_size_feed_forward,\n                self.seq_len_dim,\n                query_attention_output,\n            )\n            if attention_output.shape[1] > query_length:\n                layer_output_text = apply_chunking_to_forward(\n                    self.feed_forward_chunk,\n                    self.chunk_size_feed_forward,\n                    self.seq_len_dim,\n                    attention_output[:, query_length:, :],\n                )\n                layer_output = 
torch.cat([layer_output, layer_output_text], dim=1)\n        else:\n            layer_output = apply_chunking_to_forward(\n                self.feed_forward_chunk,\n                self.chunk_size_feed_forward,\n                self.seq_len_dim,\n                attention_output,\n            )\n        outputs = (layer_output,) + outputs\n\n        outputs = outputs + (present_key_value,)\n\n        return outputs\n\n    def feed_forward_chunk(self, attention_output):\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        return layer_output\n\n    def feed_forward_chunk_query(self, attention_output):\n        intermediate_output = self.intermediate_query(attention_output)\n        layer_output = self.output_query(intermediate_output, attention_output)\n        return layer_output\n\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.layer = nn.ModuleList(\n            [BertLayer(config, i) for i in range(config.num_hidden_layers)]\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=False,\n        output_hidden_states=False,\n        return_dict=True,\n        query_length=0,\n    ):\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = (\n            () if output_attentions and self.config.add_cross_attention else None\n        )\n\n        next_decoder_cache = () if use_cache else None\n\n        for i in range(self.config.num_hidden_layers):\n            layer_module = self.layer[i]\n            if output_hidden_states:\n                
all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            past_key_value = past_key_values[i] if past_key_values is not None else None\n\n            if getattr(self.config, \"gradient_checkpointing\", False) and self.training:\n\n                if use_cache:\n                    logger.warning(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(\n                            *inputs, past_key_value, output_attentions, query_length\n                        )\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer_module),\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                    query_length,\n                )\n\n            hidden_states = layer_outputs[0]\n            if use_cache:\n                next_decoder_cache += (layer_outputs[-1],)\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)\n\n        if output_hidden_states:\n            
all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_decoder_cache,\n                    all_hidden_states,\n                    all_self_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=next_decoder_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass BertPredictionHeadTransform(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if isinstance(config.hidden_act, str):\n            self.transform_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.transform_act_fn = config.hidden_act\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.transform_act_fn(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states\n\n\nclass 
BertLMPredictionHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.transform = BertPredictionHeadTransform(config)\n\n        # The output weights are the same as the input embeddings, but there is\n        # an output-only bias for each token.\n        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\n\n        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`\n        self.decoder.bias = self.bias\n\n    def forward(self, hidden_states):\n        hidden_states = self.transform(hidden_states)\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states\n\n\nclass BertOnlyMLMHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n\n    def forward(self, sequence_output):\n        prediction_scores = self.predictions(sequence_output)\n        return prediction_scores\n\n\nclass BertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = BertConfig\n    base_model_prefix = \"bert\"\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, (nn.Linear, nn.Embedding)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        if isinstance(module, nn.Linear) and module.bias is not None:\n   
         module.bias.data.zero_()\n\n\nclass BertModel(BertPreTrainedModel):\n    \"\"\"\n    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n    cross-attention is added between the self-attention layers, following the architecture described in `Attention is\n    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.\n\n    To use cross-attention, the model needs to be initialized with :obj:`add_cross_attention` set to :obj:`True`; an\n    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config, add_pooling_layer=False):\n        super().__init__(config)\n        self.config = config\n\n        self.embeddings = BertEmbeddings(config)\n\n        self.encoder = BertEncoder(config)\n\n        self.pooler = BertPooler(config) if add_pooling_layer else None\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        self.embeddings.word_embeddings = value\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base\n        class PreTrainedModel.\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.layer[layer].attention.prune_heads(heads)\n\n    def get_extended_attention_mask(\n        self,\n        attention_mask: Tensor,\n        input_shape: Tuple[int],\n        device: device,\n        is_decoder: bool,\n        has_query: bool = False,\n    ) -> Tensor:\n        \"\"\"\n        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n\n        Arguments:\n            attention_mask (:obj:`torch.Tensor`):\n                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n            input_shape (:obj:`Tuple[int]`):\n                The shape of the input to the model.\n            device (:obj:`torch.device`):\n                The device of the input to the model.\n            is_decoder (:obj:`bool`):\n                Whether to apply a causal mask in addition to the padding mask.\n            has_query (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether query embeddings precede the text tokens, in which case a UniLM-style mask is built.\n\n        Returns:\n            :obj:`torch.Tensor`: The extended attention mask, cast to the model's dtype (:obj:`self.dtype`).\n        \"\"\"\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if attention_mask.dim() == 3:\n            extended_attention_mask = attention_mask[:, None, :, :]\n        elif attention_mask.dim() == 2:\n            # Provided a padding mask of dimensions [batch_size, seq_length]\n            # - if the model is a decoder, apply a causal mask in addition to the padding mask\n            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]\n            if is_decoder:\n                batch_size, seq_length = input_shape\n\n                seq_ids = torch.arange(seq_length, device=device)\n                causal_mask = (\n                    seq_ids[None, None, :].repeat(batch_size, 
seq_length, 1)\n                    <= seq_ids[None, :, None]\n                )\n\n                # add a prefix ones mask to the causal mask\n                # causal and attention masks must have same type with pytorch version < 1.3\n                causal_mask = causal_mask.to(attention_mask.dtype)\n\n                if causal_mask.shape[1] < attention_mask.shape[1]:\n                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]\n                    if has_query:  # UniLM style attention mask\n                        causal_mask = torch.cat(\n                            [\n                                torch.zeros(\n                                    (batch_size, prefix_seq_len, seq_length),\n                                    device=device,\n                                    dtype=causal_mask.dtype,\n                                ),\n                                causal_mask,\n                            ],\n                            axis=1,\n                        )\n                    causal_mask = torch.cat(\n                        [\n                            torch.ones(\n                                (batch_size, causal_mask.shape[1], prefix_seq_len),\n                                device=device,\n                                dtype=causal_mask.dtype,\n                            ),\n                            causal_mask,\n                        ],\n                        axis=-1,\n                    )\n                extended_attention_mask = (\n                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]\n                )\n            else:\n                extended_attention_mask = attention_mask[:, None, None, :]\n        else:\n            raise ValueError(\n                \"Wrong shape for input_ids (shape {}) or attention_mask (shape {})\".format(\n                    input_shape, attention_mask.shape\n                )\n            )\n\n        # Since attention_mask is 1.0 
for positions we want to attend and 0.0 for\n        # masked positions, this operation will create a tensor which is 0.0 for\n        # positions we want to attend and -10000.0 for masked positions.\n        # Since we are adding it to the raw scores before the softmax, this is\n        # effectively the same as removing these entirely.\n        extended_attention_mask = extended_attention_mask.to(\n            dtype=self.dtype\n        )  # fp16 compatibility\n        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n        return extended_attention_mask\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        query_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=False,\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. 
Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if input_ids is None:\n            assert (\n                query_embeds is not None\n            ), \"You have to specify query_embeds when input_ids is None\"\n\n        # past_key_values_length\n        past_key_values_length = (\n            past_key_values[0][0].shape[2] - self.config.query_length\n            if past_key_values is not None\n            else 0\n        )\n\n        
query_length = query_embeds.shape[1] if query_embeds is not None else 0\n\n        embedding_output = self.embeddings(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            query_embeds=query_embeds,\n            past_key_values_length=past_key_values_length,\n        )\n\n        input_shape = embedding_output.size()[:-1]\n        batch_size, seq_length = input_shape\n        device = embedding_output.device\n\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                ((batch_size, seq_length + past_key_values_length)), device=device\n            )\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if is_decoder:\n            extended_attention_mask = self.get_extended_attention_mask(\n                attention_mask,\n                input_ids.shape,\n                device,\n                is_decoder,\n                has_query=(query_embeds is not None),\n            )\n        else:\n            extended_attention_mask = self.get_extended_attention_mask(\n                attention_mask, input_shape, device, is_decoder\n            )\n\n        # If a 2D or 3D attention mask is provided for the cross-attention\n        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if encoder_hidden_states is not None:\n            if type(encoder_hidden_states) == list:\n                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[\n                    0\n                ].size()\n            else:\n                (\n                    encoder_batch_size,\n                    encoder_sequence_length,\n                    _,\n                ) = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n\n            if 
type(encoder_attention_mask) == list:\n                encoder_extended_attention_mask = [\n                    self.invert_attention_mask(mask) for mask in encoder_attention_mask\n                ]\n            elif encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n            else:\n                encoder_extended_attention_mask = self.invert_attention_mask(\n                    encoder_attention_mask\n                )\n        else:\n            encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        encoder_outputs = self.encoder(\n            embedding_output,\n            attention_mask=extended_attention_mask,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_extended_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            query_length=query_length,\n        )\n        sequence_output = encoder_outputs[0]\n        pooled_output = (\n            self.pooler(sequence_output) if self.pooler is not None else None\n        )\n\n        if not return_dict:\n            return (sequence_output, pooled_output) + encoder_outputs[1:]\n\n        return 
BaseModelOutputWithPoolingAndCrossAttentions(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n            past_key_values=encoder_outputs.past_key_values,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n            cross_attentions=encoder_outputs.cross_attentions,\n        )\n\n\nclass BertLMHeadModel(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        query_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        past_key_values=None,\n        use_cache=True,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        return_logits=False,\n        is_decoder=True,\n        reduction=\"mean\",\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. 
Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in\n            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are\n            ignored (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        Returns:\n        Example::\n            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig\n            >>> import torch\n            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n            >>> config = BertConfig.from_pretrained(\"bert-base-cased\")\n            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)\n            >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n            >>> outputs = model(**inputs)\n            >>> prediction_logits = outputs.logits\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n        if labels is not None:\n            use_cache = False\n        if past_key_values is not None:\n            query_embeds = None\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            query_embeds=query_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n         
   is_decoder=is_decoder,\n        )\n\n        sequence_output = outputs[0]\n        if query_embeds is not None:\n            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]\n\n        prediction_scores = self.cls(sequence_output)\n\n        if return_logits:\n            return prediction_scores[:, :-1, :].contiguous()\n\n        lm_loss = None\n        if labels is not None:\n            # we are doing next-token prediction; shift prediction scores and input ids by one\n            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()\n            labels = labels[:, 1:].contiguous()\n            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)\n            lm_loss = loss_fct(\n                shifted_prediction_scores.view(-1, self.config.vocab_size),\n                labels.view(-1),\n            )\n            if reduction == \"none\":\n                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)\n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return ((lm_loss,) + output) if lm_loss is not None else output\n\n        return CausalLMOutputWithCrossAttentions(\n            loss=lm_loss,\n            logits=prediction_scores,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            cross_attentions=outputs.cross_attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs\n    ):\n        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly\n        if attention_mask is None:\n            attention_mask = input_ids.new_ones(input_ids.shape)\n        query_mask = input_ids.new_ones(query_embeds.shape[:-1])\n        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)\n\n        # cut 
decoder_input_ids if past is used\n        if past is not None:\n            input_ids = input_ids[:, -1:]\n\n        return {\n            \"input_ids\": input_ids,\n            \"query_embeds\": query_embeds,\n            \"attention_mask\": attention_mask,\n            \"past_key_values\": past,\n            \"encoder_hidden_states\": model_kwargs.get(\"encoder_hidden_states\", None),\n            \"encoder_attention_mask\": model_kwargs.get(\"encoder_attention_mask\", None),\n            \"is_decoder\": True,\n        }\n\n    def _reorder_cache(self, past, beam_idx):\n        reordered_past = ()\n        for layer_past in past:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx) for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\nclass BertForMaskedLM(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        position_ids=None,\n        head_mask=None,\n        query_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        return_logits=False,\n        is_decoder=False,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, 
`optional`):\n            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored\n            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.\n        \"\"\"\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            query_embeds=query_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            is_decoder=is_decoder,\n        )\n\n        # Always define sequence_output; previously it was unbound when query_embeds was None.\n        sequence_output = outputs[0]\n        if query_embeds is not None:\n            # Drop the leading query positions so that only text tokens are scored.\n            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]\n        prediction_scores = self.cls(sequence_output)\n\n        if return_logits:\n            return prediction_scores\n\n        masked_lm_loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()  # -100 index = padding token\n            masked_lm_loss = loss_fct(\n                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)\n            )\n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return (\n                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output\n            )\n\n        return MaskedLMOutput(\n            loss=masked_lm_loss,\n            logits=prediction_scores,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "stllm/models/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom stllm.common.registry import registry\nfrom stllm.models.base_model import BaseModel\nfrom stllm.models.blip2 import Blip2Base\nfrom stllm.models.st_llm import STLLMForCausalLM\nfrom stllm.processors.base_processor import BaseProcessor\n\n\n__all__ = [\n    \"load_model\",\n    \"BaseModel\",\n    \"Blip2Base\",\n    \"STLLMForCausalLM\",\n]\n\n\ndef load_model(name, model_type, is_eval=False, device=\"cpu\", checkpoint=None):\n    \"\"\"\n    Load supported models.\n\n    To list all available models and types in registry:\n    >>> from stllm.models import model_zoo\n    >>> print(model_zoo)\n\n    Args:\n        name (str): name of the model.\n        model_type (str): type of the model.\n        is_eval (bool): whether the model is in eval mode. Default: False.\n        device (str): device to use. Default: \"cpu\".\n        checkpoint (str): path or to checkpoint. 
Default: None.\n            Note that expecting the checkpoint to have the same keys in state_dict as the model.\n\n    Returns:\n        model (torch.nn.Module): model.\n    \"\"\"\n\n    model = registry.get_model_class(name).from_pretrained(model_type=model_type)\n\n    if checkpoint is not None:\n        model.load_checkpoint(checkpoint)\n\n    if is_eval:\n        model.eval()\n\n    if device == \"cpu\":\n        model = model.float()\n\n    return model.to(device)\n\n\ndef load_preprocess(config):\n    \"\"\"\n    Load preprocessor configs and construct preprocessors.\n\n    If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.\n\n    Args:\n        config (dict): preprocessor configs.\n\n    Returns:\n        vis_processors (dict): preprocessors for visual inputs.\n        txt_processors (dict): preprocessors for text inputs.\n\n        Key is \"train\" or \"eval\" for processors used in training and evaluation respectively.\n    \"\"\"\n\n    def _build_proc_from_cfg(cfg):\n        return (\n            registry.get_processor_class(cfg.name).from_config(cfg)\n            if cfg is not None\n            else BaseProcessor()\n        )\n\n    vis_processors = dict()\n    txt_processors = dict()\n\n    vis_proc_cfg = config.get(\"vis_processor\")\n    txt_proc_cfg = config.get(\"text_processor\")\n\n    if vis_proc_cfg is not None:\n        vis_train_cfg = vis_proc_cfg.get(\"train\")\n        vis_eval_cfg = vis_proc_cfg.get(\"eval\")\n    else:\n        vis_train_cfg = None\n        vis_eval_cfg = None\n\n    vis_processors[\"train\"] = _build_proc_from_cfg(vis_train_cfg)\n    vis_processors[\"eval\"] = _build_proc_from_cfg(vis_eval_cfg)\n\n    if txt_proc_cfg is not None:\n        txt_train_cfg = txt_proc_cfg.get(\"train\")\n        txt_eval_cfg = txt_proc_cfg.get(\"eval\")\n    else:\n        txt_train_cfg = None\n        txt_eval_cfg = None\n\n    txt_processors[\"train\"] = _build_proc_from_cfg(txt_train_cfg)\n    
txt_processors[\"eval\"] = _build_proc_from_cfg(txt_eval_cfg)\n\n    return vis_processors, txt_processors\n\n\ndef load_model_and_preprocess(name, model_type, is_eval=False, device=\"cpu\"):\n    \"\"\"\n    Load model and its related preprocessors.\n\n    List all available models and types in registry:\n    >>> from stllm.models import model_zoo\n    >>> print(model_zoo)\n\n    Args:\n        name (str): name of the model.\n        model_type (str): type of the model.\n        is_eval (bool): whether the model is in eval mode. Default: False.\n        device (str): device to use. Default: \"cpu\".\n\n    Returns:\n        model (torch.nn.Module): model.\n        vis_processors (dict): preprocessors for visual inputs.\n        txt_processors (dict): preprocessors for text inputs.\n    \"\"\"\n    model_cls = registry.get_model_class(name)\n\n    # load model\n    model = model_cls.from_pretrained(model_type=model_type)\n\n    if is_eval:\n        model.eval()\n\n    # load preprocess\n    cfg = OmegaConf.load(model_cls.default_config_path(model_type))\n    if cfg is not None:\n        preprocess_cfg = cfg.preprocess\n\n        vis_processors, txt_processors = load_preprocess(preprocess_cfg)\n    else:\n        vis_processors, txt_processors = None, None\n        logging.info(\n            f\"\"\"No default preprocess for model {name} ({model_type}).\n                This can happen if the model is not finetuned on downstream datasets,\n                or it is not intended for direct use without finetuning.\n            \"\"\"\n        )\n\n    if device == \"cpu\" or device == torch.device(\"cpu\"):\n        model = model.float()\n\n    return model.to(device), vis_processors, txt_processors\n\n\nclass ModelZoo:\n    \"\"\"\n    A utility class to create string representation of available model architectures and types.\n\n    >>> from stllm.models import model_zoo\n    >>> # list all available models\n    >>> print(model_zoo)\n    >>> # show total number of 
models\n    >>> print(len(model_zoo))\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.model_zoo = {\n            k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())\n            for k, v in registry.mapping[\"model_name_mapping\"].items()\n        }\n\n    def __str__(self) -> str:\n        return (\n            \"=\" * 50\n            + \"\\n\"\n            + f\"{'Architectures':<30} {'Types'}\\n\"\n            + \"=\" * 50\n            + \"\\n\"\n            + \"\\n\".join(\n                [\n                    f\"{name:<30} {', '.join(types)}\"\n                    for name, types in self.model_zoo.items()\n                ]\n            )\n        )\n\n    def __iter__(self):\n        return iter(self.model_zoo.items())\n\n    def __len__(self):\n        return sum([len(v) for v in self.model_zoo.values()])\n\n\nmodel_zoo = ModelZoo()\n"
  },
  {
    "path": "stllm/models/base_decoder.py",
    "content": "from functools import partial\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nimport torch.utils.checkpoint as checkpoint\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        # init_values defaults to None; guard the comparison so the default does not raise a TypeError\n        if init_values is not None and init_values > 0:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PretrainVisionTransformerDecoder(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, embed_dim=4096, depth=2, num_heads=32, mlp_ratio=2.6875,\n                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,\n                 norm_layer=nn.LayerNorm, init_values=0, use_checkpoint=True\n                 ):\n        super().__init__()\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.use_checkpoint = use_checkpoint\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values)\n            for i in range(depth)])\n        self.norm =  norm_layer(embed_dim)\n        self.head = 
nn.Linear(embed_dim, embed_dim) \n\n        self.apply(self._init_weights)\n\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            nn.init.xavier_uniform_(m.weight)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward(self, x, return_token_num = 0):\n        if self.use_checkpoint:\n            for blk in self.blocks:\n                x = checkpoint.checkpoint(blk, x)\n        else:   \n            for blk in self.blocks:\n                x = blk(x)\n\n        if return_token_num > 0:\n            x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels\n        else:\n            x = self.head(self.norm(x))\n\n        return x\n"
  },
  {
    "path": "stllm/models/base_model.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom stllm.common.dist_utils import download_cached_file, is_dist_avail_and_initialized\nfrom stllm.common.utils import get_abs_path, is_url\nfrom omegaconf import OmegaConf\n\n\nclass BaseModel(nn.Module):\n    \"\"\"Base class for models.\"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    @property\n    def device(self):\n        return list(self.parameters())[-1].device\n\n    def load_checkpoint(self, url_or_filename):\n        \"\"\"\n        Load from a finetuned checkpoint.\n\n        This should expect no mismatch in the model keys and the checkpoint keys.\n        \"\"\"\n\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        if \"model\" in checkpoint.keys():\n            state_dict = checkpoint[\"model\"]\n        else:\n            state_dict = checkpoint\n\n        msg = self.load_state_dict(state_dict, strict=False)\n\n        logging.info(\"Missing keys {}\".format(msg.missing_keys))\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n\n        return msg\n\n    @classmethod\n    def from_pretrained(cls, model_type):\n        \"\"\"\n        Build a pretrained model from default configuration file, specified by model_type.\n\n        Args:\n            - model_type (str): model 
type, specifying architecture and checkpoints.\n\n        Returns:\n            - model (nn.Module): pretrained or finetuned model, depending on the configuration.\n        \"\"\"\n        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model\n        model = cls.from_config(model_cfg)\n\n        return model\n\n    @classmethod\n    def default_config_path(cls, model_type):\n        assert (\n            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT\n        ), \"Unknown model type {}\".format(model_type)\n        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])\n\n    def load_checkpoint_from_config(self, cfg, **kwargs):\n        \"\"\"\n        Load checkpoint as specified in the config file.\n\n        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.\n        When loading the pretrained model, each task-specific architecture may define their\n        own load_from_pretrained() method.\n        \"\"\"\n        load_finetuned = cfg.get(\"load_finetuned\", True)\n        if load_finetuned:\n            finetune_path = cfg.get(\"finetuned\", None)\n            assert (\n                finetune_path is not None\n            ), \"Found load_finetuned is True, but finetune_path is None.\"\n            self.load_checkpoint(url_or_filename=finetune_path)\n        else:\n            # load pre-trained weights\n            pretrain_path = cfg.get(\"pretrained\", None)\n            assert (\n                pretrain_path is not None\n            ), \"Found load_finetuned is False, but pretrain_path is None.\"\n            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)\n\n    def before_evaluation(self, **kwargs):\n        pass\n\n    def show_n_params(self, return_str=True):\n        tot = 0\n        for p in self.parameters():\n            w = 1\n            for x in p.shape:\n                w *= x\n            tot += w\n        if return_str:\n            if tot >= 1e6:\n                return \"{:.1f}M\".format(tot / 1e6)\n           
 else:\n                return \"{:.1f}K\".format(tot / 1e3)\n        else:\n            return tot\n\n\nclass BaseEncoder(nn.Module):\n    \"\"\"\n    Base class for primitive encoders, such as ViT, TimeSformer, etc.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward_features(self, samples, **kwargs):\n        raise NotImplementedError\n\n    @property\n    def device(self):\n        return list(self.parameters())[0].device\n\n\nclass SharedQueueMixin:\n    @torch.no_grad()\n    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):\n        # gather keys before updating queue\n        image_feats = concat_all_gather(image_feat)\n        text_feats = concat_all_gather(text_feat)\n\n        batch_size = image_feats.shape[0]\n\n        ptr = int(self.queue_ptr)\n        assert self.queue_size % batch_size == 0  # for simplicity\n\n        # replace the keys at ptr (dequeue and enqueue)\n        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T\n        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T\n\n        if idxs is not None:\n            idxs = concat_all_gather(idxs)\n            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T\n\n        ptr = (ptr + batch_size) % self.queue_size  # move pointer\n        self.queue_ptr[0] = ptr\n\n\nclass MomentumDistilationMixin:\n    @torch.no_grad()\n    def copy_params(self):\n        for model_pair in self.model_pairs:\n            for param, param_m in zip(\n                model_pair[0].parameters(), model_pair[1].parameters()\n            ):\n                param_m.data.copy_(param.data)  # initialize\n                param_m.requires_grad = False  # not update by gradient\n\n    @torch.no_grad()\n    def _momentum_update(self):\n        for model_pair in self.model_pairs:\n            for param, param_m in zip(\n                model_pair[0].parameters(), model_pair[1].parameters()\n            ):\n                param_m.data = param_m.data * 
self.momentum + param.data * (\n                    1.0 - self.momentum\n                )\n\n\nclass GatherLayer(torch.autograd.Function):\n    \"\"\"\n    Gather tensors from all workers with support for backward propagation:\n    This implementation does not cut the gradients as torch.distributed.all_gather does.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x):\n        output = [\n            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())\n        ]\n        torch.distributed.all_gather(output, x)\n        return tuple(output)\n\n    @staticmethod\n    def backward(ctx, *grads):\n        all_gradients = torch.stack(grads)\n        torch.distributed.all_reduce(all_gradients)\n        return all_gradients[torch.distributed.get_rank()]\n\n\ndef all_gather_with_grad(tensors):\n    \"\"\"\n    Performs all_gather operation on the provided tensors.\n    Graph remains connected for backward grad computation.\n    \"\"\"\n    # Queue the gathered tensors\n    world_size = torch.distributed.get_world_size()\n    # There is no need for reduction in the single-proc case\n    if world_size == 1:\n        return tensors\n\n    # tensor_all = GatherLayer.apply(tensors)\n    tensor_all = GatherLayer.apply(tensors)\n\n    return torch.cat(tensor_all, dim=0)\n\n\n@torch.no_grad()\ndef concat_all_gather(tensor):\n    \"\"\"\n    Performs all_gather operation on the provided tensors.\n    *** Warning ***: torch.distributed.all_gather has no gradient.\n    \"\"\"\n    # if use distributed training\n    if not is_dist_avail_and_initialized():\n        return tensor\n\n    tensors_gather = [\n        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())\n    ]\n    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)\n\n    output = torch.cat(tensors_gather, dim=0)\n    return output\n\n\ndef tile(x, dim, n_tile):\n    init_dim = x.size(dim)\n    repeat_idx = [1] * x.dim()\n    repeat_idx[dim] = n_tile\n    x = 
x.repeat(*(repeat_idx))\n    order_index = torch.LongTensor(\n        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])\n    )\n    return torch.index_select(x, dim, order_index.to(x.device))\n"
  },
  {
    "path": "stllm/models/blip2.py",
    "content": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\nimport contextlib\nimport logging\nimport os\nimport time\nimport datetime\n\nimport torch\nimport torch.nn as nn\nimport torch.distributed as dist\nimport torch.nn.functional as F\n\nimport stllm.common.dist_utils as dist_utils\nfrom stllm.common.dist_utils import download_cached_file\nfrom stllm.common.utils import is_url\nfrom stllm.common.logger import MetricLogger\nfrom stllm.models.base_model import BaseModel\nfrom stllm.models.Qformer import BertConfig, BertLMHeadModel\nfrom stllm.models.eva_vit import create_eva_vit_g\nfrom stllm.models.eva_btadapter import create_eva_btadapter\nfrom transformers import BertTokenizer\n\n\nclass Blip2Base(BaseModel):\n    @classmethod\n    def init_tokenizer(cls, truncation_side=\"right\"):\n        tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\", truncation_side=truncation_side)\n        tokenizer.add_special_tokens({\"bos_token\": \"[DEC]\"})\n        return tokenizer\n\n    def maybe_autocast(self, dtype=torch.float16):\n        # if on cpu, don't use autocast\n        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16\n        enable_autocast = self.device != torch.device(\"cpu\")\n\n        if enable_autocast:\n            return torch.cuda.amp.autocast(dtype=dtype)\n        else:\n            return contextlib.nullcontext()\n\n    @classmethod\n    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):\n        encoder_config = BertConfig.from_pretrained(\"bert-base-uncased\")\n        encoder_config.encoder_width = vision_width\n        # insert cross-attention layer every other block\n        encoder_config.add_cross_attention = True\n        encoder_config.cross_attention_freq = cross_attention_freq\n 
       encoder_config.query_length = num_query_token\n        Qformer = BertLMHeadModel(config=encoder_config)\n        query_tokens = nn.Parameter(\n            torch.zeros(1, num_query_token, encoder_config.hidden_size)\n        )\n        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)\n        return Qformer, query_tokens\n\n    @classmethod\n    def init_vision_encoder(\n        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision\n    ):\n        assert model_name in [\"eva_clip_g\",\"eva_btadapter_g\"], \"vit model must be eva_clip_g or eva_btadapter_g\"\n        if model_name==\"eva_clip_g\":\n            visual_encoder = create_eva_vit_g(\n                img_size, drop_path_rate, use_grad_checkpoint, precision\n            )\n        else:\n            visual_encoder = create_eva_btadapter(precision)\n\n        ln_vision = LayerNorm(visual_encoder.num_features)\n        return visual_encoder, ln_vision\n\n    def load_from_pretrained(self, url_or_filename):\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=\"cpu\")\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=\"cpu\")\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        state_dict = checkpoint[\"model\"]\n\n        msg = self.load_state_dict(state_dict, strict=False)\n\n        # logging.info(\"Missing keys {}\".format(msg.missing_keys))\n        logging.info(\"load checkpoint from %s\" % url_or_filename)\n\n        return msg\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass 
torch's LayerNorm to handle fp16.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        ret = super().forward(x.type(torch.float32))\n        return ret.type(orig_type)\n\n\ndef compute_sim_matrix(model, data_loader, **kwargs):\n    k_test = kwargs.pop(\"k_test\")\n\n    metric_logger = MetricLogger(delimiter=\"  \")\n    header = \"Evaluation:\"\n\n    logging.info(\"Computing features for evaluation...\")\n    start_time = time.time()\n\n    texts = data_loader.dataset.text\n    num_text = len(texts)\n    text_bs = 256\n    text_ids = []\n    text_embeds = []\n    text_atts = []\n    for i in range(0, num_text, text_bs):\n        text = texts[i : min(num_text, i + text_bs)]\n        text_input = model.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=35,\n            return_tensors=\"pt\",\n        ).to(model.device)\n        text_feat = model.forward_text(text_input)\n        text_embed = F.normalize(model.text_proj(text_feat))\n        text_embeds.append(text_embed)\n        text_ids.append(text_input.input_ids)\n        text_atts.append(text_input.attention_mask)\n\n    text_embeds = torch.cat(text_embeds, dim=0)\n    text_ids = torch.cat(text_ids, dim=0)\n    text_atts = torch.cat(text_atts, dim=0)\n\n    vit_feats = []\n    image_embeds = []\n    for samples in data_loader:\n        image = samples[\"image\"]\n\n        image = image.to(model.device)\n        image_feat, vit_feat = model.forward_image(image)\n        image_embed = model.vision_proj(image_feat)\n        image_embed = F.normalize(image_embed, dim=-1)\n\n        vit_feats.append(vit_feat.cpu())\n        image_embeds.append(image_embed)\n\n    vit_feats = torch.cat(vit_feats, dim=0)\n    image_embeds = torch.cat(image_embeds, dim=0)\n\n    sims_matrix = []\n    for image_embed in image_embeds:\n        sim_q2t = image_embed @ text_embeds.t()\n        sim_i2t, _ = sim_q2t.max(0)\n        
sims_matrix.append(sim_i2t)\n    sims_matrix = torch.stack(sims_matrix, dim=0)\n\n    score_matrix_i2t = torch.full(\n        (len(data_loader.dataset.image), len(texts)), -100.0\n    ).to(model.device)\n\n    num_tasks = dist_utils.get_world_size()\n    rank = dist_utils.get_rank()\n    step = sims_matrix.size(0) // num_tasks + 1\n    start = rank * step\n    end = min(sims_matrix.size(0), start + step)\n\n    for i, sims in enumerate(\n        metric_logger.log_every(sims_matrix[start:end], 50, header)\n    ):\n        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)\n        score = model.compute_itm(\n            image_inputs=image_inputs,\n            text_ids=text_ids[topk_idx],\n            text_atts=text_atts[topk_idx],\n        ).float()\n        score_matrix_i2t[start + i, topk_idx] = score + topk_sim\n\n    sims_matrix = sims_matrix.t()\n    score_matrix_t2i = torch.full(\n        (len(texts), len(data_loader.dataset.image)), -100.0\n    ).to(model.device)\n\n    step = sims_matrix.size(0) // num_tasks + 1\n    start = rank * step\n    end = min(sims_matrix.size(0), start + step)\n\n    for i, sims in enumerate(\n        metric_logger.log_every(sims_matrix[start:end], 50, header)\n    ):\n        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)\n        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)\n        score = model.compute_itm(\n            image_inputs=image_inputs,\n            text_ids=text_ids[start + i].repeat(k_test, 1),\n            text_atts=text_atts[start + i].repeat(k_test, 1),\n        ).float()\n        score_matrix_t2i[start + i, topk_idx] = score + topk_sim\n\n    if dist_utils.is_dist_avail_and_initialized():\n        dist.barrier()\n        torch.distributed.all_reduce(\n            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM\n        )\n        torch.distributed.all_reduce(\n            score_matrix_t2i, 
op=torch.distributed.ReduceOp.SUM\n        )\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    logging.info(\"Evaluation time {}\".format(total_time_str))\n\n    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()\n"
  },
  {
    "path": "stllm/models/blip2_outputs.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom transformers.modeling_outputs import (\n    ModelOutput,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\n\n\n@dataclass\nclass BlipSimilarity(ModelOutput):\n    sim_i2t: torch.FloatTensor = None\n    sim_t2i: torch.FloatTensor = None\n\n    sim_i2t_m: Optional[torch.FloatTensor] = None\n    sim_t2i_m: Optional[torch.FloatTensor] = None\n\n    sim_i2t_targets: Optional[torch.FloatTensor] = None\n    sim_t2i_targets: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass BlipIntermediateOutput(ModelOutput):\n    \"\"\"\n    Data class for intermediate outputs of BLIP models.\n\n    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).\n    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).\n\n    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).\n    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).\n\n    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.\n    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.\n\n    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.\n    decoder_labels (torch.LongTensor): labels for the captioning loss.\n\n    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).\n    itm_labels (torch.LongTensor): 
labels for the image-text matching loss, shape (batch_size * 3,)\n\n    \"\"\"\n\n    # uni-modal features\n    image_embeds: torch.FloatTensor = None\n    text_embeds: Optional[torch.FloatTensor] = None\n\n    image_embeds_m: Optional[torch.FloatTensor] = None\n    text_embeds_m: Optional[torch.FloatTensor] = None\n\n    # intermediate outputs of multimodal encoder\n    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None\n\n    itm_logits: Optional[torch.FloatTensor] = None\n    itm_labels: Optional[torch.LongTensor] = None\n\n    # intermediate outputs of multimodal decoder\n    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None\n    decoder_labels: Optional[torch.LongTensor] = None\n\n\n@dataclass\nclass BlipOutput(ModelOutput):\n    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.\n    sims: Optional[BlipSimilarity] = None\n\n    intermediate_output: BlipIntermediateOutput = None\n\n    loss: Optional[torch.FloatTensor] = None\n\n    loss_itc: Optional[torch.FloatTensor] = None\n\n    loss_itm: Optional[torch.FloatTensor] = None\n\n    loss_lm: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass BlipOutputFeatures(ModelOutput):\n    \"\"\"\n    Data class of features from BlipFeatureExtractor.\n\n    Args:\n        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional\n        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional\n        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional\n        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional\n\n        The first embedding or feature is for the [CLS] token.\n\n        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.\n  
  \"\"\"\n\n    image_embeds: Optional[torch.FloatTensor] = None\n    image_embeds_proj: Optional[torch.FloatTensor] = None\n\n    text_embeds: Optional[torch.FloatTensor] = None\n    text_embeds_proj: Optional[torch.FloatTensor] = None\n\n    multimodal_embeds: Optional[torch.FloatTensor] = None\n"
  },
  {
    "path": "stllm/models/eva_btadapter.py",
    "content": "# --------------------------------------------------------\n# Adapted from  https://github.com/microsoft/unilm/tree/master/beit\n# --------------------------------------------------------\nimport math\nimport os\nimport copy\nfrom functools import partial\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom .eva_vit import create_eva_vit_g, DropPath, Attention, Block, VisionTransformer,\\\n    interpolate_pos_embed, convert_weights_to_fp16\n\nfrom einops import rearrange\ntry:\n    from timm.models.layers import drop_path, to_2tuple, trunc_normal_\nexcept:\n    from timm.layers import drop_path, to_2tuple, trunc_normal_\n\nfrom torch.utils.checkpoint import checkpoint\n\n\nif os.getenv('ENV_TYPE') == 'deepspeed':\n    try:\n        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint\n    except:\n        from torch.utils.checkpoint import checkpoint\nelse:\n    from torch.utils.checkpoint import checkpoint\n\ntry:\n    import xformers.ops as xops\nexcept ImportError:\n    xops = None\n    print(\"Please 'pip install xformers'\")\n\n\ndef constant_init(module, val, bias=0):\n    if hasattr(module, 'weight') and module.weight is not None:\n        nn.init.constant_(module.weight, val)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)\n\nclass EVAVisionTransformer_BTAdapter(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, depth=4, mask_rate=0):\n        super().__init__()\n\n        clip = create_eva_vit_g(\n            224, 0, False, \"fp16\"\n        )\n        \n        self.image_size = clip.image_size\n        self.num_classes = clip.num_classes\n        self.num_features = self.embed_dim = clip.embed_dim  # num_features for consistency with other models\n        self.num_heads = clip.num_heads\n        self.patch_embed = clip.patch_embed\n   
     self.num_patches = self.patch_embed.num_patches\n\n        self.cls_token = clip.cls_token\n        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = clip.pos_embed\n        self.pos_drop = clip.pos_drop\n        self.rel_pos_bias = clip.rel_pos_bias\n\n        self.rope = None\n\n        self.use_rel_pos_bias = clip.use_rel_pos_bias\n        self.blocks = clip.blocks\n\n        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn\n        self.grad_checkpointing = False\n\n        self.depth = depth\n        self.mask_rate = mask_rate\n        \n        dpr = np.linspace(0, 0.1, depth)\n        self.BTAdapter_cls =  nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n        self.BTAdapter_S = nn.ModuleList([BTAdapter_Spatial(self.embed_dim, self.num_heads, drop_num=dpr[i]) for i in range(depth)])\n        self.BTAdapter_T = nn.ModuleList([BTAdapter_Temp(self.embed_dim, self.num_heads, drop_num=dpr[i]) for i in range(depth)])\n        self.BTAdapter_position = nn.Embedding(64, self.embed_dim)\n        self.init_weights()\n\n        del clip\n\n    def init_weights(self):\n        total_depth = len(self.blocks)\n        self.num_layers = total_depth\n        layer_para = self.blocks.state_dict()\n        spatial_para = {}\n        load_start = total_depth - self.depth\n        for k, v in layer_para.items():\n            num_layer = int(k.split(\".\")[0])\n            if num_layer >= load_start:\n                spatial_para[k.replace(str(num_layer),str(num_layer-load_start),1)] = v.clone()\n        self.BTAdapter_S.load_state_dict(spatial_para)\n\n    def fix_init_weight(self):\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            if self.naiveswiglu:\n                rescale(layer.mlp.w3.weight.data, 
layer_id + 1)\n            else:\n                rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def get_cast_dtype(self) -> torch.dtype:\n        return self.blocks[0].mlp.fc2.weight.dtype\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_num_layers(self):\n        return len(self.blocks)\n    \n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        assert unlocked_groups == 0, 'partial locking not currently supported for this model'\n        for param in self.parameters():\n            param.requires_grad = False\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.grad_checkpointing = enable\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x, mask=None):\n        \n        x = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        \n        branch_x = None\n        \n        for idx,blk in enumerate(self.blocks):\n            if self.training and self.grad_checkpointing:\n                x = 
checkpoint(blk, x, rel_pos_bias, None)\n            else:\n                x = blk(x, rel_pos_bias=rel_pos_bias)\n            if idx >= self.num_layers-self.depth:\n                num_layer = idx + self.depth - self.num_layers\n                branch_x = self.forward_branch(x, branch_x, num_layer, mask)\n        \n        #x_patch, branch_patch = x[:,1:,], branch_x[:,1:,]\n        #x_cls = rearrange(x[:,:1], '(b t) p m -> b t p m', t=self.T).mean(dim=1)\n        #branch_cls = branch_x[:,:1,]\n        #x_patch = rearrange(x_patch, '(b t) p m -> b t p m', t=self.T).mean(dim=1)\n        #p = x_patch.size(1)\n        #branch_patch = rearrange(branch_patch, 'b (p t) m -> b t p m', p=p).mean(dim=1)\n        #x = torch.cat(((x_patch+branch_patch)/2,(x_cls+branch_cls)/2),dim=1)\n\n        p = x.size(1) - 1\n        branch_cls, branch_patch = branch_x[:,0], branch_x[:,1:,]\n        branch_patch = rearrange(branch_patch, 'b (p t) m -> (b t) p m', p=p)\n        branch_cls = branch_cls.repeat(1, self.T).view(branch_cls.size(0) * self.T, -1).unsqueeze(1)\n        x = (x + torch.cat((branch_cls,branch_patch),dim=1))/2\n        return x\n\n    def forward_branch(self, x, branch_x, num_layer, mask=None):\n        x = rearrange(x, '(b t) l d -> b t l d', t=self.T)\n        if branch_x is not None:\n            cls_x = x[:,:,0]\n            cls_branch = cls_x.mean(dim=1).unsqueeze(1)\n            x = rearrange(x[:,:,1:], 'b t l d -> b (l t) d')\n            if mask is not None:\n                B, _, D = x.size()\n                x = x[~mask].reshape(B,-1,D)\n            x = torch.cat((cls_branch,x), dim=1)\n            x = x + branch_x\n        \n        if num_layer==0:\n            x = self.init_input(x,mask)\n\n        if self.grad_checkpointing and self.training:\n            x = checkpoint(self.BTAdapter_T[num_layer],x,self.T)\n            x = checkpoint(self.BTAdapter_S[num_layer],x,self.T)\n        else: \n            x = self.BTAdapter_T[num_layer](x, self.T)\n            
x = self.BTAdapter_S[num_layer](x, self.T)\n        return x\n    \n    def init_input(self, x, mask=None):\n        cls_x = x[:,:,0].mean(dim=1).unsqueeze(1)\n        x = x[:,:,1:,:]\n        b,t,l,d = x.size()\n        x = rearrange(x, 'b t l d -> (b t) l d')\n        #cls_branch = self.class_embedding.expand(1, x.size(1), -1)\n        cls_branch = self.BTAdapter_cls.expand(x.size(0), 1, -1)\n        x = torch.cat((cls_branch, x), dim=1)\n        x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        cls_branch = x[:b, 0, :].unsqueeze(1)\n        x = rearrange(x[:, 1:, :], '(b t) l d -> (b l) t d', b=b)\n        position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device).unsqueeze(0).expand(x.size(0),-1)\n        time_embed = self.BTAdapter_position(position_ids)\n        x = x + time_embed\n\n        x = rearrange(x, '(b l) t d -> b (l t) d', b=b)\n        if mask is not None:\n            x = x[~mask].reshape(b,-1,d)\n        cls = (cls_x + cls_branch) / 2\n        x = torch.cat((cls, x), dim=1)\n        return x\n    \n    def forward(self, x, return_all_features=False):\n        if x.ndim == 5:\n            if x.shape[1]==3:\n                B, D, T, H, W = x.shape             \n                x = x.permute(0, 2, 1, 3, 4)\n            else:\n                B, T, D, H, W = x.shape   \n            x = x.reshape((-1,) + x.shape[2:])\n        elif x.ndim == 4:\n            T, D, H, W = x.shape \n            B = 1\n        else:\n            B, _, _, _ = x.shape\n            T = 1\n        self.T = T\n\n        if self.mask_rate>0 and self.training:\n            mask = TubeMaskingGenerator((T,self.num_patches),self.mask_rate,B,x.device)\n        else:\n            mask = None\n\n        x = self.forward_features(x, mask)\n        return x\n\nclass BTAdapter_Spatial(Block):\n    def __init__(self, d_model, n_head, drop_num=0.1):\n        super().__init__(dim=d_model, num_heads=n_head, drop_path=drop_num, qkv_bias=True, mlp_ratio=4.3637)\n   
\n    def forward(self, x, T):\n        # x: (b, 1 + p*T, m) = one video-level cls token followed by patch tokens laid out as (p t)\n        residual = x\n        init_cls_token = x[:, :1, :]\n        query_s = x[:, 1:, :]\n\n        b, pt, m = query_s.size()\n        p, t = pt // T, T\n\n        # Replicate the cls token for every frame, then fold time into the batch so attention runs per frame\n        cls_token = init_cls_token.unsqueeze(1).repeat(1, t, 1, 1).reshape(b * t, 1, m)\n        query_s = rearrange(query_s, 'b (p t) m -> (b t) p m', p=p, t=t)\n\n        x = torch.cat((cls_token, query_s), 1)\n        x = self.attn(self.norm1(x))\n        res_spatial = self.drop_path(x.contiguous())\n        # Average the per-frame cls outputs back into a single token and restore the (p t) patch layout\n        cls_token = res_spatial[:, :1, :].reshape(b, t, 1, m).mean(1)\n        res_spatial = rearrange(res_spatial[:, 1:, :], '(b t) p m -> b (p t) m', p=p, t=t)\n\n        x = residual + torch.cat((cls_token, res_spatial), 1)\n        x = x + self.mlp(self.norm2(x))\n        x = self.drop_path(x.contiguous())\n        return x\n\nclass BTAdapter_Temp(nn.Module):\n    def __init__(self, d_model, n_head, drop_num=0.1, norm_layer=partial(nn.LayerNorm, eps=1e-6)):\n        super().__init__()\n        self.drop_path = DropPath(drop_num) if drop_num > 0. 
else nn.Identity()\n        self.attn = Attention(\n            d_model, num_heads=n_head, qkv_bias=True)\n        self.norm1 = norm_layer(d_model)\n\n        # Zero-initialised projection: the temporal branch contributes nothing at the start of training,\n        # so the block initially reduces to the identity on the patch tokens.\n        self.temporal_fc = nn.Linear(d_model, d_model)\n        constant_init(self.temporal_fc, val=0, bias=0)\n\n    def forward(self, x, T):\n        # x: (b, 1 + p*T, m); the cls token bypasses temporal attention\n        residual = x[:, 1:, :]\n\n        init_cls_token = x[:, :1, :]\n        query_t = x[:, 1:, :]\n        b, pt, m = query_t.size()\n        p, t = pt // T, T\n        # Tokens are laid out as (p t), so this groups the T frames of each patch location into one sequence\n        x = query_t.reshape(b * p, t, m)\n\n        x = self.attn(self.norm1(x))\n        res_temporal = self.drop_path(x.contiguous())\n        res_temporal = self.temporal_fc(res_temporal)\n\n        x = res_temporal.reshape(b, p * t, m) + residual\n        x = torch.cat((init_cls_token, x), 1)\n        return x\n\n\ndef create_eva_btadapter(precision=\"fp16\"):\n    # Attach BT-Adapter branches to the last 3 blocks of the EVA ViT-g backbone\n    model = EVAVisionTransformer_BTAdapter(depth=3)\n    if precision == \"fp16\":\n        # model.to(\"cuda\")\n        convert_weights_to_fp16(model)\n    return model\n"
  },
  {
    "path": "stllm/models/eva_vit.py",
    "content": "# Based on EVA, BEIT, timm and DeiT code bases\n# https://github.com/baaivision/EVA\n# https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# https://github.com/microsoft/unilm/tree/master/beit\n# https://github.com/facebookresearch/deit/\n# https://github.com/facebookresearch/dino\n# --------------------------------------------------------'\nimport math\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import drop_path, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\n\nfrom stllm.common.dist_utils import download_cached_file\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),\n        **kwargs\n    }\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n    def extra_repr(self) -> str:\n        return 'p={}'.format(self.drop_prob)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        # x = 
self.drop(x)\n        # commit this for the orignal BERT implement \n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\n            proj_drop=0., window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        if attn_head_dim is not None:\n            head_dim = attn_head_dim\n        all_head_dim = head_dim * self.num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n\n        if window_size:\n            self.window_size = window_size\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n            # cls to token & token 2 cls & cls to cls\n\n            # get pair-wise relative position index for each token inside the window\n            coords_h = torch.arange(window_size[0])\n            coords_w = torch.arange(window_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += window_size[1] - 1\n         
   relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n            relative_position_index = \\\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\n            relative_position_index[0, 0] = self.num_relative_distance - 1\n\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n        else:\n            self.window_size = None\n            self.relative_position_bias_table = None\n            self.relative_position_index = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(all_head_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, rel_pos_bias=None):\n        B, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.relative_position_bias_table is not None:\n            relative_position_bias = \\\n                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                    self.window_size[0] * self.window_size[1] + 1,\n                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = 
relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if rel_pos_bias is not None:\n            attn = attn + rel_pos_bias\n        \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 window_size=None, attn_head_dim=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if init_values is not None and init_values > 0:\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        else:\n            self.gamma_1, self.gamma_2 = None, None\n\n    def forward(self, x, rel_pos_bias=None):\n        if self.gamma_1 is None:\n            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x, **kwargs):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return 
x\n\n\nclass RelativePositionBias(nn.Module):\n\n    def __init__(self, window_size, num_heads):\n        super().__init__()\n        self.window_size = window_size\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        # cls to token & token 2 cls & cls to cls\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(window_size[0])\n        coords_w = torch.arange(window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\n        relative_position_index = \\\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\n        relative_position_index[0, 0] = self.num_relative_distance - 1\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        # trunc_normal_(self.relative_position_bias_table, std=.02)\n\n    def forward(self):\n        relative_position_bias = \\\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.window_size[0] * self.window_size[1] + 1,\n       
         self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,\n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,\n                 use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):\n        super().__init__()\n        self.image_size = img_size\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.num_heads = num_heads\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        if use_abs_pos_emb:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        else:\n            self.pos_embed = None\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if use_shared_rel_pos_bias:\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\n        else:\n            self.rel_pos_bias = None\n        self.use_checkpoint = use_checkpoint\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.use_rel_pos_bias = use_rel_pos_bias\n        self.blocks = nn.ModuleList([\n            Block(\n                
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)\n            for i in range(depth)])\n#         self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)\n#         self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None\n#         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        if self.pos_embed is not None:\n            trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        # trunc_normal_(self.mask_token, std=.02)\n#         if isinstance(self.head, nn.Linear):\n#             trunc_normal_(self.head.weight, std=.02)\n        self.apply(self._init_weights)\n        self.fix_init_weight()\n#         if isinstance(self.head, nn.Linear):\n#             self.head.weight.data.mul_(init_scale)\n#             self.head.bias.data.mul_(init_scale)\n\n    def fix_init_weight(self):\n        # Scale the output projections of block l (1-based) by 1 / sqrt(2 * l)\n        def rescale(param, layer_id):\n            param.div_(math.sqrt(2.0 * layer_id))\n\n        for layer_id, layer in enumerate(self.blocks):\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\n            rescale(layer.mlp.fc2.weight.data, layer_id + 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_classifier(self):\n        # NOTE: self.head is not created in __init__; it exists only after reset_classifier() has been called\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = 
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, rel_pos_bias)\n            else:\n                x = blk(x, rel_pos_bias)\n        return x\n#         x = self.norm(x)\n\n#         if self.fc_norm is not None:\n#             t = x[:, 1:, :]\n#             return self.fc_norm(t.mean(1))\n#         else:\n#             return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n#         x = self.head(x)\n        return x\n\n    def get_intermediate_layers(self, x):\n        x = self.patch_embed(x)\n        batch_size, seq_len, _ = x.size()\n\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        features = []\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\n        for blk in self.blocks:\n            x = blk(x, rel_pos_bias)\n            features.append(x)\n\n        return features\n    \n    \ndef interpolate_pos_embed(model, checkpoint_model):\n    if 'pos_embed' in checkpoint_model:\n        pos_embed_checkpoint = checkpoint_model['pos_embed'].float()\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        
num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            checkpoint_model['pos_embed'] = new_pos_embed\n            \n            \ndef convert_weights_to_fp16(model: nn.Module):\n    \"\"\"Convert applicable model parameters to fp16\"\"\"\n\n    def _convert_weights_to_fp16(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.half()\n            if l.bias is not None:\n                l.bias.data = l.bias.data.half()\n\n#         if isinstance(l, (nn.MultiheadAttention, Attention)):\n#             for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n#                 tensor = getattr(l, attr)\n#                 if tensor is not None:\n#                     tensor.data = tensor.data.half()\n\n    model.apply(_convert_weights_to_fp16)\n    \n    \ndef 
create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision=\"fp16\"):\n    # EVA ViT-g configuration: 39 blocks, width 1408, 16 heads (head dim 88), patch size 14\n    model = VisionTransformer(\n        img_size=img_size,\n        patch_size=14,\n        use_mean_pooling=False,\n        embed_dim=1408,\n        depth=39,\n        num_heads=1408 // 88,\n        mlp_ratio=4.3637,\n        qkv_bias=True,\n        drop_path_rate=drop_path_rate,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        use_checkpoint=use_checkpoint,\n    )\n    url = \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth\"\n    cached_file = download_cached_file(\n        url, check_hash=False, progress=True\n    )\n    # To load a local copy instead: cached_file = 'Path/to/eva_vit_g.pth'\n    state_dict = torch.load(cached_file, map_location=\"cpu\")\n    # Resize the checkpoint position embedding in place if its patch grid differs from the model's\n    interpolate_pos_embed(model, state_dict)\n\n    # strict=False: keys missing from or unexpected in the checkpoint are tolerated and returned here\n    incompatible_keys = model.load_state_dict(state_dict, strict=False)\n    # print(incompatible_keys)\n\n    if precision == \"fp16\":\n        # model.to(\"cuda\")\n        convert_weights_to_fp16(model)\n    return model"
  },
  {
    "path": "stllm/models/modeling_llama_mem.py",
    "content": "# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py\n\n\"\"\" PyTorch LLaMA model.\"\"\"\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nimport warnings\nfrom flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func # flash_attnv2\nfrom flash_attn.bert_padding import unpad_input, pad_input\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings\nfrom transformers.models.llama.configuration_llama import LlamaConfig\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"LlamaConfig\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(\n    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef 
_expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass LlamaRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n        # convert into half-precision if necessary\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            hidden_states = hidden_states.to(self.weight.dtype)\n\n        return self.weight * hidden_states\n\n\nclass LlamaRotaryEmbedding(torch.nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq)\n\n        # Build here to make `torch.jit.trace` work.\n        self.max_seq_len_cached = max_position_embeddings\n        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n   
     self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.\n        if seq_len > self.max_seq_len_cached:\n            self.max_seq_len_cached = seq_len\n            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)\n            freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n            # Different from paper, but it uses a different permutation in order to obtain the same calculation\n            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n            self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n            self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n        return (\n            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n        )\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]\n    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])\n    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)\n    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(\n        self,\n       
 hidden_size: int,\n        intermediate_size: int,\n        hidden_act: str,\n    ):\n        super().__init__()\n        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.act_fn = ACT2FN[hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass LlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.max_position_embeddings = config.max_position_embeddings\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        
attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if output_attentions:\n            warnings.warn(\n                \"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.\"\n            )\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = (\n            self.q_proj(hidden_states)\n            .view(bsz, q_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n        )\n        key_states = (\n            self.k_proj(hidden_states)\n            .view(bsz, q_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n        )\n        value_states = (\n            self.v_proj(hidden_states)\n            .view(bsz, q_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n        )  # shape: (b, num_heads, s, head_dim)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin, position_ids\n        )\n\n        if past_key_value is not None:\n            # reuse k, v\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        # Transform the data into the format required by flash attention\n        qkv = torch.stack([query_states, key_states, value_states], dim=2)\n        qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]\n        
key_padding_mask = attention_mask\n\n        input_type = qkv.dtype\n        if input_type != torch.float16:\n            qkv = qkv.to(dtype=torch.float16)\n\n        if key_padding_mask is None:\n            qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)\n            cu_q_lens = torch.arange(\n                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device\n            )\n            max_s = q_len\n            output = flash_attn_varlen_qkvpacked_func(\n                qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True\n            )\n            output = output.view(bsz, q_len, -1)\n        else:\n            qkv = qkv.reshape(bsz, q_len, -1)\n            qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)\n            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)\n            output_unpad = flash_attn_varlen_qkvpacked_func(\n                qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True\n            )\n            output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)\n            output = pad_input(output_unpad, indices, bsz, q_len)\n        output = output.to(dtype=input_type)\n        return self.o_proj(output).to(dtype=input_type), None, past_key_value\n\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = LlamaAttention(config=config)\n        self.mlp = LlamaMLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n       
 position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if 
output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`LlamaConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaPreTrainedModel(PreTrainedModel):\n    config_class = LlamaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"LlamaDecoderLayer\"]\n    _keys_to_ignore_on_load_unexpected = [r\"decoder\\.version\"]\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def 
_set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, LlamaModel):\n            module.gradient_checkpointing = value\n\n\nLLAMA_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaModel(LlamaPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape,\n                                        inputs_embeds, past_key_values_length):\n        # [bsz, seq_len]\n        return attention_mask\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = 
None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        query_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n        if query_embeds is not None:\n            inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)\n            batch_size, seq_length, _ = inputs_embeds.shape\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        if 
position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(\n                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device\n            )\n            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n        else:\n            position_ids = position_ids.view(-1, seq_length).long()\n\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device\n            )\n        attention_mask = self._prepare_decoder_attention_mask(\n            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n        )\n\n        hidden_states = inputs_embeds\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n       
 next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass LlamaForCausalLM(LlamaPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = LlamaModel(config)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        query_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] 
= None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, LlamaForCausalLM\n\n        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            query_embeds=query_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            
hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n                query_embeds = None\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"query_embeds\": query_embeds,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)\n        return reordered_past\n\n"
  },
  {
    "path": "stllm/models/peft_model.py",
    "content": "# coding=utf-8\n# Copyright 2023-present the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nimport peft\nimport torch\n\n# `torch` and `_get_batch_size` are used on the prompt-learning path of `forward` below.\nfrom peft.utils import (\n    PeftType,\n    _get_batch_size,\n)\n\n\ndef forward(\n    self,\n    samples=None,\n    input_ids=None,\n    attention_mask=None,\n    inputs_embeds=None,\n    labels=None,\n    output_attentions=None,\n    output_hidden_states=None,\n    return_dict=None,\n    task_ids=None,\n    **kwargs,\n):\n    peft_config = self.active_peft_config\n    if not peft_config.is_prompt_learning:\n        if self.base_model.config.model_type == \"mpt\":\n            if inputs_embeds is not None:\n                raise AssertionError(\"forward in MPTForCausalLM does not support inputs_embeds\")\n            return self.base_model(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                labels=labels,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n                **kwargs,\n            )\n        return self.base_model(\n            samples=samples,\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            inputs_embeds=inputs_embeds,\n            labels=labels,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            
return_dict=return_dict,\n            **kwargs,\n        )\n    batch_size = _get_batch_size(input_ids, inputs_embeds)\n    if attention_mask is not None:\n        # concat prompt attention mask\n        prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)\n        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n    if kwargs.get(\"position_ids\", None) is not None:\n        warnings.warn(\"Position ids are not supported for parameter efficient tuning. Ignoring position ids.\")\n        kwargs[\"position_ids\"] = None\n    if kwargs.get(\"token_type_ids\", None) is not None:\n        warnings.warn(\"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids\")\n        kwargs[\"token_type_ids\"] = None\n    kwargs.update(\n        {\n            \"attention_mask\": attention_mask,\n            \"labels\": labels,\n            \"output_attentions\": output_attentions,\n            \"output_hidden_states\": output_hidden_states,\n            \"return_dict\": return_dict,\n        }\n    )\n    if peft_config.peft_type == PeftType.PREFIX_TUNING:\n        past_key_values = self.get_prompt(batch_size)\n        return self.base_model(\n            input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs\n        )\n    else:\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n        # concat prompt labels\n        if labels is not None:\n            prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)\n            kwargs[\"labels\"] = torch.cat((prefix_labels, labels), dim=1)\n        prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)\n        prompts = prompts.to(inputs_embeds.dtype)\n        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)\n        return self.base_model(inputs_embeds=inputs_embeds, 
**kwargs)\n\ndef replace_peftmodel_with_sample_input():\n    peft.peft_model.PeftModelForCausalLM.forward = forward"
  },
  {
    "path": "stllm/models/st_llm.py",
    "content": "import logging\nimport random\nimport re\nimport os\nimport math\nimport einops\nimport ast\n\nimport torch\nfrom torch.cuda.amp import autocast as autocast\nimport torch.nn as nn\nimport numpy as np\nfrom torch.nn import CrossEntropyLoss\n\nfrom stllm.common.registry import registry\nfrom stllm.models.utils import RandomMaskingGenerator, get_sinusoid_encoding_table\nfrom stllm.models.blip2 import Blip2Base, disabled_train\nfrom stllm.models.peft_model import replace_peftmodel_with_sample_input\nfrom stllm.models.base_model import BaseModel\nfrom transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaConfig, LlamaModel\n#from stllm.models.modeling_llama_mem import LlamaForCausalLM, LlamaModel\n#from transformers.models.llama.configuration_llama import LlamaConfig\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\nfrom transformers import LlamaTokenizer\n\nfrom peft import (\n    LoraConfig,\n    get_peft_model,\n)\n\nclass StllmConfig(LlamaConfig):\n    model_type = \"st_llm_hf\"\n\n\nclass Linear_Decoder(nn.Module):\n    def __init__(self, output_dim=4096, embed_dim=4096):\n        super().__init__()\n        self.head = nn.Linear(embed_dim, output_dim)\n        self.norm = nn.LayerNorm(output_dim)\n\n    def forward(self, x):\n        x = self.norm(self.head(x))\n        return x\n\nclass STLLMLlamaModel(LlamaModel):\n    config_class = StllmConfig\n    def __init__(self, config: LlamaConfig):  # TODO: Remove unused params\n        super(STLLMLlamaModel, self).__init__(config)\n    \n    def initialize_vision_modules(self, cfg):\n        self.stllm_model = STLLMModel.from_config(cfg)\n        if cfg.get(\"qformer_text_input\", False):\n            self.resize_token_embeddings(len(self.stllm_model.llama_tokenizer))\n        self.stllm_model.embed_tokens = self.embed_tokens\n    \n    def forward(self, samples=None, inputs_embeds=None, **kwargs):\n        if samples is None:\n            
return super(STLLMLlamaModel, self).forward(inputs_embeds=inputs_embeds, **kwargs)\n        \n        inputs_embeds, attention_mask, unmask_inputs_embeds, unmask_attention_mask, labels = self.stllm_model(samples)\n        output_hidden_states = not (unmask_inputs_embeds is None)\n        outputs = super(STLLMLlamaModel, self).forward(\n            input_ids=None, attention_mask=attention_mask,\n            inputs_embeds=inputs_embeds, use_cache=False,\n            output_hidden_states=output_hidden_states,\n            return_dict=True\n        )\n        if unmask_inputs_embeds is None:\n            return outputs, None, labels\n        \n        img_start = 0 if self.stllm_model.qformer_text_input else 8\n        mask_output = outputs.hidden_states[-1]\n        B, _, D = mask_output.size()\n        mask_img_output = mask_output[:,img_start:img_start+self.stllm_model.mask_img_len]\n        if hasattr(self.stllm_model, \"mvm_decoder\"):\n            mask_img_output = self.stllm_model.mvm_decoder(mask_img_output)\n\n        with torch.no_grad():\n            unmask_outputs = super(STLLMLlamaModel, self).forward(\n                inputs_embeds=unmask_inputs_embeds,\n                attention_mask=unmask_attention_mask,\n                return_dict=True, use_cache=False,\n                output_hidden_states=True,\n            )\n        unmask_output = unmask_outputs.hidden_states[-1]\n        unmask_img_output = unmask_output[:,img_start:img_start+self.stllm_model.img_len]\n        unmask_img_output = unmask_img_output[~(self.stllm_model.mask.squeeze(1))].reshape(B, -1, D)\n\n        mask_img_output = mask_img_output / mask_img_output.norm(dim=-1, keepdim=True)\n        unmask_img_output = unmask_img_output / unmask_img_output.norm(dim=-1, keepdim=True)\n        loss_mvm = (2 - 2 * (mask_img_output * unmask_img_output).sum(dim=-1)).mean()\n        return outputs, loss_mvm, labels\n\n@registry.register_model(\"st_llm_hf\")\nclass STLLMForCausalLM(LlamaForCausalLM, 
BaseModel):\n    config_class = StllmConfig\n    PRETRAINED_MODEL_CONFIG_DICT = {\n        \"instructblip_vicuna0\": \"configs/models/instructblip_vicuna0.yaml\",\n        \"instructblip_vicuna0_btadapter\": \"configs/models/instructblip_vicuna0_btadapter.yaml\",\n        \"minigpt4_vicuna0\": \"configs/models/minigpt4_vicuna0.yaml\",\n        \"minigpt4_vicuna0_btadapter\": \"configs/models/minigpt4_vicuna0_btadapter.yaml\",\n    }\n\n    def __init__(self, config):\n        super(LlamaForCausalLM, self).__init__(config)\n        self.model = STLLMLlamaModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()  \n\n    def get_model(self):\n        return self.model\n    \n    def forward(self, samples=None, inputs_embeds=None, **kwargs):\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        if samples is None:\n            return super(STLLMForCausalLM, self).forward(inputs_embeds=inputs_embeds, **kwargs)\n        outputs, loss_pretrain, labels = self.model(samples)\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model/pipeline parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if loss_pretrain is not None:\n            loss += loss_pretrain\n\n        return CausalLMOutputWithPast(\n       
     loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    @classmethod\n    def get_state_dict(cls, path, prefix='pytorch_model'):\n        # Merge sharded checkpoints named '<prefix>-XXXXX-of-YYYYY.bin' into one state dict.\n        pattern = re.compile(rf'{re.escape(prefix)}-(\\d+)-of-(\\d+)\\.bin')\n        matching_files = [filename for filename in os.listdir(path) if pattern.match(filename)]\n\n        model_state_dict = {}\n        for model_path in matching_files:\n            partial_state_dict = torch.load(os.path.join(path, model_path), map_location=torch.device('cpu'))\n            model_state_dict.update(partial_state_dict)\n        return model_state_dict\n\n    @classmethod\n    def from_config(cls, cfg):\n        llama_model = cfg.get(\"llama_model\")\n        \n        model = cls.from_pretrained(llama_model, torch_dtype=torch.float16)        \n        lora_r = cfg.get(\"lora_r\", 0)\n        lora_alpha = cfg.get(\"lora_alpha\", 32)\n        if lora_r > 0:\n            replace_peftmodel_with_sample_input()\n            loraconfig = LoraConfig(\n                r=lora_r,\n                lora_alpha=lora_alpha,\n                target_modules=[\"q_proj\", \"v_proj\"],\n                lora_dropout=0.05,\n                bias=\"none\",\n                task_type=\"CAUSAL_LM\"\n            )\n            model = get_peft_model(model, loraconfig)\n\n        model.get_model().initialize_vision_modules(cfg)\n        if cfg.get(\"qformer_text_input\", False):\n            model.resize_token_embeddings(model.config.vocab_size)\n        if cfg.get(\"freeze_LLM\",True):\n            for name, param in model.named_parameters():\n                if 'stllm_model' not in name and 'lora' not in name:\n                    param.requires_grad = False\n        if cfg.get(\"use_grad_checkpoint\",True):\n            model.gradient_checkpointing_enable()\n\n        ckpt_path = cfg.get(\"ckpt\", \"\")  # load weights of 
MiniGPT-4\n        if ckpt_path:\n            print(\"Load BLIP2-LLM Checkpoint: {}\".format(ckpt_path))\n            if os.path.isdir(ckpt_path):\n                ckpt = cls.get_state_dict(ckpt_path)\n            else:\n                ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n            if 'model' in ckpt:\n                ckpt = ckpt['model']\n            if 'llm_proj.weight' in ckpt:\n                ckpt['llama_proj.weight'] = ckpt.pop('llm_proj.weight')\n                ckpt['llama_proj.bias'] = ckpt.pop('llm_proj.bias')\n            msg = model.load_state_dict(ckpt, strict=False)\n                \n        return model\n\nclass STLLMModel(Blip2Base):\n    \"\"\"\n    BLIP2 GPT-LLAMA model.\n    \"\"\"\n    def __init__(\n        self,\n        vit_model=\"eva_clip_g\",\n        q_former_model=\"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth\",\n        img_size=224,\n        pre_encoding=False,\n        use_mask=False,\n        mvm_decode=False,\n        video_input=None,\n        residual_size=4,\n        qformer_text_input=False,\n        drop_path_rate=0,\n        use_grad_checkpoint=False,\n        vit_precision=\"fp16\",\n        freeze_vit=True,\n        has_qformer=True,\n        freeze_qformer=True,\n        num_query_token=32,\n        llama_model=\"\",\n        max_txt_len=32,\n        end_sym='\\n',\n    ):\n        super().__init__()\n\n        self.tokenizer = self.init_tokenizer(truncation_side=\"left\")\n        self.pre_encoding = pre_encoding\n        self.video_input = video_input\n        self.use_mask = use_mask\n        self.mvm_decode = mvm_decode\n        self.qformer_text_input = qformer_text_input\n        self.residual_size = residual_size\n        if self.video_input == 'residual':\n            self.down_proj = nn.Linear(4096, 1024)\n            self.non_linear_func = nn.ReLU()\n            self.up_proj = nn.Linear(1024, 4096)\n            
nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))\n            nn.init.zeros_(self.up_proj.weight)\n            nn.init.zeros_(self.down_proj.bias)\n            nn.init.zeros_(self.up_proj.bias)\n\n        if self.mvm_decode:\n            self.mvm_decoder = Linear_Decoder()\n\n        print('Loading VIT')\n        self.vit_model = vit_model\n        self.visual_encoder, self.ln_vision = self.init_vision_encoder(\n            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision\n        )\n        if freeze_vit:\n            for name, param in self.visual_encoder.named_parameters():\n                if 'BTAdapter' in name:\n                    continue\n                param.requires_grad = False\n            for name, param in self.ln_vision.named_parameters():\n                param.requires_grad = False\n            if vit_model=='eva_clip_g':\n                self.ln_vision = self.ln_vision.eval()\n                self.ln_vision.train = disabled_train\n            logging.info(\"freeze vision encoder\")\n        print('Loading VIT Done')\n\n        self.has_qformer = has_qformer\n        if self.has_qformer:\n            print('Loading Q-Former')\n            self.Qformer, self.query_tokens = self.init_Qformer(\n                num_query_token, self.visual_encoder.num_features\n            )\n            \n            if not qformer_text_input:\n                self.Qformer.bert.embeddings.word_embeddings = None\n                self.Qformer.bert.embeddings.position_embeddings = None\n                for layer in self.Qformer.bert.encoder.layer:\n                    layer.output = None\n                    layer.intermediate = None\n                self.load_from_pretrained(url_or_filename=q_former_model)\n            else:\n                self.Qformer.resize_token_embeddings(len(self.tokenizer))\n                self.load_from_pretrained(url_or_filename=q_former_model)\n\n            self.Qformer.cls = None\n            if 
freeze_qformer:\n                for name, param in self.Qformer.named_parameters():\n                    param.requires_grad = False\n                if vit_model=='eva_clip_g':\n                    self.Qformer = self.Qformer.eval()\n                    self.Qformer.train = disabled_train\n                self.query_tokens.requires_grad = False\n                logging.info(\"freeze Qformer\")\n\n            img_f_dim = self.Qformer.config.hidden_size\n            print('Loading Q-Former Done')\n        else:\n            img_f_dim = self.visual_encoder.num_features * 4\n            print('Do not use Q-Former here.')\n\n        print('Loading LLAMA')\n        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)\n        if qformer_text_input:\n            self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n            self.llama_tokenizer.add_special_tokens({'bos_token': '</s>'})\n            self.llama_tokenizer.add_special_tokens({'eos_token': '</s>'})\n            self.llama_tokenizer.add_special_tokens({'unk_token': '</s>'})\n        else:\n            self.llama_tokenizer.pad_token = \"$$\"\n\n        self.llama_proj = nn.Linear(\n            img_f_dim, 4096\n        )\n\n        self.max_txt_len = max_txt_len\n        self.end_sym = end_sym\n\n    def encode_img(self, image, text=None):\n        device = image.device\n        with self.maybe_autocast():\n            T = image.shape[1]\n            infer = True if len(image.shape)==4 else False\n            use_image = True if T == 1 or (len(image.shape)==4) else False\n            if (not use_image or len(image.shape)==5) and (self.vit_model=='eva_clip_g'):\n                image = einops.rearrange(image,'B T C H W -> (B T) C H W')\n\n            image_embeds = self.visual_encoder(image)\n            image_embeds = self.ln_vision(image_embeds)\n            if self.has_qformer:\n                image_atts = torch.ones(image_embeds.size()[:-1], 
dtype=torch.long).to(device)\n                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)\n                if self.qformer_text_input:\n                    assert text\n                    if isinstance(text, str):\n                        text = [text] * query_tokens.size(0)\n                    elif len(text) != query_tokens.size(0):\n                        text_ = []\n                        for t in text:\n                            text_ += [t] * T\n                        text = text_\n                    text_Qformer = self.tokenizer(\n                        text,\n                        padding='longest',\n                        truncation=True,\n                        max_length=self.max_txt_len,\n                        return_tensors=\"pt\",\n                    ).to(query_tokens.device)\n                    query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)\n                    Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1)\n                    query_output = self.Qformer.bert(\n                        text_Qformer.input_ids,\n                        attention_mask=Qformer_atts,\n                        query_embeds=query_tokens,\n                        encoder_hidden_states=image_embeds,\n                        encoder_attention_mask=image_atts,\n                        return_dict=True,\n                    )\n                else:\n                    query_output = self.Qformer.bert(\n                        query_embeds=query_tokens,\n                        encoder_hidden_states=image_embeds,\n                        encoder_attention_mask=image_atts,\n                        return_dict=True,\n                    )\n                inputs_llama = self.llama_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])\n            else:\n                image_embeds = image_embeds[:, 1:, :]\n                bs, pn, hs = image_embeds.shape\n           
     image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))\n                inputs_llama = self.llama_proj(image_embeds)\n            if not infer:\n                inputs_llama = einops.rearrange(inputs_llama,'(B T) L D -> B T L D',T=T)\n            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)\n        return inputs_llama, atts_llama, use_image\n  \n    def prompt_wrap(self, img_embeds, atts_img, prompts):\n        if prompts:\n            emb_lists = []\n            if isinstance(prompts, str):\n                prompts = [prompts] * len(img_embeds)\n\n            for each_img_embed, each_prompt in zip(img_embeds, prompts):\n                p_before, p_after = each_prompt.split('<ImageHere>')\n                p_before_tokens = self.llama_tokenizer(\n                    p_before, return_tensors=\"pt\", add_special_tokens=False).to(img_embeds.device)\n                p_after_tokens = self.llama_tokenizer(\n                    p_after, return_tensors=\"pt\", add_special_tokens=self.qformer_text_input).to(img_embeds.device)\n                p_before_embed = self.embed_tokens(p_before_tokens.input_ids) if min(p_before_tokens.input_ids.shape) != 0 else None\n                p_after_embed = self.embed_tokens(p_after_tokens.input_ids)\n                if len(each_img_embed.size())==2:\n                    each_img_embed = each_img_embed[None]\n                wrapped_emb = torch.cat([p_before_embed, each_img_embed, p_after_embed], dim=1) if p_before_embed is not None \\\n                    else torch.cat([each_img_embed, p_after_embed], dim=1)\n                emb_lists.append(wrapped_emb)\n            emb_lens = [emb.shape[1] for emb in emb_lists]\n            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))\n            wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()\n            wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], 
dtype=torch.int, device=img_embeds.device)\n            for i, emb in enumerate(emb_lists):\n                wrapped_embs[i, :emb_lens[i]] = emb\n                wrapped_atts[i, :emb_lens[i]] = 1\n            return wrapped_embs, wrapped_atts\n        else:\n            return img_embeds, atts_img\n    \n    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):\n        input_lens = []\n        cat_embs = []\n        cat_atts = []\n        for i in range(input_embs.size(0)):\n            input_len = input_atts[i].sum()\n            input_lens.append(input_len)\n            cat_embs.append(\n                torch.cat([\n                    input_embs[i][:input_len],\n                    output_embs[i],\n                    input_embs[i][input_len:]\n                ])\n            )\n            cat_atts.append(\n                torch.cat([\n                    input_atts[i][:input_len],\n                    output_atts[i],\n                    input_atts[i][input_len:]\n                ])\n            )\n        cat_embs = torch.stack(cat_embs)\n        cat_atts = torch.stack(cat_atts)\n        return cat_embs, cat_atts, input_lens\n\n    def get_residual_index(self, sample_segments, total_segments, devices):\n        if hasattr(self,'residual_index'):\n            return self.residual_index\n        else:\n            seg_size = float(total_segments) / sample_segments\n            frame_indices = np.array([\n            int((seg_size / 2) + np.round(seg_size * idx))\n            for idx in range(sample_segments)\n            ])\n            frame_indices = torch.from_numpy(frame_indices).to(devices)\n            self.register_buffer('residual_index', frame_indices)\n            return frame_indices\n\n    def forward(self, samples):\n        image = samples[\"image\"]\n        instruction = samples[\"instruction_input\"] if \"instruction_input\" in samples else None\n\n        use_image = False\n        if self.pre_encoding:\n        
    image = image.type_as(self.llama_proj.weight)\n            img_embeds = self.llama_proj(image)\n            atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(image.device)\n        else:\n            if self.qformer_text_input:\n                qformer_text_input = [it.split('Human: ')[1].split(' ###')[0] for it in instruction]\n            else:\n                qformer_text_input = None\n            img_embeds, atts_img, use_image = self.encode_img(image, qformer_text_input)\n\n        if not use_image:\n            T = img_embeds.size(1)\n        if not use_image and self.video_input == 'all':\n            img_embeds = img_embeds.view(img_embeds.size(0),1,-1,img_embeds.size(-1)).contiguous()\n        elif not use_image and self.video_input == 'mean':\n            img_embeds = img_embeds.mean(dim=1, keepdim=True)\n        elif not use_image and self.video_input == 'residual':\n            residual_index = self.get_residual_index(self.residual_size,T,img_embeds.device)\n            global_embeds = img_embeds.mean(dim=1, keepdim=True)\n            \n            local_embeds = img_embeds[:,residual_index]\n            global_embeds = global_embeds.expand((-1,self.residual_size,-1,-1)).to(self.up_proj.weight.dtype)\n            global_embeds = self.up_proj(self.non_linear_func(self.down_proj(global_embeds)))\n            img_embeds = (local_embeds + global_embeds).view(img_embeds.size(0),1,-1,img_embeds.size(-1)).contiguous()\n        else:\n            pass\n\n        B, _, L, D = img_embeds.size()\n        unmask_img_embeds = None\n        if not use_image and self.use_mask:\n            self.img_len = L\n            rate = np.random.normal(0.5, 0.1)\n            mask_rate = float(np.clip(rate,0.1,0.7))\n            mask = RandomMaskingGenerator(L, mask_rate, B, img_embeds.device).unsqueeze(1)\n            self.mask = mask\n\n            unmask_img_embeds = img_embeds\n            unmask_atts_img = torch.ones(unmask_img_embeds.size()[:-1], 
dtype=torch.long).to(image.device)\n            img_embeds = img_embeds[~mask].reshape(B, 1, -1, D)\n            atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(image.device)\n            self.mask_img_len = img_embeds.size(2)\n\n\n        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)\n        self.llama_tokenizer.padding_side = \"right\"\n        text = [t + self.llama_tokenizer.eos_token for t in samples[\"answer\"]] if self.qformer_text_input \\\n            else [t + self.end_sym for t in samples[\"answer\"]]\n\n        to_regress_tokens = self.llama_tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            truncation=True,\n            max_length=self.max_txt_len,\n            add_special_tokens=False\n        ).to(image.device)\n        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)\n\n        inputs_embeds, attention_mask, input_lens = \\\n            self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)\n\n        if unmask_img_embeds is not None:\n            unmask_img_embeds, unmask_atts_img = self.prompt_wrap(unmask_img_embeds, unmask_atts_img, instruction)\n            unmask_inputs_embeds, unmask_attention_mask, unmask_input_lens = \\\n            self.concat_emb_input_output(unmask_img_embeds, unmask_atts_img, to_regress_embeds, to_regress_tokens.attention_mask)\n\n        if not self.qformer_text_input:\n            batch_size = img_embeds.shape[0]\n            bos = torch.ones([batch_size, 1],\n                             dtype=to_regress_tokens.input_ids.dtype,\n                             device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id\n            bos_embeds = self.embed_tokens(bos)\n            atts_bos = atts_img[:, :1]\n            inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)\n            attention_mask = 
torch.cat([atts_bos, attention_mask], dim=1)\n            if unmask_img_embeds is not None:\n                unmask_inputs_embeds = torch.cat([bos_embeds, unmask_inputs_embeds], dim=1)\n                unmask_attention_mask = torch.cat([atts_bos, unmask_attention_mask], dim=1)\n\n        part_targets = to_regress_tokens.input_ids.masked_fill(\n            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100\n        )\n        targets = (\n            torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],\n                       dtype=torch.long).to(image.device).fill_(-100)\n        )\n\n        offset = 0 if self.qformer_text_input else 1\n        for i, target in enumerate(part_targets):\n            targets[i, input_lens[i] + offset:input_lens[i] + len(target) + offset] = target  # plus 1 for bos\n\n        if unmask_img_embeds is None:\n            unmask_inputs_embeds, unmask_attention_mask = None, None\n        return inputs_embeds, attention_mask, unmask_inputs_embeds, unmask_attention_mask, targets\n\n    @classmethod\n    def from_config(cls, cfg):\n        vit_model = cfg.get(\"vit_model\", \"eva_clip_g\")\n        q_former_model = cfg.get(\"q_former_model\", \"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth\")\n        img_size = cfg.get(\"image_size\")\n        num_query_token = cfg.get(\"num_query_token\")\n        llama_model = cfg.get(\"llama_model\")\n\n        drop_path_rate = cfg.get(\"drop_path_rate\", 0)\n        use_grad_checkpoint = cfg.get(\"use_grad_checkpoint\", False)\n        vit_precision = cfg.get(\"vit_precision\", \"fp16\")\n        freeze_vit = cfg.get(\"freeze_vit\", True)\n        has_qformer = cfg.get(\"has_qformer\", True)\n        freeze_qformer = cfg.get(\"freeze_qformer\", True)\n\n        max_txt_len = cfg.get(\"max_txt_len\", 32)\n        end_sym = cfg.get(\"end_sym\", '\\n')\n\n        pre_encoding = cfg.get(\"pre_encoding\", False)\n     
   video_input = cfg.get(\"video_input\", None)\n        use_mask = cfg.get(\"use_mask\", False)\n        qformer_text_input = cfg.get(\"qformer_text_input\", False)\n        residual_size = cfg.get(\"residual_size\", 4)\n        mvm_decode = cfg.get(\"mvm_decode\", False)\n        \n        model = cls(\n            vit_model=vit_model,\n            q_former_model=q_former_model,\n            img_size=img_size,\n            pre_encoding=pre_encoding,\n            video_input=video_input,\n            use_mask=use_mask,\n            mvm_decode=mvm_decode,\n            residual_size=residual_size,\n            qformer_text_input=qformer_text_input,\n            drop_path_rate=drop_path_rate,\n            use_grad_checkpoint=use_grad_checkpoint,\n            vit_precision=vit_precision,\n            freeze_vit=freeze_vit,\n            has_qformer=has_qformer,\n            freeze_qformer=freeze_qformer,\n            num_query_token=num_query_token,\n            llama_model=llama_model,\n            max_txt_len=max_txt_len,\n            end_sym=end_sym,\n        )\n\n        ckpt_path = cfg.get(\"ckpt\", \"\")  # load weights of MiniGPT-4\n        if ckpt_path and not os.path.isdir(ckpt_path):\n            print(\"Load BLIP2-LLM Checkpoint: {}\".format(ckpt_path))\n            ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n            if 'model' in ckpt:\n                ckpt = ckpt['model']\n            if 'llm_proj.weight' in ckpt:\n                ckpt['llama_proj.weight'] = ckpt.pop('llm_proj.weight')\n                ckpt['llama_proj.bias'] = ckpt.pop('llm_proj.bias')\n            msg = model.load_state_dict(ckpt, strict=False)\n\n        return model\n"
  },
  {
    "path": "stllm/models/utils.py",
    "content": "import numpy as np\nimport torch\n\ndef RandomMaskingGenerator(num_patches, mask_ratio, batch, device='cuda'):\n    num_mask = int(mask_ratio * num_patches)\n\n    mask_list = []\n    for _ in range(batch):\n        mask = np.hstack([\n            np.zeros(num_patches - num_mask),\n            np.ones(num_mask),\n        ])\n        np.random.shuffle(mask)\n        mask_list.append(mask)\n    mask = torch.Tensor(mask_list).to(device, non_blocking=True).to(torch.bool)\n    return mask\n\ndef get_sinusoid_encoding_table(n_position, d_hid): \n    ''' Sinusoid position encoding table ''' \n    # TODO: make it with torch instead of numpy \n    def get_position_angle_vec(position): \n        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] \n\n    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) \n    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i \n    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 \n\n    return  torch.tensor(sinusoid_table,dtype=torch.float, requires_grad=False).unsqueeze(0) "
  },
  {
    "path": "stllm/processors/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom stllm.processors.base_processor import BaseProcessor\nfrom stllm.processors.blip_processors import (\n    Blip2ImageTrainProcessor,\n    Blip2ImageEvalProcessor,\n    BlipCaptionProcessor,\n)\n\nfrom stllm.common.registry import registry\n\n__all__ = [\n    \"BaseProcessor\",\n    \"Blip2ImageTrainProcessor\",\n    \"Blip2ImageEvalProcessor\",\n    \"BlipCaptionProcessor\",\n]\n\n\ndef load_processor(name, cfg=None):\n    \"\"\"\n    Example\n\n    >>> processor = load_processor(\"alpro_video_train\", cfg=None)\n    \"\"\"\n    processor = registry.get_processor_class(name).from_config(cfg)\n\n    return processor\n"
  },
  {
    "path": "stllm/processors/base_processor.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom omegaconf import OmegaConf\n\n\nclass BaseProcessor:\n    def __init__(self):\n        self.transform = lambda x: x\n        return\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        return cls()\n\n    def build(self, **kwargs):\n        cfg = OmegaConf.create(kwargs)\n\n        return self.from_config(cfg)\n"
  },
  {
    "path": "stllm/processors/blip_processors.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport re\n\nfrom stllm.common.registry import registry\nfrom stllm.processors.base_processor import BaseProcessor\nfrom stllm.processors.randaugment import RandomAugment\nfrom stllm.processors.video_transform import SampleFrames\nfrom omegaconf import OmegaConf\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import InterpolationMode\n\n\nclass BlipImageBaseProcessor(BaseProcessor):\n    def __init__(self, mean=None, std=None):\n        if mean is None:\n            mean = (0.48145466, 0.4578275, 0.40821073)\n        if std is None:\n            std = (0.26862954, 0.26130258, 0.27577711)\n\n        self.normalize = transforms.Normalize(mean, std)\n\n\n@registry.register_processor(\"blip_caption\")\nclass BlipCaptionProcessor(BaseProcessor):\n    def __init__(self, prompt=\"\", max_words=50):\n        self.prompt = prompt\n        self.max_words = max_words\n\n    def __call__(self, caption):\n        caption = self.prompt + self.pre_caption(caption)\n\n        return caption\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        prompt = cfg.get(\"prompt\", \"\")\n        max_words = cfg.get(\"max_words\", 50)\n\n        return cls(prompt=prompt, max_words=max_words)\n\n    def pre_caption(self, caption):\n        caption = re.sub(\n            r\"([.!\\\"()*#:;~])\",\n            \" \",\n            caption.lower(),\n        )\n        caption = re.sub(\n            r\"\\s{2,}\",\n            \" \",\n            caption,\n        )\n        caption = caption.rstrip(\"\\n\")\n        caption = caption.strip(\" \")\n\n        # truncate caption\n        caption_words = caption.split(\" \")\n        if 
len(caption_words) > self.max_words:\n            caption = \" \".join(caption_words[: self.max_words])\n\n        return caption\n\n\n@registry.register_processor(\"blip2_image_train\")\nclass Blip2ImageTrainProcessor(BlipImageBaseProcessor):\n    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.RandomResizedCrop(\n                    image_size,\n                    scale=(min_scale, max_scale),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 224)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        min_scale = cfg.get(\"min_scale\", 0.5)\n        max_scale = cfg.get(\"max_scale\", 1.0)\n\n        return cls(\n            image_size=image_size,\n            mean=mean,\n            std=std,\n            min_scale=min_scale,\n            max_scale=max_scale,\n        )\n\n@registry.register_processor(\"blip2_video_train\")\nclass Blip2VideoTrainProcessor(BaseProcessor):\n    def __init__(self, num_frames=16, test_mode=True):\n        self.num_frames = num_frames\n\n        self.transform = transforms.Compose(\n            [\n                SampleFrames(clip_len=1,frame_interval=1,num_clips=num_frames,test_mode=test_mode),\n                transforms.ToTensor(),\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        
num_frames = cfg.get(\"num_frames\", 16)\n        test_mode = cfg.get(\"test_mode\", True)\n        return cls(num_frames=num_frames, test_mode=test_mode)\n\n\n@registry.register_processor(\"blip2_image_eval\")\nclass Blip2ImageEvalProcessor(BlipImageBaseProcessor):\n    def __init__(self, image_size=224, mean=None, std=None):\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(\n                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC\n                ),\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    @classmethod\n    def from_config(cls, cfg=None):\n        if cfg is None:\n            cfg = OmegaConf.create()\n\n        image_size = cfg.get(\"image_size\", 224)\n\n        mean = cfg.get(\"mean\", None)\n        std = cfg.get(\"std\", None)\n\n        return cls(image_size=image_size, mean=mean, std=std)"
  },
  {
    "path": "stllm/processors/randaugment.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport cv2\nimport numpy as np\n\nimport torch\n\n\n## aug functions\ndef identity_func(img):\n    return img\n\n\ndef autocontrast_func(img, cutoff=0):\n    \"\"\"\n    same output as PIL.ImageOps.autocontrast\n    \"\"\"\n    n_bins = 256\n\n    def tune_channel(ch):\n        n = ch.size\n        cut = cutoff * n // 100\n        if cut == 0:\n            high, low = ch.max(), ch.min()\n        else:\n            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n            low = np.argwhere(np.cumsum(hist) > cut)\n            low = 0 if low.shape[0] == 0 else low[0]\n            high = np.argwhere(np.cumsum(hist[::-1]) > cut)\n            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]\n        if high <= low:\n            table = np.arange(n_bins)\n        else:\n            scale = (n_bins - 1) / (high - low)\n            offset = -low * scale\n            table = np.arange(n_bins) * scale + offset\n            table[table < 0] = 0\n            table[table > n_bins - 1] = n_bins - 1\n        table = table.clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef equalize_func(img):\n    \"\"\"\n    same output as PIL.ImageOps.equalize\n    PIL's implementation is different from cv2.equalize\n    \"\"\"\n    n_bins = 256\n\n    def tune_channel(ch):\n        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n        non_zero_hist = hist[hist != 0].reshape(-1)\n        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)\n        if step == 0:\n            return ch\n        n = np.empty_like(hist)\n        n[0] = step // 2\n        n[1:] = hist[:-1]\n        
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef rotate_func(img, degree, fill=(0, 0, 0)):\n    \"\"\"\n    like PIL, rotate by degree, not radians\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    center = W / 2, H / 2\n    M = cv2.getRotationMatrix2D(center, degree, 1)\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)\n    return out\n\n\ndef solarize_func(img, thresh=128):\n    \"\"\"\n    same output as PIL.ImageOps.solarize\n    \"\"\"\n    table = np.array([el if el < thresh else 255 - el for el in range(256)])\n    table = table.clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef color_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Color\n    \"\"\"\n    ## implementation according to PIL definition, quite slow\n    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]\n    #  out = blend(degenerate, img, factor)\n    #  M = (\n    #      np.eye(3) * factor\n    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor)\n    #  )[np.newaxis, np.newaxis, :]\n    M = np.float32(\n        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]\n    ) * factor + np.float32([[0.114], [0.587], [0.299]])\n    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)\n    return out\n\n\ndef contrast_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Contrast\n    \"\"\"\n    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))\n    table = (\n        np.array([(el - mean) * factor + mean for el in range(256)])\n        .clip(0, 255)\n        .astype(np.uint8)\n    )\n    out = table[img]\n    return out\n\n\ndef brightness_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Brightness\n    \"\"\"\n    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef sharpness_func(img, factor):\n    \"\"\"\n    The differences between this result and PIL's are all on the 4 boundaries;\n    the center areas are the same\n    \"\"\"\n    kernel = np.ones((3, 3), dtype=np.float32)\n    kernel[1][1] = 5\n    kernel /= 13\n    degenerate = cv2.filter2D(img, -1, kernel)\n    if factor == 0.0:\n        out = degenerate\n    elif factor == 1.0:\n        out = img\n    else:\n        out = img.astype(np.float32)\n        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]\n        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)\n        out = out.astype(np.uint8)\n    return out\n\n\ndef shear_x_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, factor, 0], [0, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef translate_x_func(img, offset, fill=(0, 0, 0)):\n    \"\"\"\n    same output as PIL.Image.transform\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 
-offset], [0, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef translate_y_func(img, offset, fill=(0, 0, 0)):\n    \"\"\"\n    same output as PIL.Image.transform\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [0, 1, -offset]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef posterize_func(img, bits):\n    \"\"\"\n    same output as PIL.ImageOps.posterize\n    \"\"\"\n    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))\n    return out\n\n\ndef shear_y_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [factor, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef cutout_func(img, pad_size, replace=(0, 0, 0)):\n    replace = np.array(replace, dtype=np.uint8)\n    H, W = img.shape[0], img.shape[1]\n    rh, rw = np.random.random(2)\n    pad_size = pad_size // 2\n    ch, cw = int(rh * H), int(rw * W)\n    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)\n    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)\n    out = img.copy()\n    out[x1:x2, y1:y2, :] = replace\n    return out\n\n\n### level to args\ndef enhance_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        return ((level / MAX_LEVEL) * 1.8 + 0.1,)\n\n    return level_to_args\n\n\ndef shear_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 0.3\n        if np.random.random() > 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef translate_level_to_args(translate_const, MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * float(translate_const)\n      
  if np.random.random() > 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * cutout_const)\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef solarize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 256)\n        return (level,)\n\n    return level_to_args\n\n\ndef none_level_to_args(level):\n    return ()\n\n\ndef posterize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 4)\n        return (level,)\n\n    return level_to_args\n\n\ndef rotate_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 30\n        if np.random.random() < 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\nfunc_dict = {\n    \"Identity\": identity_func,\n    \"AutoContrast\": autocontrast_func,\n    \"Equalize\": equalize_func,\n    \"Rotate\": rotate_func,\n    \"Solarize\": solarize_func,\n    \"Color\": color_func,\n    \"Contrast\": contrast_func,\n    \"Brightness\": brightness_func,\n    \"Sharpness\": sharpness_func,\n    \"ShearX\": shear_x_func,\n    \"TranslateX\": translate_x_func,\n    \"TranslateY\": translate_y_func,\n    \"Posterize\": posterize_func,\n    \"ShearY\": shear_y_func,\n}\n\ntranslate_const = 10\nMAX_LEVEL = 10\nreplace_value = (128, 128, 128)\narg_dict = {\n    \"Identity\": none_level_to_args,\n    \"AutoContrast\": none_level_to_args,\n    \"Equalize\": none_level_to_args,\n    \"Rotate\": rotate_level_to_args(MAX_LEVEL, replace_value),\n    \"Solarize\": solarize_level_to_args(MAX_LEVEL),\n    \"Color\": enhance_level_to_args(MAX_LEVEL),\n    \"Contrast\": enhance_level_to_args(MAX_LEVEL),\n    \"Brightness\": 
enhance_level_to_args(MAX_LEVEL),\n    \"Sharpness\": enhance_level_to_args(MAX_LEVEL),\n    \"ShearX\": shear_level_to_args(MAX_LEVEL, replace_value),\n    \"TranslateX\": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),\n    \"TranslateY\": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),\n    \"Posterize\": posterize_level_to_args(MAX_LEVEL),\n    \"ShearY\": shear_level_to_args(MAX_LEVEL, replace_value),\n}\n\n\nclass RandomAugment(object):\n    def __init__(self, N=2, M=10, isPIL=False, augs=[]):\n        self.N = N\n        self.M = M\n        self.isPIL = isPIL\n        if augs:\n            self.augs = augs\n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N)\n        return [(op, 0.5, self.M) for op in sampled_ops]\n\n    def __call__(self, img):\n        if self.isPIL:\n            img = np.array(img)\n        ops = self.get_random_ops()\n        for name, prob, level in ops:\n            if np.random.random() > prob:\n                continue\n            args = arg_dict[name](level)\n            img = func_dict[name](img, *args)\n        return img\n\n\nclass VideoRandomAugment(object):\n    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):\n        self.N = N\n        self.M = M\n        self.p = p\n        self.tensor_in_tensor_out = tensor_in_tensor_out\n        if augs:\n            self.augs = augs\n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N, replace=False)\n        return [(op, self.M) for op in sampled_ops]\n\n    def __call__(self, frames):\n        assert (\n            frames.shape[-1] == 3\n        ), \"Expecting last dimension for 3-channels RGB (b, h, w, c).\"\n\n        if self.tensor_in_tensor_out:\n            frames = frames.numpy().astype(np.uint8)\n\n        
num_frames = frames.shape[0]\n\n        ops = num_frames * [self.get_random_ops()]\n        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]\n\n        frames = torch.stack(\n            list(map(self._aug, frames, ops, apply_or_not)), dim=0\n        ).float()\n\n        return frames\n\n    def _aug(self, img, ops, apply_or_not):\n        for i, (name, level) in enumerate(ops):\n            if not apply_or_not[i]:\n                continue\n            args = arg_dict[name](level)\n            img = func_dict[name](img, *args)\n        return torch.from_numpy(img)\n\n\nif __name__ == \"__main__\":\n    a = RandomAugment()\n    img = np.random.randn(32, 32, 3)\n    a(img)\n"
  },
  {
    "path": "stllm/processors/video_transform.py",
    "content": "import numpy as np\n\nclass SampleFrames:\n    \"\"\"Sample frames from the video.\n\n    Required Keys:\n\n        - total_frames\n        - start_index\n\n    Added Keys:\n\n        - frame_inds\n        - frame_interval\n        - num_clips\n\n    Args:\n        clip_len (int): Frames of each sampled output clip.\n        frame_interval (int): Temporal interval of adjacent sampled frames.\n            Defaults to 1.\n        num_clips (int): Number of clips to be sampled. Default: 1.\n        temporal_jitter (bool): Whether to apply temporal jittering.\n            Defaults to False.\n        twice_sample (bool): Whether to use twice sample when testing.\n            If set to True, it will sample frames with and without fixed shift,\n            which is commonly used for testing in TSM model. Defaults to False.\n        out_of_bound_opt (str): The way to deal with out of bounds frame\n            indexes. Available options are 'loop', 'repeat_last'.\n            Defaults to 'loop'.\n        test_mode (bool): Store True when building test or validation dataset.\n            Defaults to False.\n        keep_tail_frames (bool): Whether to keep tail frames when sampling.\n            Defaults to False.\n        target_fps (optional, int): Convert input videos with arbitrary frame\n            rates to the unified target FPS before sampling frames. If\n            ``None``, the frame rate will not be adjusted. 
Defaults to\n            ``None``.\n    \"\"\"\n\n    def __init__(self,\n                 clip_len: int,\n                 frame_interval: int = 1,\n                 num_clips: int = 1,\n                 twice_sample: bool = False,\n                 out_of_bound_opt: str = 'loop',\n                 test_mode: bool = False,\n                 keep_tail_frames: bool = False,\n                 target_fps = None,\n                 **kwargs) -> None:\n\n        self.clip_len = clip_len\n        self.frame_interval = frame_interval\n        self.num_clips = num_clips\n        self.twice_sample = twice_sample\n        self.out_of_bound_opt = out_of_bound_opt\n        self.test_mode = test_mode\n        self.keep_tail_frames = keep_tail_frames\n        self.target_fps = target_fps\n        assert self.out_of_bound_opt in ['loop', 'repeat_last']\n\n    def _get_train_clips(self, num_frames: int,\n                         ori_clip_len: float) -> np.array:\n        \"\"\"Get clip offsets in train mode.\n\n        It will calculate the average interval for selected frames,\n        and randomly shift them within offsets between [0, avg_interval].\n        If the total number of frames is smaller than clips num or origin\n        frames length, it will return all zero indices.\n\n        Args:\n            num_frames (int): Total number of frame in the video.\n            ori_clip_len (float): length of original sample clip.\n\n        Returns:\n            np.ndarray: Sampled frame indices in train mode.\n        \"\"\"\n\n        if self.keep_tail_frames:\n            avg_interval = (num_frames - ori_clip_len + 1) / float(\n                self.num_clips)\n            if num_frames > ori_clip_len - 1:\n                base_offsets = np.arange(self.num_clips) * avg_interval\n                clip_offsets = (base_offsets + np.random.uniform(\n                    0, avg_interval, self.num_clips)).astype(np.int32)\n            else:\n                clip_offsets = 
np.zeros((self.num_clips, ), dtype=np.int32)\n        else:\n            avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips\n\n            if avg_interval > 0:\n                base_offsets = np.arange(self.num_clips) * avg_interval\n                clip_offsets = base_offsets + np.random.randint(\n                    avg_interval, size=self.num_clips)\n            elif num_frames > max(self.num_clips, ori_clip_len):\n                clip_offsets = np.sort(\n                    np.random.randint(\n                        num_frames - ori_clip_len + 1, size=self.num_clips))\n            elif avg_interval == 0:\n                ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips\n                clip_offsets = np.around(np.arange(self.num_clips) * ratio)\n            else:\n                clip_offsets = np.zeros((self.num_clips, ), dtype=np.int32)\n\n        return clip_offsets\n\n    def _get_test_clips(self, num_frames: int,\n                        ori_clip_len: float) -> np.array:\n        \"\"\"Get clip offsets in test mode.\n\n        If the total number of frames is\n        not enough, it will return all zero indices.\n\n        Args:\n            num_frames (int): Total number of frame in the video.\n            ori_clip_len (float): length of original sample clip.\n\n        Returns:\n            np.ndarray: Sampled frame indices in test mode.\n        \"\"\"\n        if self.clip_len == 1:  # 2D recognizer\n            # assert self.frame_interval == 1\n            avg_interval = num_frames / float(self.num_clips)\n            base_offsets = np.arange(self.num_clips) * avg_interval\n            clip_offsets = base_offsets + avg_interval / 2.0\n            if self.twice_sample:\n                clip_offsets = np.concatenate([clip_offsets, base_offsets])\n        else:  # 3D recognizer\n            max_offset = max(num_frames - ori_clip_len, 0)\n            if self.twice_sample:\n                num_clips = self.num_clips * 2\n          
  else:\n                num_clips = self.num_clips\n            if num_clips > 1:\n                num_segments = self.num_clips - 1\n                # align test sample strategy with `PySlowFast` repo\n                if self.target_fps is not None:\n                    offset_between = np.floor(max_offset / float(num_segments))\n                    clip_offsets = np.arange(num_clips) * offset_between\n                else:\n                    offset_between = max_offset / float(num_segments)\n                    clip_offsets = np.arange(num_clips) * offset_between\n                    clip_offsets = np.round(clip_offsets)\n            else:\n                clip_offsets = np.array([max_offset // 2])\n        return clip_offsets\n\n    def _sample_clips(self, num_frames: int, ori_clip_len: float) -> np.array:\n        \"\"\"Choose clip offsets for the video in a given mode.\n\n        Args:\n            num_frames (int): Total number of frame in the video.\n\n        Returns:\n            np.ndarray: Sampled frame indices.\n        \"\"\"\n        if self.test_mode:\n            clip_offsets = self._get_test_clips(num_frames, ori_clip_len)\n        else:\n            clip_offsets = self._get_train_clips(num_frames, ori_clip_len)\n\n        return clip_offsets\n\n    def _get_ori_clip_len(self, fps_scale_ratio: float) -> float:\n        \"\"\"calculate length of clip segment for different strategy.\n\n        Args:\n            fps_scale_ratio (float): Scale ratio to adjust fps.\n        \"\"\"\n        if self.target_fps is not None:\n            # align test sample strategy with `PySlowFast` repo\n            ori_clip_len = self.clip_len * self.frame_interval\n            ori_clip_len = np.maximum(1, ori_clip_len * fps_scale_ratio)\n        elif self.test_mode:\n            ori_clip_len = (self.clip_len - 1) * self.frame_interval + 1\n        else:\n            ori_clip_len = self.clip_len * self.frame_interval\n\n        return ori_clip_len\n\n    def 
__call__(self, x):\n        \"\"\"Perform the SampleFrames loading.\n\n        Args:\n            x (np.ndarray): Frames to sample from; the first axis indexes\n                frames.\n        \"\"\"\n        total_frames = x.shape[0]\n        # if can't get fps, same value of `fps` and `target_fps`\n        # will perform nothing\n        fps_scale_ratio = 1.0\n        \n        ori_clip_len = self._get_ori_clip_len(fps_scale_ratio)\n        clip_offsets = self._sample_clips(total_frames, ori_clip_len)\n\n        if self.target_fps:\n            frame_inds = clip_offsets[:, None] + np.linspace(\n                0, ori_clip_len - 1, self.clip_len).astype(np.int32)\n        else:\n            frame_inds = clip_offsets[:, None] + np.arange(\n                self.clip_len)[None, :] * self.frame_interval\n            frame_inds = np.concatenate(frame_inds)\n\n\n        frame_inds = frame_inds.reshape((-1, self.clip_len))\n        if self.out_of_bound_opt == 'loop':\n            frame_inds = np.mod(frame_inds, total_frames)\n        elif self.out_of_bound_opt == 'repeat_last':\n            safe_inds = frame_inds < total_frames\n            unsafe_inds = 1 - safe_inds\n            last_ind = np.max(safe_inds * frame_inds, axis=1)\n            new_inds = (safe_inds * frame_inds + (unsafe_inds.T * last_ind).T)\n            frame_inds = new_inds\n        else:\n            raise ValueError('Illegal out_of_bound option.')\n        \n        frame_inds = np.concatenate(frame_inds).astype(np.int32)\n        \n        results = x[frame_inds]\n        results = results.transpose((1, 2, 0))\n        return results"
  },
  {
    "path": "stllm/runners/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom stllm.runners.runner_base import RunnerBase\n\n__all__ = [\"RunnerBase\"]\n"
  },
  {
    "path": "stllm/runners/runner_base.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport datetime\nimport json\nimport logging\nimport os\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.distributed as dist\nimport webdataset as wds\nfrom stllm.common.dist_utils import (\n    download_cached_file,\n    get_rank,\n    get_world_size,\n    is_main_process,\n    main_process,\n)\nfrom stllm.common.registry import registry\nfrom stllm.common.utils import is_url\nfrom stllm.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset\nfrom stllm.datasets.datasets.dataloader_utils import (\n    IterLoader,\n    MultiIterLoader,\n    PrefetchLoader,\n    MetaLoader,\n)\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.data import DataLoader, DistributedSampler\n\n\n@registry.register_runner(\"runner_base\")\nclass RunnerBase:\n    \"\"\"\n    A runner class to train and evaluate a model given a task and datasets.\n\n    The runner uses pytorch distributed data parallel by default. 
Future release\n    will support other distributed frameworks.\n    \"\"\"\n\n    def __init__(self, cfg, task, model, datasets, job_id):\n        self.config = cfg\n        self.job_id = job_id\n\n        self.task = task\n        self.datasets = datasets\n\n        self._model = model\n\n        self._wrapped_model = None\n        self._device = None\n        self._optimizer = None\n        self._scaler = None\n        self._dataloaders = None\n        self._lr_sched = None\n\n        self.start_epoch = 0\n\n        # self.setup_seeds()\n        self.setup_output_dir()\n\n    @property\n    def device(self):\n        if self._device is None:\n            self._device = torch.device(self.config.run_cfg.device)\n\n        return self._device\n\n    @property\n    def use_distributed(self):\n        return self.config.run_cfg.distributed\n\n    @property\n    def model(self):\n        \"\"\"\n        A property to get the DDP-wrapped model on the device.\n        \"\"\"\n        # move model to device\n        if self._model.device != self.device:\n            self._model = self._model.to(self.device)\n\n            # distributed training wrapper\n            if self.use_distributed:\n                if self._wrapped_model is None:\n                    self._wrapped_model = DDP(\n                        self._model, device_ids=[self.config.run_cfg.gpu]\n                    )\n            else:\n                self._wrapped_model = self._model\n\n        return self._wrapped_model\n\n    @property\n    def optimizer(self):\n        # TODO make optimizer class and configurations\n        if self._optimizer is None:\n            num_parameters = 0\n            p_wd, p_non_wd = [], []\n            for n, p in self.model.named_parameters():\n                if not p.requires_grad:\n                    continue  # frozen weights\n                print(n)\n                if p.ndim < 2 or \"bias\" in n or \"ln\" in n or \"bn\" in n:\n                    
p_non_wd.append(p)\n                else:\n                    p_wd.append(p)\n                num_parameters += p.data.nelement()\n            logging.info(\"number of trainable parameters: %d\" % num_parameters)\n            optim_params = [\n                {\n                    \"params\": p_wd,\n                    \"weight_decay\": float(self.config.run_cfg.weight_decay),\n                },\n                {\"params\": p_non_wd, \"weight_decay\": 0},\n            ]\n            beta2 = self.config.run_cfg.get(\"beta2\", 0.999)\n            self._optimizer = torch.optim.AdamW(\n                optim_params,\n                lr=float(self.config.run_cfg.init_lr),\n                weight_decay=float(self.config.run_cfg.weight_decay),\n                betas=(0.9, beta2),\n            )\n\n        return self._optimizer\n\n    @property\n    def scaler(self):\n        amp = self.config.run_cfg.get(\"amp\", False)\n\n        if amp:\n            if self._scaler is None:\n                self._scaler = torch.cuda.amp.GradScaler()\n\n        return self._scaler\n\n    @property\n    def lr_scheduler(self):\n        \"\"\"\n        A property to get and create learning rate scheduler by split just in need.\n        \"\"\"\n        if self._lr_sched is None:\n            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)\n\n            # max_epoch = self.config.run_cfg.max_epoch\n            max_epoch = self.max_epoch\n            # min_lr = self.config.run_cfg.min_lr\n            min_lr = self.min_lr\n            # init_lr = self.config.run_cfg.init_lr\n            init_lr = self.init_lr\n\n            # optional parameters\n            decay_rate = self.config.run_cfg.get(\"lr_decay_rate\", None)\n            warmup_start_lr = self.config.run_cfg.get(\"warmup_lr\", -1)\n            warmup_steps = self.config.run_cfg.get(\"warmup_steps\", 0)\n            iters_per_epoch = self.config.run_cfg.get(\"iters_per_epoch\", None)\n\n            if 
iters_per_epoch is None:\n                try:\n                    #iters_per_epoch = len(self.dataloaders['train'])\n                    #iters_per_epoch = len(self.dataloaders['train'].loaders[0])\n                    iters_per_epoch = sum([len(i) for i in self.dataloaders['train'].loaders])\n                except (AttributeError, TypeError):\n                    iters_per_epoch = 10000\n\n            self._lr_sched = lr_sched_cls(\n                optimizer=self.optimizer,\n                max_epoch=max_epoch,\n                iters_per_epoch=iters_per_epoch,\n                min_lr=min_lr,\n                init_lr=init_lr,\n                decay_rate=decay_rate,\n                warmup_start_lr=warmup_start_lr,\n                warmup_steps=warmup_steps,\n            )\n\n        return self._lr_sched\n\n    @property\n    def dataloaders(self) -> dict:\n        \"\"\"\n        A property to get and create dataloaders by split just in need.\n\n        If no train_dataset_ratio is provided, concatenate map-style datasets and\n        chain wds.DataPipe datasets separately. Training set becomes a tuple\n        (ConcatDataset, ChainDataset), both are optional but at least one of them is\n        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.\n\n        If train_dataset_ratio is provided, create a MultiIterLoader to sample\n        each dataset by ratios during training.\n\n        Currently do not support multiple datasets for validation and test.\n\n        Returns:\n            dict: {split_name: (tuples of) dataloader}\n        \"\"\"\n        if self._dataloaders is None:\n\n            # concatenate map-style datasets and chain wds.DataPipe datasets separately\n            # training set becomes a tuple (ConcatDataset, ChainDataset), both are\n            # optional but at least one of them is required. 
The resultant ConcatDataset\n            # and ChainDataset will be sampled evenly.\n            logging.info(\n                \"dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline).\"\n            )\n\n            datasets = reorg_datasets_by_split(self.datasets)\n            self.datasets = datasets\n            # self.datasets = concat_datasets(datasets)\n\n            # print dataset statistics after concatenation/chaining\n            for split_name in self.datasets:\n                if isinstance(self.datasets[split_name], tuple) or isinstance(\n                    self.datasets[split_name], list\n                ):\n                    # mixed wds.DataPipeline and torch.utils.data.Dataset\n                    num_records = sum(\n                        [\n                            len(d)\n                            if not type(d) in [wds.DataPipeline, ChainDataset]\n                            else 0\n                            for d in self.datasets[split_name]\n                        ]\n                    )\n\n                else:\n                    if hasattr(self.datasets[split_name], \"__len__\"):\n                        # a single map-style dataset\n                        num_records = len(self.datasets[split_name])\n                    else:\n                        # a single wds.DataPipeline\n                        num_records = -1\n                        logging.info(\n                            \"Only a single wds.DataPipeline dataset, no __len__ attribute.\"\n                        )\n\n                if num_records >= 0:\n                    logging.info(\n                        \"Loaded {} records for {} split from the dataset.\".format(\n                            num_records, split_name\n                        )\n                    )\n\n            # create dataloaders\n            split_names = sorted(self.datasets.keys())\n\n            datasets = 
[self.datasets[split] for split in split_names]\n            is_trains = [split in self.train_splits for split in split_names]\n\n            batch_sizes = [\n                self.config.run_cfg.batch_size_train\n                if split == \"train\"\n                else self.config.run_cfg.batch_size_eval\n                for split in split_names\n            ]\n\n            collate_fns = []\n            for dataset in datasets:\n                if isinstance(dataset, tuple) or isinstance(dataset, list):\n                    collate_fns.append([getattr(d, \"collater\", None) for d in dataset])\n                else:\n                    collate_fns.append(getattr(dataset, \"collater\", None))\n\n            dataloaders = self.create_loaders(\n                datasets=datasets,\n                num_workers=self.config.run_cfg.num_workers,\n                batch_sizes=batch_sizes,\n                is_trains=is_trains,\n                collate_fns=collate_fns,\n            )\n\n            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}\n\n        return self._dataloaders\n\n    @property\n    def cuda_enabled(self):\n        return self.device.type == \"cuda\"\n\n    @property\n    def max_epoch(self):\n        return int(self.config.run_cfg.max_epoch)\n\n    @property\n    def log_freq(self):\n        log_freq = self.config.run_cfg.get(\"log_freq\", 50)\n        return int(log_freq)\n\n    @property\n    def init_lr(self):\n        return float(self.config.run_cfg.init_lr)\n\n    @property\n    def min_lr(self):\n        return float(self.config.run_cfg.min_lr)\n\n    @property\n    def accum_grad_iters(self):\n        return int(self.config.run_cfg.get(\"accum_grad_iters\", 1))\n\n    @property\n    def valid_splits(self):\n        valid_splits = self.config.run_cfg.get(\"valid_splits\", [])\n\n        if len(valid_splits) == 0:\n            logging.info(\"No validation splits found.\")\n\n        return valid_splits\n\n    @property\n    
def test_splits(self):\n        test_splits = self.config.run_cfg.get(\"test_splits\", [])\n\n        return test_splits\n\n    @property\n    def train_splits(self):\n        train_splits = self.config.run_cfg.get(\"train_splits\", [])\n\n        if len(train_splits) == 0:\n            logging.info(\"Empty train splits.\")\n\n        return train_splits\n\n    @property\n    def evaluate_only(self):\n        \"\"\"\n        Set to True to skip training.\n        \"\"\"\n        return self.config.run_cfg.evaluate\n\n    @property\n    def use_dist_eval_sampler(self):\n        return self.config.run_cfg.get(\"use_dist_eval_sampler\", True)\n\n    @property\n    def resume_ckpt_path(self):\n        return self.config.run_cfg.get(\"resume_ckpt_path\", None)\n\n    @property\n    def train_loader(self):\n        train_dataloader = self.dataloaders[\"train\"]\n\n        return train_dataloader\n\n    def setup_output_dir(self):\n        lib_root = Path(registry.get_path(\"library_root\"))\n\n        output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id\n        result_dir = output_dir / \"result\"\n\n        output_dir.mkdir(parents=True, exist_ok=True)\n        result_dir.mkdir(parents=True, exist_ok=True)\n\n        registry.register_path(\"result_dir\", str(result_dir))\n        registry.register_path(\"output_dir\", str(output_dir))\n\n        self.result_dir = result_dir\n        self.output_dir = output_dir\n\n    def train(self):\n        start_time = time.time()\n        best_agg_metric = 0\n        best_epoch = 0\n\n        self.log_config()\n\n        # resume from checkpoint if specified\n        if not self.evaluate_only and self.resume_ckpt_path is not None:\n            self._load_checkpoint(self.resume_ckpt_path)\n\n        for cur_epoch in range(self.start_epoch, self.max_epoch):\n            # training phase\n            if not self.evaluate_only:\n                logging.info(\"Start training\")\n                train_stats = 
self.train_epoch(cur_epoch)\n                self.log_stats(split_name=\"train\", stats=train_stats)\n\n            # evaluation phase\n            if len(self.valid_splits) > 0:\n                for split_name in self.valid_splits:\n                    logging.info(\"Evaluating on {}.\".format(split_name))\n\n                    val_log = self.eval_epoch(\n                        split_name=split_name, cur_epoch=cur_epoch\n                    )\n                    if val_log is not None:\n                        if is_main_process():\n                            assert (\n                                \"agg_metrics\" in val_log\n                            ), \"No agg_metrics found in validation log.\"\n\n                            agg_metrics = val_log[\"agg_metrics\"]\n                            if agg_metrics > best_agg_metric and split_name == \"val\":\n                                best_epoch, best_agg_metric = cur_epoch, agg_metrics\n\n                                self._save_checkpoint(cur_epoch, is_best=True)\n\n                            val_log.update({\"best_epoch\": best_epoch})\n                            self.log_stats(val_log, split_name)\n\n            else:\n                # if no validation split is provided, we just save the checkpoint at the end of each epoch.\n                if not self.evaluate_only:\n                    self._save_checkpoint(cur_epoch, is_best=False)\n\n            if self.evaluate_only:\n                break\n\n            if self.config.run_cfg.distributed:\n                dist.barrier()\n\n        # testing phase\n        test_epoch = \"best\" if len(self.valid_splits) > 0 else cur_epoch\n        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)\n\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        logging.info(\"Training time {}\".format(total_time_str))\n\n    def evaluate(self, cur_epoch=\"best\", 
skip_reload=False):\n        test_logs = dict()\n\n        if len(self.test_splits) > 0:\n            for split_name in self.test_splits:\n                test_logs[split_name] = self.eval_epoch(\n                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload\n                )\n\n            return test_logs\n\n    def train_epoch(self, epoch):\n        # train\n        self.model.train()\n\n        return self.task.train_epoch(\n            epoch=epoch,\n            model=self.model,\n            data_loader=self.train_loader,\n            optimizer=self.optimizer,\n            scaler=self.scaler,\n            lr_scheduler=self.lr_scheduler,\n            cuda_enabled=self.cuda_enabled,\n            log_freq=self.log_freq,\n            accum_grad_iters=self.accum_grad_iters,\n        )\n\n    @torch.no_grad()\n    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):\n        \"\"\"\n        Evaluate the model on a given split.\n\n        Args:\n            split_name (str): name of the split to evaluate on.\n            cur_epoch (int): current epoch.\n            skip_reload_best (bool): whether to skip reloading the best checkpoint.\n                During training, we will reload the best checkpoint for validation.\n                During testing, we will use provided weights and skip reloading the best checkpoint .\n        \"\"\"\n        data_loader = self.dataloaders.get(split_name, None)\n        assert data_loader, \"data_loader for split {} is None.\".format(split_name)\n\n        # TODO In validation, you need to compute loss as well as metrics\n        # TODO consider moving to model.before_evaluation()\n        model = self.unwrap_dist_model(self.model)\n        if not skip_reload and cur_epoch == \"best\":\n            model = self._reload_best_model(model)\n        model.eval()\n\n        self.task.before_evaluation(\n            model=model,\n            dataset=self.datasets[split_name],\n        )\n        
results = self.task.evaluation(model, data_loader)\n\n        if results is not None:\n            return self.task.after_evaluation(\n                val_result=results,\n                split_name=split_name,\n                epoch=cur_epoch,\n            )\n\n    def unwrap_dist_model(self, model):\n        if self.use_distributed:\n            return model.module\n        else:\n            return model\n\n    def create_loaders(\n        self,\n        datasets,\n        num_workers,\n        batch_sizes,\n        is_trains,\n        collate_fns,\n        dataset_ratios=None,\n    ):\n        \"\"\"\n        Create dataloaders for training and validation.\n        \"\"\"\n\n        def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):\n            # create a single dataloader for each split\n            if isinstance(dataset, ChainDataset) or isinstance(\n                dataset, wds.DataPipeline\n            ):\n                # wds.WebDataset instances are chained together\n                # webdataset.DataPipeline has its own sampler and collate_fn\n                loader = iter(\n                    DataLoader(\n                        dataset,\n                        batch_size=bsz,\n                        num_workers=num_workers,\n                        pin_memory=True,\n                    )\n                )\n            else:\n                # map-style datasets are concatenated together\n                # set up the distributed sampler\n                if self.use_distributed:\n                    sampler = DistributedSampler(\n                        dataset,\n                        shuffle=is_train,\n                        num_replicas=get_world_size(),\n                        rank=get_rank(),\n                    )\n                    if not self.use_dist_eval_sampler:\n                        # e.g. 
retrieval evaluation\n                        sampler = sampler if is_train else None\n                else:\n                    sampler = None\n\n                loader = DataLoader(\n                    dataset,\n                    batch_size=bsz,\n                    num_workers=num_workers,\n                    pin_memory=True,\n                    sampler=sampler,\n                    shuffle=sampler is None and is_train,\n                    collate_fn=collate_fn,\n                    drop_last=True if is_train else False,\n                )\n                loader = PrefetchLoader(loader)\n\n                if is_train:\n                    loader = IterLoader(loader, use_distributed=self.use_distributed)\n\n            return loader\n\n        loaders = []\n\n        for dataset, bsz, is_train, collate_fn in zip(\n            datasets, batch_sizes, is_trains, collate_fns\n        ):\n            if isinstance(dataset, list) or isinstance(dataset, tuple):\n                if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None:\n                    dataset_ratios = [d.sample_ratio for d in dataset]\n                #loader = MultiIterLoader(\n                #    loaders=[\n                #        _create_loader(d, num_workers, bsz, is_train, collate_fn[i])\n                #        for i, d in enumerate(dataset)\n                #    ],\n                #    ratios=dataset_ratios,\n                #)\n                loader = MetaLoader(\n                    loaders=[\n                        _create_loader(d, num_workers, bsz, is_train, collate_fn[i])\n                        for i, d in enumerate(dataset)\n                    ]\n                )\n            else:\n                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)\n\n            loaders.append(loader)\n\n        return loaders\n\n    @main_process\n    def _save_checkpoint(self, cur_epoch, is_best=False):\n        \"\"\"\n        Save the checkpoint at 
the current epoch.\n        \"\"\"\n        model_no_ddp = self.unwrap_dist_model(self.model)\n        param_grad_dic = {\n            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()\n        }\n        state_dict = model_no_ddp.state_dict()\n        for k in list(state_dict.keys()):\n            if k in param_grad_dic.keys() and not param_grad_dic[k]:\n                # delete parameters that do not require gradient\n                del state_dict[k]\n        save_obj = {\n            \"model\": state_dict,\n            \"optimizer\": self.optimizer.state_dict(),\n            \"config\": self.config.to_dict(),\n            \"scaler\": self.scaler.state_dict() if self.scaler else None,\n            \"epoch\": cur_epoch,\n        }\n        save_to = os.path.join(\n            self.output_dir,\n            \"checkpoint_{}.pth\".format(\"best\" if is_best else cur_epoch),\n        )\n        logging.info(\"Saving checkpoint at epoch {} to {}.\".format(cur_epoch, save_to))\n        torch.save(save_obj, save_to)\n\n    def _reload_best_model(self, model):\n        \"\"\"\n        Load the best checkpoint for evaluation.\n        \"\"\"\n        checkpoint_path = os.path.join(self.output_dir, \"checkpoint_best.pth\")\n\n        logging.info(\"Loading checkpoint from {}.\".format(checkpoint_path))\n        checkpoint = torch.load(checkpoint_path, map_location=\"cpu\")\n        try:\n            model.load_state_dict(checkpoint[\"model\"])\n        except RuntimeError as e:\n            logging.warning(\n                \"\"\"\n                Key mismatch when loading checkpoint. 
This is expected if only part of the model is saved.\n                Trying to load the model with strict=False.\n                \"\"\"\n            )\n            model.load_state_dict(checkpoint[\"model\"], strict=False)\n        return model\n\n    def _load_checkpoint(self, url_or_filename):\n        \"\"\"\n        Resume from a checkpoint.\n        \"\"\"\n        if is_url(url_or_filename):\n            cached_file = download_cached_file(\n                url_or_filename, check_hash=False, progress=True\n            )\n            checkpoint = torch.load(cached_file, map_location=self.device)\n        elif os.path.isfile(url_or_filename):\n            checkpoint = torch.load(url_or_filename, map_location=self.device)\n        else:\n            raise RuntimeError(\"checkpoint url or path is invalid\")\n\n        state_dict = checkpoint[\"model\"]\n        self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)\n\n        self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n        if self.scaler and \"scaler\" in checkpoint:\n            self.scaler.load_state_dict(checkpoint[\"scaler\"])\n\n        self.start_epoch = checkpoint[\"epoch\"] + 1\n        logging.info(\"Resume checkpoint from {}\".format(url_or_filename))\n\n    @main_process\n    def log_stats(self, stats, split_name):\n        if isinstance(stats, dict):\n            log_stats = {**{f\"{split_name}_{k}\": v for k, v in stats.items()}}\n            with open(os.path.join(self.output_dir, \"log.txt\"), \"a\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n        elif isinstance(stats, list):\n            pass\n\n    @main_process\n    def log_config(self):\n        with open(os.path.join(self.output_dir, \"log.txt\"), \"a\") as f:\n            f.write(json.dumps(self.config.to_dict(), indent=4) + \"\\n\")\n"
  },
  {
    "path": "stllm/tasks/__init__.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom stllm.common.registry import registry\nfrom stllm.tasks.base_task import BaseTask\nfrom stllm.tasks.image_text_pretrain import ImageTextPretrainTask, VideoTextItTask\n\n\ndef setup_task(cfg):\n    assert \"task\" in cfg.run_cfg, \"Task name must be provided.\"\n\n    task_name = cfg.run_cfg.task\n    task = registry.get_task_class(task_name).setup_task(cfg=cfg)\n    assert task is not None, \"Task {} not properly registered.\".format(task_name)\n\n    return task\n\n\n__all__ = [\n    \"BaseTask\",\n    \"ImageTextPretrainTask\",\n    \"VideoTextItTask\",\n]\n"
  },
  {
    "path": "stllm/tasks/base_task.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom stllm.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized\nfrom stllm.common.logger import MetricLogger, SmoothedValue\nfrom stllm.common.registry import registry\nfrom stllm.datasets.data_utils import prepare_sample\n\n\nclass BaseTask:\n    def __init__(self, **kwargs):\n        super().__init__()\n\n        self.inst_id_key = \"instance_id\"\n\n    @classmethod\n    def setup_task(cls, **kwargs):\n        return cls()\n\n    def build_model(self, cfg):\n        model_config = cfg.model_cfg\n\n        model_cls = registry.get_model_class(model_config.arch)\n        return model_cls.from_config(model_config)\n\n    def build_datasets(self, cfg):\n        \"\"\"\n        Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.\n        Download dataset and annotations automatically if not exist.\n\n        Args:\n            cfg (common.config.Config): _description_\n\n        Returns:\n            dict: Dictionary of torch.utils.data.Dataset objects by split.\n        \"\"\"\n\n        datasets = dict()\n\n        datasets_config = cfg.datasets_cfg\n\n        assert len(datasets_config) > 0, \"At least one dataset has to be specified.\"\n\n        for name in datasets_config:\n            dataset_config = datasets_config[name]\n\n            builder = registry.get_builder_class(name)(dataset_config)\n            dataset = builder.build_datasets()\n\n            dataset['train'].name = name\n            if 'sample_ratio' in dataset_config:\n                dataset['train'].sample_ratio = dataset_config.sample_ratio\n\n            datasets[name] = dataset\n\n        return 
datasets\n\n    def train_step(self, model, samples):\n        loss = model(samples)[\"loss\"]\n        return loss\n\n    def valid_step(self, model, samples):\n        raise NotImplementedError\n\n    def before_evaluation(self, model, dataset, **kwargs):\n        model.before_evaluation(dataset=dataset, task_type=type(self))\n\n    def after_evaluation(self, **kwargs):\n        pass\n\n    def inference_step(self):\n        raise NotImplementedError\n\n    def evaluation(self, model, data_loader, cuda_enabled=True):\n        metric_logger = MetricLogger(delimiter=\"  \")\n        header = \"Evaluation\"\n        # TODO make it configurable\n        print_freq = 10\n\n        results = []\n\n        for samples in metric_logger.log_every(data_loader, print_freq, header):\n            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)\n\n            eval_output = self.valid_step(model=model, samples=samples)\n            results.extend(eval_output)\n\n        if is_dist_avail_and_initialized():\n            dist.barrier()\n\n        return results\n\n    def train_epoch(\n        self,\n        epoch,\n        model,\n        data_loader,\n        optimizer,\n        lr_scheduler,\n        scaler=None,\n        cuda_enabled=False,\n        log_freq=50,\n        accum_grad_iters=1,\n    ):\n        return self._train_inner_loop(\n            epoch=epoch,\n            iters_per_epoch=lr_scheduler.iters_per_epoch,\n            model=model,\n            data_loader=data_loader,\n            optimizer=optimizer,\n            scaler=scaler,\n            lr_scheduler=lr_scheduler,\n            log_freq=log_freq,\n            cuda_enabled=cuda_enabled,\n            accum_grad_iters=accum_grad_iters,\n        )\n\n    def train_iters(\n        self,\n        epoch,\n        start_iters,\n        iters_per_inner_epoch,\n        model,\n        data_loader,\n        optimizer,\n        lr_scheduler,\n        scaler=None,\n        cuda_enabled=False,\n        
log_freq=50,\n        accum_grad_iters=1,\n    ):\n        return self._train_inner_loop(\n            epoch=epoch,\n            start_iters=start_iters,\n            iters_per_epoch=iters_per_inner_epoch,\n            model=model,\n            data_loader=data_loader,\n            optimizer=optimizer,\n            scaler=scaler,\n            lr_scheduler=lr_scheduler,\n            log_freq=log_freq,\n            cuda_enabled=cuda_enabled,\n            accum_grad_iters=accum_grad_iters,\n        )\n\n    def _train_inner_loop(\n        self,\n        epoch,\n        iters_per_epoch,\n        model,\n        data_loader,\n        optimizer,\n        lr_scheduler,\n        scaler=None,\n        start_iters=None,\n        log_freq=50,\n        cuda_enabled=False,\n        accum_grad_iters=1,\n    ):\n        \"\"\"\n        An inner training loop compatible with both epoch-based and iter-based training.\n\n        When using epoch-based, training stops after one epoch; when using iter-based,\n        training stops after #iters_per_epoch iterations.\n        \"\"\"\n        use_amp = scaler is not None\n\n        if not hasattr(data_loader, \"__next__\"):\n            # convert to iterator if not already\n            data_loader = iter(data_loader)\n\n        metric_logger = MetricLogger(delimiter=\"  \")\n        metric_logger.add_meter(\"lr\", SmoothedValue(window_size=1, fmt=\"{value:.6f}\"))\n        metric_logger.add_meter(\"loss\", SmoothedValue(window_size=1, fmt=\"{value:.4f}\"))\n\n        # if iter-based runner, schedule lr based on inner epoch.\n        logging.info(\n            \"Start training epoch {}, {} iters per inner epoch.\".format(\n                epoch, iters_per_epoch\n            )\n        )\n        header = \"Train: data epoch: [{}]\".format(epoch)\n        if start_iters is None:\n            # epoch-based runner\n            inner_epoch = epoch\n        else:\n            # In iter-based runner, we schedule the learning rate based on 
iterations.\n            inner_epoch = start_iters // iters_per_epoch\n            header = header + \"; inner epoch [{}]\".format(inner_epoch)\n\n        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):\n            # if using iter-based runner, we stop after iters_per_epoch iterations.\n            if i >= iters_per_epoch:\n                break\n\n            samples = next(data_loader)\n\n            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)\n            samples.update(\n                {\n                    \"epoch\": inner_epoch,\n                    \"num_iters_per_epoch\": iters_per_epoch,\n                    \"iters\": i,\n                }\n            )\n\n            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)\n\n            with torch.cuda.amp.autocast(enabled=use_amp):\n                loss = self.train_step(model=model, samples=samples)\n\n            # after_train_step()\n            if use_amp:\n                scaler.scale(loss).backward()\n            else:\n                loss.backward()\n\n            # update gradients every accum_grad_iters iterations\n            if (i + 1) % accum_grad_iters == 0:\n                if use_amp:\n                    scaler.step(optimizer)\n                    scaler.update()\n                else:\n                    optimizer.step()\n                optimizer.zero_grad()\n\n            metric_logger.update(loss=loss.item())\n            metric_logger.update(lr=optimizer.param_groups[0][\"lr\"])\n\n        # after train_epoch()\n        # gather the stats from all processes\n        metric_logger.synchronize_between_processes()\n        logging.info(\"Averaged stats: \" + str(metric_logger.global_avg()))\n        return {\n            k: \"{:.3f}\".format(meter.global_avg)\n            for k, meter in metric_logger.meters.items()\n        }\n\n    @staticmethod\n    def save_result(result, result_dir, filename, 
remove_duplicate=\"\"):\n        import json\n\n        result_file = os.path.join(\n            result_dir, \"%s_rank%d.json\" % (filename, get_rank())\n        )\n        final_result_file = os.path.join(result_dir, \"%s.json\" % filename)\n\n        # write this rank's results; closing the file flushes it before the barrier\n        with open(result_file, \"w\") as f:\n            json.dump(result, f)\n\n        if is_dist_avail_and_initialized():\n            dist.barrier()\n\n        if is_main_process():\n            logging.warning(\"rank %d starts merging results.\" % get_rank())\n            # combine results from all processes\n            result = []\n\n            for rank in range(get_world_size()):\n                result_file = os.path.join(\n                    result_dir, \"%s_rank%d.json\" % (filename, rank)\n                )\n                with open(result_file, \"r\") as f:\n                    res = json.load(f)\n                result += res\n\n            if remove_duplicate:\n                # keep only the first result for each value of the remove_duplicate key\n                result_new = []\n                id_list = []\n                for res in result:\n                    if res[remove_duplicate] not in id_list:\n                        id_list.append(res[remove_duplicate])\n                        result_new.append(res)\n                result = result_new\n\n            with open(final_result_file, \"w\") as f:\n                json.dump(result, f)\n            print(\"result file saved to %s\" % final_result_file)\n\n        return final_result_file\n"
  },
  {
    "path": "stllm/tasks/image_text_pretrain.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nfrom stllm.common.registry import registry\nfrom stllm.tasks.base_task import BaseTask\nfrom stllm.datasets.datasets.instruction_data import available_corpus, train_transform\nfrom stllm.datasets.datasets.image_video_itdatasets import ITImgTrainDataset, ITVidTrainDataset\n\n@registry.register_task(\"image_text_pretrain\")\nclass ImageTextPretrainTask(BaseTask):\n    def __init__(self):\n        super().__init__()\n\n    def evaluation(self, model, data_loader, cuda_enabled=True):\n        pass\n\n@registry.register_task(\"video_text_it\")\nclass VideoTextItTask(ImageTextPretrainTask):\n    def __init__(self):\n        super().__init__()\n\n    def build_datasets(self, cfg):\n        \"\"\"\n        Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.\n        Download dataset and annotations automatically if not exist.\n\n        Args:\n            cfg (common.config.Config): _description_\n\n        Returns:\n            dict: Dictionary of torch.utils.data.Dataset objects by split.\n        \"\"\"\n        datasets = dict()\n        datasets_config = cfg.datasets_cfg\n        assert len(datasets_config) > 0, \"At least one dataset has to be specified.\"\n        simple = cfg.model_cfg.get('qformer_text_input',False)\n        for name in datasets_config:\n            dataset_config = datasets_config[name]\n            dataset_info = available_corpus[name]\n            dataset_cls = ITImgTrainDataset if get_media_type(dataset_info)==\"image\" else ITVidTrainDataset\n\n            datasets[name] = {'train': dataset_cls(ann_file=dataset_info, simple=simple,\n                        transform=train_transform, **dataset_config)}\n\n        return datasets\n\ndef get_media_type(dataset_info):\n    
if len(dataset_info) == 3 and dataset_info[2] == \"video\":\n        return \"video\"\n    else:\n        return \"image\"\n"
  },
  {
    "path": "stllm/test/__init__.py",
    "content": ""
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_activitynet_qa.py",
    "content": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"question-answer-generation-using-gpt-3\")\n    parser.add_argument(\"--pred_path\", required=True, help=\"The path to file containing prediction.\")\n    parser.add_argument(\"--output_dir\", required=True, help=\"The path to save annotation json files.\")\n    parser.add_argument(\"--output_json\", required=True, help=\"The path to save annotation final combined json file.\")\n    parser.add_argument(\"--api_key\", required=True, help=\"OpenAI API key.\")\n    parser.add_argument(\"--num_tasks\", required=True, type=int, help=\"Number of splits.\")\n    args = parser.parse_args()\n    return args\n\n\ndef annotate(prediction_set, caption_files, output_dir):\n    \"\"\"\n    Evaluates question and answer pairs using GPT-3\n    Returns a score for correctness.\n    \"\"\"\n    for file in caption_files:\n        key = file[:-5] # Strip file extension\n        qa_set = prediction_set[key]\n        question = qa_set['q']\n        answer = qa_set['a']\n        pred = qa_set['pred']\n        try:\n            # Compute the correctness score\n            completion = openai.ChatCompletion.create(\n                model=\"gpt-3.5-turbo\",\n                messages=[\n                    {\n                        \"role\": \"system\",\n                        \"content\": \n                            \"You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. \"\n                            \"Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. 
Here's how you can accomplish the task:\"\n                            \"------\"\n                            \"##INSTRUCTIONS: \"\n                            \"- Focus on the meaningful match between the predicted answer and the correct answer.\\n\"\n                            \"- Consider synonyms or paraphrases as valid matches.\\n\"\n                            \"- Evaluate the correctness of the prediction compared to the answer.\"\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\":\n                            \"Please evaluate the following video-based question-answer pair:\\n\\n\"\n                            f\"Question: {question}\\n\"\n                            f\"Correct Answer: {answer}\\n\"\n                            f\"Predicted Answer: {pred}\\n\\n\"\n                            \"Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. \"\n                            \"Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is  a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING.\"\n                            \"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. 
\"\n                            \"For example, your response should look like this: {'pred': 'yes', 'score': 4.8}.\"\n                    }\n                ]\n            )\n            # Convert response to a Python dictionary.\n            response_message = completion[\"choices\"][0][\"message\"][\"content\"]\n            response_dict = ast.literal_eval(response_message)\n            result_qa_pair = [response_dict, qa_set]\n\n            # Save the question-answer pairs to a json file.\n            with open(f\"{output_dir}/{key}.json\", \"w\") as f:\n                json.dump(result_qa_pair, f)\n\n        except Exception as e:\n            print(f\"Error processing file '{key}': {e}\")\n\n\ndef main():\n    \"\"\"\n    Main function to control the flow of the program.\n    \"\"\"\n    # Parse arguments.\n    args = parse_args()\n\n    file = open(args.pred_path)\n    pred_contents = json.load(file)\n\n    # Dictionary to store the count of occurrences for each video_id\n    video_id_counts = {}\n    new_pred_contents = []\n\n    # Iterate through each sample in pred_contents\n    for sample in pred_contents:\n        video_id = sample['id']\n        if video_id in video_id_counts:\n            video_id_counts[video_id] += 1\n        else:\n            video_id_counts[video_id] = 0\n\n        # Create a new sample with the modified key\n        new_sample = sample\n        new_sample['id'] = f\"{video_id}_{video_id_counts[video_id]}\"\n        new_pred_contents.append(new_sample)\n\n    # Generating list of id's and corresponding files\n    id_list = [x['id'] for x in new_pred_contents]\n    caption_files = [f\"{id}.json\" for id in id_list]\n\n    output_dir = args.output_dir\n    # Generate output directory if not exists.\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    # Preparing dictionary of question-answer sets\n    prediction_set = {}\n    for sample in new_pred_contents:\n        id = sample['id']\n        question = 
sample['question']\n        answer = sample['answer']\n        pred = sample['pred']\n        qa_set = {\"q\": question, \"a\": answer, \"pred\": pred}\n        prediction_set[id] = qa_set\n\n    # Set the OpenAI API key.\n    openai.api_key = args.api_key\n    num_tasks = args.num_tasks\n\n    # While loop to ensure that all captions are processed.\n    while True:\n        try:\n            # Files that have not been processed yet.\n            completed_files = os.listdir(output_dir)\n            print(f\"completed_files: {len(completed_files)}\")\n\n            # Files that have not been processed yet.\n            incomplete_files = [f for f in caption_files if f not in completed_files]\n            print(f\"incomplete_files: {len(incomplete_files)}\")\n\n            # Break the loop when there are no incomplete files\n            if len(incomplete_files) == 0:\n                break\n            if len(incomplete_files) <= num_tasks:\n                num_tasks = 1\n\n            # Split tasks into parts.\n            part_len = len(incomplete_files) // num_tasks\n            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]\n            task_args = [(prediction_set, part, args.output_dir) for part in all_parts]\n\n            # Use a pool of workers to process the files in parallel.\n            with Pool() as pool:\n                pool.starmap(annotate, task_args)\n\n        except Exception as e:\n            print(f\"Error: {e}\")\n\n    # Combine all the processed files into one\n    combined_contents = {}\n    json_path = args.output_json\n\n    # Iterate through json files\n    for file_name in os.listdir(output_dir):\n        if file_name.endswith(\".json\"):\n            file_path = os.path.join(output_dir, file_name)\n            with open(file_path, \"r\") as json_file:\n                content = json.load(json_file)\n                combined_contents[file_name[:-5]] = content\n\n    # Write combined 
content to a json file\n    with open(json_path, \"w\") as json_file:\n        json.dump(combined_contents, json_file)\n    print(\"All evaluation completed!\")\n\n    # Calculate average score and accuracy\n    score_sum = 0\n    count = 0\n    yes_count = 0\n    no_count = 0\n    for key, result in combined_contents.items():\n        # Computing score\n        count += 1\n        if isinstance(result[0],list):\n            result = result[0]\n        score_match = result[0]['score']\n\n        score = int(score_match)\n        score_sum += score\n\n        # Computing accuracy\n        pred = result[0]['pred']\n        if \"yes\" in pred.lower():\n            yes_count += 1\n        elif \"no\" in pred.lower():\n            no_count += 1\n\n    average_score = score_sum / count\n    accuracy = yes_count / (yes_count + no_count)\n    print(\"Yes count:\", yes_count)\n    print(\"No count:\", no_count)\n    print(\"Accuracy:\", accuracy)\n    print(\"Average score:\", average_score)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_1_correctness.py",
    "content": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"question-answer-generation-using-gpt-3\")\n    parser.add_argument(\"--pred_path\", required=True, help=\"The path to file containing prediction.\")\n    parser.add_argument(\"--output_dir\", required=True, help=\"The path to save annotation json files.\")\n    parser.add_argument(\"--output_json\", required=True, help=\"The path to save annotation final combined json file.\")\n    parser.add_argument(\"--api_key\", required=True, help=\"OpenAI API key.\")\n    parser.add_argument(\"--num_tasks\", required=True, type=int, help=\"Number of splits.\")\n    args = parser.parse_args()\n    return args\n\n\ndef annotate(prediction_set, caption_files, output_dir):\n    \"\"\"\n    Evaluates question and answer pairs using GPT-3\n    Returns a score for correctness.\n    \"\"\"\n    for file in caption_files:\n        key = file[:-5] # Strip file extension\n        qa_set = prediction_set[key]\n        question = qa_set['q']\n        answer = qa_set['a']\n        pred = qa_set['pred']\n        try:\n            # Compute the correctness score\n            completion = openai.ChatCompletion.create(\n                model=\"gpt-3.5-turbo\",\n                messages=[\n                    {\n                        \"role\": \"system\",\n                        \"content\": \n                            \"You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. \"\n                            \"Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. 
Here's how you can accomplish the task:\"\n                            \"------\"\n                            \"##INSTRUCTIONS: \"\n                            \"- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\\n\"\n                            \"- The predicted answer must be factually accurate and align with the video content.\\n\"\n                            \"- Consider synonyms or paraphrases as valid matches.\\n\"\n                            \"- Evaluate the factual accuracy of the prediction compared to the answer.\"\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\":\n                            \"Please evaluate the following video-based question-answer pair:\\n\\n\"\n                            f\"Question: {question}\\n\"\n                            f\"Correct Answer: {answer}\\n\"\n                            f\"Predicted Answer: {pred}\\n\\n\"\n                            \"Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. \"\n                            \"Please assign a score of 0 when the meaning of Predicted Answer is similar to 'I don't know'.\"\n                            \"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING.\"\n                            \"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. 
\"\n                            \"For example, your response should look like this: {''score': 4.8}.\"\n                    }\n                ]\n            )\n            # Convert response to a Python dictionary.\n            response_message = completion[\"choices\"][0][\"message\"][\"content\"]\n            response_dict = ast.literal_eval(response_message)\n            result_qa_pair = [response_dict, qa_set]\n\n            # Save the question-answer pairs to a json file.\n            with open(f\"{output_dir}/{key}.json\", \"w\") as f:\n                json.dump(result_qa_pair, f)\n\n        except Exception as e:\n            print(f\"Error processing file '{key}': {e}\")\n\n\ndef main():\n    \"\"\"\n    Main function to control the flow of the program.\n    \"\"\"\n    # Parse arguments.\n    args = parse_args()\n\n    file = open(args.pred_path)\n    pred_contents = json.load(file)\n\n    # Dictionary to store the count of occurrences for each video_id\n    video_id_counts = {}\n    new_pred_contents = []\n\n    # Iterate through each sample in pred_contents\n    for sample in pred_contents:\n        video_id = sample['video_name']\n        if video_id in video_id_counts:\n            video_id_counts[video_id] += 1\n        else:\n            video_id_counts[video_id] = 0\n\n        # Create a new sample with the modified key\n        new_sample = sample\n        new_sample['video_name'] = f\"{video_id}_{video_id_counts[video_id]}\"\n        new_pred_contents.append(new_sample)\n\n    # Generating list of id's and corresponding files\n    id_list = [x['video_name'] for x in new_pred_contents]\n    caption_files = [f\"{id}.json\" for id in id_list]\n\n    output_dir = args.output_dir\n    # Generate output directory if not exists.\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    # Preparing dictionary of question-answer sets\n    prediction_set = {}\n    for sample in new_pred_contents:\n        id = sample['video_name']\n   
     question = sample['Q']\n        answer = sample['A']\n        pred = sample['pred']\n        qa_set = {\"q\": question, \"a\": answer, \"pred\": pred}\n        prediction_set[id] = qa_set\n\n    # Set the OpenAI API key.\n    openai.api_key = args.api_key\n    num_tasks = args.num_tasks\n\n    # While loop to ensure that all captions are processed.\n    while True:\n        try:\n            # Files that have not been processed yet.\n            completed_files = os.listdir(output_dir)\n            print(f\"completed_files: {len(completed_files)}\")\n\n            # Files that have not been processed yet.\n            incomplete_files = [f for f in caption_files if f not in completed_files]\n            print(f\"incomplete_files: {len(incomplete_files)}\")\n\n            # Break the loop when there are no incomplete files\n            if len(incomplete_files) == 0:\n                break\n            if len(incomplete_files) <= num_tasks:\n                num_tasks = 1\n\n            # Split tasks into parts.\n            part_len = len(incomplete_files) // num_tasks\n            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]\n            task_args = [(prediction_set, part, args.output_dir) for part in all_parts]\n\n            # Use a pool of workers to process the files in parallel.\n            with Pool() as pool:\n                pool.starmap(annotate, task_args)\n\n        except Exception as e:\n            print(f\"Error: {e}\")\n\n    # Combine all the processed files into one\n    combined_contents = {}\n    json_path = args.output_json\n\n    # Iterate through json files\n    for file_name in os.listdir(output_dir):\n        if file_name.endswith(\".json\"):\n            file_path = os.path.join(output_dir, file_name)\n            with open(file_path, \"r\") as json_file:\n                content = json.load(json_file)\n                combined_contents[file_name[:-5]] = content\n\n    # Write 
combined content to a json file\n    with open(json_path, \"w\") as json_file:\n        json.dump(combined_contents, json_file)\n    print(\"All evaluation completed!\")\n\n    # Calculate average score\n    score_sum = 0\n    count = 0\n    for key, result in combined_contents.items():\n        count += 1\n        score_match = result[0]['score']\n        score = int(score_match)\n        score_sum += score\n    average_score = score_sum / count\n\n    print(\"Average score for correctness:\", average_score)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_2_detailed_orientation.py",
    "content": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"question-answer-generation-using-gpt-3\")\n    parser.add_argument(\"--pred_path\", required=True, help=\"The path to file containing prediction.\")\n    parser.add_argument(\"--output_dir\", required=True, help=\"The path to save annotation json files.\")\n    parser.add_argument(\"--output_json\", required=True, help=\"The path to save annotation final combined json file.\")\n    parser.add_argument(\"--api_key\", required=True, help=\"OpenAI API key.\")\n    parser.add_argument(\"--num_tasks\", required=True, type=int, help=\"Number of splits.\")\n    args = parser.parse_args()\n    return args\n\n\ndef annotate(prediction_set, caption_files, output_dir):\n    \"\"\"\n    Evaluates question and answer pairs using GPT-3 and\n    returns a score for detailed orientation.\n    \"\"\"\n    for file in caption_files:\n        key = file[:-5] # Strip file extension\n        qa_set = prediction_set[key]\n        question = qa_set['q']\n        answer = qa_set['a']\n        pred = qa_set['pred']\n        try:\n            # Compute the detailed-orientation score\n            completion = openai.ChatCompletion.create(\n                model=\"gpt-3.5-turbo\",\n                messages=[\n                    {\n                        \"role\": \"system\",\n                        \"content\":\n                            \"You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. \"\n                            \"Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. 
Here's how you can accomplish the task:\"\n                            \"------\"\n                            \"##INSTRUCTIONS: \"\n                            \"- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\\n\"\n                            \"- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\\n\"\n                            \"- Consider synonyms or paraphrases as valid matches.\\n\"\n                            \"- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity.\"\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\":\n                            \"Please evaluate the following video-based question-answer pair:\\n\\n\"\n                            f\"Question: {question}\\n\"\n                            f\"Correct Answer: {answer}\\n\"\n                            f\"Predicted Answer: {pred}\\n\\n\"\n                            \"Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. \"\n                            \"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING.\"\n                            \"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. 
\"\n                            \"For example, your response should look like this: {''score': 4.8}.\"\n                    }\n                ]\n            )\n            # Convert response to a Python dictionary.\n            response_message = completion[\"choices\"][0][\"message\"][\"content\"]\n            response_dict = ast.literal_eval(response_message)\n            result_qa_pair = [response_dict, qa_set]\n\n            # Save the question-answer pairs to a json file.\n            with open(f\"{output_dir}/{key}.json\", \"w\") as f:\n                json.dump(result_qa_pair, f)\n\n        except Exception as e:\n            print(f\"Error processing file '{key}': {e}\")\n\n\ndef main():\n    \"\"\"\n    Main function to control the flow of the program.\n    \"\"\"\n    # Parse arguments.\n    args = parse_args()\n\n    file = open(args.pred_path)\n    pred_contents = json.load(file)\n\n    # Dictionary to store the count of occurrences for each video_id\n    video_id_counts = {}\n    new_pred_contents = []\n\n    # Iterate through each sample in pred_contents\n    for sample in pred_contents:\n        video_id = sample['video_name']\n        if video_id in video_id_counts:\n            video_id_counts[video_id] += 1\n        else:\n            video_id_counts[video_id] = 0\n\n        # Create a new sample with the modified key\n        new_sample = sample\n        new_sample['video_name'] = f\"{video_id}_{video_id_counts[video_id]}\"\n        new_pred_contents.append(new_sample)\n\n    # Generating list of id's and corresponding files\n    id_list = [x['video_name'] for x in new_pred_contents]\n    caption_files = [f\"{id}.json\" for id in id_list]\n\n    output_dir = args.output_dir\n    # Generate output directory if not exists.\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    # Preparing dictionary of question-answer sets\n    prediction_set = {}\n    for sample in new_pred_contents:\n        id = sample['video_name']\n   
     question = sample['Q']\n        answer = sample['A']\n        pred = sample['pred']\n        qa_set = {\"q\": question, \"a\": answer, \"pred\": pred}\n        prediction_set[id] = qa_set\n\n    # Set the OpenAI API key.\n    openai.api_key = args.api_key\n    num_tasks = args.num_tasks\n\n    # While loop to ensure that all captions are processed.\n    while True:\n        try:\n            # Files that have not been processed yet.\n            completed_files = os.listdir(output_dir)\n            print(f\"completed_files: {len(completed_files)}\")\n\n            # Files that have not been processed yet.\n            incomplete_files = [f for f in caption_files if f not in completed_files]\n            print(f\"incomplete_files: {len(incomplete_files)}\")\n\n            # Break the loop when there are no incomplete files\n            if len(incomplete_files) == 0:\n                break\n            if len(incomplete_files) <= num_tasks:\n                num_tasks = 1\n\n            # Split tasks into parts.\n            part_len = len(incomplete_files) // num_tasks\n            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]\n            task_args = [(prediction_set, part, args.output_dir) for part in all_parts]\n\n            # Use a pool of workers to process the files in parallel.\n            with Pool() as pool:\n                pool.starmap(annotate, task_args)\n\n        except Exception as e:\n            print(f\"Error: {e}\")\n\n    # Combine all the processed files into one\n    combined_contents = {}\n    json_path = args.output_json\n\n    # Iterate through json files\n    for file_name in os.listdir(output_dir):\n        if file_name.endswith(\".json\"):\n            file_path = os.path.join(output_dir, file_name)\n            with open(file_path, \"r\") as json_file:\n                content = json.load(json_file)\n                combined_contents[file_name[:-5]] = content\n\n    # Write 
combined content to a json file\n    with open(json_path, \"w\") as json_file:\n        json.dump(combined_contents, json_file)\n    print(\"All evaluation completed!\")\n\n    # Calculate average score\n    score_sum = 0\n    count = 0\n    for key, result in combined_contents.items():\n        count += 1\n        score_match = result[0]['score']\n        score = int(score_match)\n        score_sum += score\n    average_score = score_sum / count\n\n    print(\"Average score for detailed orientation:\", average_score)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_3_context.py",
    "content": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"question-answer-generation-using-gpt-3\")\n    parser.add_argument(\"--pred_path\", required=True, help=\"The path to file containing prediction.\")\n    parser.add_argument(\"--output_dir\", required=True, help=\"The path to save annotation json files.\")\n    parser.add_argument(\"--output_json\", required=True, help=\"The path to save annotation final combined json file.\")\n    parser.add_argument(\"--api_key\", required=True, help=\"OpenAI API key.\")\n    parser.add_argument(\"--num_tasks\", required=True, type=int, help=\"Number of splits.\")\n    args = parser.parse_args()\n    return args\n\n\ndef annotate(prediction_set, caption_files, output_dir):\n    \"\"\"\n    Evaluates question and answer pairs using GPT-3 and\n    returns a score for contextual understanding.\n    \"\"\"\n    for file in caption_files:\n        key = file[:-5] # Strip file extension\n        qa_set = prediction_set[key]\n        question = qa_set['q']\n        answer = qa_set['a']\n        pred = qa_set['pred']\n        try:\n            # Compute the contextual understanding score\n            completion = openai.ChatCompletion.create(\n                model=\"gpt-3.5-turbo\",\n                messages=[\n                    {\n                        \"role\": \"system\",\n                        \"content\":\n                            \"You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. \"\n                            \"Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. 
Here's how you can accomplish the task:\"\n                            \"------\"\n                            \"##INSTRUCTIONS: \"\n                            \"- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\\n\"\n                            \"- The predicted answer must capture the main themes and sentiments of the video.\\n\"\n                            \"- Consider synonyms or paraphrases as valid matches.\\n\"\n                            \"- Provide your evaluation of the contextual understanding of the prediction compared to the answer.\"\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\":\n                            \"Please evaluate the following video-based question-answer pair:\\n\\n\"\n                            f\"Question: {question}\\n\"\n                            f\"Correct Answer: {answer}\\n\"\n                            f\"Predicted Answer: {pred}\\n\\n\"\n                            \"Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. \"\n                            \"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING.\"\n                            \"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. 
\"\n                            \"For example, your response should look like this: {''score': 4.8}.\"\n                    }\n                ]\n            )\n            # Convert response to a Python dictionary.\n            response_message = completion[\"choices\"][0][\"message\"][\"content\"]\n            response_dict = ast.literal_eval(response_message)\n            result_qa_pair = [response_dict, qa_set]\n\n            # Save the question-answer pairs to a json file.\n            with open(f\"{output_dir}/{key}.json\", \"w\") as f:\n                json.dump(result_qa_pair, f)\n\n        except Exception as e:\n            print(f\"Error processing file '{key}': {e}\")\n\n\ndef main():\n    \"\"\"\n    Main function to control the flow of the program.\n    \"\"\"\n    # Parse arguments.\n    args = parse_args()\n\n    file = open(args.pred_path)\n    pred_contents = json.load(file)\n\n    # Dictionary to store the count of occurrences for each video_id\n    video_id_counts = {}\n    new_pred_contents = []\n\n    # Iterate through each sample in pred_contents\n    for sample in pred_contents:\n        video_id = sample['video_name']\n        if video_id in video_id_counts:\n            video_id_counts[video_id] += 1\n        else:\n            video_id_counts[video_id] = 0\n\n        # Create a new sample with the modified key\n        new_sample = sample\n        new_sample['video_name'] = f\"{video_id}_{video_id_counts[video_id]}\"\n        new_pred_contents.append(new_sample)\n\n    # Generating list of id's and corresponding files\n    id_list = [x['video_name'] for x in new_pred_contents]\n    caption_files = [f\"{id}.json\" for id in id_list]\n\n    output_dir = args.output_dir\n    # Generate output directory if not exists.\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    # Preparing dictionary of question-answer sets\n    prediction_set = {}\n    for sample in new_pred_contents:\n        id = sample['video_name']\n   
     question = sample['Q']\n        answer = sample['A']\n        pred = sample['pred']\n        qa_set = {\"q\": question, \"a\": answer, \"pred\": pred}\n        prediction_set[id] = qa_set\n\n    # Set the OpenAI API key.\n    openai.api_key = args.api_key\n    num_tasks = args.num_tasks\n\n    # While loop to ensure that all captions are processed.\n    while True:\n        try:\n            # Files that have not been processed yet.\n            completed_files = os.listdir(output_dir)\n            print(f\"completed_files: {len(completed_files)}\")\n\n            # Files that have not been processed yet.\n            incomplete_files = [f for f in caption_files if f not in completed_files]\n            print(f\"incomplete_files: {len(incomplete_files)}\")\n\n            # Break the loop when there are no incomplete files\n            if len(incomplete_files) == 0:\n                break\n            if len(incomplete_files) <= num_tasks:\n                num_tasks = 1\n\n            # Split tasks into parts.\n            part_len = len(incomplete_files) // num_tasks\n            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]\n            task_args = [(prediction_set, part, args.output_dir) for part in all_parts]\n\n            # Use a pool of workers to process the files in parallel.\n            with Pool() as pool:\n                pool.starmap(annotate, task_args)\n\n        except Exception as e:\n            print(f\"Error: {e}\")\n\n    # Combine all the processed files into one\n    combined_contents = {}\n    json_path = args.output_json\n\n    # Iterate through json files\n    for file_name in os.listdir(output_dir):\n        if file_name.endswith(\".json\"):\n            file_path = os.path.join(output_dir, file_name)\n            with open(file_path, \"r\") as json_file:\n                try:\n                    content = json.load(json_file)\n                except:\n                    print 
(file_path)\n                    os.remove(file_path)\n                    exit(-1)\n                combined_contents[file_name[:-5]] = content\n\n    # Write combined content to a json file\n    with open(json_path, \"w\") as json_file:\n        json.dump(combined_contents, json_file)\n    print(\"All evaluation completed!\")\n\n    # Calculate average score\n    score_sum = 0\n    count = 0\n    for key, result in combined_contents.items():\n        count += 1\n        score_match = result[0]['score']\n        score = int(score_match)\n        score_sum += score\n    average_score = score_sum / count\n\n    print(\"Average score for contextual understanding:\", average_score)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_4_temporal.py",
    "content": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"question-answer-generation-using-gpt-3\")\n    parser.add_argument(\"--pred_path\", required=True, help=\"The path to file containing prediction.\")\n    parser.add_argument(\"--output_dir\", required=True, help=\"The path to save annotation json files.\")\n    parser.add_argument(\"--output_json\", required=True, help=\"The path to save annotation final combined json file.\")\n    parser.add_argument(\"--api_key\", required=True, help=\"OpenAI API key.\")\n    parser.add_argument(\"--num_tasks\", required=True, type=int, help=\"Number of splits.\")\n    args = parser.parse_args()\n    return args\n\n\ndef annotate(prediction_set, caption_files, output_dir):\n    \"\"\"\n    Evaluates question and answer pairs using GPT-3 and\n    returns a score for temporal understanding.\n    \"\"\"\n    for file in caption_files:\n        key = file[:-5] # Strip file extension\n        qa_set = prediction_set[key]\n        question = qa_set['q']\n        answer = qa_set['a']\n        pred = qa_set['pred']\n        try:\n            # Compute the temporal understanding score\n            completion = openai.ChatCompletion.create(\n                model=\"gpt-3.5-turbo\",\n                messages=[\n                    {\n                        \"role\": \"system\",\n                        \"content\":\n                            \"You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. \"\n                            \"Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. 
Here's how you can accomplish the task:\"\n                            \"------\"\n                            \"##INSTRUCTIONS: \"\n                            \"- Focus on the temporal consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\\n\"\n                            \"- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\\n\"\n                            \"- Evaluate the temporal accuracy of the prediction compared to the answer.\"\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\":\n                            \"Please evaluate the following video-based question-answer pair:\\n\\n\"\n                            f\"Question: {question}\\n\"\n                            f\"Correct Answer: {answer}\\n\"\n                            f\"Predicted Answer: {pred}\\n\\n\"\n                            \"Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. \"\n                            \"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING.\"\n                            \"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. 
\"\n                            \"For example, your response should look like this: {''score': 4.8}.\"\n                    }\n                ]\n            )\n            # Convert response to a Python dictionary.\n            response_message = completion[\"choices\"][0][\"message\"][\"content\"]\n            response_dict = ast.literal_eval(response_message)\n            result_qa_pair = [response_dict, qa_set]\n\n            # Save the question-answer pairs to a json file.\n            with open(f\"{output_dir}/{key}.json\", \"w\") as f:\n                json.dump(result_qa_pair, f)\n\n        except Exception as e:\n            print(f\"Error processing file '{key}': {e}\")\n\n\ndef main():\n    \"\"\"\n    Main function to control the flow of the program.\n    \"\"\"\n    # Parse arguments.\n    args = parse_args()\n\n    file = open(args.pred_path)\n    pred_contents = json.load(file)\n\n    # Dictionary to store the count of occurrences for each video_id\n    video_id_counts = {}\n    new_pred_contents = []\n\n    # Iterate through each sample in pred_contents\n    for sample in pred_contents:\n        video_id = sample['video_name']\n        if video_id in video_id_counts:\n            video_id_counts[video_id] += 1\n        else:\n            video_id_counts[video_id] = 0\n\n        # Create a new sample with the modified key\n        new_sample = sample\n        new_sample['video_name'] = f\"{video_id}_{video_id_counts[video_id]}\"\n        new_pred_contents.append(new_sample)\n\n    # Generating list of id's and corresponding files\n    id_list = [x['video_name'] for x in new_pred_contents]\n    caption_files = [f\"{id}.json\" for id in id_list]\n\n    output_dir = args.output_dir\n    # Generate output directory if not exists.\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    # Preparing dictionary of question-answer sets\n    prediction_set = {}\n    for sample in new_pred_contents:\n        id = sample['video_name']\n   
     question = sample['Q']\n        answer = sample['A']\n        pred = sample['pred']\n        qa_set = {\"q\": question, \"a\": answer, \"pred\": pred}\n        prediction_set[id] = qa_set\n\n    # Set the OpenAI API key.\n    openai.api_key = args.api_key\n    num_tasks = args.num_tasks\n\n    # While loop to ensure that all captions are processed.\n    while True:\n        try:\n            # Files that have not been processed yet.\n            completed_files = os.listdir(output_dir)\n            print(f\"completed_files: {len(completed_files)}\")\n\n            # Files that have not been processed yet.\n            incomplete_files = [f for f in caption_files if f not in completed_files]\n            print(f\"incomplete_files: {len(incomplete_files)}\")\n\n            # Break the loop when there are no incomplete files\n            if len(incomplete_files) == 0:\n                break\n            if len(incomplete_files) <= num_tasks:\n                num_tasks = 1\n\n            # Split tasks into parts.\n            part_len = len(incomplete_files) // num_tasks\n            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]\n            task_args = [(prediction_set, part, args.output_dir) for part in all_parts]\n\n            # Use a pool of workers to process the files in parallel.\n            with Pool() as pool:\n                pool.starmap(annotate, task_args)\n\n        except Exception as e: \n            print(f\"Error: {e}\")\n\n    # Combine all the processed files into one\n    combined_contents = {}\n    json_path = args.output_json\n\n    # Iterate through json files\n    for file_name in os.listdir(output_dir):\n        if file_name.endswith(\".json\"):\n            file_path = os.path.join(output_dir, file_name)\n            with open(file_path, \"r\") as json_file:\n                content = json.load(json_file)\n                combined_contents[file_name[:-5]] = content\n\n    # Write 
combined content to a json file\n    with open(json_path, \"w\") as json_file:\n        json.dump(combined_contents, json_file)\n    print(\"All evaluation completed!\")\n\n    # Calculate average score\n    score_sum = 0\n    count = 0\n    for key, result in combined_contents.items():\n        count += 1\n        score_match = result[0]['score']\n        score = int(score_match)\n        score_sum += score\n    average_score = score_sum / count\n\n    print(\"Average score temporal understanding:\", average_score)\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "stllm/test/gpt_evaluation/evaluate_benchmark_5_consistency.py",
    "content": "import openai\nimport os\nimport argparse\nimport json\nimport ast\nfrom multiprocessing.pool import Pool\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"question-answer-generation-using-gpt-3\")\n    parser.add_argument(\"--pred_path\", required=True, help=\"The path to file containing prediction.\")\n    parser.add_argument(\"--output_dir\", required=True, help=\"The path to save annotation json files.\")\n    parser.add_argument(\"--output_json\", required=True, help=\"The path to save annotation final combined json file.\")\n    parser.add_argument(\"--api_key\", required=True, help=\"OpenAI API key.\")\n    parser.add_argument(\"--num_tasks\", required=True, type=int, help=\"Number of splits.\")\n    args = parser.parse_args()\n    return args\n\n\ndef annotate(prediction_set, caption_files, output_dir):\n    \"\"\"\n    Evaluates question and answer pairs using GPT-3 and\n    returns a score for consistency.\n    \"\"\"\n    for file in caption_files:\n        key = file[:-5] # Strip file extension\n        qa_set = prediction_set[key]\n        question1 = qa_set['q1']\n        question2 = qa_set['q2']\n        answer = qa_set['a']\n        pred1 = qa_set['pred1']\n        pred2 = qa_set['pred2']\n        try:\n            # Compute the consistency score\n            completion = openai.ChatCompletion.create(\n                model=\"gpt-3.5-turbo\",\n                messages=[\n                    {\n                        \"role\": \"system\",\n                        \"content\":\n                            \"You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. 
\"\n                            \"You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions .\"\n                            \"Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. Here's how you can accomplish the task:\"\n                            \"------\"\n                            \"##INSTRUCTIONS: \"\n                            \"- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\\n\"\n                            \"- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\\n\"\n                            \"- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\\n\"\n                            \"- Evaluate the consistency of the two predicted answers compared to the correct answer.\"\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\":\n                            \"Please evaluate the following video-based question-answer pair:\\n\\n\"\n                            f\"Question 1: {question1}\\n\"\n                            f\"Question 2: {question2}\\n\"\n                            f\"Correct Answer: {answer}\\n\"\n                            f\"Predicted Answer to Question 1: {pred1}\\n\"\n                            f\"Predicted Answer to Question 2: {pred2}\\n\\n\"\n                            \"Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest 
level of consistency. \"\n                            \"Please generate the response in the form of a Python dictionary string with key 'score', where its value is the consistency score in INTEGER, not STRING. \"\n                            \"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. \"\n                            \"For example, your response should look like this: {'score': 4}.\"\n                    }\n                ]\n            )\n            # Convert response to a Python dictionary.\n            response_message = completion[\"choices\"][0][\"message\"][\"content\"]\n            response_dict = ast.literal_eval(response_message)\n            result_qa_pair = [response_dict, qa_set]\n\n            # Save the question-answer pairs to a json file.\n            with open(f\"{output_dir}/{key}.json\", \"w\") as f:\n                json.dump(result_qa_pair, f)\n\n        except Exception as e:\n            print(f\"Error processing file '{key}': {e}\")\n\n\ndef main():\n    \"\"\"\n    Main function to control the flow of the program.\n    \"\"\"\n    # Parse arguments.\n    args = parse_args()\n\n    # Load the predictions, closing the file handle once parsed.\n    with open(args.pred_path) as file:\n        pred_contents = json.load(file)\n\n    # Dictionary to store the count of occurrences for each video_id\n    video_id_counts = {}\n    new_pred_contents = []\n\n    # Iterate through each sample in pred_contents\n    for sample in pred_contents:\n        video_id = sample['video_name']\n        if video_id in video_id_counts:\n            video_id_counts[video_id] += 1\n        else:\n            video_id_counts[video_id] = 0\n\n        # Create a new sample with the modified key\n        new_sample = sample\n        new_sample['video_name'] = f\"{video_id}_{video_id_counts[video_id]}\"\n        new_pred_contents.append(new_sample)\n\n    # Generating list of id's and corresponding files\n    id_list = [x['video_name'] for x in new_pred_contents]\n    caption_files = 
[f\"{id}.json\" for id in id_list]\n\n    output_dir = args.output_dir\n    # Generate output directory if not exists.\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    # Preparing dictionary of question-answer sets\n    prediction_set = {}\n    for sample in new_pred_contents:\n        id = sample['video_name']\n        question1 = sample['Q1']\n        question2 = sample['Q2']\n        answer = sample['A']\n        pred1 = sample['pred1']\n        pred2 = sample['pred2']\n        qa_set = {\"q1\": question1, \"q2\": question2, \"a\": answer, \"pred1\": pred1, \"pred2\": pred2}\n        prediction_set[id] = qa_set\n\n    # Set the OpenAI API key.\n    openai.api_key = args.api_key\n    num_tasks = args.num_tasks\n\n    # While loop to ensure that all captions are processed.\n    while True:\n        try:\n            # Files that have already been processed.\n            completed_files = os.listdir(output_dir)\n            print(f\"completed_files: {len(completed_files)}\")\n\n            # Files that have not been processed yet.\n            incomplete_files = [f for f in caption_files if f not in completed_files]\n            print(f\"incomplete_files: {len(incomplete_files)}\")\n\n            # Break the loop when there are no incomplete files\n            if len(incomplete_files) == 0:\n                break\n            if len(incomplete_files) <= num_tasks:\n                num_tasks = 1\n\n            # Split tasks into parts.\n            part_len = len(incomplete_files) // num_tasks\n            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]\n            task_args = [(prediction_set, part, args.output_dir) for part in all_parts]\n\n            # Use a pool of workers to process the files in parallel.\n            with Pool() as pool:\n                pool.starmap(annotate, task_args)\n\n        except Exception as e:\n            print(f\"Error: {e}\")\n\n    # Combine all 
the processed files into one\n    combined_contents = {}\n    json_path = args.output_json\n\n    # Iterate through json files\n    for file_name in os.listdir(output_dir):\n        if file_name.endswith(\".json\"):\n            file_path = os.path.join(output_dir, file_name)\n            with open(file_path, \"r\") as json_file:\n                content = json.load(json_file)\n                combined_contents[file_name[:-5]] = content\n\n    # Write combined content to a json file\n    with open(json_path, \"w\") as json_file:\n        json.dump(combined_contents, json_file)\n    print(\"All evaluation completed!\")\n\n    # Calculate average score\n    score_sum = 0\n    count = 0\n    for key, result in combined_contents.items():\n        count += 1\n        score_match = result[0]['score']\n        score = int(score_match)\n        score_sum += score\n    average_score = score_sum / count\n\n    print(\"Average score for consistency:\", average_score)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stllm/test/mvbench/mv_bench.py",
    "content": "import os \nimport json\nimport math\nimport numpy as np\nimport cv2\nimport io\nimport imageio\nfrom mmengine.fileio import FileClient\nclient = FileClient('disk')\nfrom decord import VideoReader, cpu\nfrom PIL import Image\nimport torchvision.transforms as T\nfrom stllm.test.video_transforms import (\n    GroupNormalize, GroupScale, GroupCenterCrop, \n    Stack, ToTorchFormatTensor\n)\nfrom torchvision.transforms.functional import InterpolationMode\nfrom torch.utils.data import Dataset\nimport torch\n\nfrom stllm.conversation.mvbench_conversation import ask, answer, EasyDict\n\n\ndata_list = {\n    \"Action Sequence\": (\"action_sequence.json\", \"your_data_path/star/Charades_v1_480/\", \"video\", True), # has start & end\n    \"Action Prediction\": (\"action_prediction.json\", \"your_data_path/star/Charades_v1_480/\", \"video\", True), # has start & end\n    \"Action Antonym\": (\"action_antonym.json\", \"your_data_path/ssv2_video/\", \"video\", False),\n    \"Fine-grained Action\": (\"fine_grained_action.json\", \"your_data_path/Moments_in_Time_Raw/videos/\", \"video\", False),\n    \"Unexpected Action\": (\"unexpected_action.json\", \"your_data_path/FunQA_test/test/\", \"video\", False),\n    \"Object Existence\": (\"object_existence.json\", \"your_data_path/clevrer/video_validation/\", \"video\", False),\n    \"Object Interaction\": (\"object_interaction.json\", \"your_data_path/star/Charades_v1_480/\", \"video\", True), # has start & end\n    \"Object Shuffle\": (\"object_shuffle.json\", \"your_data_path/perception/videos/\", \"video\", False),\n    \"Moving Direction\": (\"moving_direction.json\", \"your_data_path/clevrer/video_validation/\", \"video\", False),\n    \"Action Localization\": (\"action_localization.json\", \"your_data_path/sta/sta_video/\", \"video\", True),  # has start & end\n    \"Scene Transition\": (\"scene_transition.json\", \"your_data_path/scene_qa/video/\", \"video\", False),\n    \"Action Count\": 
(\"action_count.json\", \"your_data_path/perception/videos/\", \"video\", False),\n    \"Moving Count\": (\"moving_count.json\", \"your_data_path/clevrer/video_validation/\", \"video\", False),\n    \"Moving Attribute\": (\"moving_attribute.json\", \"your_data_path/clevrer/video_validation/\", \"video\", False),\n    \"State Change\": (\"state_change.json\", \"your_data_path/perception/videos/\", \"video\", False),\n    \"Fine-grained Pose\": (\"fine_grained_pose.json\", \"your_data_path/nturgbd/\", \"video\", False),\n    \"Character Order\": (\"character_order.json\", \"your_data_path/perception/videos/\", \"video\", False),\n    \"Egocentric Navigation\": (\"egocentric_navigation.json\", \"your_data_path/vlnqa/\", \"video\", False),\n    \"Episodic Reasoning\": (\"episodic_reasoning.json\", \"your_data_path/tvqa/frames_fps3_hq/\", \"frame\", True),  # has start & end, read frame\n    \"Counterfactual Inference\": (\"counterfactual_inference.json\", \"your_data_path/clevrer/video_validation/\", \"video\", False),\n}\n\ndata_dir = \"your_mvpbench_path/json\"\n\nclass MVBench_dataset(Dataset):\n    def __init__(self, data_dir, data_list=data_list, num_segments=8, resolution=224, specified_item=None):\n        self.data_list = []\n        if specified_item:\n            data_list = {specified_item: data_list[specified_item]}\n        for k, v in data_list.items():\n            with open(os.path.join(data_dir, v[0]), 'r') as f:\n                json_data = json.load(f)\n            for data in json_data:\n                self.data_list.append({\n                    'task_type': k,\n                    'prefix': v[1],\n                    'data_type': v[2],\n                    'bound': v[3],\n                    'data': data\n                })\n        \n        self.decord_method = {\n            'video': self.read_video,\n            'gif': self.read_gif,\n            'frame': self.read_frame,\n        }\n        \n        self.num_segments = num_segments\n        
\n        # transform\n        crop_size = resolution\n        scale_size = resolution\n        input_mean = [0.48145466, 0.4578275, 0.40821073]\n        input_std = [0.26862954, 0.26130258, 0.27577711]\n        self.transform = T.Compose([\n            GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),\n            GroupCenterCrop(crop_size),\n            Stack(),\n            ToTorchFormatTensor(),\n            GroupNormalize(input_mean, input_std) \n        ])\n    \n    def __str__(self):\n        len_list = {}\n        option_list = {}\n        for data in self.data_list:\n            if data['task_type'] not in len_list:\n                len_list[data['task_type']] = 0\n            len_list[data['task_type']] += 1\n            if data['task_type'] not in option_list:\n                option_list[data['task_type']] = 0\n            option_list[data['task_type']] += len(data['data']['candidates'])\n        \n        correct = 0\n        total = 0\n        res = f\"There are {len(self.data_list)} videos as follow:\\n\"\n        for k, v in len_list.items():\n            correct += len_list[k]\n            total += option_list[k]\n            res += f\"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\\n\"\n            correct = correct + 1 / option_list[k]\n        res += f\"Total random accuracy: {correct/total*100:.2f}%\"\n        return res.rstrip()\n        \n    def __len__(self):\n        return len(self.data_list)\n    \n    def get_index(self, bound, fps, max_frame, first_idx=0):\n        if bound:\n            start, end = bound[0], bound[1]\n        else:\n            start, end = -100000, 100000\n        start_idx = max(first_idx, round(start * fps))\n        end_idx = min(round(end * fps), max_frame)\n\n        if bound:\n            video_len = bound[1] - bound[0]\n        else:\n            video_len = max_frame / fps\n\n        if self.num_segments > 0:\n            num_segments = self.num_segments  
\n        else:  #fps 1\n            if video_len < 4:\n                num_segments = 4\n            elif video_len > 16:\n                num_segments = 16\n            else:\n                num_segments = math.floor(video_len)\n        seg_size = float(end_idx - start_idx) / num_segments\n        frame_indices = np.array([\n            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))\n            for idx in range(num_segments)\n        ])\n        return frame_indices\n    \n    def read_video(self, video_path, bound=None):\n        video_bytes = client.get(video_path)\n        vr = VideoReader(io.BytesIO(video_bytes), ctx=cpu(0), num_threads=1)\n        max_frame = len(vr) - 1\n        fps = float(vr.get_avg_fps())\n\n        images_group = list()\n        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) \n        for frame_index in frame_indices:\n            img = Image.fromarray(vr[frame_index].numpy())\n            images_group.append(img)\n        torch_imgs = self.transform(images_group)\n\n        return torch_imgs\n    \n    def read_gif(self, video_path, bound=None, fps=25):\n        video_bytes = client.get(video_path)\n        gif = imageio.get_reader(io.BytesIO(video_bytes))\n        max_frame = len(gif) - 1\n        \n        images_group = list()\n        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) \n        for index, frame in enumerate(gif):\n            if index in frame_indices:\n                img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)\n                img = Image.fromarray(img)\n                images_group.append(img)\n        torch_imgs = self.transform(images_group)\n        return torch_imgs\n    \n    def read_frame(self, video_path, bound=None, fps=3):\n        if os.path.exists(video_path):\n            max_frame = len(os.listdir(video_path))\n        else:\n            max_frame = len([k for k in client.list(video_path)])\n            \n        images_group = list()\n        
frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1\n        for frame_index in frame_indices:\n            img_bytes = client.get(os.path.join(video_path, f\"{frame_index:05d}.jpg\"))\n            img = Image.open(io.BytesIO(img_bytes))\n            images_group.append(img)\n        torch_imgs = self.transform(images_group)\n\n        return torch_imgs\n\n    def qa_template(self, data):\n        question = f\"Question: {data['question']}\\n\"\n        question += \"Options:\\n\"\n        answer = data['answer']\n        answer_idx = -1\n        for idx, c in enumerate(data['candidates']):\n            question += f\"({chr(ord('A') + idx)}) {c}\\n\"\n            if c == answer:\n                answer_idx = idx\n        question = question.rstrip()\n        answer = f\"({chr(ord('A') + answer_idx)}) {answer}\"\n        return question, answer\n\n    def __getitem__(self, idx):\n        decord_method = self.decord_method[self.data_list[idx]['data_type']]\n        bound = None\n        if self.data_list[idx]['bound']:\n            bound = (\n                self.data_list[idx]['data']['start'],\n                self.data_list[idx]['data']['end'],\n            )\n        video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])\n        torch_imgs = decord_method(video_path, bound)\n        question, answer = self.qa_template(self.data_list[idx]['data'])\n            \n        return {\n            'video': torch_imgs, \n            'video_path': video_path,\n            'question': question, \n            'answer': answer,\n            'task_type': self.data_list[idx]['task_type'],\n        }\n\ndef get_residual_index(sample_segments, total_segments, devices):\n    seg_size = float(total_segments) / sample_segments\n    frame_indices = np.array([\n    int((seg_size / 2) + np.round(seg_size * idx))\n    for idx in range(sample_segments)\n    ])\n    frame_indices = 
torch.from_numpy(frame_indices).to(devices)\n    return frame_indices\n\ndef infer_mvbench(\n        model,\n        data_sample, system=\"\", \n        question_prompt='', # add in the end of question\n        answer_prompt=None, # add in the begining of answer\n        return_prompt='',  # add in the begining of return message\n        system_llm=False,\n        all_token=False,\n        ask_simple=False,\n    ):\n    video = data_sample[\"video\"]\n    TC, H, W = video.shape\n    video = video.reshape(TC//3, 3, H, W).to(\"cuda:0\")\n    \n    video_list = []\n    with torch.no_grad():\n        if hasattr(model.model,'stllm_model'):\n            encode_model = model.model.stllm_model\n        else:\n            encode_model = model.model.model.stllm_model\n            \n        video_emb, _, _ = encode_model.encode_img(video, data_sample['question'])\n        \n    if not all_token:\n        video_emb = video_emb.mean(dim=0, keepdim=True)\n    else:\n        video_emb = video_emb.view(1, -1, video_emb.size(-1))\n    video_list.append(video_emb)\n\n    chat = EasyDict({\n        \"system\": system,\n        \"roles\": (\"Human\", \"Assistant\"),\n        \"messages\": [],\n        \"sep\": \"###\"\n    })\n\n    chat.messages.append([chat.roles[0], f\"<Video><VideoHere></Video>\\n\"])\n    \n    if system_llm:\n        prompt = system + data_sample['question'] + question_prompt\n    else:\n        prompt = data_sample['question'] + question_prompt\n    \n    ask(prompt, chat)\n\n    llm_message = answer(\n        conv=chat, model=model, ask_simple=ask_simple, do_sample=False, \n        img_list=video_list, max_new_tokens=100, \n        answer_prompt=answer_prompt\n    )[0]\n    # remove potential explanation\n    llm_message = return_prompt + llm_message.strip().split('\\n')[0]\n    print(llm_message)\n    print(f\"GT: {data_sample['answer']}\")\n    return llm_message\n\ndef check_ans(pred, gt):\n    flag = False\n    \n    pred_list = pred.lower().split(' ')\n   
 pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])\n    gt_list = gt.lower().split(' ')\n    gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])\n    if gt_content[-1] == '.':\n        gt_content = gt_content[:-1]\n    \n    if pred_option.replace('.', '') in gt_option:\n        flag = True\n    elif gt_option in pred_option:\n        flag = True\n        \n    return flag\n\nif __name__ == \"__main__\":\n    dataset = MVBench_dataset(data_dir, data_list, num_segments=16, resolution=224)"
  },
  {
    "path": "stllm/test/mvbench/mv_bench_infer.py",
    "content": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm import tqdm\nfrom mv_bench import MVBench_dataset, infer_mvbench, check_ans\nimport argparse\nimport os\n\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to checkpoint file.\")\n    parser.add_argument(\"--anno-path\", required=True, help=\"path to mvbench annotation.\")\n    parser.add_argument(\"--num-frames\", type=int, required=False, default=100)\n    parser.add_argument(\"--specified_item\", type=str, required=False, default=None)\n    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)\n    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    parser.add_argument(\"--system_llm\", action='store_false')\n    parser.add_argument(\"--ask_simple\", action='store_true')\n    return parser.parse_args()\n\ndef run_inference(args):\n    \"\"\"\n    Run inference on a set of video files using the provided model.\n\n    Args:\n        args: 
Command-line arguments.\n    \"\"\"\n    # Initialize the model\n\n    print('Initializing Chat')\n    args = parse_args()\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    cfg = Config(args)\n\n    model_config = cfg.model_cfg\n    model_config.device_8bit = args.gpu_id\n    model_config.ckpt = args.ckpt_path\n    model_cls = registry.get_model_class(model_config.arch)\n    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\n    for name, para in model.named_parameters():\n        para.requires_grad = False\n    model.eval()\n    \n    # Bitwise `~` on a bool yields -2 or -1 (both truthy), so compare directly instead.\n    all_token = model_config.video_input != 'mean'\n    correct = 0\n    total = 0\n    res_list = []\n    acc_dict = {}\n    videos_len = []\n    dataset = MVBench_dataset(args.anno_path, num_segments=args.num_frames, resolution=224, specified_item=args.specified_item)\n    for example in tqdm(dataset):\n        task_type = example['task_type']\n        if task_type not in acc_dict:\n            acc_dict[task_type] = [0, 0] # correct, total\n        acc_dict[task_type][1] += 1\n        total += 1\n\n        pred = infer_mvbench(\n            model, example,\n            system=\"Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. 
Based on your observations, select the best option that accurately addresses the question.\\n\",\n            question_prompt=\"\\nOnly give the best option.\",\n            answer_prompt=\"Best option:(\",\n            return_prompt='(',\n            system_llm=args.system_llm,\n            all_token=all_token,\n            ask_simple=args.ask_simple,\n        )\n\n        gt = example['answer']\n        if args.specified_item:\n            res_list.append({\n                'video_path': example['video_path'],\n                'question': example['question'],\n                'pred': pred,\n                'gt': gt,\n            })\n        else:\n            res_list.append({\n                'pred': pred,\n                'gt': gt\n            })\n        if check_ans(pred=pred, gt=gt):\n            acc_dict[task_type][0] += 1\n            correct += 1\n        print(f\"Part  Acc: {acc_dict[task_type][0] / acc_dict[task_type][1] * 100 :.2f}%\")\n        print(f\"Total Acc: {correct / total * 100 :.2f}%\")\n        print('-' * 30, task_type, '-' * 30)\n    acc_dict['Total Acc'] = f\"{correct / total * 100 :.2f}%\"\n    with open(os.path.join(args.output_dir, f\"{args.output_name}.json\"), 'w') as f:\n        json.dump({\n            \"acc_dict\": acc_dict,\n            \"res_list\": res_list\n        }, f)\n\n              \n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    run_inference(args)\n"
  },
  {
    "path": "stllm/test/qabench/activitynet_qa.py",
    "content": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm import tqdm\n\nimport argparse\nimport os\nimport torch\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stllm.conversation.conversation import Chat, CONV_VIDEO_LLama2, CONV_VIDEO_Vicuna0, \\\n                    CONV_VISION_LLama2, CONV_instructblip_Vicuna0\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\nfrom stllm.test.video_utils import load_video_rawframes\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    # Define the command-line arguments\n    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)\n    parser.add_argument('--gt_file_question', help='Path to the ground truth file containing question.', required=True)\n    parser.add_argument('--gt_file_answers', help='Path to the ground truth file containing answers.', required=True)\n    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)\n    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)\n    parser.add_argument(\"--frames\", type=str, required=False, default=None)\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to checkpoint file.\")\n    parser.add_argument(\"--num-frames\", type=int, required=False, default=100)\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change 
to --cfg-options instead.\",\n    )\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    return parser.parse_args()\n\n\ndef run_inference(args):\n    \"\"\"\n    Run inference on ActivityNet QA DataSet using the Video-ChatGPT model.\n\n    Args:\n        args: Command-line arguments.\n    \"\"\"\n    # Initialize the model\n    conv_dict = {'minigpt4_vicuna0': CONV_VIDEO_Vicuna0,\n             \"instructblip_vicuna0\": CONV_instructblip_Vicuna0,\n             \"instructblip_vicuna0_btadapter\": CONV_instructblip_Vicuna0,\n             'minigpt4_vicuna0_btadapter': CONV_VIDEO_Vicuna0,}\n\n    print('Initializing Chat')\n    args = parse_args()\n    cfg = Config(args)\n\n    model_config = cfg.model_cfg\n    model_config.device_8bit = args.gpu_id\n    model_config.ckpt = args.ckpt_path\n    model_cls = registry.get_model_class(model_config.arch)\n    #model_config.eval = True\n    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\n    for name, para in model.named_parameters():\n        para.requires_grad = False\n    model.eval()\n    \n    CONV_VISION = conv_dict[model_config.model_type]\n    model = model.to(torch.float16)\n\n    chat = Chat(model, device='cuda:{}'.format(args.gpu_id))\n\n    # Load both ground truth file containing questions and answers\n    with open(args.gt_file_question) as file:\n        gt_questions = json.load(file)\n    with open(args.gt_file_answers) as file:\n        gt_answers = json.load(file)\n    \n    if args.frames is not None:\n        with open(args.frames,'r') as f:\n            frames = json.load(f)\n\n    # Create the output directory if it doesn't exist\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    output_list = []  # List to store the output results\n\n    video_formats = ['.mp4', '.avi', '.mov', '.mkv']\n\n    # Iterate over each sample in the ground truth file\n    index = 0\n    \n    for 
sample in tqdm(gt_questions):\n        video_name = 'v_' + sample['video_name']\n        question = sample['question']\n        id = sample['question_id']\n        answer = gt_answers[index]['answer']\n        index += 1\n\n        sample_set = {'id': id, 'question': question, 'answer': answer}\n        video_path = os.path.join(args.video_dir, video_name)    \n        # Check if the video exists\n        chat_state = CONV_VISION.copy()\n        img_list = []\n        chat.upload_video(video_path, chat_state, img_list, args.num_frames, question)\n        chat.ask(question, chat_state)\n        llm_message = chat.answer(conv=chat_state,\n                              img_list=img_list,\n                              num_beams=5,\n                              do_sample=False,\n                              temperature=1,\n                              max_new_tokens=300,\n                              max_length=2000)[0]\n\n        sample_set['pred'] = llm_message\n        output_list.append(sample_set)\n\n    # Save the output list to a JSON file\n    with open(os.path.join(args.output_dir, f\"{args.output_name}.json\"), 'w') as file:\n        json.dump(output_list, file)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    run_inference(args)\n"
  },
  {
    "path": "stllm/test/qabench/msrvtt_qa.py",
    "content": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm import tqdm\n\nimport torch\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stllm.conversation.conversation import Chat, CONV_VIDEO_LLama2, CONV_VIDEO_Vicuna0, \\\n                    CONV_VISION_LLama2, CONV_instructblip_Vicuna0\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    # Define the command-line arguments\n    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)\n    parser.add_argument('--gt_file', help='Path to the ground truth file containing questions and answers.', required=True)\n    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)\n    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to checkpoint file.\")\n    parser.add_argument(\"--num-frames\", type=int, required=False, default=100)\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecated), \"\n        \"change to --cfg-options instead.\",\n    )\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    return parser.parse_args()\n\n\ndef run_inference(args):\n    \"\"\"\n    Run inference on the MSRVTT-QA dataset
using the Video-ChatGPT model.\n\n    Args:\n        args: Command-line arguments.\n    \"\"\"\n    # Initialize the model\n    conv_dict = {'minigpt4_vicuna0': CONV_VIDEO_Vicuna0,\n             \"instructblip_vicuna0\": CONV_instructblip_Vicuna0,\n             \"instructblip_vicuna0_btadapter\": CONV_instructblip_Vicuna0,\n             'minigpt4_vicuna0_btadapter': CONV_VIDEO_Vicuna0,}\n\n    print('Initializing Chat')\n    args = parse_args()\n    cfg = Config(args)\n\n    model_config = cfg.model_cfg\n    model_config.device_8bit = args.gpu_id\n    model_config.ckpt = args.ckpt_path\n    model_cls = registry.get_model_class(model_config.arch)\n    #model_config.eval = True\n    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\n    for name, para in model.named_parameters():\n        para.requires_grad = False\n    model.eval()\n    CONV_VISION = conv_dict[model_config.model_type]\n    model = model.to(torch.float16)\n    \n    chat = Chat(model, device='cuda:{}'.format(args.gpu_id))\n\n    # Load both ground truth file containing questions and answers\n    with open(args.gt_file) as file:\n        gt_file = json.load(file)\n\n    # Create the output directory if it doesn't exist\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    output_list = []  # List to store the output results\n\n    video_formats = ['.mp4', '.avi', '.mov', '.mkv']\n\n    # Iterate over each sample in the ground truth file\n    index = 0\n    \n    for sample in tqdm(gt_file):\n        video_name = str(sample['video_id']) + '.mp4'\n        question = sample['question'] \n        answer = sample['answer']\n        index += 1\n\n        sample_set = {'id': index, 'question': question, 'answer': answer}\n        video_path = os.path.join(args.video_dir, video_name)\n        # Check if the video exists\n        chat_state = CONV_VISION.copy()\n        img_list = []\n        chat.upload_video(video_path, chat_state, img_list, 
args.num_frames, question)\n        chat.ask(question, chat_state)\n        llm_message = chat.answer(conv=chat_state,\n                              img_list=img_list,\n                              num_beams=5,\n                              do_sample=False,\n                              temperature=1,\n                              system=False,\n                              max_new_tokens=300,\n                              max_length=2000)[0]\n\n        sample_set['pred'] = llm_message\n        output_list.append(sample_set)\n\n    # Save the output list to a JSON file\n    with open(os.path.join(args.output_dir, f\"{args.output_name}.json\"), 'w') as file:\n        json.dump(output_list, file)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    run_inference(args)\n"
  },
  {
    "path": "stllm/test/qabench/msvd_qa.py",
    "content": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm import tqdm\n\nimport torch\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stllm.conversation.conversation import Chat, CONV_VIDEO_LLama2, CONV_VIDEO_Vicuna0, \\\n                    CONV_VISION_LLama2, CONV_instructblip_Vicuna0\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    # Define the command-line arguments\n    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)\n    parser.add_argument('--gt_file', help='Path to the ground truth file containing questions and answers.', required=True)\n    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)\n    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to checkpoint file.\")\n    parser.add_argument(\"--num-frames\", type=int, required=False, default=100)\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecated), \"\n        \"change to --cfg-options instead.\",\n    )\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    return parser.parse_args()\n\n\ndef run_inference(args):\n    \"\"\"\n    Run inference on the MSVD-QA dataset
using the Video-ChatGPT model.\n\n    Args:\n        args: Command-line arguments.\n    \"\"\"\n    # Initialize the model\n    conv_dict = {'minigpt4_vicuna0': CONV_VIDEO_Vicuna0,\n             \"instructblip_vicuna0\": CONV_instructblip_Vicuna0,\n             \"instructblip_vicuna0_btadapter\": CONV_instructblip_Vicuna0,\n             'minigpt4_vicuna0_btadapter': CONV_VIDEO_Vicuna0,}\n\n    print('Initializing Chat')\n    args = parse_args()\n    cfg = Config(args)\n\n    model_config = cfg.model_cfg\n    model_config.device_8bit = args.gpu_id\n    model_config.ckpt = args.ckpt_path\n    model_cls = registry.get_model_class(model_config.arch)\n    #model_config.eval = True\n    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\n    for name, para in model.named_parameters():\n        para.requires_grad = False\n    model.eval()\n    \n    CONV_VISION = conv_dict[model_config.model_type]\n    model = model.to(torch.float16)\n    chat = Chat(model, device='cuda:{}'.format(args.gpu_id))\n\n    # Load both ground truth file containing questions and answers\n    with open(args.gt_file) as file:\n        gt_file = json.load(file)\n\n    # Create the output directory if it doesn't exist\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    output_list = []  # List to store the output results\n\n    video_formats = ['.mp4', '.avi', '.mov', '.mkv']\n\n    # Iterate over each sample in the ground truth file\n    index = 0\n    \n    for sample in tqdm(gt_file):\n        video_name = sample['video_name'] if 'video_name' in sample else sample['video']\n        id = sample['question_id'] if 'question_id' in sample else sample['id']\n        question = sample['question']\n        answer = sample['answer']\n        index += 1\n\n        sample_set = {'id': id, 'question': question, 'answer': answer}\n        video_path = os.path.join(args.video_dir, video_name)\n        # Check if the video exists\n        
chat_state = CONV_VISION.copy()\n        img_list = []\n        chat.upload_video(video_path, chat_state, img_list, args.num_frames, question)\n        chat.ask(question, chat_state)\n        llm_message = chat.answer(conv=chat_state,\n                              img_list=img_list,\n                              num_beams=5,\n                              do_sample=False,\n                              temperature=1,\n                              system=False,\n                              max_new_tokens=300,\n                              max_length=2000)[0]\n\n        sample_set['pred'] = llm_message\n        output_list.append(sample_set)\n        \n\n    # Save the output list to a JSON file\n    with open(os.path.join(args.output_dir, f\"{args.output_name}.json\"), 'w') as file:\n        json.dump(output_list, file)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    run_inference(args)\n"
  },
  {
    "path": "stllm/test/vcgbench/videochatgpt_benchmark_consist.py",
    "content": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm import tqdm\n\nimport argparse\nimport os\nimport torch\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stllm.conversation.conversation import Chat, CONV_VIDEO_LLama2, CONV_VIDEO_Vicuna0, \\\n                    CONV_VISION_LLama2, CONV_instructblip_Vicuna0\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    # Define the command-line arguments\n    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)\n    parser.add_argument('--gt_file', help='Path to the ground truth file.', required=True)\n    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)\n    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to checkpoint file.\")\n    parser.add_argument(\"--num-frames\", type=int, required=False, default=100)\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    return parser.parse_args()\n\n\ndef run_inference(args):\n    \"\"\"\n    Run inference on a set of video files using the provided 
model.\n\n    Args:\n        args: Command-line arguments.\n    \"\"\"\n    # Initialize the model\n    conv_dict = {'minigpt4_vicuna0': CONV_VIDEO_Vicuna0,\n             \"instructblip_vicuna0\": CONV_instructblip_Vicuna0,\n             \"instructblip_vicuna0_btadapter\": CONV_instructblip_Vicuna0,\n             'minigpt4_vicuna0_btadapter': CONV_VIDEO_Vicuna0,}\n\n    print('Initializing Chat')\n    args = parse_args()\n    cfg = Config(args)\n\n    model_config = cfg.model_cfg\n    model_config.device_8bit = args.gpu_id\n    model_config.ckpt = args.ckpt_path\n    model_cls = registry.get_model_class(model_config.arch)\n    #model_config.eval = True\n    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\n    for name, para in model.named_parameters():\n        para.requires_grad = False\n    model.eval()\n    CONV_VISION = conv_dict[model_config.model_type]\n    model = model.to(torch.float16)\n    \n    chat = Chat(model, device='cuda:{}'.format(args.gpu_id))\n\n    # Load the ground truth file\n    with open(args.gt_file) as file:\n        gt_contents = json.load(file)\n\n    # Create the output directory if it doesn't exist\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    output_list = []  # List to store the output results\n\n    video_formats = ['.mp4', '.avi', '.mov', '.mkv']\n\n    # Iterate over each sample in the ground truth file\n    for sample in tqdm(gt_contents):\n        video_name = sample['video_name']\n        sample_set = sample\n        question_1 = sample['Q1']\n        question_2 = sample['Q2']\n\n        # Load the video file\n        for fmt in video_formats:  # Added this line\n            temp_path = os.path.join(args.video_dir, f\"{video_name}{fmt}\")\n            if os.path.exists(temp_path):\n                video_path = temp_path\n                break\n\n        for i in range(1,3):\n            chat_state = CONV_VISION.copy()\n            img_list = []\n          
  question = question_1 if i==1 else question_2\n            chat.upload_video(video_path, chat_state, img_list, args.num_frames, question)\n            chat.ask(question, chat_state)\n            llm_message = chat.answer(conv=chat_state,\n                                  img_list=img_list,\n                                  num_beams=5,\n                                  do_sample=False,\n                                  temperature=1,\n                                  max_new_tokens=300,\n                                  max_length=2000)[0]\n\n            sample_set['pred{}'.format(i)] = llm_message\n        output_list.append(sample_set)\n\n    # Save the output list to a JSON file\n    with open(os.path.join(args.output_dir, f\"{args.output_name}.json\"), 'w') as file:\n        json.dump(output_list, file)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    run_inference(args)\n"
  },
  {
    "path": "stllm/test/vcgbench/videochatgpt_benchmark_general.py",
    "content": "import os\nimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport json\nfrom tqdm import tqdm\n\nimport argparse\nimport os\nimport torch\nfrom stllm.common.config import Config\nfrom stllm.common.registry import registry\nfrom stllm.conversation.conversation import Chat, CONV_VIDEO_LLama2, CONV_VIDEO_Vicuna0, \\\n                    CONV_VISION_LLama2, CONV_instructblip_Vicuna0\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    # Define the command-line arguments\n    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)\n    parser.add_argument('--gt_file', help='Path to the ground truth file.', required=True)\n    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)\n    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\"--ckpt-path\", required=True, help=\"path to checkpoint file.\")\n    parser.add_argument(\"--num-frames\", type=int, required=False, default=100)\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n    parser.add_argument(\"--gpu-id\", type=int, default=0, help=\"specify the gpu to load the model.\")\n    return parser.parse_args()\n\n\ndef run_inference(args):\n    \"\"\"\n    Run inference on a set of video files using the provided 
model.\n\n    Args:\n        args: Command-line arguments.\n    \"\"\"\n    # Initialize the model\n    conv_dict = {'minigpt4_vicuna0': CONV_VIDEO_Vicuna0,\n             \"instructblip_vicuna0\": CONV_instructblip_Vicuna0,\n             \"instructblip_vicuna0_btadapter\": CONV_instructblip_Vicuna0,\n             'minigpt4_vicuna0_btadapter': CONV_VIDEO_Vicuna0,}\n\n    print('Initializing Chat')\n    args = parse_args()\n    cfg = Config(args)\n\n    model_config = cfg.model_cfg\n    model_config.device_8bit = args.gpu_id\n    model_config.ckpt = args.ckpt_path\n    model_cls = registry.get_model_class(model_config.arch)\n    #model_config.eval = True\n    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))\n    for name, para in model.named_parameters():\n        para.requires_grad = False\n    model.eval()\n    \n    CONV_VISION = conv_dict[model_config.model_type]\n    model = model.to(torch.float16)\n  \n    chat = Chat(model, device='cuda:{}'.format(args.gpu_id))\n\n    # Load the ground truth file\n    with open(args.gt_file) as file:\n        gt_contents = json.load(file)\n\n    # Create the output directory if it doesn't exist\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    output_list = []  # List to store the output results\n\n    video_formats = ['.mp4', '.avi', '.mov', '.mkv']\n\n    # Iterate over each sample in the ground truth file\n    for sample in tqdm(gt_contents):\n        video_name = sample['video_name']\n        sample_set = sample\n        question = sample['Q']\n\n        # Load the video file\n        for fmt in video_formats:  # Added this line\n            temp_path = os.path.join(args.video_dir, f\"{video_name}{fmt}\")\n            if os.path.exists(temp_path):\n                video_path = temp_path\n                break\n        \n        chat_state = CONV_VISION.copy()\n        img_list = []\n        chat.upload_video(video_path, chat_state, img_list, 
args.num_frames, question)\n        chat.ask(question, chat_state)\n        llm_message = chat.answer(conv=chat_state,\n                              img_list=img_list,\n                              num_beams=5,\n                              do_sample=False,\n                              temperature=1,\n                              max_new_tokens=300,\n                              max_length=2000)[0]\n        \n        sample_set['pred'] = llm_message\n        output_list.append(sample_set)\n\n    # Save the output list to a JSON file\n    with open(os.path.join(args.output_dir, f\"{args.output_name}.json\"), 'w') as file:\n        json.dump(output_list, file)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    run_inference(args)\n"
  },
  {
    "path": "stllm/test/video_transforms.py",
    "content": "import torchvision\nimport random\nfrom PIL import Image, ImageOps\nimport numpy as np\nimport numbers\nimport math\nimport torch\n\n\nclass GroupRandomCrop(object):\n    def __init__(self, size):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n\n    def __call__(self, img_group):\n\n        w, h = img_group[0].size\n        th, tw = self.size\n\n        out_images = list()\n\n        x1 = random.randint(0, w - tw)\n        y1 = random.randint(0, h - th)\n\n        for img in img_group:\n            assert(img.size[0] == w and img.size[1] == h)\n            if w == tw and h == th:\n                out_images.append(img)\n            else:\n                out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))\n\n        return out_images\n\n\nclass MultiGroupRandomCrop(object):\n    def __init__(self, size, groups=1):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n        self.groups = groups\n\n    def __call__(self, img_group):\n\n        w, h = img_group[0].size\n        th, tw = self.size\n\n        out_images = list()\n\n        for i in range(self.groups):\n            x1 = random.randint(0, w - tw)\n            y1 = random.randint(0, h - th)\n\n            for img in img_group:\n                assert(img.size[0] == w and img.size[1] == h)\n                if w == tw and h == th:\n                    out_images.append(img)\n                else:\n                    out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))\n\n        return out_images\n\n\nclass GroupCenterCrop(object):\n    def __init__(self, size):\n        self.worker = torchvision.transforms.CenterCrop(size)\n\n    def __call__(self, img_group):\n        return [self.worker(img) for img in img_group]\n\n\nclass GroupRandomHorizontalFlip(object):\n    \"\"\"Randomly horizontally flips the 
given PIL.Image with a probability of 0.5\n    \"\"\"\n\n    def __init__(self, is_flow=False):\n        self.is_flow = is_flow\n\n    def __call__(self, img_group, is_flow=False):\n        v = random.random()\n        if v < 0.5:\n            ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]\n            if self.is_flow:\n                for i in range(0, len(ret), 2):\n                    # invert flow pixel values when flipping\n                    ret[i] = ImageOps.invert(ret[i])\n            return ret\n        else:\n            return img_group\n\n\nclass GroupNormalize(object):\n    def __init__(self, mean, std):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, tensor):\n        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))\n        rep_std = self.std * (tensor.size()[0] // len(self.std))\n\n        # TODO: make efficient\n        for t, m, s in zip(tensor, rep_mean, rep_std):\n            t.sub_(m).div_(s)\n\n        return tensor\n\n\nclass GroupScale(object):\n    \"\"\" Rescales the input PIL.Image to the given 'size'.\n    'size' will be the size of the smaller edge.\n    For example, if height > width, then image will be\n    rescaled to (size * height / width, size)\n    size: size of the smaller edge\n    interpolation: Default: PIL.Image.BILINEAR\n    \"\"\"\n\n    def __init__(self, size, interpolation=Image.BILINEAR):\n        self.worker = torchvision.transforms.Resize(size, interpolation)\n\n    def __call__(self, img_group):\n        return [self.worker(img) for img in img_group]\n\n\nclass GroupOverSample(object):\n    def __init__(self, crop_size, scale_size=None, flip=True):\n        self.crop_size = crop_size if not isinstance(\n            crop_size, int) else (crop_size, crop_size)\n\n        if scale_size is not None:\n            self.scale_worker = GroupScale(scale_size)\n        else:\n            self.scale_worker = None\n        self.flip = flip\n\n    def __call__(self, 
img_group):\n\n        if self.scale_worker is not None:\n            img_group = self.scale_worker(img_group)\n\n        image_w, image_h = img_group[0].size\n        crop_w, crop_h = self.crop_size\n\n        offsets = GroupMultiScaleCrop.fill_fix_offset(\n            False, image_w, image_h, crop_w, crop_h)\n        oversample_group = list()\n        for o_w, o_h in offsets:\n            normal_group = list()\n            flip_group = list()\n            for i, img in enumerate(img_group):\n                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))\n                normal_group.append(crop)\n                flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)\n\n                if img.mode == 'L' and i % 2 == 0:\n                    flip_group.append(ImageOps.invert(flip_crop))\n                else:\n                    flip_group.append(flip_crop)\n\n            oversample_group.extend(normal_group)\n            if self.flip:\n                oversample_group.extend(flip_group)\n        return oversample_group\n\n\nclass GroupFullResSample(object):\n    def __init__(self, crop_size, scale_size=None, flip=True):\n        self.crop_size = crop_size if not isinstance(\n            crop_size, int) else (crop_size, crop_size)\n\n        if scale_size is not None:\n            self.scale_worker = GroupScale(scale_size)\n        else:\n            self.scale_worker = None\n        self.flip = flip\n\n    def __call__(self, img_group):\n\n        if self.scale_worker is not None:\n            img_group = self.scale_worker(img_group)\n\n        image_w, image_h = img_group[0].size\n        crop_w, crop_h = self.crop_size\n\n        w_step = (image_w - crop_w) // 4\n        h_step = (image_h - crop_h) // 4\n\n        offsets = list()\n        offsets.append((0 * w_step, 2 * h_step))  # left\n        offsets.append((4 * w_step, 2 * h_step))  # right\n        offsets.append((2 * w_step, 2 * h_step))  # center\n\n        oversample_group = list()\n        
for o_w, o_h in offsets:\n            normal_group = list()\n            flip_group = list()\n            for i, img in enumerate(img_group):\n                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))\n                normal_group.append(crop)\n                if self.flip:\n                    flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)\n\n                    if img.mode == 'L' and i % 2 == 0:\n                        flip_group.append(ImageOps.invert(flip_crop))\n                    else:\n                        flip_group.append(flip_crop)\n\n            oversample_group.extend(normal_group)\n            oversample_group.extend(flip_group)\n        return oversample_group\n\n\nclass GroupMultiScaleCrop(object):\n\n    def __init__(self, input_size, scales=None, max_distort=1,\n                 fix_crop=True, more_fix_crop=True):\n        self.scales = scales if scales is not None else [1, .875, .75, .66]\n        self.max_distort = max_distort\n        self.fix_crop = fix_crop\n        self.more_fix_crop = more_fix_crop\n        self.input_size = input_size if not isinstance(input_size, int) else [\n            input_size, input_size]\n        self.interpolation = Image.BILINEAR\n\n    def __call__(self, img_group):\n\n        im_size = img_group[0].size\n\n        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)\n        crop_img_group = [\n            img.crop(\n                (offset_w,\n                 offset_h,\n                 offset_w +\n                 crop_w,\n                 offset_h +\n                 crop_h)) for img in img_group]\n        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)\n                         for img in crop_img_group]\n        return ret_img_group\n\n    def _sample_crop_size(self, im_size):\n        image_w, image_h = im_size[0], im_size[1]\n\n        # find a crop size\n        base_size = min(image_w, image_h)\n        crop_sizes = 
[int(base_size * x) for x in self.scales]\n        crop_h = [\n            self.input_size[1] if abs(\n                x - self.input_size[1]) < 3 else x for x in crop_sizes]\n        crop_w = [\n            self.input_size[0] if abs(\n                x - self.input_size[0]) < 3 else x for x in crop_sizes]\n\n        pairs = []\n        for i, h in enumerate(crop_h):\n            for j, w in enumerate(crop_w):\n                if abs(i - j) <= self.max_distort:\n                    pairs.append((w, h))\n\n        crop_pair = random.choice(pairs)\n        if not self.fix_crop:\n            w_offset = random.randint(0, image_w - crop_pair[0])\n            h_offset = random.randint(0, image_h - crop_pair[1])\n        else:\n            w_offset, h_offset = self._sample_fix_offset(\n                image_w, image_h, crop_pair[0], crop_pair[1])\n\n        return crop_pair[0], crop_pair[1], w_offset, h_offset\n\n    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):\n        offsets = self.fill_fix_offset(\n            self.more_fix_crop, image_w, image_h, crop_w, crop_h)\n        return random.choice(offsets)\n\n    @staticmethod\n    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):\n        w_step = (image_w - crop_w) // 4\n        h_step = (image_h - crop_h) // 4\n\n        ret = list()\n        ret.append((0, 0))  # upper left\n        ret.append((4 * w_step, 0))  # upper right\n        ret.append((0, 4 * h_step))  # lower left\n        ret.append((4 * w_step, 4 * h_step))  # lower right\n        ret.append((2 * w_step, 2 * h_step))  # center\n\n        if more_fix_crop:\n            ret.append((0, 2 * h_step))  # center left\n            ret.append((4 * w_step, 2 * h_step))  # center right\n            ret.append((2 * w_step, 4 * h_step))  # lower center\n            ret.append((2 * w_step, 0 * h_step))  # upper center\n\n            ret.append((1 * w_step, 1 * h_step))  # upper left quarter\n            ret.append((3 * w_step, 1 * 
h_step))  # upper right quarter\n            ret.append((1 * w_step, 3 * h_step))  # lower left quarter\n            ret.append((3 * w_step, 3 * h_step))  # lower right quarter\n\n        return ret\n\n\nclass GroupRandomSizedCrop(object):\n    \"\"\"Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size\n    and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio\n    This is popularly used to train the Inception networks\n    size: size of the smaller edge\n    interpolation: Default: PIL.Image.BILINEAR\n    \"\"\"\n\n    def __init__(self, size, interpolation=Image.BILINEAR):\n        self.size = size\n        self.interpolation = interpolation\n\n    def __call__(self, img_group):\n        for attempt in range(10):\n            area = img_group[0].size[0] * img_group[0].size[1]\n            target_area = random.uniform(0.08, 1.0) * area\n            aspect_ratio = random.uniform(3. / 4, 4. / 3)\n\n            w = int(round(math.sqrt(target_area * aspect_ratio)))\n            h = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if random.random() < 0.5:\n                w, h = h, w\n\n            if w <= img_group[0].size[0] and h <= img_group[0].size[1]:\n                x1 = random.randint(0, img_group[0].size[0] - w)\n                y1 = random.randint(0, img_group[0].size[1] - h)\n                found = True\n                break\n        else:\n            found = False\n            x1 = 0\n            y1 = 0\n\n        if found:\n            out_group = list()\n            for img in img_group:\n                img = img.crop((x1, y1, x1 + w, y1 + h))\n                assert(img.size == (w, h))\n                out_group.append(\n                    img.resize(\n                        (self.size, self.size), self.interpolation))\n            return out_group\n        else:\n            # Fallback\n            scale = GroupScale(self.size, interpolation=self.interpolation)\n            crop
= GroupRandomCrop(self.size)\n            return crop(scale(img_group))\n\n\nclass ConvertDataFormat(object):\n    def __init__(self, model_type):\n        self.model_type = model_type\n\n    def __call__(self, images):\n        if self.model_type == '2D':\n            return images\n        tc, h, w = images.size()\n        t = tc // 3\n        images = images.view(t, 3, h, w)\n        images = images.permute(1, 0, 2, 3)\n        return images\n\n\nclass Stack(object):\n\n    def __init__(self, roll=False):\n        self.roll = roll\n\n    def __call__(self, img_group):\n        if img_group[0].mode == 'L':\n            return np.concatenate([np.expand_dims(x, 2)\n                                   for x in img_group], axis=2)\n        elif img_group[0].mode == 'RGB':\n            if self.roll:\n                return np.concatenate([np.array(x)[:, :, ::-1]\n                                       for x in img_group], axis=2)\n            else:\n                #print(np.concatenate(img_group, axis=2).shape)\n                # print(img_group[0].shape)\n                return np.concatenate(img_group, axis=2)\n\n\nclass ToTorchFormatTensor(object):\n    \"\"\" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]\n    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] \"\"\"\n\n    def __init__(self, div=True):\n        self.div = div\n\n    def __call__(self, pic):\n        if isinstance(pic, np.ndarray):\n            # handle numpy array\n            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()\n        else:\n            # handle PIL Image\n            img = torch.ByteTensor(\n                torch.ByteStorage.from_buffer(\n                    pic.tobytes()))\n            img = img.view(pic.size[1], pic.size[0], len(pic.mode))\n            # put it from HWC to CHW format\n            # yikes, this transpose takes 80% of the loading time/CPU\n            img = img.transpose(0, 1).transpose(0, 2).contiguous()\n 
       return img.float().div(255) if self.div else img.float()\n\n\nclass IdentityTransform(object):\n\n    def __call__(self, data):\n        return data\n\n\nif __name__ == \"__main__\":\n    trans = torchvision.transforms.Compose([\n        GroupScale(256),\n        GroupRandomCrop(224),\n        Stack(),\n        ToTorchFormatTensor(),\n        GroupNormalize(\n            mean=[.485, .456, .406],\n            std=[.229, .224, .225]\n        )]\n    )\n\n    im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')\n\n    color_group = [im] * 3\n    rst = trans(color_group)\n\n    gray_group = [im.convert('L')] * 9\n    gray_rst = trans(gray_group)\n\n    trans2 = torchvision.transforms.Compose([\n        GroupRandomSizedCrop(256),\n        Stack(),\n        ToTorchFormatTensor(),\n        GroupNormalize(\n            mean=[.485, .456, .406],\n            std=[.229, .224, .225])\n    ])\n    print(trans2(color_group))\n"
  },
  {
    "path": "stllm/test/video_utils.py",
    "content": "import os\nimport copy\nimport numpy as np\nfrom PIL import Image\nimport decord\nfrom decord import VideoReader, cpu\nfrom transformers import AutoTokenizer, CLIPVisionModel, CLIPImageProcessor\nimport torch\nfrom mmengine.fileio import FileClient\n\ndef load_video(vis_path, n_clips=1, num_frm=100):\n    \"\"\"\n    Load video frames from a video file.\n\n    Parameters:\n    vis_path (str): Path to the video file.\n    n_clips (int): Number of clips to extract from the video. Defaults to 1.\n    num_frm (int): Number of frames to extract from each clip. Defaults to 100.\n\n    Returns:\n    list: List of PIL.Image.Image objects representing video frames.\n    \"\"\"\n\n    # Load video with VideoReader\n    vr = VideoReader(vis_path, ctx=cpu(0))\n    total_frame_num = len(vr)\n\n    # Currently, this function supports only 1 clip\n    assert n_clips == 1\n\n    # Calculate total number of frames to extract\n    total_num_frm = min(total_frame_num, num_frm)\n    # Get indices of frames to extract\n    frame_idx = get_seq_frames(total_frame_num, total_num_frm)\n    # Extract frames as numpy array\n    img_array = vr.get_batch(frame_idx)\n\n    if isinstance(img_array, decord.ndarray.NDArray):\n        img_array = img_array.asnumpy()\n    else:\n        img_array = img_array.numpy()\n    \n    img_array = img_array.reshape(\n        (n_clips, total_num_frm, img_array.shape[-3], img_array.shape[-2], img_array.shape[-1]))\n    # Convert numpy arrays to PIL Image objects\n    clip_imgs = [Image.fromarray(img_array[0, j]) for j in range(total_num_frm)]\n\n    return clip_imgs\n\ndef load_video_rawframes(vis_path, total_frame_num, n_clips=1, num_frm=100):\n    # Currently, this function supports only 1 clip\n    assert n_clips == 1\n    # Calculate total number of frames to extract\n    total_num_frm = min(total_frame_num, num_frm)\n    # Get indices of frames to extract\n    frame_idx = get_seq_frames(total_frame_num, total_num_frm)\n    # Extract frames 
as numpy array\n    img_array = get_frames_from_raw(vis_path, frame_idx)\n    # Set target image height and width\n    target_h, target_w = 224, 224\n    # If image shape is not as target, resize it\n    if img_array.shape[-3] != target_h or img_array.shape[-2] != target_w:\n        img_array = torch.from_numpy(img_array).permute(0, 3, 1, 2).float()\n        img_array = torch.nn.functional.interpolate(img_array, size=(target_h, target_w))\n        img_array = img_array.permute(0, 2, 3, 1).to(torch.uint8).numpy()\n\n    # Reshape array to match number of clips and frames\n    img_array = img_array.reshape(\n        (n_clips, total_num_frm, img_array.shape[-3], img_array.shape[-2], img_array.shape[-1]))\n    # Convert numpy arrays to PIL Image objects\n    clip_imgs = [Image.fromarray(img_array[0, j]) for j in range(total_num_frm)]\n\n    return clip_imgs\n\ndef get_seq_frames(total_num_frames, desired_num_frames):\n    \"\"\"\n    Calculate the indices of frames to extract from a video.\n\n    Parameters:\n    total_num_frames (int): Total number of frames in the video.\n    desired_num_frames (int): Desired number of frames to extract.\n\n    Returns:\n    list: List of indices of frames to extract.\n    \"\"\"\n\n    # Calculate the size of each segment from which a frame will be extracted\n    seg_size = float(total_num_frames - 1) / desired_num_frames\n\n    seq = []\n    for i in range(desired_num_frames):\n        # Calculate the start and end indices of each segment\n        start = int(np.round(seg_size * i))\n        end = int(np.round(seg_size * (i + 1)))\n\n        # Append the middle index of the segment to the list\n        seq.append((start + end) // 2)\n\n    return seq\n\ndef get_frames_from_raw(directory, frame_idx, filename_tmpl=\"{:0>6}.jpg\", offset=1):\n    import mmcv\n    mmcv.use_backend('cv2')\n    file_client = FileClient('disk')\n    imgs = list()\n    cache = {}\n    for i, frame_idx in enumerate(frame_idx):\n        if frame_idx in 
cache:\n            imgs.append(copy.deepcopy(imgs[cache[frame_idx]]))\n            continue\n        else:\n            cache[frame_idx] = i\n        frame_idx += offset\n        filepath = os.path.join(directory, filename_tmpl.format(frame_idx))\n        try:\n            img_bytes = file_client.get(filepath)\n        except Exception:\n            # Fall back to the next frame if this one cannot be read.\n            filepath = os.path.join(directory, filename_tmpl.format(frame_idx + 1))\n            img_bytes = file_client.get(filepath)\n        cur_frame = mmcv.imfrombytes(img_bytes, channel_order='rgb')\n        imgs.append(cur_frame)\n    return np.stack(imgs, axis=0)\n"
  },
  {
    "path": "stllm/train/stllm_trainer.py",
    "content": "import os\nimport torch\nfrom stllm.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset\nfrom stllm.datasets.datasets.dataloader_utils import (\n    IterLoader,\n    PrefetchLoader,\n    MetaLoader,\n)\nfrom stllm.common.dist_utils import (\n    get_rank,\n    get_world_size,\n)\nimport webdataset as wds\nfrom torch.utils.data import DataLoader, DistributedSampler\nfrom torch.utils.data import Sampler\n\nfrom transformers import Trainer\nfrom transformers.trainer import (\n    is_sagemaker_mp_enabled,\n    get_parameter_names,\n    has_length,\n    ALL_LAYERNORM_LAYERS,\n    ShardedDDPOption,\n    logger,\n)\nfrom typing import List, Optional\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                print(name, 'no ignore status')\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef split_to_even_chunks(indices, lengths, num_chunks):\n    \"\"\"\n    Split a list of indices into `chunks` chunks of roughly equal lengths.\n    \"\"\"\n\n    if len(indices) % num_chunks != 0:\n        return [indices[i::num_chunks] for i in range(num_chunks)]\n\n    num_indices_per_chunk = len(indices) // num_chunks\n\n    chunks = [[] for _ in range(num_chunks)]\n    chunks_lengths = [0 for _ in range(num_chunks)]\n    for index in indices:\n        
shortest_chunk = chunks_lengths.index(min(chunks_lengths))\n        chunks[shortest_chunk].append(index)\n        chunks_lengths[shortest_chunk] += lengths[index]\n        if len(chunks[shortest_chunk]) == num_indices_per_chunk:\n            chunks_lengths[shortest_chunk] = float(\"inf\")\n\n    return chunks\n\n\ndef get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):\n    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.\n    assert all(l != 0 for l in lengths), \"Should not have zero length.\"\n    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):\n        # all samples are in the same modality\n        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)\n    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])\n    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])\n\n    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]\n    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]\n    megabatch_size = world_size * batch_size\n    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]\n    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]\n\n    last_mm = mm_megabatches[-1]\n    last_lang = lang_megabatches[-1]\n    additional_batch = last_mm + last_lang\n    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]\n    megabatch_indices = torch.randperm(len(megabatches), generator=generator)\n    megabatches = [megabatches[i] for i in megabatch_indices]\n\n    if len(additional_batch) > 0:\n        megabatches.append(sorted(additional_batch))\n\n    return [i for megabatch in megabatches for i in megabatch]\n\n\ndef 
get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):\n    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.\n    indices = torch.randperm(len(lengths), generator=generator)\n    megabatch_size = world_size * batch_size\n    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]\n    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]\n    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]\n\n    return [i for megabatch in megabatches for batch in megabatch for i in batch]\n\n\nclass LengthGroupedSampler(Sampler):\n    r\"\"\"\n    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while\n    keeping a bit of randomness.\n    \"\"\"\n\n    def __init__(\n        self,\n        batch_size: int,\n        world_size: int,\n        lengths: Optional[List[int]] = None,\n        generator=None,\n        group_by_modality: bool = False,\n    ):\n        if lengths is None:\n            raise ValueError(\"Lengths must be provided.\")\n\n        self.batch_size = batch_size\n        self.world_size = world_size\n        self.lengths = lengths\n        self.generator = generator\n        self.group_by_modality = group_by_modality\n\n    def __len__(self):\n        return len(self.lengths)\n\n    def __iter__(self):\n        if self.group_by_modality:\n            indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)\n        else:\n            indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)\n        return iter(indices)\n\n\nclass STLLMTrainer(Trainer):\n\n    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:\n        if self.train_dataset 
is None or not has_length(self.train_dataset):\n            return None\n\n        if self.args.group_by_modality_length:\n            lengths = self.train_dataset.modality_lengths\n            return LengthGroupedSampler(\n                self.args.train_batch_size,\n                world_size=self.args.world_size * self.args.gradient_accumulation_steps,\n                lengths=lengths,\n                group_by_modality=True,\n            )\n        else:\n            return super()._get_train_sampler()\n\n    def get_train_dataloader(self):\n        def _create_loader(dataset, num_workers, bsz, collate_fn=None, is_train=True):\n            # create a single dataloader for each split\n            if isinstance(dataset, ChainDataset) or isinstance(\n                dataset, wds.DataPipeline\n            ):\n                # wds.WebdDataset instance are chained together\n                # webdataset.DataPipeline has its own sampler and collate_fn\n                loader = iter(\n                    DataLoader(\n                        dataset,\n                        batch_size=bsz,\n                        num_workers=num_workers,\n                        pin_memory=True,\n                    )\n                )\n            else:\n                sampler = DistributedSampler(\n                    dataset,\n                    shuffle=is_train,\n                    num_replicas=get_world_size(),\n                    rank=get_rank(),\n                    seed=42,\n                )       \n                loader = DataLoader(\n                    dataset,\n                    batch_size=bsz,\n                    num_workers=num_workers,\n                    pin_memory=True,\n                    sampler=sampler,\n                    shuffle=sampler is None and is_train,\n                    collate_fn=collate_fn,\n                    drop_last=True if is_train else False,\n                )\n                loader = PrefetchLoader(loader)\n\n                if 
is_train:\n                    loader = IterLoader(loader, use_distributed=True)\n\n            return loader\n\n        dataset = self.train_dataset\n        if isinstance(dataset, torch.utils.data.IterableDataset) or \\\n            isinstance(dataset, torch.utils.data.Dataset):\n            return super().get_train_dataloader()\n        else:\n            dataset = [dt.pop('train') for dt in dataset.values()]\n            batch_size = self.args.per_device_train_batch_size\n            num_workers = self.args.dataloader_num_workers\n            loader = MetaLoader(\n                loaders=[\n                    _create_loader(d, num_workers, batch_size)\n                    for i, d in enumerate(dataset)\n                ]\n            )\n            return loader\n\n    def create_optimizer(self):\n        \"\"\"\n        Setup the optimizer.\n\n        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method in a subclass.\n        \"\"\"\n        if is_sagemaker_mp_enabled():\n            return super().create_optimizer()\n        if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n            return super().create_optimizer()\n\n        opt_model = self.model\n\n        if self.optimizer is None:\n            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)\n            decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n            if self.args.mm_projector_lr is not None:\n                projector_parameters = [name for name, _ in opt_model.named_parameters() if \"llama_proj\" in name]\n                optimizer_grouped_parameters = [\n                    {\n                        \"params\": [\n                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)\n                        
],\n                        \"weight_decay\": self.args.weight_decay,\n                    },\n                    {\n                        \"params\": [\n                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": 0.0,\n                    },\n                    {\n                        \"params\": [\n                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": self.args.weight_decay,\n                        \"lr\": self.args.mm_projector_lr,\n                    },\n                    {\n                        \"params\": [\n                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": 0.0,\n                        \"lr\": self.args.mm_projector_lr,\n                    },\n                ]\n            else:\n                optimizer_grouped_parameters = [\n                    {\n                        \"params\": [\n                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": self.args.weight_decay,\n                    },\n                    {\n                        \"params\": [\n                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": 0.0,\n                    },\n                ]\n\n            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)\n\n            if 
self.sharded_ddp == ShardedDDPOption.SIMPLE:\n                # OSS is not imported at module level; import it where it is used.\n                from fairscale.optim import OSS\n\n                self.optimizer = OSS(\n                    params=optimizer_grouped_parameters,\n                    optim=optimizer_cls,\n                    **optimizer_kwargs,\n                )\n            else:\n                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n                if optimizer_cls.__name__ == \"Adam8bit\":\n                    import bitsandbytes\n\n                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()\n\n                    skipped = 0\n                    for module in opt_model.modules():\n                        # `nn` is not imported in this module, so reference it through `torch`.\n                        if isinstance(module, torch.nn.Embedding):\n                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())\n                            logger.info(f\"skipped {module}: {skipped/2**20}M params\")\n                            manager.register_module_override(module, \"weight\", {\"optim_bits\": 32})\n                            logger.debug(f\"bitsandbytes: will optimize {module} in fp32\")\n                    logger.info(f\"skipped: {skipped/2**20}M params\")\n\n        return self.optimizer\n\n    def compute_loss(self, model, inputs, return_outputs=False):\n        \"\"\"\n        How the loss is computed by Trainer. 
By default, all models return the loss in the first element.\n\n        Subclass and override for custom behavior.\n        \"\"\"\n        if self.label_smoother is not None and \"labels\" in inputs:\n            labels = inputs.pop(\"labels\")\n        else:\n            labels = None\n        outputs = model(inputs)\n        # Save past state if it exists\n        # TODO: this needs to be fixed and made cleaner later.\n        if self.args.past_index >= 0:\n            self._past = outputs[self.args.past_index]\n\n        if labels is not None:\n            # These helpers are not imported at module level; import them where they are used.\n            from transformers.modeling_utils import unwrap_model\n            from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES\n            from transformers.utils import is_peft_available\n\n            if is_peft_available():\n                from peft import PeftModel\n\n            if is_peft_available() and isinstance(model, PeftModel):\n                model_name = unwrap_model(model.base_model)._get_name()\n            else:\n                model_name = unwrap_model(model)._get_name()\n            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():\n                loss = self.label_smoother(outputs, labels, shift_labels=True)\n            else:\n                loss = self.label_smoother(outputs, labels)\n        else:\n            if isinstance(outputs, dict) and \"loss\" not in outputs:\n                raise ValueError(\n                    \"The model did not return a loss from the inputs, only the following keys: \"\n                    f\"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}.\"\n                )\n            # We don't use .loss here since the model may return tuples instead of ModelOutput.\n            loss = outputs[\"loss\"] if isinstance(outputs, dict) else outputs[0]\n\n        return (loss, outputs) if return_outputs else loss\n\n"
  },
  {
    "path": "stllm/train/train.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nimport numpy as np\nimport torch\nimport torch.backends.cudnn as cudnn\n\nimport stllm.tasks as tasks\nfrom stllm.common.config import Config\nfrom stllm.common.dist_utils import get_rank, init_distributed_mode\nfrom stllm.common.logger import setup_logger\nfrom stllm.common.optims import (\n    LinearWarmupCosineLRScheduler,\n    LinearWarmupStepLRScheduler,\n)\nfrom stllm.common.registry import registry\nfrom stllm.common.utils import now\n\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\nlocal_rank = None\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Training\")\n\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n\n    args = parser.parse_args()\n    # if 'LOCAL_RANK' not in os.environ:\n    #     os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n\n\ndef setup_seeds(config):\n    seed = config.run_cfg.seed + get_rank()\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n\n    cudnn.benchmark = False\n    cudnn.deterministic = True\n\n\ndef get_runner_class(cfg):\n    \"\"\"\n    Get runner class from config. 
Defaults to the epoch-based runner.\n    \"\"\"\n    runner_cls = registry.get_runner_class(cfg.run_cfg.get(\"runner\", \"runner_base\"))\n\n    return runner_cls\n\n\ndef main():\n    # allow automatic downloads to complete on the main process without timing out when using the NCCL backend.\n    # os.environ[\"NCCL_BLOCKING_WAIT\"] = \"1\"\n\n    # set before init_distributed_mode() to ensure the same job_id is shared across all ranks.\n    job_id = now()\n\n    cfg = Config(parse_args())\n\n    init_distributed_mode(cfg.run_cfg)\n\n    setup_seeds(cfg)\n\n    # set after init_distributed_mode() to only log on master.\n    setup_logger()\n\n    cfg.pretty_print()\n\n    task = tasks.setup_task(cfg)\n    datasets = task.build_datasets(cfg)\n    model = task.build_model(cfg)\n\n    runner = get_runner_class(cfg)(\n        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets\n    )\n    runner.train()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stllm/train/train_hf.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nfrom dataclasses import dataclass, field\nimport logging\nimport pathlib\nfrom typing import Dict, Optional, Sequence, List\n\nimport sys\nimport torch\nimport deepspeed\n\nimport transformers\nfrom torch.utils.data.dataloader import default_collate\nfrom stllm.train.stllm_trainer import STLLMTrainer\n\nimport argparse\nimport random\n\nimport numpy as np\nimport torch\nimport torch.backends.cudnn as cudnn\n\nimport stllm.tasks as tasks\nfrom stllm.common.config import Config\nfrom stllm.common.dist_utils import get_rank, init_distributed_mode\nfrom stllm.common.logger import setup_logger\n# imports modules for registration\nfrom stllm.datasets.builders import *\nfrom stllm.models import *\nfrom stllm.processors import *\nfrom stllm.runners import *\nfrom stllm.tasks import *\n\nlocal_rank = None\n\n\ndef rank0_print(*args):\n    if local_rank == 0:\n        print(*args)\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Training\")\n\n    parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n    parser.add_argument(\"--local_rank\", required=False, default=0)\n    
parser.add_argument(\n        \"--options\",\n        nargs=\"+\",\n        help=\"override some settings in the used config, the key-value pair \"\n        \"in xxx=yyy format will be merged into config file (deprecate), \"\n        \"change to --cfg-options instead.\",\n    )\n\n    args = parser.parse_args()\n    # if 'LOCAL_RANK' not in os.environ:\n    #     os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n    \n@dataclass\nclass ModelArguments:\n    freeze_backbone: bool = field(default=False)\n\n@dataclass\nclass DataArguments:\n    data_path: str = field(default=None,\n                           metadata={\"help\": \"Path to the training data.\"})\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    model_max_length: int = field(\n        default=1024,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n    double_quant: bool = field(\n        default=True,\n        metadata={\"help\": \"Compress the quantization statistics through double quantization.\"}\n    )\n    quant_type: str = field(\n        default=\"nf4\",\n        metadata={\"help\": \"Quantization data type to use. 
Should be one of `fp4` or `nf4`.\"}\n    )\n    bits: int = field(\n        default=16,\n        metadata={\"help\": \"How many bits to use.\"}\n    )\n    mm_projector_lr: Optional[float] = None\n    group_by_modality_length: bool = field(default=False)\n\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        # Keep only the biases that belong to a module with LoRA weights.\n        for k, t in maybe_lora_bias.items():\n            if k in lora_bias_names:\n                to_return[k] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}\n    return to_return\n\n\ndef get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):\n    to_return = {k: t for k, t in named_params if 
\"lora_\" not in k}\n    if require_grad_only:\n        to_return = {k: t for k, t in to_return.items() if t.requires_grad}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):\n    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef find_all_linear_names(model):\n    cls = torch.nn.Linear\n    lora_module_names = set()\n    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']\n    for name, module in model.named_modules():\n        if any(mm_keyword in name for mm_keyword in multimodal_keywords):\n            continue\n        if isinstance(module, cls):\n            names = name.split('.')\n            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n    if 'lm_head' in lora_module_names: # needed for 16-bit\n        lora_module_names.remove('lm_head')\n    return list(lora_module_names)\n\n\ndef safe_save_model_for_hf_trainer(trainer: transformers.Trainer,\n                                   output_dir: str):\n    \"\"\"Collects the state dict and dump to disk.\"\"\"\n    if trainer.deepspeed:\n        torch.cuda.synchronize()\n        trainer.save_model(output_dir)\n        return\n\n    if trainer.args.should_save:\n        model_no_ddp = trainer.model\n        param_grad_dic = {\n            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()\n        }\n        state_dict = model_no_ddp.state_dict()\n        for k in list(state_dict.keys()):\n            if k in param_grad_dic.keys() and not param_grad_dic[k]:\n                # delete parameters that do not require gradient\n                del state_dict[k]\n        cpu_state_dict = {\n            key: value.cpu()\n            for key, value in 
state_dict.items()\n        }\n        del state_dict     \n        trainer._save(output_dir, state_dict=cpu_state_dict)\n\ndef merge_dict_to_argv(input_dict):\n    input_dict.pop('task')\n    i = 0\n    while i < len(sys.argv):\n        if sys.argv[i].startswith('--cfg-path'):\n            sys.argv.pop(i)\n            sys.argv.pop(i)\n            break\n        else:\n            i += 1\n    sys.argv.extend([f'--{key}={value}' for key, value in input_dict.items()])\n\n@dataclass\nclass DefaultDataCollator(object):\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        return default_collate(instances)\n    \ndef train():\n    global local_rank\n\n    cfg = Config(parse_args())\n\n    task = tasks.setup_task(cfg)\n    datasets = task.build_datasets(cfg)\n    model = task.build_model(cfg)\n\n    parser = transformers.HfArgumentParser(\n        (ModelArguments, DataArguments, TrainingArguments))\n    merge_dict_to_argv(cfg.run_cfg)\n    #sys.argv.extend([f'--{key}={value}' for key, value in cfg.run_cfg.items()])\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n    local_rank = training_args.local_rank\n\n    if training_args.bits in [4, 8]:\n        from peft import prepare_model_for_kbit_training\n        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))\n        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)\n\n    # set after init_distributed_mode() to only log on master.\n    setup_logger()\n\n    cfg.pretty_print()\n\n    data_module = {}\n    dataset_name = list(datasets.keys())[0]\n    data_module['train_dataset'] = datasets\n    data_module['eval_dataset'] = None\n    data_module['data_collator'] = DefaultDataCollator()\n\n    trainer = STLLMTrainer(model=model,\n                    tokenizer=None,\n                    args=training_args,\n                 
   **data_module)\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n\n    safe_save_model_for_hf_trainer(trainer=trainer,\n                                       output_dir=training_args.output_dir)\n\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "stllm/train/zero2.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": \"auto\"\n    }\n}"
  },
  {
    "path": "stllm/train/zero3.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": \"auto\",\n        \"stage3_prefetch_bucket_size\": \"auto\",\n        \"stage3_param_persistence_threshold\": \"auto\",\n        \"stage3_max_live_parameters\": 1e9,\n        \"stage3_max_reuse_distance\": 1e9,\n        \"stage3_gather_16bit_weights_on_model_save\": true\n    }\n}"
  },
  {
    "path": "stllm/train/zero3_offload.json",
    "content": "{\n    \"fp16\": {\n      \"enabled\": \"auto\",\n      \"loss_scale\": 0,\n      \"loss_scale_window\": 1000,\n      \"initial_scale_power\": 16,\n      \"hysteresis\": 2,\n      \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n      \"enabled\": \"auto\"\n    },\n    \"optimizer\": {\n      \"type\": \"AdamW\",\n      \"params\": {\n        \"lr\": \"auto\",\n        \"betas\": \"auto\",\n        \"eps\": \"auto\",\n        \"weight_decay\": \"auto\"\n      }\n    },\n    \"scheduler\": {\n      \"type\": \"WarmupLR\",\n      \"params\": {\n        \"warmup_min_lr\": \"auto\",\n        \"warmup_max_lr\": \"auto\",\n        \"warmup_num_steps\": \"auto\"\n      }\n    },\n    \"zero_optimization\": {\n      \"stage\": 3,\n      \"offload_optimizer\": {\n        \"device\": \"cpu\",\n        \"pin_memory\": true\n      },\n      \"offload_param\": {\n        \"device\": \"cpu\",\n        \"pin_memory\": true\n      },\n      \"overlap_comm\": true,\n      \"contiguous_gradients\": true,\n      \"sub_group_size\": 1e9,\n      \"reduce_bucket_size\": \"auto\",\n      \"stage3_prefetch_bucket_size\": \"auto\",\n      \"stage3_param_persistence_threshold\": \"auto\",\n      \"stage3_max_live_parameters\": 1e9,\n      \"stage3_max_reuse_distance\": 1e9,\n      \"gather_16bit_weights_on_model_save\": true\n    },\n    \"gradient_accumulation_steps\": \"auto\",\n    \"gradient_clipping\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"steps_per_print\": 1e5,\n    \"wall_clock_breakdown\": false\n  }"
  },
  {
    "path": "trainval.md",
    "content": "## 1. Prepare the Pretrained Weights\nAlthough some weights can be downloaded dynamically at runtime, it is recommended to pre-download them for speeding up each run.\n\n#### Pre-trained Image Encoder (EVA ViT-g)\n```\nwget https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth\n```\nthe path of image encoder weight can be modified [here](stllm/models/eva_vit.py#L433).\n\n#### Pre-trained Q-Former and Linear Projection\n```\n# InstructBLIP (recommended)\nwget https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth\n```\n```\n# MiniGPT4\nwget https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth\nwget https://huggingface.co/Vision-CAIR/MiniGPT-4/blob/main/pretrained_minigpt4.pth\n```\nthe path of Q-Former and Linear Weight can be modified in ```q_former_model``` and ```ckpt``` in each config [here](config).\n\n#### Prepare Vicuna Weights\nPlease first follow the [instructions](https://github.com/lm-sys/FastChat) to prepare Vicuna v1.1 (for InstructBLIP) or Vicuna v1.0 (for MiniGPT4). \nThen modify the ```llama_model``` in each config [here](config) to the folder that contains Vicuna weights.\n\n## 2. Training \n#### Data\nWe follow [VideoChat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) to maintain consistency in the format of each instruction dataset. 
\nPlease follow the source [instructions](https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/DATA.md) to prepare the videos and annotations for each dataset.\nThen modify the path for each dataset [here](stllm/datasets/datasets/instruction_data.py).\n\nPlease note:\n\n(1) You do not need to prepare all datasets; prepare only those required by the configs you plan to run.\n\n(2) The annotations for videochat11k and videochatgpt100k differ slightly from the source; the modified annotations can be found [here](https://drive.google.com/file/d/1HIcT0WOmnHNU_xLtezKaHeUG8qa0_wQh/view).\n\n#### Running\nPlease first set the path in the [train script](script/train/train.sh) to the desired config from the [config folder](config), then run\n```\nbash script/train/train.sh\n```\n\n## 3. Inference\n#### MVBench\nPlease first modify the checkpoint path and annotation path in the [test script](script/inference/mvbench/test_mvbench.sh), then run\n```\nbash script/inference/mvbench/test_mvbench.sh\n```\n\n#### VcgBench\nAll evaluation scripts can be found [here](script/inference/vcgbench).\n\nFor instance, to evaluate the temporal score on the VideoChatGPT benchmark, first run inference to get the prediction results:\n```\nbash script/inference/vcgbench/test_temporal.sh\n```\nand then execute the corresponding evaluation script to compute the score:\n```\nbash script/inference/vcgbench/score_temporal.sh\n```\n\n#### VideoQABench\nThe testing procedure is identical to that of VcgBench; all evaluation scripts are [here](script/inference/qabench).\n\nFor instance, to evaluate the result on MSVD, first run\n```\nbash script/inference/qabench/msvd_qa.sh\n```\nand then run\n```\nbash script/inference/qabench/score_msvd.sh\n```\n"
  }
]